For my RAG use case, I've registered my LangChain chain as a model in Unity Catalog. When I try to serve the model, container image creation fails with the following error in the build log:
[...]
#16 178.1 Downloading langchain_core-0.3.17-py3-none-any.whl (409 kB)
#16 178.1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 409.3/409.3 kB 48.1 MB/s eta 0:00:00
#16 178.1 Downloading langchain_text_splitters-0.3.2-py3-none-any.whl (25 kB)
#16 178.1 Downloading langsmith-0.1.142-py3-none-any.whl (306 kB)
#16 178.1 ━━━━━━━━━━━Pip subprocess error:
#16 178.1 error: subprocess-exited-with-error
#16 178.1
#16 178.1 × python setup.py bdist_wheel did not run successfully.
#16 178.1 │ exit code: 1
#16 178.1 ╰─> [23 lines of output]
#16 178.1 /opt/conda/envs/mlflow-env/lib/python3.11/site-packages/setuptools/_distutils/dist.py:261: UserWarning: Unknown distribution option: 'cffi_modules'
#16 178.1 warnings.warn(msg)
#16 178.1 running bdist_wheel
#16 178.1 running build
#16 178.1 running build_py
#16 178.1 creating build/lib.linux-x86_64-cpython-311/snappy
#16 178.1 copying src/snappy/snappy_formats.py -> build/lib.linux-x86_64-cpython-311/snappy
#16 178.1 copying src/snappy/hadoop_snappy.py -> build/lib.linux-x86_64-cpython-311/snappy
#16 178.1 copying src/snappy/snappy_cffi.py -> build/lib.linux-x86_64-cpython-311/snappy
#16 178.1 copying src/snappy/__main__.py -> build/lib.linux-x86_64-cpython-311/snappy
#16 178.1 copying src/snappy/snappy.py -> build/lib.linux-x86_64-cpython-311/snappy
#16 178.1 copying src/snappy/__init__.py -> build/lib.linux-x86_64-cpython-311/snappy
#16 178.1 copying src/snappy/snappy_cffi_builder.py -> build/lib.linux-x86_64-cpython-311/snappy
#16 178.1 running build_ext
#16 178.1 building 'snappy._snappy' extension
#16 178.1 creating build/temp.linux-x86_64-cpython-311/src/snappy
#16 178.1 gcc -pthread -B /opt/conda/envs/mlflow-env/compiler_compat -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /opt/conda/envs/mlflow-env/include -fPIC -O2 -isystem /opt/conda/envs/mlflow-env/include -fPIC -I/opt/conda/envs/mlflow-env/include/python3.11 -c src/snappy/crc32c.c -o build/temp.linux-x86_64-cpython-311/src/snappy/crc32c.o
#16 178.1 g++ -pthread -B /opt/conda/envs/mlflow-env/compiler_compat -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /opt/conda/envs/mlflow-env/include -fPIC -O2 -isystem /opt/conda/envs/mlflow-env/include -fPIC -I/opt/conda/envs/mlflow-env/include/python3.11 -c src/snappy/snappymodule.cc -o build/temp.linux-x86_64-cpython-311/src/snappy/snappymodule.o
#16 178.1 src/snappy/snappymodule.cc:33:10: fatal error: snappy-c.h: No such file or directory
#16 178.1 33 | #include <snappy-c.h>
#16 178.1 | ^~~~~~~~~~~~
#16 178.1 compilation terminated.
#16 178.1 error: command '/usr/bin/g++' failed with exit code 1
#16 178.1 [end of output]
#16 178.1
#16 178.1 note: This error originates from a subprocess, and is likely not a problem with pip.
#16 178.1 ERROR: Failed building wheel for python-snappy
#16 178.1 ERROR: Could not build wheels for python-snappy, which is required to install pyproject.toml-based projects
#16 178.1
#16 178.1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 306.7/306.7 kB 49.2 MB/s eta 0:00:00
#16 178.1 Downloading Markdown-3.7-py3-none-any.whl (106 kB)
#16 178.1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 106.3/106.3 kB 21.6 MB/s eta 0:00:00
[...]
The latest version of my RAG chain adds a cross-encoder for reranking. Before I added the reranking step, everything worked fine. This is my chain code before the change:
import os
import mlflow
from operator import itemgetter
from databricks.vector_search.client import VectorSearchClient
from langchain_community.chat_models import ChatDatabricks
from langchain_community.vectorstores import DatabricksVectorSearch
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables import RunnableParallel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
# Enable MLflow Tracing by turning on MLflow's LangChain autologging before
# the chain below is built.
mlflow.langchain.autolog()
def extract_user_query_string(chat_messages_array):
    """Return the string contents of the most recent chat message.

    Args:
        chat_messages_array: Sequence of message dicts, each with a
            ``"content"`` key, ordered oldest to newest.

    Returns:
        The ``"content"`` value of the last message. The caller presumably
        sends the user's question as the final message (TODO confirm); the
        ``role`` of that message is not checked here.

    Raises:
        IndexError: If ``chat_messages_array`` is empty.
        KeyError: If the last message has no ``"content"`` key.
    """
    return chat_messages_array[-1]["content"]
# Get the conf from the local conf file.
# ModelConfig reads chain_config.yaml as the development config; each .get()
# below pulls one top-level section of that file.
model_config = mlflow.models.ModelConfig(
development_config="chain_config.yaml"
)
databricks_resources = model_config.get("databricks_resources")
retriever_config = model_config.get("retriever_config")
llm_config = model_config.get("llm_config")
# Connect to the Vector Search Index.
# The endpoint and index names come from the config sections read above.
vsc = VectorSearchClient(disable_notice=True)
vs_index = vsc.get_index(
endpoint_name=databricks_resources.get("vector_search_endpoint_name"),
index_name=retriever_config.get("vector_search_index"),
)
# Mapping of logical column roles (primary_key / chunk_text / document_uri)
# to the index's actual column names.
vector_search_schema = retriever_config.get("schema")
# Turn the Vector Search index into a LangChain retriever.
# Only the three schema columns are requested. search_kwargs is taken verbatim
# from retriever_config["parameters"] (presumably e.g. the number of results k;
# confirm in chain_config.yaml).
# NOTE(review): format_context later reads doc.metadata["url"]; that only works
# if the document_uri column is literally named "url". Confirm against the schema.
vector_search_as_retriever = DatabricksVectorSearch(
vs_index,
text_column=vector_search_schema.get("chunk_text"),
columns=[
vector_search_schema.get("primary_key"),
vector_search_schema.get("chunk_text"),
vector_search_schema.get("document_uri"),
],
).as_retriever(search_kwargs=retriever_config.get("parameters"))
# Required to:
# 1. Enable the RAG Studio Review App to properly display retrieved chunks
# 2. Enable evaluation suite to measure the retriever
# The same three schema columns are declared to MLflow here.
mlflow.models.set_retriever_schema(
primary_key=vector_search_schema.get("primary_key"),
text_column=vector_search_schema.get("chunk_text"),
doc_uri=vector_search_schema.get("document_uri"),
)
def format_context(docs):
    """Convert retrieved LangChain documents into source dicts for the prompt.

    Args:
        docs: Iterable of documents exposing ``page_content`` and a
            ``metadata`` mapping that contains a ``"url"`` key.

    Returns:
        A list with one ``{"content": ..., "url": ...}`` dict per document,
        in retrieval order.

    Raises:
        KeyError: If a document's metadata has no ``"url"`` entry.
    """
    # NOTE(review): the retriever requests the schema's "document_uri" column;
    # this lookup only succeeds if that column is literally named "url".
    # Confirm against chain_config.yaml.
    return [
        {"content": doc.page_content, "url": doc.metadata["url"]}
        for doc in docs
    ]
# Prompt Template for generation.
# Both the template text and its variable names are read from llm_config.
prompt = PromptTemplate(
template=llm_config.get("llm_prompt_template"),
input_variables=llm_config.get("llm_prompt_template_variables"),
)
# FM for generation: chat model served at the configured endpoint, with
# llm_config["llm_parameters"] passed through as extra_params.
model = ChatDatabricks(
endpoint=databricks_resources.get("llm_endpoint_name"),
extra_params=llm_config.get("llm_parameters")
)
# RAG Chain
# The "|" syntax is powered by the LangChain Expression Language (LCEL)
# To learn more about LCEL, read the documentation: https://python.langchain.com/v0.1/docs/expression_language/
# Generation stage: fill the prompt, call the chat model, return plain text.
second_chain_part = prompt | model | StrOutputParser()
# The chain input is a dict with a "messages" list. The parallel stage builds:
#   question - content of the last message (extract_user_query_string)
#   history  - the full messages list, unchanged
#   context  - documents retrieved for the question, formatted by format_context
# .assign() then adds "answer" (the output of second_chain_part) next to those keys.
# Presumably llm_prompt_template_variables lists question/history/context. Confirm.
chain = RunnableParallel(
{
"question": itemgetter("messages")
| RunnableLambda(extract_user_query_string),
"history": itemgetter("messages") | RunnablePassthrough(),
"context": itemgetter("messages")
| RunnableLambda(extract_user_query_string)
| vector_search_as_retriever
| RunnableLambda(format_context),
}
).assign(answer=second_chain_part)
# Tell MLflow logging where to find your chain
mlflow.models.set_model(model=chain)
... and this is my chain code after adding the reranker:
import os
import mlflow
from operator import itemgetter
from databricks.vector_search.client import VectorSearchClient
from langchain_community.chat_models import ChatDatabricks
from langchain_community.vectorstores import DatabricksVectorSearch
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables import RunnableParallel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
# Enable MLflow Tracing by turning on MLflow's LangChain autologging before
# the chain below is built.
mlflow.langchain.autolog()
def extract_user_query_string(chat_messages_array):
    """Return the string contents of the most recent chat message.

    Args:
        chat_messages_array: Sequence of message dicts, each with a
            ``"content"`` key, ordered oldest to newest.

    Returns:
        The ``"content"`` value of the last message. The caller presumably
        sends the user's question as the final message (TODO confirm); the
        ``role`` of that message is not checked here.

    Raises:
        IndexError: If ``chat_messages_array`` is empty.
        KeyError: If the last message has no ``"content"`` key.
    """
    return chat_messages_array[-1]["content"]
# Get the conf from the local conf file.
# ModelConfig reads chain_config.yaml as the development config; each .get()
# below pulls one top-level section of that file.
model_config = mlflow.models.ModelConfig(
development_config="chain_config.yaml"
)
databricks_resources = model_config.get("databricks_resources")
retriever_config = model_config.get("retriever_config")
llm_config = model_config.get("llm_config")
# Connect to the Vector Search Index.
# The endpoint and index names come from the config sections read above.
vsc = VectorSearchClient(disable_notice=True)
vs_index = vsc.get_index(
endpoint_name=databricks_resources.get("vector_search_endpoint_name"),
index_name=retriever_config.get("vector_search_index"),
)
# Mapping of logical column roles (primary_key / chunk_text / document_uri)
# to the index's actual column names.
vector_search_schema = retriever_config.get("schema")
# Turn the Vector Search index into a LangChain retriever.
# Only the three schema columns are requested. search_kwargs is taken verbatim
# from retriever_config["parameters"] (presumably e.g. the number of results k,
# which bounds how many candidates the reranker sees; confirm in chain_config.yaml).
# NOTE(review): format_context later reads doc.metadata["url"]; that only works
# if the document_uri column is literally named "url". Confirm against the schema.
vector_search_as_retriever = DatabricksVectorSearch(
vs_index,
text_column=vector_search_schema.get("chunk_text"),
columns=[
vector_search_schema.get("primary_key"),
vector_search_schema.get("chunk_text"),
vector_search_schema.get("document_uri"),
],
).as_retriever(search_kwargs=retriever_config.get("parameters"))
# Required to:
# 1. Enable the RAG Studio Review App to properly display retrieved chunks
# 2. Enable evaluation suite to measure the retriever
# The same three schema columns are declared to MLflow here.
mlflow.models.set_retriever_schema(
primary_key=vector_search_schema.get("primary_key"),
text_column=vector_search_schema.get("chunk_text"),
doc_uri=vector_search_schema.get("document_uri"),
)
# Load CrossEncoder model for reranking.
# The reranker model name and the number of documents kept after reranking
# can be overridden via retriever_config in chain_config.yaml
# ("reranker_model_name" / "reranker_top_n"). When those keys are absent, the
# defaults reproduce the previous hard-coded values.
# NOTE(review): HuggingFaceCrossEncoder loads the model when this module is
# imported, which also happens inside the serving container. Presumably the
# weights are fetched from the Hugging Face Hub there; confirm the endpoint has
# network access, or ship the weights as a model artifact.
# NOTE(review): the serving image build started failing only after this
# reranker was added. The failure is while compiling python-snappy from source
# (snappy-c.h not found). The package most likely enters via the logged
# model's inferred pip requirements. Consider passing explicit pip_requirements
# when logging the model, or pinning a python-snappy version that ships a
# prebuilt wheel. TODO confirm.
reranking_model = HuggingFaceCrossEncoder(
    model_name=retriever_config.get("reranker_model_name", "BAAI/bge-reranker-base")
)
compressor = CrossEncoderReranker(
    model=reranking_model,
    top_n=retriever_config.get("reranker_top_n", 5),
)
# Wrap the vector-search retriever: it fetches candidate documents first, and
# the cross-encoder compressor then keeps at most top_n of them.
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=vector_search_as_retriever
)
def format_context(docs):
    """Convert retrieved LangChain documents into source dicts for the prompt.

    Args:
        docs: Iterable of documents exposing ``page_content`` and a
            ``metadata`` mapping that contains a ``"url"`` key.

    Returns:
        A list with one ``{"content": ..., "url": ...}`` dict per document,
        in the order the (reranking) retriever returned them.

    Raises:
        KeyError: If a document's metadata has no ``"url"`` entry.
    """
    # NOTE(review): the retriever requests the schema's "document_uri" column;
    # this lookup only succeeds if that column is literally named "url".
    # Confirm against chain_config.yaml.
    return [
        {"content": doc.page_content, "url": doc.metadata["url"]}
        for doc in docs
    ]
# Prompt Template for generation.
# Both the template text and its variable names are read from llm_config.
prompt = PromptTemplate(
template=llm_config.get("llm_prompt_template"),
input_variables=llm_config.get("llm_prompt_template_variables"),
)
# FM for generation: chat model served at the configured endpoint, with
# llm_config["llm_parameters"] passed through as extra_params.
model = ChatDatabricks(
endpoint=databricks_resources.get("llm_endpoint_name"),
extra_params=llm_config.get("llm_parameters"),
)
# RAG Chain
# The "|" syntax is powered by the LangChain Expression Language (LCEL)
# To learn more about LCEL, read the documentation: https://python.langchain.com/v0.1/docs/expression_language/
# Generation stage: fill the prompt, call the chat model, return plain text.
second_chain_part = prompt | model | StrOutputParser()
# The chain input is a dict with a "messages" list. The parallel stage builds:
#   question - content of the last message (extract_user_query_string)
#   history  - the full messages list, unchanged
#   context  - documents retrieved for the question by the reranking
#              compression_retriever, formatted by format_context
# .assign() then adds "answer" (the output of second_chain_part) next to those keys.
# Presumably llm_prompt_template_variables lists question/history/context. Confirm.
chain = RunnableParallel(
{
"question": itemgetter("messages")
| RunnableLambda(extract_user_query_string),
"history": itemgetter("messages") | RunnablePassthrough(),
"context": itemgetter("messages")
| RunnableLambda(extract_user_query_string)
| compression_retriever
| RunnableLambda(format_context),
}
).assign(answer=second_chain_part)
# Tell MLflow logging where to find your chain
mlflow.models.set_model(model=chain)
I realize that essentially the same issue has already been discussed here (https://community.databricks.com/t5/machine-learning/serving-endpoint-container-image-creation-fails...), but downgrading my ML cluster from DBR 15.4 LTS ML to DBR 14.3 LTS ML didn't solve the problem for me.