[8586fsbgpb] An error occurred while loading the model. Failed to load the pickled function from a hexadecimal string. Error: Can't get attribute 'transform_input' on <module '__main__' from '/opt/conda/envs/mlflow-env/bin/gunicorn'>.
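The message suggests the serialized payload referred to `transform_input` only by name, as an attribute of `__main__`. In the serving process, `__main__` is the gunicorn launcher, which does not define that function, so the lookup fails. The sketch below reproduces this class of error using only the standard-library `pickle` module. The module name `notebook_main` is a hypothetical stand-in for the notebook's top-level namespace. Whether the Databricks/MLflow serializer pickles by reference in exactly this way depends on the library versions involved; that is an assumption, not something the post confirms.

```python
import json
import pickle
import sys
import types

# Hypothetical stand-in for the notebook's top-level module (assumption for illustration).
fake_main = types.ModuleType("notebook_main")
sys.modules["notebook_main"] = fake_main


def transform_input(**request):
    request["stop"] = ["\n\n"]
    return request


# Make the function look as if it were defined at the top level of notebook_main.
transform_input.__module__ = "notebook_main"
fake_main.transform_input = transform_input

# Plain pickle stores only a reference (module name + qualified name), hex-encoded here.
payload = pickle.dumps(transform_input).hex()

# Simulate the serving process: the module exists but no longer defines the function.
del fake_main.transform_input

try:
    pickle.loads(bytes.fromhex(payload))
    error = None
except AttributeError as exc:
    error = str(exc)

print(error)

# By contrast, a function that lives in an importable module is resolved by import.
restored = pickle.loads(pickle.dumps(json.dumps))
print(restored is json.dumps)
```

In other words, a by-reference payload only loads where the same module path and function name can be imported again.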
I'm using the following functions to transform the input and output:
from langchain.llms import Databricks  # newer releases: from langchain_community.llms import Databricks

def transform_input(**request):
    print('Type of prompt', type(request["prompt"]))
    request["messages"] = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": request["prompt"]},
    ]
    request["stop"] = ['\n\n']
    print("Request format", request)
    return request

def transform_output(response):
    return response['candidates'][0]

# If using serving endpoint, the model serving endpoint is created in `02_[chat]_mlflow_logging_inference`
llm = Databricks(
    endpoint_name='llama2-7b-chat-completion',
    transform_input_fn=transform_input,
    transform_output_fn=transform_output,
    extra_params={"temperature": 0.01, "max_tokens": 300},
)
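Both functions are plain Python, so their behaviour can be checked without an endpoint. This is a minimal sketch that assumes the wrapper passes `prompt` as a keyword argument, as the `**request` signature above expects. The sample response dict is invented to match the `response['candidates'][0]` access and is not real endpoint output.

```python
def transform_input(**request):
    # Build the chat-style payload from the plain prompt (debug prints removed).
    request["messages"] = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": request["prompt"]},
    ]
    request["stop"] = ["\n\n"]
    return request


def transform_output(response):
    # Return the first candidate from the endpoint response.
    return response["candidates"][0]


# Simulate the keyword arguments the wrapper is assumed to pass.
request = transform_input(prompt="What is MLflow?")
print(request["messages"][1])

# Hypothetical response shape matching response['candidates'][0].
fake_response = {"candidates": ["MLflow is an open-source platform."]}
print(transform_output(fake_response))
```

Note that the original `prompt` key is still present in the returned request alongside `messages`.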
Is there anything else I'm missing that would avoid this error?