owbenita
Databricks Employee

In the era of generative AI and large language models (LLMs), utilizing domain and business data has become crucial for tailoring models to achieve differentiation and enhance personalization. For more advanced customizations, Foundation Model Fine-tuning is a popular choice when prompt engineering or RAG with a base pretrained foundation model is not providing the results you need, when you want to maintain complete control over the fine-tuned model and its deployment, or when you need to improve overall price-performance by using a smaller, specialized model.

What is Fine-Tuning?

Fine-tuning is a process that uses supervised learning to train a pretrained large language model (LLM) to improve the performance and quality of its responses for a specific use case. Supervised fine-tuning can be data-intensive, often requiring thousands of examples to significantly specialize a model for complex tasks. However, manually building a good-quality, diverse, and relevant dataset for fine-tuning is a resource-intensive and challenging effort for many organizations. Databricks supports both chat completion and instruction fine-tuning.
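To make the two task types concrete, here is an illustrative sketch of how a single training example is typically shaped for each format (the question and answer text are invented placeholders):

# Illustrative only: the typical shape of one training example for each task type.

# Chat completion: a full conversation expressed as role/content messages.
chat_completion_example = {
    "messages": [
        {"role": "user", "content": "How do I set up a direct deposit?"},
        {"role": "assistant", "content": "You can set up a direct deposit from the Payments tab in the app..."},
    ]
}

# Instruction fine-tuning: a single prompt/response pair.
instruction_example = {
    "prompt": "How do I set up a direct deposit?",
    "response": "You can set up a direct deposit from the Payments tab in the app...",
}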

Synthetic Data for Fine-Tuning

One effective strategy for fine-tuning involves using an LLM as a "teacher" to produce synthetic data, which is then used to fine-tune a smaller "student" model. For example, Meta's Llama 3.1 405B, a state-of-the-art LLM, can generate synthetic datasets for fine-tuning smaller models like Llama 3.1 70B. This process, known as knowledge distillation, transfers the teacher model's expertise to the student model, making it more efficient and task-specific.

 

In this blog, we will walk through approaches and best practices for generating synthetic data for fine-tuning and show how Databricks helps simplify this process, including Batch Inference on Serverless Model Serving. Databricks supports three types of fine-tuning tasks: chat completion, instruction fine-tuning, and continued pre-training; learn more about these tasks in the documentation. We will focus on the example of a fictional bank and, specifically, on chat completion fine-tuning.

Architecture

The architecture below represents an end-to-end solution that comprises three main steps:

  1. Data Generation - A high-quality, human-curated dataset is used as a seed to generate a larger dataset with an LLM via a Databricks Job.
  2. Training - Databricks Jobs fine-tune the base model with Mosaic AI's fine-tuning APIs, using the combined dataset of human-curated and LLM-generated examples.
  3. Serving and Evaluation - The trained model is served with Model Serving on Databricks and assessed with Mosaic AI Evaluation and LLM judges; the evaluation metrics are fed back into data generation to improve quality.

[Architecture diagram: data generation, training, and serving & evaluation]

Synthesizing Data in Banking

To help understand how this works, we will walk through an example scenario using a fictional bank that is improving customer experience with an AI-powered chatbot. They need to ensure that they are providing financial guidance without giving direct investment advice. 

One significant challenge they experienced when relying on prompt-engineering techniques, including those used in RAG, was that the system prompts became prohibitively long in order to incorporate all company domain rules and tone of voice consistently, significantly increasing the cost of invoking the model.

Data Preparation

First, curate a high-quality, representative dataset of what your users may ask, and scope the model to a specific set of specialized topics, including both in-scope and out-of-scope topics. The goal is to ensure the model responds appropriately even on topics that are deemed risky or inappropriate. For example:

 

topic                 question                              in_scope
Account Management    How do I set up a direct deposit?     Yes
Investment Advice     Should I buy Apple stock?             No
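As a minimal sketch (table and column names are placeholders, chosen to line up with the seed table and the record keys referenced later in this post), the seed examples can be stored as a Delta table like this:

# A minimal sketch: persist the seed examples as a Delta table. The table name matches the
# placeholder used later (<your_catalog.your_schema.seed_examples>), and the column names
# follow what the generation code below expects (record["Topic"], record["Message"]).
seed_rows = [
    ("Account Management", "How do I set up a direct deposit?", "Yes"),
    ("Investment Advice", "Should I buy Apple stock?", "No"),
]
seed_df = spark.createDataFrame(seed_rows, schema="Topic STRING, Message STRING, in_scope STRING")
seed_df.write.mode("overwrite").saveAsTable("<your_catalog>.<your_schema>.seed_examples")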

Synthetic Data Generation

Step 1: Seeded Prompt

In this step, we provide a well-written seeded prompt for data generation. One best practice is to use a prompt structure that encourages the model to engage in multi-stage reasoning, such as thinking step by step or considering various constraints sequentially, to generate comprehensive and relevant responses.
For instance, the prompt might instruct the teacher LLM to first identify the user's intent, then recall relevant bank policies related to that intent, and finally formulate a response that is helpful, compliant, and maintains the bank's tone of voice.

QUESTION_PROMPT = {
  "system": """
You are a knowledgeable assistant for a banking app. Your role is to generate relevant questions inspired by provided examples. To ensure the questions are well structured and diverse, follow step-by-step reasoning before generating the final output. Format your thought process clearly, separating analysis from final outputs.""",

  "user": """
Generate {n} new questions for our banking app, following these steps:
1. First analyze the example question and category
   - Category: {example_category}
   - Question: {example_question}
2. List potential subtopics
3. Consider how to frame them as relevant questions
4. Evaluate if the topics are diverse
5. Keep questions short and simple
6. Show your thought process before generating the final output
"""
}

 

Step 2: Prompt Expansion

Prompt expansion refers to dynamically substituting variables and example values into the prompt template before calling the model. Once we have a quality prompt, we need to ensure that the outputs are structured to match our table schema. To make this task easier and more flexible, we use Pydantic's BaseModel, from the open-source Pydantic data-validation library, to generate a structured object when we call our LLM.

 

from typing import Any
from pydantic import BaseModel
from openai import OpenAI

# Prepares the list of chat messages in the format expected by the LLM
def get_chat_messages(
    prompt: dict[str, str],
    user_prompt_args: dict[str, str]
) -> list[dict[str, str]]:
    messages = []
    if "system" in prompt:
        messages.append({"role": "system", "content": prompt["system"]})
    user_prompt = prompt["user"].format(**user_prompt_args)
    messages.append({"role": "user", "content": user_prompt})

    return messages

# Generates a structured output from the LLM, using a Pydantic model for validation
def generate_structured_object(
    prompt: dict[str, str],
    prompt_args: dict[str, Any] | None,
    base_model: type[BaseModel],
    model: str,
    credentials: dict[str, dict[str, str]] | None,
    repeat_count: int = 2
):
    messages = get_chat_messages(prompt, prompt_args)
    _creds = credentials[model]
    client = OpenAI(  # OpenAI client pointed at a Databricks model serving endpoint
        api_key=_creds["api_key"],
        base_url=_creds["base_url"]
    )
    # Retry a fixed number of times to handle transient failures
    for _ in range(repeat_count):
        try:
            completion = client.beta.chat.completions.parse(
                model=model,
                messages=messages,
                response_format=base_model
            )
            return completion.choices[0].message.parsed
        except Exception as e:
            print(f"Error: {e}")
    raise RuntimeError(f"Failed to generate a structured output after {repeat_count} attempts")

We can then create a function that passes the arguments through to a model served on Databricks. For this task, we recommend using a provisioned throughput endpoint with scale-to-zero enabled to avoid hitting throughput limits; provisioned throughput endpoints provide optimized inference with performance guarantees for production workloads. For development purposes, however, pay-per-token is a more cost-effective option.

class PromptModel(BaseModel):
    reasoning: str
    questions: list[str]

# Calls generate_structured_object and passes the required arguments
def generate_one_prompt(
    example_question: str,
    example_category: str,
    in_scope: str,
    credentials: dict[str, dict[str, str]] | None,
    model: str,
    num_questions: int = 1,
    prompt: dict[str, str] | None = None
) -> list[str]:
    prompt_args = {}
    prompt_args["n"] = num_questions
    prompt_args["example_question"] = example_question
    prompt_args["example_category"] = example_category
    prompt_args["in_scope"] = in_scope

    response: PromptModel = generate_structured_object(
        prompt=prompt,
        prompt_args=prompt_args,
        base_model=PromptModel,
        model=model,
        credentials=credentials
    )
    questions = [q for q in response.questions if q and len(q.strip()) > 0]
    return questions

The function below implements the higher-level logic that takes a batch of inputs and distributes the calls to the LLM endpoint concurrently. This is essential for speeding up the generation of a large dataset.

import itertools
from multiprocessing.pool import ThreadPool
from typing import Any, Callable

# Applies a function to each record in parallel using a thread pool
def apply_function_parallel(
    func: Callable, data: list[dict[str, Any]], n_processes: int
) -> list[list[dict[str, Any]]]:
    with ThreadPool(processes=n_processes) as pool:
        result = pool.map(func, data)
    return result

# Generates questions for every seed record and flattens the results into a single list
def generate_n_prompts(
    records: list[dict[str, Any]],
    credentials: dict[str, dict[str, str]],
    num_questions: int,
    n_processes: int,
    model: str = "databricks-meta-llama-3-1-405b-instruct",
) -> list[dict[str, Any]]:

    def _process(record: dict[str, Any]) -> list[dict[str, Any]]:
        questions = generate_one_prompt(
            example_question=record["Message"],
            example_category=record["Topic"],
            in_scope=record.get("in_scope", ""),
            credentials=credentials,
            model=model,
            num_questions=num_questions,
            prompt=QUESTION_PROMPT,
        )
        return [{"question": q} for q in questions]

    return list(itertools.chain(*apply_function_parallel(_process, records, n_processes)))

Finally, we use the Apache Spark APIs to read from the Delta table that holds our sample questions, call the function to generate prompts in parallel, and write the results back to a Delta table.

from pyspark.sql import DataFrame
import pandas as pd

input_table_path = '<your_catalog.your_schema.seed_examples>'
output_table_path = '<your_catalog.your_schema.synth_examples>'
serving_endpoint = "<your_endpoint>"
api_token = "<your_api_token>"
model_name = "<model_name>"
num_questions = 10
processes = 2
generation_kwargs = {"temperature": 0.99, "max_tokens": 200}  # optional model parameters

# Read the seed examples from the Delta table
dataframe = spark.sql(f"SELECT * FROM {input_table_path}")
records = dataframe.toPandas().to_dict(orient="records")

# Generate the synthetic questions in parallel
question_records = generate_n_prompts(
    records=records,
    model=model_name,
    num_questions=num_questions,
    n_processes=processes,
    credentials={
        f"{model_name}": {
            "base_url": serving_endpoint,
            "api_key": api_token,
        }
    },
)

# Inspect the results and write them back to a Delta table for the next step
final_df = spark.createDataFrame(pd.DataFrame(question_records))
display(final_df)
final_df.write.mode("overwrite").saveAsTable(output_table_path)

Step 3: Response Generation with Batch Inference


Now that we have the questions and prompts, we will use Databricks AI Functions for batch inference. This simplifies a number of the tasks we performed above, such as creating structured outputs, validating the schema, and parallelizing calls.

We recommend decoupling the generation of prompts from the generation of responses, as this helps the model focus on the kind of question or task it needs to handle at each step, improving quality. It also allows you to tune each part of the pipeline independently to achieve more realistic outputs.

In this step, we create an ai_query call that produces AI-generated responses in parallel across the synthetic questions in our table.

synth_responses = "<your_catalog.your_schema.synth_responses>"
rules_text = "<your company policy and tone-of-voice rules>"

QUERY = f"""
  CREATE OR REPLACE TABLE {synth_responses}
  AS
  SELECT
    question,
    ai_query(
      '{model_name}',
      'You are a knowledgeable expert for a banking app. Your role is to generate a well-formed response to the provided question. To ensure the answers are safe and adhere to the banking app policies and tone-of-voice, follow these guidelines:
      {rules_text}
      ##Reasoning
      Use step-by-step reasoning before generating the final output. Format your thought process clearly, separating analysis from final outputs.
      Generate a response for our banking app, following these steps:
      1. First analyze the question and what the user is asking for
      Input:' || question || '
      2. Analyze if the question is in scope for the banking app
      3. Generate an initial response that follows each rule
      4. Analyze if your response is safe and adheres to each of the rules
      5. Generate a final response, modifying your response as needed
      6. Keep responses straight to the point
      7. Only generate the final answer
      ',
      responseFormat => 'STRUCT<content:STRUCT<content:STRING>>',
      modelParameters => named_struct('temperature', 0.7, 'max_tokens', 512)
    ) AS response
  FROM {output_table_path}
"""

synth_generation = spark.sql(QUERY)
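Before formatting the data for training, it is worth spot-checking a few of the generated pairs, for example:

# Quick sanity check: sample a few generated question/response pairs
display(spark.sql(f"SELECT question, response FROM {synth_responses} LIMIT 5"))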

Step 4: Chat Completion Formatting

For chat completion fine-tuning, chat-formatted data must be in a .jsonl file, where each line is a separate JSON object representing a single chat session. Each chat session is represented as a JSON object with a single key, messages, that maps to an array of message objects.

We create a Delta table that holds the data in the format required for chat completion:

CREATE OR REPLACE TABLE <your_catalog.your_schema.formatted_synth_responses>
AS
SELECT to_json(named_struct(
  'messages', array(
    named_struct('role', 'user', 'content', question),
    named_struct('role', 'assistant', 'content', content)
  )
)) AS messages
FROM (
  SELECT question, from_json(response, 'STRUCT<content:STRING>').content AS content
  FROM <your_catalog.your_schema.synth_responses>
)

We can then store this as a .jsonl file in a Databricks Volume once we are happy with the results.
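As a minimal sketch (the Volume path and table name are placeholders), the formatted table can be exported to a .jsonl file in a Unity Catalog Volume:

# A minimal sketch: each row's `messages` column already holds a JSON string,
# so writing one row per line produces a valid .jsonl file. Paths are placeholders.
volume_path = "/Volumes/<your_catalog>/<your_schema>/<your_volume>/train_chat_completion.jsonl"

rows = spark.table("<your_catalog>.<your_schema>.formatted_synth_responses").collect()
with open(volume_path, "w") as f:
    for row in rows:
        f.write(row["messages"] + "\n")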

 

Step 5: Mosaic AI Fine-Tuning

Once we have all the synthetic and seed data for model training stored in Databricks, we can use the ML Experiments UI or the Foundation Model Fine-tuning API to run fine-tuning.

  • You can validate the data to ensure you are following the right format. See Documentation: Prepare data for Foundation Model Fine-tuning.
  • You can then create a training run, as sketched after this list.
    • To create training runs via the API, use the create() function; to create a training run via the UI, use the Experiments section. Select the task to perform, either Instruction Fine-tuning or Chat Completion.
    • Select the foundation model to tune or train. For a list of supported models, see Supported models.
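Here is a rough sketch of an API-driven training run; the base model, paths, and parameters are placeholders, and the exact arguments for create() are described in the Foundation Model Fine-tuning documentation:

# A rough sketch, not a definitive recipe: launch a chat completion fine-tuning run
# with the Foundation Model Fine-tuning API. All names and values are placeholders.
from databricks.model_training import foundation_model as fm

run = fm.create(
    model="meta-llama/Meta-Llama-3.1-70B-Instruct",  # base model to fine-tune (placeholder)
    train_data_path="/Volumes/<your_catalog>/<your_schema>/<your_volume>/train_chat_completion.jsonl",
    task_type="CHAT_COMPLETION",
    register_to="<your_catalog>.<your_schema>",      # Unity Catalog location for the tuned model
    training_duration="3ep",                         # e.g., three epochs
)
print(run.name)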

[Screenshot: creating a fine-tuning run from the Experiments UI]

  1. To view the run in the UI:
    1. Click Experiments in the left navigation bar to display the Experiments page.
    2. In the table, click the name of your experiment to display the experiment page. The experiment page lists all runs associated with the experiment.

[Screenshot: the experiment page listing the fine-tuning runs]

Step 6: Mosaic AI Evaluation

Databricks provides a scalable framework, Mosaic AI Agent Evaluation, for assessing the quality, cost, and latency of agentic AI applications, including RAG applications and chains. Alternatively, you can add an evaluation dataset to the fine-tuning experiment in the UI to evaluate the model as it trains. A minimal sketch of running Agent Evaluation over the fine-tuned model's responses follows.
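The sketch below assumes the fine-tuned model's responses have already been collected into an evaluation set; the example rows are invented placeholders.

# A rough sketch, not a definitive implementation: run Mosaic AI Agent Evaluation
# (built-in LLM judges) over requests with pre-computed responses from the
# fine-tuned model. The rows below are invented placeholders.
import mlflow
import pandas as pd

eval_set = pd.DataFrame([
    {"request": "How do I set up a direct deposit?",
     "response": "<answer returned by the fine-tuned model>"},
    {"request": "Should I buy Apple stock?",
     "response": "<answer returned by the fine-tuned model>"},
])

with mlflow.start_run():
    results = mlflow.evaluate(
        data=eval_set,
        model_type="databricks-agent",  # enables the Agent Evaluation LLM judges
    )

print(results.metrics)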

[Screenshot: evaluation metrics in the fine-tuning experiment UI]

 

Conclusion

In this blog, we looked at an example of how you can generate synthetic data utilizing Databricks Model Serving and Serverless Batch Inference. We also shared key best practices such as ensuring structured outputs, decoupling question generation from response generation, and ensuring you have scoped the fine-tuned model to a predefined set of topics.

Synthetic data generation is a powerful and scalable solution for addressing challenges such as data scarcity, privacy concerns, and domain specificity when fine-tuning large language models (LLMs). By incorporating Databricks platform features, the process becomes even more robust, efficient, and impactful for organizations across industries.

Some key benefits of synthetic data generation on the Databricks platform include:

  • Accelerated Model Development and Customization: The Databricks platform enables efficient integration of synthetic data pipelines with Databricks' Delta Lake for seamless data management and MLflow for centralized tracking of experiments and model artifacts.
  • Improved Cost Efficiency: Databricks simplifies the fine-tuning process by caching base models in-cluster, significantly reducing startup times for large models (e.g., from over an hour to under two minutes). 
  • Advanced Synthetic Data Capabilities: Databricks provides tools like the Synthetic Data Generation API and integration with libraries such as SDV (Synthetic Data Vault) to expedite the creation of high-quality, diverse datasets. 

These features make Databricks a comprehensive platform for leveraging synthetic data to fine-tune LLMs, driving innovation while solving critical challenges uniquely associated with real-world datasets.