Errors using Dolly Deployed as a REST API

GKH
New Contributor II

We have deployed Dolly (https://huggingface.co/databricks/dolly-v2-3b) as a REST API endpoint on our infrastructure. The notebook we used to do this is included in the text below my question.

The Databricks cluster had the following configuration: Databricks Runtime 13.2 ML (GPU), Spark 3.4.0, g5.2xlarge.

Dolly executes perfectly in the notebook, without any issues. We created two chains in LangChain to test execution. The first is a vanilla chain that answers questions directly, with no context provided. The second is a contextual Q&A chain. Both worked perfectly.

Creating, registering, and deploying the model all proceed without errors, although the process can take a long time (sometimes more than 20 minutes).

Problems arise when we try to access the model indirectly through either the REST interface or by loading a logged and registered model using its URI. I've uploaded the error we see in the attached image.
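For comparison, here is a minimal sketch of what a REST call to the served model looks like, using only the Python standard library. The host, endpoint name, and token are hypothetical placeholders, and the `/serving-endpoints/<name>/invocations` path with a `dataframe_records` body assumes Databricks Model Serving backed by MLflow 2.x. The request is built but not sent.

```python
import json
import urllib.request


def build_invocation_request(host: str, endpoint_name: str, token: str,
                             records: list) -> urllib.request.Request:
    """Build (but do not send) a POST request to a model serving endpoint.

    Assumes the Databricks Model Serving URL layout
    <host>/serving-endpoints/<endpoint_name>/invocations.
    """
    url = f"{host.rstrip('/')}/serving-endpoints/{endpoint_name}/invocations"
    # One JSON object per input row, keyed by the chain's input variable
    body = json.dumps({"dataframe_records": records}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        method="POST",
    )


# Hypothetical workspace host, endpoint name, and token
req = build_invocation_request(
    "https://example.cloud.databricks.com",
    "gpt-dolly-3b-vanilla",
    "dapi-XXXX",
    [{"instruction": "which continent is India located in?"}],
)
# To actually send it: urllib.request.urlopen(req)
print(req.full_url)
```

Inspecting the built request this way makes it easy to confirm the body matches what the in-notebook `chain.predict([{'instruction': question}])` call receives before spending GPU time on the endpoint.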

From there, we tried several things to debug the error ourselves, without success. We changed the input format (passing lists instead of strings, and DataFrames instead of strings), changed runtime versions, changed the way we log the model (using mlflow.pyfunc.log_model instead of mlflow.langchain.log_model), and experimented with several JSON formats consistent with the MLflow documentation for JSON inputs to the REST APIs of served models.
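As a reference point, MLflow 2.x scoring servers accept tabular input in more than one JSON encoding. This sketch builds the `dataframe_records` and `dataframe_split` encodings of the same rows, assuming the single input variable `instruction` from the vanilla chain; the helper names are illustrative and not part of MLflow. Both encodings should describe an identical one-column table.

```python
import json


def to_dataframe_records(records: list) -> str:
    """Encode rows as {"dataframe_records": [{column: value, ...}, ...]}."""
    return json.dumps({"dataframe_records": records})


def to_dataframe_split(records: list) -> str:
    """Encode the same rows as {"dataframe_split": {"columns": [...], "data": [[...], ...]}}."""
    # Column order is taken from the first record; every record must share its keys
    columns = list(records[0].keys())
    data = [[row[col] for col in columns] for row in records]
    return json.dumps({"dataframe_split": {"columns": columns, "data": data}})


rows = [
    {"instruction": "which continent is India located in?"},
    {"instruction": "Explain to me the difference between nuclear fission and fusion."},
]
print(to_dataframe_records(rows))
print(to_dataframe_split(rows))
```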

In every case we get the same error. From our debugging, it appears that the prompt being constructed somehow evaluates to None. However, once a model has been logged as a LangChain model, we cannot pass the prompt explicitly as an argument: the input schema is fixed when the model is logged, and the inputs must use the keywords that are the prompt's input variables.
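The point about the input schema can be illustrated with the standard library alone: the keys an input record must carry are exactly the placeholder names in the prompt template. This sketch uses `string.Formatter` to extract them and rejects records that are missing one. It only mimics the idea and is not LangChain's or MLflow's actual validation code; the two template strings are copied from the notebook.

```python
from string import Formatter

# Templates copied from the notebook's PromptTemplate definitions
VANILLA_TEMPLATE = "{instruction}"
CONTEXT_TEMPLATE = "{instruction}\n\nInput:\n{context}"


def required_keys(template: str) -> set:
    """Return the placeholder names that an input record must supply."""
    # Formatter.parse yields (literal_text, field_name, format_spec, conversion);
    # field_name is None for trailing literal text, so keep only real names.
    return {field for _, field, _, _ in Formatter().parse(template) if field}


def render_prompt(template: str, record: dict) -> str:
    """Fill the template from one input record, failing loudly on missing keys."""
    missing = required_keys(template) - set(record)
    if missing:
        raise KeyError(f"input record is missing prompt variables: {sorted(missing)}")
    return template.format(**record)


print(sorted(required_keys(CONTEXT_TEMPLATE)))  # ['context', 'instruction']
print(render_prompt(VANILLA_TEMPLATE, {"instruction": "which continent is India located in?"}))
```

Under this model, a request body keyed by anything other than `instruction` (for the vanilla chain) or `instruction` plus `context` (for the contextual chain) has no way to fill the prompt.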

We've spent a lot of time and GPU cycles trying to get this rather straightforward use case to work. Does anyone in this community have any insight into what we might be doing wrong here?

Any help would be greatly appreciated!

Thanks in advance!


-----------------begin notebook code-----------------------

# Databricks notebook source
# MAGIC %pip install -U mlflow langchain==0.0.164 transformers numpy==1.24.4 sqlalchemy==2.0.17

# COMMAND ----------

dbutils.library.restartPython()

# COMMAND ----------

import torch
from transformers import pipeline

generate_text = pipeline(model="databricks/dolly-v2-3b", torch_dtype=torch.bfloat16,
                         trust_remote_code=True, device_map="auto", return_full_text=True)

# COMMAND ----------

from langchain import PromptTemplate, LLMChain
from langchain.llms import HuggingFacePipeline

# template for an instruction with no input
prompt = PromptTemplate(
    input_variables=["instruction"],
    template="{instruction}")

# template for an instruction with input
prompt_with_context = PromptTemplate(
    input_variables=["instruction", "context"],
    template="{instruction}\n\nInput:\n{context}")

hf_pipeline = HuggingFacePipeline(pipeline=generate_text)

llm_chain = LLMChain(llm=hf_pipeline, prompt=prompt)
llm_context_chain = LLMChain(llm=hf_pipeline, prompt=prompt_with_context)

# COMMAND ----------

print(llm_chain.predict(instruction="Explain to me the difference between nuclear fission and fusion.").lstrip())

# COMMAND ----------

context = """George Washington (February 22, 1732[b] - December 14, 1799) was an American military officer, statesman,
and Founding Father who served as the first president of the United States from 1789 to 1797."""

print(llm_context_chain.predict(instruction="When was George Washington president?", context=context).lstrip())

# COMMAND ----------



# COMMAND ----------

import mlflow
import pandas as pd
from json import dumps
def publish_model_to_mlflow(llm_chain, chain_name):
    # Log and register under the same artifact path so the registered
    # model version points at the artifacts that were actually written
    artifact_path = "ketos-gpt-" + chain_name
    model_name = "gpt-" + chain_name

    with mlflow.start_run() as run:
        # Save model to MLflow
        # Note that this only saves the LangChain pipeline (we could also add the ChatBot with a custom model wrapper class)
        # The vector database lives outside of your model
        mlflow.langchain.log_model(llm_chain, artifact_path=artifact_path)
        model_registered = mlflow.register_model(f"runs:/{run.info.run_id}/{artifact_path}", model_name)

    # Move the model to production
    client = mlflow.tracking.MlflowClient()
    print(f"registering model version {model_registered.version} as production model")
    client.transition_model_version_stage(model_name, model_registered.version, stage="Production", archive_existing_versions=True)

def load_model_and_answer(question, model_uri):
    # Note: this will load the model once more in memory
    # Load the LangChain pipeline & run inferences
    chain = mlflow.pyfunc.load_model(model_uri)
    print(type(chain))
    #chain = mlflow.langchain.load_model(model_uri)
    # Return the predictions so the calling cell can display them
    return chain.predict([{'instruction': question}])

# COMMAND ----------

publish_model_to_mlflow(llm_chain,'dolly-3b-vanilla')

# COMMAND ----------

question = "which continent is India located in?"
chain_name = 'dolly-3b-vanilla'
model_name = "gpt-"+chain_name
model_version = 6
model_uri = f"models:/{model_name}/{model_version}"
load_model_and_answer(question=question, model_uri=model_uri)

 

 

-----------------end notebook code-----------------------

1 Reply

marcelo2108
Contributor

I had a similar problem when I used HuggingFacePipeline(pipeline=generate_text) with LangChain. It worked for me when I switched to HuggingFaceHub instead, with the same dolly-3b model.
