Technical Blog
Explore in-depth articles, tutorials, and insights on data analytics and machine learning in the Databricks Technical Blog. Stay updated on industry trends, best practices, and advanced techniques.


Since the release of ChatGPT in November 2022, interest in Generative AI (GenAI) has grown exponentially. Almost every company has identified an opportunity or a use case and is trying to leverage GenAI capabilities to become the next leader in its industry. By exploiting GenAI's potential, you may reach the top of your market and gain a real competitive advantage.


The GenAI boom has been very beneficial, as it has brought to light numerous use cases that can leverage Natural Language Processing (NLP) techniques. NLP is a field of AI made up of machine learning algorithms that analyze and understand human language and generate text similar to what a human would write. However, GenAI is not a prerequisite for NLP. You can still achieve significant results with NLP techniques that do not require the large foundation models associated with GenAI use cases. By leveraging open source models, these use cases can be developed quickly, with just a few lines of code and at a very low cost, while still reaching a fair level of performance.

In this article, we’ll explore techniques that enable you to tackle popular NLP tasks such as text classification, sentiment analysis, and translation using large language models (LLMs) with tens or hundreds of millions of parameters. The best part is that you do not even need a GPU. If you have access to one, it will most likely accelerate execution, but it’s not mandatory. Additionally, you don’t need to install any library, as everything is already installed in the Databricks ML Runtime. MLflow and the transformers library from Hugging Face will be the main ingredients of your recipe.



Text Classification with NLP

A common task in NLP is text classification, where we assign labels to text. Organizing raw texts into a set of metadata categories makes the evaluation and post-analysis steps easier, and attaching this metadata to the text significantly facilitates key information extraction. Two methods that help accomplish this task are zero-shot classification and few-shot classification.

Zero-shot learning for text classification

Zero-shot Learning (ZSL) refers to predicting a class that the model did not see during training. It’s worth noting that the concept first appeared in 2008, almost fifteen years before ChatGPT. It was used intensively in computer vision before being widely adopted in NLP. A common approach to ZSL is to compute embeddings for the input text and for each candidate label, then measure the semantic similarity between them.
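As a minimal sketch of the embedding-similarity approach, the snippet below ranks candidate labels by the cosine similarity between their embedding and the text embedding. The vectors are hypothetical toy values; in practice they would come from a sentence-embedding model, which is an assumption not covered here.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def zero_shot_by_similarity(
    text_embedding: Sequence[float],
    label_embeddings: Dict[str, Sequence[float]],
) -> List[Tuple[str, float]]:
    """Rank candidate labels by similarity to the text (no training involved)."""
    scores = [
        (label, cosine_similarity(text_embedding, emb))
        for label, emb in label_embeddings.items()
    ]
    return sorted(scores, key=lambda pair: pair[1], reverse=True)


# Hypothetical 3-dimensional embeddings, for illustration only
labels = {
    "sports": [0.9, 0.1, 0.0],
    "pop culture": [0.2, 0.9, 0.1],
    "science and technology": [0.0, 0.2, 0.9],
}
text = [0.8, 0.3, 0.1]  # pretend embedding of a football-related sentence
ranking = zero_shot_by_similarity(text, labels)
print(ranking[0][0])  # sports
```

The label list is just an input here, which is what makes the approach "zero-shot": changing the labels requires no retraining.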

An alternative method, used in the example below, is based on Natural Language Inference (NLI). NLI determines whether one sequence (the premise) entails, contradicts, or is neutral with respect to another sequence (the hypothesis). Here, the input text is compared semantically with each candidate label, one by one. For each comparison, the input text is the premise, and the label is turned into a hypothesis stating that the text belongs to that label (for example, "This example is sports."). The model then determines whether this hypothesis is true (entailment), false (contradiction), or neutral. The more strongly the hypothesis is entailed, the more relevant the label is to the input text.
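The sketch below mirrors this logic for the single-label case: each candidate label is inserted into a hypothesis template, an NLI model scores the (text, hypothesis) pair, and the entailment scores are normalized across labels with a softmax. This is a simplified reading of how NLI-based zero-shot classification works, not the transformers implementation; the template matches the transformers default ("This example is {}."), and `toy_nli_logits` is a hypothetical stand-in for a real NLI model such as bart-large-mnli.

```python
import math
from typing import Callable, Dict, List, Tuple

# (contradiction, neutral, entailment) logits for a (premise, hypothesis) pair
NliScorer = Callable[[str, str], Tuple[float, float, float]]


def softmax(values: List[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def nli_zero_shot(
    text: str,
    candidate_labels: List[str],
    nli_logits: NliScorer,
    hypothesis_template: str = "This example is {}.",
) -> Dict[str, float]:
    """Score each label by how strongly the text entails its hypothesis."""
    entailment = []
    for label in candidate_labels:
        hypothesis = hypothesis_template.format(label)
        _, _, entail_logit = nli_logits(text, hypothesis)  # one model call per label
        entailment.append(entail_logit)
    # Single-label mode: normalize the entailment logits across all candidates
    return dict(zip(candidate_labels, softmax(entailment)))


def toy_nli_logits(premise: str, hypothesis: str) -> Tuple[float, float, float]:
    """Hypothetical stand-in for an NLI model: keyword hits act as entailment."""
    keywords = {"sports": ["football", "player"], "pop culture": ["movie", "singer"]}
    hits = sum(
        word in premise.lower()
        for label, words in keywords.items()
        if label in hypothesis
        for word in words
    )
    return (-float(hits), 0.0, float(hits))


scores = nli_zero_shot(
    "Zinedine Zidane is the GOAT french football player",
    ["sports", "pop culture", "breaking news"],
    toy_nli_logits,
)
print(max(scores, key=scores.get))  # sports
```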

Since we must compare all the candidate labels with the input text, the number of these comparisons increases linearly with the number of candidates, so we can quickly hit performance issues. However, the advantage resides in the method itself. All those comparisons are done during the inference. There is no training or fine-tuning specific to our task. There is no dependency between the model and the list of candidate labels. This is a huge advantage! If our use case evolves, or if we need to add new labels, remove or modify some of them, there is no impact on the model.
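To make the linear growth concrete, the sketch below expands a batch of texts into the (premise, hypothesis) pairs an NLI model would have to score. Each pair is one forward pass, so the workload is the number of texts times the number of labels. It also shows that labels are plain inputs: adding one changes the pairs, not the model. The helper name and template are assumptions for illustration.

```python
from typing import List, Tuple


def build_nli_pairs(
    texts: List[str],
    candidate_labels: List[str],
    hypothesis_template: str = "This example is {}.",
) -> List[Tuple[str, str]]:
    """One (premise, hypothesis) pair per text and per candidate label."""
    return [
        (text, hypothesis_template.format(label))
        for text in texts
        for label in candidate_labels
    ]


texts = ["Zidane scores again", "New smartphone released", "Box office record"]
labels = ["sports", "pop culture", "breaking news", "science and technology"]
pairs = build_nli_pairs(texts, labels)
print(len(pairs))  # 12 = 3 texts x 4 labels

# Adding a label only adds pairs at inference time; the model itself is unchanged
pairs_with_new_label = build_nli_pairs(texts, labels + ["politics"])
print(len(pairs_with_new_label))  # 15
```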

While prompt engineering makes it possible to use LLMs for classification, comparable results can be attained with these two methods, which are designed specifically for the classification task. These techniques are particularly useful for risk-averse teams or for those struggling with more advanced methods: they allow a quick first iteration, early feedback collection, and solid foundations for further exploration. Lastly, both rely on smaller LLMs, which reduces infrastructure requirements and costs.

Implementing ZSL with the Databricks ML Runtime

Now, let’s look at how easily we can use such a model. We will leverage the transformers library and the model hub from Hugging Face; the library is already installed in the Databricks ML Runtime. We also use PyTorch to detect whether a GPU is available, so the same code runs in both cases. You can expect better performance with a single GPU (see below for a deeper analysis).

The bart-large-mnli model, made available by Facebook, applies the NLI method to the pretrained BART model: it is BART fine-tuned on the MultiNLI (MNLI) dataset.


from transformers import pipeline
import torch

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the zero-shot pipeline (the device is set when the pipeline is built)
model = "facebook/bart-large-mnli"
task = "zero-shot-classification"
zshot = pipeline(task=task, model=model, device=device)

# Inference
candidate_labels = ['sports',
  'pop culture',
  'breaking news',
  'science and technology']

inference_config = {'candidate_labels': candidate_labels, 'multi_label': False}
input_text = "Zinedine Zidane is the GOAT french football player"
pred_label = zshot(input_text, **inference_config)

# Labels are sorted by score: the first one is the prediction
print(pred_label['labels'][0])


Register and log your Model with Unity Catalog

Now, we can register our model within Databricks Unity Catalog. By doing so, you will be able to centrally manage the entire lifecycle of your model, including access controls, lineage, and model discovery. You just need to define the catalog and the schema to register the model to.


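import mlflow

# Use Unity Catalog as the model registry, and set the target catalog and schema
mlflow.set_registry_uri("databricks-uc")
catalog = "dev"
schema = "classif"

# Log the model and register it (the input example lets MLflow infer the signature)
model_name = f"{catalog}.{schema}.zeroshot_model"
with mlflow.start_run():
    mlflow.transformers.log_model(transformers_model=zshot, artifact_path="model",
        input_example={"sequences": input_text, "candidate_labels": candidate_labels},
        registered_model_name=model_name)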


Few-shot learning classification

Thanks to zero-shot learning, you've successfully deployed your first classification model into production within a short timeframe, delivering quick solutions to your business end users. As their feedback comes in, it becomes evident that, while the model is useful, some misclassifications have occurred during testing, as expected.

You have been asked to improve your existing model, but budget and time constraints remain. This is where few-shot classification can help. In contrast to the previous solution, we will adapt the existing model to your data. However, the power of this method lies in the fact that you only need a few examples per label to fine-tune the model; some frameworks need fewer than 10 samples per label! This is far from the extensive data requirements typically associated with training or fine-tuning LLMs.

In this example, if we break down the few-shot model, we'll discover two 'sub-models'.

  • The first part is a sentence transformer. It is based on an already trained model, often derived from BERT (such as RoBERTa or DistilBERT), which was pretrained on a substantial corpus of 800M words (BooksCorpus) and 2,500M words (English Wikipedia). Fine-tuning updates the weights of this model slightly so that it better represents your input data. The sentence transformer encodes the raw input texts into embeddings, which serve as inputs for the second part of the few-shot model: the classification head.
  • The classification head is fitted on your training data. With SetFit, the two steps (1. fine-tuning the sentence transformer, 2. training the classification head) are packaged together and transparent to the user.

In the end, the two parts together make up your trained few-shot model.
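SetFit fine-tunes the sentence transformer contrastively: it builds pairs of sentences that share a label (positive pairs) and pairs with different labels (negative pairs), then trains the encoder to pull the former together and push the latter apart. The sketch below only illustrates the pair-generation idea and why a handful of examples per label goes a long way (n examples yield on the order of n² pairs). It is a simplified assumption of how pairs are formed, not SetFit's actual sampling code, and the dataset is hypothetical.

```python
from itertools import combinations
from typing import List, Tuple


def contrastive_pairs(
    examples: List[Tuple[str, str]],
) -> List[Tuple[str, str, float]]:
    """Build (sentence_a, sentence_b, target) pairs from (text, label) examples.

    target is 1.0 when both sentences share a label (positive pair) and
    0.0 otherwise (negative pair), as used with a cosine-similarity loss.
    """
    pairs = []
    for (text_a, label_a), (text_b, label_b) in combinations(examples, 2):
        pairs.append((text_a, text_b, 1.0 if label_a == label_b else 0.0))
    return pairs


# Hypothetical few-shot dataset: 2 labels x 3 examples each
examples = [
    ("Zidane scores a late winner", "sports"),
    ("The final ends in a penalty shootout", "sports"),
    ("A new tennis champion is crowned", "sports"),
    ("The new rover lands on Mars", "science and technology"),
    ("A faster chip is announced", "science and technology"),
    ("Researchers publish a vaccine study", "science and technology"),
]
pairs = contrastive_pairs(examples)
positives = sum(1 for _, _, target in pairs if target == 1.0)
print(len(pairs), positives)  # 15 6
```

Six labeled sentences already produce 15 training pairs, which is part of why the fine-tuning step needs so little data.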



Two options are possible for the second sub-model: an additional layer on top of the sentence transformer's architecture, or a completely separate model. A library such as SetFit lets you choose between these two approaches depending on your needs. For example, you can use a traditional logistic regression model from scikit-learn as the classification head, which is an advantage if you want to leverage the scikit-learn ecosystem and dig deeper into explainability.

Prepare your data

The first step is to prepare your data. As mentioned earlier, the method can work with fewer than ten examples per label. So, even with 50 different labels, a dataset of 500 rows would be enough, and training takes less than 5 minutes on a single 32 GB CPU node!

There is no magic number or hard limit. You can increase the number of examples per label and evaluate whether your model becomes more accurate; it’s a trade-off between training time and overall performance. While you can expect a real improvement going from 5 to 10 examples per label, there is no guarantee that more examples help beyond that. It depends on how representative they are of your use case and of the data you will see during inference.

Testing the performance with 2 or 3 examples per label is also a good exercise. Is it better than zero-shot learning? By how much? Again, training time is very low, so you can easily explore these behaviors. Beyond the number of training samples, their quality is key.
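To run these experiments (2, 3, 5 or 10 examples per label), a small helper that draws k examples per label from a larger labeled set can be handy. This is a hypothetical, framework-agnostic sketch; `sample_per_label` is not part of SetFit or Spark, and the data is synthetic.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple


def sample_per_label(
    rows: List[Tuple[str, str]], k: int, seed: int = 42
) -> List[Tuple[str, str]]:
    """Return at most k (text, label) rows per label, sampled reproducibly."""
    by_label: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
    for row in rows:
        by_label[row[1]].append(row)
    rng = random.Random(seed)
    sample = []
    for label in sorted(by_label):  # sorted for a deterministic order
        group = by_label[label]
        sample.extend(rng.sample(group, min(k, len(group))))
    return sample


# Hypothetical labeled data: 50 labels with 20 rows each
rows = [(f"text {i} for label {j}", f"label_{j}") for j in range(50) for i in range(20)]
for k in (2, 3, 5, 10):
    print(k, len(sample_per_label(rows, k)))  # 100, 150, 250 and 500 rows
```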


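input_text_col = 'raw_text'
output_label_col = 'category'

# Read the labeled samples, rename the columns to the 'text'/'label' names SetFit expects, and convert to pandas
df = spark.read.csv(f'/Volumes/{catalog}/{schema}/landing_zone/sample_classif_data.csv', header=True, sep=';') \
          .withColumnRenamed(input_text_col, 'text') \
          .withColumnRenamed(output_label_col, 'label') \
          .select('text', 'label') \
          .toPandas()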


To train and evaluate the model, we need a training set and a test set, so we split our data.


# Split train-test (80/20)
from sklearn.model_selection import train_test_split
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

# Convert to Hugging Face datasets (expected input format), dropping the pandas index
from datasets import Dataset, DatasetDict
train_ds = Dataset.from_pandas(train_df, preserve_index=False)
test_ds = Dataset.from_pandas(test_df, preserve_index=False)

ds = DatasetDict()
ds['train'] = train_ds
ds['test'] = test_ds


Note: best practice would be to also create a validation dataset; we skip it here to keep the process simple.

Train your model

Now that we have our training and test datasets, we can start training the model. You will have to provide the sentence transformer, which will be the foundation of your few-shot model, and a few parameters, such as the batch size and the number of epochs. One epoch is often the default choice, as we have only a few input rows. As previously mentioned, the sentence transformer has already been pretrained, and we want to benefit from it: increasing the number of epochs may hurt performance, as the model risks overfitting the small training set.


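from setfit import SetFitModel, Trainer, TrainingArguments

# Load a pretrained sentence transformer from Hugging Face
hf_pretrained_model = "sentence-transformers/paraphrase-mpnet-base-v2"
model = SetFitModel.from_pretrained(hf_pretrained_model).to(device)

# Set training arguments (batch_size=16 is an example value, tune it for your data)
training_args = TrainingArguments(batch_size=16, num_epochs=1)

# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds['train'],
    eval_dataset=ds['test'],
)

# Train
trainer.train()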


Evaluate your model

The choice of the sentence transformer is essential and will impact the overall performance of the process. Try several sentence transformers to evaluate their impact on your final model. One criterion could be the language of the corpus used to train the transformer. However, even though the open source community keeps adding new models, many more models are still available in English than in other languages today. Depending on your use case and language, an English model can be a good starting point when benchmarking different models.


# Evaluate
metrics = trainer.evaluate()


Log the model

We first need to save the model locally, in a temporary location.


# First, save the fine-tuned model locally in a 'snapshot' folder
model.save_pretrained('snapshot')


Then, we have to define a custom model with MLflow. Remember that we have two sub-models, the sentence transformer and the classifier, which is not the standard structure expected by MLflow. We could store each sub-model independently using its respective flavor (sentence-transformers and scikit-learn). However, inference would then be less straightforward, as we would first need to encode our text and then pass the embeddings to the classification head.


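from mlflow.pyfunc import PythonModel
from setfit import SetFitModel

class SetFitCustomModel(PythonModel):
  def load_context(self, context):
    # Load the fine-tuned sentence transformer and classification head together
    self.model = SetFitModel.from_pretrained(context.artifacts['snapshot'])

  def predict(self, context, inputs):
    # inputs is a pandas DataFrame with a 'prompt' column; return one label per row
    preds = self.model.predict(inputs['prompt'].tolist(), as_numpy=True)
    return [str(p) for p in preds]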


The next step is to create the signature of our model. The model takes a string as input and produces another string (corresponding to the label) as the output. 


from mlflow.models.signature import ModelSignature
from mlflow.types import DataType, Schema, ColSpec

# Define input and output schema
input_schema = Schema([ColSpec(DataType.string, "prompt"),])
output_schema = Schema([ColSpec(DataType.string, "text")])

#Define signature
signature = ModelSignature(inputs=input_schema, outputs=output_schema)


Now, the model can be registered within Unity Catalog using MLflow. As we have a custom model, we use the pyfunc flavor.



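# Register the model in Unity Catalog, under the name the serving endpoint references
with mlflow.start_run() as run:
  model_details = mlflow.pyfunc.log_model(
    artifact_path="model",
    python_model=SetFitCustomModel(),
    artifacts={'snapshot': 'snapshot'},
    signature=signature,
    registered_model_name=f"{catalog}.{schema}.fewshot_setfit",
  )
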

Performing ML Inference

The next step is to deploy your model to a serving endpoint so it can be used for real-time inference. We will use the MLflow Deployments client to deploy it to Databricks Model Serving.

Serve the model with Model Serving

The paraphrase-mpnet-base-v2 model we used above has just over 100M parameters, far from the 70B of Llama 2 or the 175B of GPT-3. This means the endpoint can be smaller while still offering low-latency inference at a reduced cost.
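A rough way to see why parameter count matters for serving is the memory needed just to hold the weights: parameters times bytes per parameter. The sketch below assumes 32-bit floats (4 bytes) for the small model and 16-bit floats (2 bytes) for the large ones, and ignores activations and runtime overhead.

```python
def weights_size_gb(n_params: float, bytes_per_param: int) -> float:
    """Approximate memory for the model weights alone, in gigabytes (1e9 bytes)."""
    return n_params * bytes_per_param / 1e9


# Assumed precisions: fp32 for the small model, fp16 for the large ones
models = {
    "paraphrase-mpnet-base-v2 (~100M, fp32)": (100e6, 4),
    "Llama 2 70B (fp16)": (70e9, 2),
    "GPT-3 175B (fp16)": (175e9, 2),
}
for name, (n_params, nbytes) in models.items():
    print(f"{name}: {weights_size_gb(n_params, nbytes):.1f} GB")
```

Under these assumptions, the small model needs well under 1 GB for its weights, versus hundreds of gigabytes for the larger LLMs.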


from mlflow.deployments import get_deploy_client

client = get_deploy_client("databricks")

endpoint_name = "fewshot_setfit"
version = "1"

# Serve version 1 of the registered few-shot model on a small, scale-to-zero endpoint
endpoint = client.create_endpoint(
    name=endpoint_name,
    config={
        "served_entities": [
            {
                "name": f"{endpoint_name}-{version}",
                "entity_name": f"{catalog}.{schema}.{endpoint_name}",
                "entity_version": version,
                "workload_size": "Small",
                "scale_to_zero_enabled": True,
            }
        ],
        "traffic_config": {
            "routes": [
                {
                    "served_model_name": f"{endpoint_name}-{version}",
                    "traffic_percentage": 100,
                }
            ]
        },
    },
)

Distribute your inference with Spark

If the volume of your data increases, you will need to scale out the inference process. Using Spark and pandas UDFs, we can distribute inference across all the available workers in your cluster. This is fully compatible with GPUs, as the two operate at different levels: the pandas UDF distributes the workload over several nodes, and each node then independently leverages its own GPU, if any. In short, you can combine Spark's cluster distribution with the GPU hardware acceleration of PyTorch.

In the following example, we use an Iterator to Iterator pandas UDF so that the model is loaded only once per task rather than once per batch, which improves overall performance.


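import pandas as pd
import torch
import pyspark.sql.types as T
from typing import Iterator
from pyspark.sql.functions import pandas_udf, col
from transformers import pipeline

@pandas_udf(T.StringType())
def predict_iterator(series: Iterator[pd.Series]) -> Iterator[pd.Series]:
  # Load the pipeline once per task, on the worker's GPU if it has one
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  zshot = pipeline(task="zero-shot-classification",
                   model="facebook/bart-large-mnli",
                   device=device)

  # Classify each batch of texts and keep the top-scoring label for each one
  for s in series:
      results = zshot(s.to_list(), candidate_labels=[
                "science and technology",
                "pop culture",
                "breaking news",
            ], multi_label=False)
      output = [result['labels'][0] for result in results]
      yield pd.Series(output)

# texts_df is a placeholder for your Spark DataFrame with a 'texts' column
display(texts_df.select('texts', predict_iterator(col('texts'))))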
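The benefit of the Iterator to Iterator variant comes from running the expensive setup once, before consuming the batches. The pure-Python sketch below mimics that pattern with a hypothetical `load_model` function and counts how many times it is called; it does not involve Spark.

```python
from typing import Callable, Iterator, List

load_calls = 0


def load_model() -> Callable[[str], str]:
    """Hypothetical expensive loader (stands in for building the HF pipeline)."""
    global load_calls
    load_calls += 1
    return lambda text: "sports" if "football" in text else "other"


def predict_batches(batches: Iterator[List[str]]) -> Iterator[List[str]]:
    model = load_model()  # runs once per iterator, not once per batch
    for batch in batches:
        yield [model(text) for text in batch]


batches = iter([["football match", "new phone"], ["football transfer"]])
results = list(predict_batches(batches))
print(results, load_calls)  # [['sports', 'other'], ['sports']] 1
```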


Query the model in SQL with AI Functions

This endpoint can be used directly in SQL, so data analysts and analytics engineers can use it in their analyses and queries. It bridges the gap between data scientists and analysts and increases collaboration within teams.


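SELECT text_col,
       ai_query("fewshot_setfit", named_struct("prompt", text_col),
                returnType => "STRING") AS predicted_label
FROM my_table -- placeholder: replace with your table
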

Many people assume that using a GPU will cost more than using a CPU. To compare inference costs properly, we benchmarked the previous few-shot classification model. The results are clear: a GPU can accelerate inference and reduce the overall cost at the same time. Processing the dataset in batches of 500 rows, inference time was reduced by 37x and cost by a factor of 16! This comparison was performed on a 500k-row dataset with 5 candidate labels, using a g5.2xlarge node for the GPU and an m5.xlarge node for the CPU.
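The two ratios are consistent with each other. If cost is hourly price times runtime, then cost_cpu / cost_gpu = speedup / (GPU-to-CPU price ratio), so a 37x speedup combined with a 16x cost reduction implies the GPU node's hourly price is roughly 37 / 16 ≈ 2.3 times the CPU node's. A small sketch of that arithmetic, assuming per-hour billing and no other cost components:

```python
def cost_reduction(speedup: float, gpu_to_cpu_price_ratio: float) -> float:
    """CPU cost divided by GPU cost, when cost = hourly price x runtime."""
    # cost_cpu / cost_gpu = (p_cpu * t_cpu) / (p_gpu * t_gpu) = speedup / price_ratio
    return speedup / gpu_to_cpu_price_ratio


def implied_price_ratio(speedup: float, cost_reduction_factor: float) -> float:
    """GPU-to-CPU hourly price ratio implied by the measured ratios."""
    return speedup / cost_reduction_factor


ratio = implied_price_ratio(speedup=37, cost_reduction_factor=16)
print(round(ratio, 2))  # 2.31
print(cost_reduction(speedup=37, gpu_to_cpu_price_ratio=ratio))  # 16.0
```

In other words, the GPU wins on cost as long as its speedup exceeds its price premium.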


Thanks to the Databricks ML Runtime, you can switch between CPU and GPU without changing your code. The runtime handles everything, so you can transition to GPU seamlessly.


In this article, we have explored two simple yet efficient classification techniques for implementing LLMs in Databricks. Zero-shot learning is the most straightforward approach, as it requires no training: you can leverage existing open source models and quickly deploy a zero-shot pipeline to production. We have also seen how efficiently a few-shot model can be trained with only a few examples per label, using open source models such as BERT. This approach tailors the model to your specific data while requiring only a small dataset, which minimizes training costs and complexity. These two methods let you deploy models to production quickly, removing the uncertainties that slow down a project. If you find it difficult to forecast the ROI of a use case, they are valuable starting points before jumping to larger LLMs.

Most importantly, we have seen that whichever technique you choose to build and deploy your model, this journey can easily be completed in Databricks. The ready-to-use environment of the Databricks ML Runtime, the central governance of Unity Catalog, and MLflow together with Model Serving and AI Functions will accelerate your path to production.