Whenever I try to log my chain with mlflow.langchain.log_model, I get the following error:
MlflowException: Failed to save runnable sequence: {'0': 'RunnableParallel<query,context> -- Failed to save runnable sequence: {\'context\': "RunnableSequence -- Failed to save runnable sequence: {\'2\': \'VectorStoreRetriever -- The `loader_fn` must be a function that returns a retriever.\'}."}.', '2': "ChatDatabricks -- 1 validation error for ChatDatabricks\nendpoint\n Field required [type=missing, input_value={'extra_params': {}, 'max...ks', 'temperature': 0.0}, input_type=dict]\n For further information visit https://errors.pydantic.dev/2.10/v/missing"}.
I'm not sure what exactly is missing here. Initially I passed parameters such as temperature and max_tokens to ChatDatabricks, but I removed those as well, and the error persists.
- my_retriever is a retriever built on DatabricksVectorSearch.
- ChatDatabricks gets no parameters other than endpoint.
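The ChatDatabricks part of the message is a plain "required field missing" failure. Pydantic was handed a dict that contains extra_params and temperature (the middle of the dict is truncated in the message) but no endpoint, and endpoint has no default. Because the check concerns a required key, removing optional keys such as temperature and max_tokens would not be expected to change it. The sketch below reproduces that shape of failure with a standard-library dataclass in place of pydantic. ChatConfig, rebuild, and saved_config are hypothetical names, and the field list is an assumption based only on the keys visible in the error.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class ChatConfig:
    """Hypothetical stand-in for the ChatDatabricks pydantic model (not the real class)."""

    endpoint: str  # required: no default, like the field pydantic reports as missing
    temperature: float = 0.0
    max_tokens: Optional[int] = None
    extra_params: Dict[str, Any] = field(default_factory=dict)


def rebuild(config: Dict[str, Any]) -> ChatConfig:
    """Rebuild the object from a serialized dict, roughly as a loader might."""
    return ChatConfig(**config)


# Roughly the keys visible in the error message; note there is no "endpoint" key.
saved_config = {"extra_params": {}, "temperature": 0.0}

try:
    rebuild(saved_config)
except TypeError as exc:
    # Fails on the missing required key, regardless of which optional keys are present.
    print(f"rebuild failed: {exc}")

# Supplying the required key is what makes the rebuild succeed.
print(rebuild({**saved_config, "endpoint": "databricks-meta-llama-3-3-70b-instruct"}))
```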
I'm not sure how to fix this. Can someone help?
Here is a simplified version of my code:
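from operator import itemgetter

from databricks_langchain import ChatDatabricks, DatabricksVectorSearch
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableMap

# Vector Search index; vs_endpoint and my_index_name are defined elsewhere
my_index = DatabricksVectorSearch(
    endpoint=vs_endpoint,
    index_name=my_index_name,
    columns=["ID", "TEXT"],
)
my_retriever = my_index.as_retriever(search_kwargs={"k": 3, "query_type": "HYBRID"})

prompt = PromptTemplate.from_template(
    template="""Some template: {query} and {context} """
)


def format_context(text):
    # modified() stands in for my own formatting logic (omitted here)
    return modified(text)


llm_endpoint = ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct")

# Both branches read the "messages" key of the input dict
chain = (
    RunnableMap({
        "query": RunnableLambda(itemgetter("messages")),
        "context": RunnableLambda(itemgetter("messages")) | my_retriever | RunnableLambda(format_context),
    })
    | prompt
    | llm_endpoint
    | StrOutputParser()
)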
import mlflow
from mlflow.models import infer_signature

model_name = "some_model_name"
# The chain applies itemgetter("messages"), so the input must be a dict with that key
input_example = {"messages": "input_example"}
resp = chain.invoke(input_example)

with mlflow.start_run(run_name="run_name") as run:
    signature = infer_signature(input_example, resp)
    model_info = mlflow.langchain.log_model(
        chain,
        loader_fn=my_retriever,
        artifact_path="path_to_artifact",
        registered_model_name=model_name,
        input_example=input_example,
        signature=signature,
    )
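The VectorStoreRetriever part of the error points at the loader_fn argument. In the code above it receives my_retriever, which is a retriever object, while the message says it must be a function that returns a retriever. According to the MLflow documentation for mlflow.langchain.log_model, loader_fn takes a persist_dir string and returns the object the model needs, such as a retriever. The plain-Python sketch below shows only that contract. StubRetriever and load_retriever are hypothetical stand-ins for my_index.as_retriever(...) and a real loader, and looks_like_valid_loader is an assumed check modeled on the error text, not MLflow's actual code.

```python
import types
from typing import List, Optional


class StubRetriever:
    """Hypothetical stand-in for the object returned by my_index.as_retriever(...)."""

    def __init__(self, k: int) -> None:
        self.k = k

    def invoke(self, query: str) -> List[str]:
        # A real retriever would return Documents from the vector index.
        return [f"doc {i} for {query!r}" for i in range(self.k)]


def load_retriever(persist_dir: Optional[str] = None) -> StubRetriever:
    """Loader-style function: builds and returns a retriever each time it is called.

    persist_dir is accepted to match the documented loader_fn argument; a
    retriever backed by a remote index may not need it.
    """
    return StubRetriever(k=3)


def looks_like_valid_loader(candidate: object) -> bool:
    # Assumed check, modeled on "must be a function that returns a retriever".
    return isinstance(candidate, types.FunctionType)


my_retriever = load_retriever()
print(looks_like_valid_loader(my_retriever))    # → False (an object, not a function)
print(looks_like_valid_loader(load_retriever))  # → True (the function that builds it)
print(load_retriever("/tmp/unused").invoke("hello"))
```

The point of the sketch is only the distinction between handing over the retriever object and handing over a function that can rebuild the retriever later.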