
Unable to log MLFlow run for LangChain chain while using databricks-langchain library

heramb13
New Contributor II

Whenever I try to log my run, it throws the following error:

 

MlflowException: Failed to save runnable sequence: {'0': 'RunnableParallel<query,context> -- Failed to save runnable sequence: {\'context\': "RunnableSequence -- Failed to save runnable sequence: {\'2\': \'VectorStoreRetriever -- The `loader_fn` must be a function that returns a retriever.\'}."}.', '2': "ChatDatabricks -- 1 validation error for ChatDatabricks\nendpoint\n  Field required [type=missing, input_value={'extra_params': {}, 'max...ks', 'temperature': 0.0}, input_type=dict]\n    For further information visit https://errors.pydantic.dev/2.10/v/missing"}.

 

I am not sure what exactly is missing here. Initially I was passing parameters such as temperature and max_tokens, but I removed those as well; the error persisted.

  • my_retriever is basically a retriever built with DatabricksVectorSearch.
  • No additional parameters are passed to ChatDatabricks.

I am not sure how to fix this.

Can someone help me with this?

Here is my simple code:

 

from databricks_langchain import DatabricksVectorSearch, ChatDatabricks
# excluding all langchain imports

my_index = DatabricksVectorSearch(
    endpoint=vs_endpoint,
    index_name=my_index_name,
    columns=["ID", "TEXT"]
    )
my_retriever = my_index.as_retriever(search_kwargs={"k": 3, "query_type": "HYBRID"})

prompt = PromptTemplate.from_template(
    template="""Some template: {query} and {context} """
)
def format_context(text):
    return modified(text)
llm_endpoint = ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct")

chain = (
    RunnableMap({
        "query": RunnableLambda(itemgetter("messages")),
        "context": RunnableLambda(itemgetter("messages")) | my_retriever | RunnableLambda(format_context),
    })
    | prompt
    | llm_endpoint
    | StrOutputParser()
)


from mlflow.models import infer_signature
import mlflow

model_name = f"some_model_name"
input_example = "input_example"
resp = chain.invoke(input_example)

with mlflow.start_run(run_name="run_name") as run:
    signature = infer_signature(input_example, resp)
    model_info = mlflow.langchain.log_model(
        chain,
        loader_fn=my_retriever, 
        artifact_path="path_to_artifact",
        registered_model_name=model_name,
        input_example=input_example,
        signature=signature
    )

4 REPLIES

Alberto_Umana
Databricks Employee

Hi @heramb13,

Can you try using this revised version of your code?

from databricks_langchain import DatabricksVectorSearch, ChatDatabricks
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnableMap, RunnableLambda
from langchain.schema.output_parser import StrOutputParser
from operator import itemgetter
import mlflow

vs_endpoint = "your_vector_search_endpoint"
my_index_name = "your_index_name"

def retriever_loader():
    my_index = DatabricksVectorSearch(
        endpoint=vs_endpoint,
        index_name=my_index_name,
        columns=["ID", "TEXT"]
    )
    return my_index.as_retriever(search_kwargs={"k": 3, "query_type": "HYBRID"})

my_retriever = retriever_loader()

prompt = PromptTemplate.from_template(
    template="""Some template: {query} and {context} """
)

def format_context(text):
    return modified(text)  # Ensure this function is defined

llm_endpoint = ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct")

chain = (
    RunnableMap({
        "query": RunnableLambda(itemgetter("messages")),
        "context": RunnableLambda(itemgetter("messages")) | my_retriever | RunnableLambda(format_context),
    })
    | prompt
    | llm_endpoint
    | StrOutputParser()
)

model_name = "some_model_name"
input_example = {"messages": "Your example query here"}
resp = chain.invoke(input_example)

with mlflow.start_run(run_name="run_name") as run:
    model_info = mlflow.langchain.log_model(
        chain,
        loader_fn=retriever_loader,
        artifact_path="path_to_artifact",
        registered_model_name=model_name,
        input_example=input_example
    )
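
The key change is that loader_fn is passed a function that builds and returns the retriever, rather than the retriever object itself; MLflow calls that function when the logged model is loaded, which is what the original error ("The `loader_fn` must be a function that returns a retriever.") was pointing at.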

 

heramb13
New Contributor II

Thank you so much for your response @Alberto_Umana!

The changes you suggested for the retriever worked, and the retriever error is gone.

But I am still facing an issue with llm_endpoint; the error is:

 

MlflowException: Failed to save runnable sequence: {'2': "ChatDatabricks -- No module named 'langchain_databricks'"}.

 

My code is now exactly the same as what you posted.

Thank you!

 

heramb13
New Contributor II

@Alberto_Umana or anyone? 

stbjelcevic
Databricks Employee

Hi @heramb13 ,

Your chain is being serialized by MLflow; during that process, MLflow re-imports each runnable by its module path. The error shows MLflow is trying to import ChatDatabricks from the legacy module path "langchain_databricks", which isn't installed in the environment that is saving the chain.

This exact failure has been reported before and typically occurs when the ChatDatabricks object (or its alias) resolves to the old path even if your source import line uses databricks_langchain. The fix is to either rebuild the chain so the class resolves to the new module path, or include the legacy shim as a dependency at log time.
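
For the rebuild option, here is a minimal sketch; it reuses the names from the snippets above (llm_endpoint, chain), so treat them as placeholders for your own setup:

# Sketch: make sure ChatDatabricks comes from the new package, then
# reconstruct the LLM before rebuilding and logging the chain.
from databricks_langchain import ChatDatabricks

print(ChatDatabricks.__module__)  # should start with 'databricks_langchain'

llm_endpoint = ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct")
# Rebuild `chain` with this fresh llm_endpoint so the serialized chain
# records the new module path.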

Quick checks:

  • Confirm ChatDatabricks resolves to the new module path with print(ChatDatabricks.__module__), as in the sketch above
  • Ensure your input_example shape matches the first runnable in the chain (a dict with "messages", not a raw string)
  • Always pass pip_requirements with databricks-langchain (and optionally langchain-databricks if the legacy path appears); see the sketch below
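
For the dependency option, a sketch of the log call with pinned requirements (the names and paths are carried over from the earlier snippets, so adjust them to your environment):

import mlflow

with mlflow.start_run(run_name="run_name"):
    model_info = mlflow.langchain.log_model(
        chain,
        loader_fn=retriever_loader,
        artifact_path="path_to_artifact",
        registered_model_name=model_name,
        input_example=input_example,
        # Pin both packages so whichever module path the serialized chain
        # references can be imported when the model is loaded.
        pip_requirements=[
            "databricks-langchain",
            "langchain-databricks",  # legacy shim; include only if needed
        ],
    )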
