
Unable to log MLFlow run for LangChain chain while using databricks-langchain library

heramb13
New Contributor II

Whenever I try to log my run, it throws the following error:

 

MlflowException: Failed to save runnable sequence: {'0': 'RunnableParallel<query,context> -- Failed to save runnable sequence: {\'context\': "RunnableSequence -- Failed to save runnable sequence: {\'2\': \'VectorStoreRetriever -- The `loader_fn` must be a function that returns a retriever.\'}."}.', '2': "ChatDatabricks -- 1 validation error for ChatDatabricks\nendpoint\n  Field required [type=missing, input_value={'extra_params': {}, 'max...ks', 'temperature': 0.0}, input_type=dict]\n    For further information visit https://errors.pydantic.dev/2.10/v/missing"}.

 

I am not sure what exactly is missing here. Initially I was passing parameters such as temperature and max_tokens, but I removed those as well and the error persisted.

  • my_retriever is a retriever built on DatabricksVectorSearch.
  • There are no additional parameters passed to ChatDatabricks.

I am not sure how to fix this.

Can someone help me with this?

Here is my simple code:

 

from databricks_langchain import DatabricksVectorSearch, ChatDatabricks
# excluding all langchain imports

my_index = DatabricksVectorSearch(
    endpoint=vs_endpoint,
    index_name=my_index_name,
    columns=["ID", "TEXT"],
)
my_retriever = my_index.as_retriever(search_kwargs={"k": 3, "query_type": "HYBRID"})

prompt = PromptTemplate.from_template(
    template="""Some template: {query} and {context} """
)
def format_context(text):
    return modified(text)
llm_endpoint = ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct")

chain = (
    RunnableMap({
        "query": RunnableLambda(itemgetter("messages")),
        "context": RunnableLambda(itemgetter("messages")) | my_retriever | RunnableLambda(format_context),
    })
    | prompt
    | llm_endpoint
    | StrOutputParser()
)


from mlflow.models import infer_signature
import mlflow

model_name = "some_model_name"
input_example = "input_example"
resp = chain.invoke(input_example)

with mlflow.start_run(run_name="run_name") as run:
    signature = infer_signature(input_example, resp)
    model_info = mlflow.langchain.log_model(
        chain,
        loader_fn=my_retriever, 
        artifact_path="path_to_artifact",
        registered_model_name=model_name,
        input_example=input_example,
        signature=signature
    )

 

 


Alberto_Umana
Databricks Employee

Hi @heramb13,

Can you try this revised version of your code? The first error says the loader_fn must be a function that returns a retriever, so the retriever construction is wrapped in a retriever_loader function and that function is passed to log_model. The input example is also changed to a dict with a "messages" key, matching what the chain's itemgetter expects.

from databricks_langchain import DatabricksVectorSearch, ChatDatabricks
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnableMap, RunnableLambda
from langchain.schema.output_parser import StrOutputParser
from operator import itemgetter
import mlflow

vs_endpoint = "your_vector_search_endpoint"
my_index_name = "your_index_name"

# Wrap retriever construction in a function: MLflow's loader_fn must be a
# function that returns a retriever, not a retriever instance
def retriever_loader():
    my_index = DatabricksVectorSearch(
        endpoint=vs_endpoint,
        index_name=my_index_name,
        columns=["ID", "TEXT"]
    )
    return my_index.as_retriever(search_kwargs={"k": 3, "query_type": "HYBRID"})

my_retriever = retriever_loader()

prompt = PromptTemplate.from_template(
    template="""Some template: {query} and {context} """
)

def format_context(text):
    return modified(text)  # Ensure this function is defined

llm_endpoint = ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct")

chain = (
    RunnableMap({
        "query": RunnableLambda(itemgetter("messages")),
        "context": RunnableLambda(itemgetter("messages")) | my_retriever | RunnableLambda(format_context),
    })
    | prompt
    | llm_endpoint
    | StrOutputParser()
)

model_name = "some_model_name"
input_example = {"messages": "Your example query here"}
resp = chain.invoke(input_example)  # sanity-check the chain end-to-end before logging

with mlflow.start_run(run_name="run_name") as run:
    model_info = mlflow.langchain.log_model(
        chain,
        loader_fn=retriever_loader,  # pass the function itself, not the retriever it returns
        artifact_path="path_to_artifact",
        registered_model_name=model_name,
        input_example=input_example
    )
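
For context: retrievers are not natively serialized by LangChain, which is why log_model takes a loader_fn instead of the retriever object. MLflow saves the function and calls it again when the logged model is loaded, rebuilding the retriever on the serving side. Databricks examples often declare the loader as def retriever_loader(persist_dir: str = None), since MLflow can pass a persist_dir argument at load time.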

 

heramb13
New Contributor II

Thank you so much for your response, @Alberto_Umana!

The change you suggested for the retriever worked, and the retriever error is gone.

But I am still facing an issue with llm_endpoint; the error is:

 

MlflowException: Failed to save runnable sequence: {'2': "ChatDatabricks -- No module named 'langchain_databricks'"}.

 

My code is now exactly the same as what you posted.

Thank you!

 

heramb13
New Contributor II

@Alberto_Umana or anyone? 
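
For anyone who lands on this thread with the same ModuleNotFoundError: the chat model class originally shipped in the legacy langchain-databricks package (Python module langchain_databricks), which databricks-langchain later replaced, and the MLflow serialization path can still try to import the old module name. Below is a sketch of one possible workaround, reusing the names from the snippet above; the package choices are an assumption rather than something confirmed in this thread. Install the legacy package alongside the new one and record both as model dependencies. Upgrading mlflow and databricks-langchain to recent versions that handle the rename is the other avenue worth trying.

# In a Databricks notebook, install both package names first:
# %pip install databricks-langchain langchain-databricks
# dbutils.library.restartPython()

import mlflow

with mlflow.start_run(run_name="run_name"):
    model_info = mlflow.langchain.log_model(
        chain,
        loader_fn=retriever_loader,
        artifact_path="path_to_artifact",
        registered_model_name=model_name,
        input_example=input_example,
        # Record both package names so the environment that reloads the
        # model can resolve either module path; exact version pins are
        # left to match your cluster.
        extra_pip_requirements=[
            "databricks-langchain",
            "langchain-databricks",  # legacy name providing langchain_databricks
        ],
    )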
