In the era of generative AI and large language models (LLMs), utilizing domain and business data has become crucial for tailoring models to achieve differentiation and enhance personalization. For more advanced customization, Foundation Model Fine-tuning is a popular choice when a base pretrained foundation model with prompt engineering or RAG is not providing the results you need, when you want to retain complete control over the fine-tuned model and its deployment, or when you need to improve overall price-performance by using a smaller, specialized model.
Fine-tuning is a process that uses supervised learning to further train a pretrained LLM to improve the performance and quality of its responses for a specific use case. Supervised fine-tuning can be data-intensive, often requiring thousands of examples to significantly specialize a model for complex tasks. However, manually building a high-quality, diverse, and relevant dataset for fine-tuning is a resource-intensive and challenging effort for many organizations.
One effective strategy for fine-tuning involves using an LLM as a "teacher" to produce synthetic data, which is then used to fine-tune a smaller "student" model. For example, Meta's Llama 3.1 405B, a state-of-the-art LLM, can generate synthetic datasets for fine-tuning smaller models like Llama 3.1 70B. This process, known as knowledge distillation, transfers the teacher model's expertise to the student model, making it more efficient and task-specific.
In this blog, we will walk through approaches and best practices for generating synthetic data for fine-tuning and how Databricks helps to simplify this process, including batch inference on serverless Model Serving. Databricks supports three types of fine-tuning tasks: chat completion, instruction fine-tuning, and continued pre-training; learn more about these tasks in the documentation. We will focus on the example of a fictional bank and, specifically, on chat completion fine-tuning.
The end-to-end solution comprises three main steps: generating synthetic questions with a teacher LLM, generating responses to those questions with serverless batch inference, and fine-tuning a smaller student model on the resulting dataset.
To help understand how this works, we will walk through an example scenario using a fictional bank that is improving customer experience with an AI-powered chatbot. They need to ensure that they are providing financial guidance without giving direct investment advice.
One significant challenge they experienced when relying on prompt-engineering techniques, including those used in RAG, is that the system prompts became prohibitively long in order to incorporate all company domain rules and tone of voice consistently, significantly increasing the cost of invoking the model.
First, curate a quality and representative dataset of what your users may ask and scope the model to a specific set of specialized topics, including in-scope topics and out-of-scope topics. The goal is to ensure the model is able to respond appropriately even on topics that are deemed risky or inappropriate. For example:
| topic | question | in_scope |
| Account Management | How do I set up a direct deposit? | Yes |
| Investment Advice | Should I buy Apple stock? | No |
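If you do not already have such a table, the seed examples can be materialized as a Delta table. The snippet below is a minimal sketch; the table name and rows are illustrative, and the column names (Topic, Message, in_scope) are assumed to match what the generation code later reads.
# Illustrative seed examples; in practice, curate these with domain and compliance experts
seed_rows = [
    ("Account Management", "How do I set up a direct deposit?", "Yes"),
    ("Investment Advice", "Should I buy Apple stock?", "No"),
]

seed_df = spark.createDataFrame(seed_rows, schema="Topic STRING, Message STRING, in_scope STRING")
seed_df.write.mode("overwrite").saveAsTable("<your_catalog.your_schema.seed_examples>")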
In this step, we provide a well-written seed prompt for data generation. One best practice is to use a prompt structure that encourages the model to engage in multi-stage reasoning, such as thinking step by step or considering various constraints sequentially, to generate comprehensive and relevant responses.
For instance, the prompt might instruct the teacher LLM to first identify the user's intent, then recall relevant bank policies related to that intent, and finally, formulate a response that is helpful, compliant, and maintains the bank's tone-of-voice.
QUESTION_PROMPT = {
    "system": """
You are a knowledgeable assistant for a banking app. Your role is to generate relevant questions inspired by provided examples. To ensure the questions are well structured and diverse, follow step-by-step reasoning before generating the final output. Format your thought process clearly, separating analysis from final outputs.""",
    "user": """
Generate {n} new questions for our banking app, following these steps:
1. First analyze the example question and category
   - Category: {example_category}
   - Question: {example_question}
2. List potential subtopics
3. Consider how to frame them as relevant questions
4. Evaluate if the topics are diverse
5. Keep questions short and simple
6. Show your thought process before generating the final output
"""
}
Prompt expansion here refers to substituting the run-time variables, such as the example question, category, and number of questions to generate, into the prompt template. Once we have a quality prompt, we need to ensure that the outputs are structured to match our table schema. To make this task easier and more flexible, we use Pydantic, an open-source data validation library, and its BaseModel class to generate a structured object when we call our LLM.
from typing import Any
from pydantic import BaseModel
from openai import OpenAI

# Prepares a list of chat messages in the correct format needed for LLMs
def get_chat_messages(
    prompt: dict[str, str],
    user_prompt_args: dict[str, str]
) -> list[dict[str, str]]:
    messages = []
    # Include the system prompt so the model follows the persona and reasoning instructions
    messages.append({"role": "system", "content": prompt["system"]})
    user_prompt = prompt["user"].format(**user_prompt_args)
    messages.append({"role": "user", "content": user_prompt})
    return messages
# Generates a structured output from the LLM using a Pydantic model for validation
def generate_structured_object(
    prompt: dict[str, str],
    prompt_args: dict[str, Any] | None,
    base_model: type[BaseModel],
    model: str,
    credentials: dict[str, dict[str, str]] | None,
    repeat_count: int = 2
):
    messages = get_chat_messages(prompt, prompt_args)
    _creds = credentials[model]
    # The OpenAI client can point at a Databricks serving endpoint via base_url
    client = OpenAI(
        api_key=_creds["api_key"],
        base_url=_creds["base_url"]
    )
    # Retry a couple of times in case a call fails or the response fails validation
    for _ in range(repeat_count):
        try:
            completion = client.beta.chat.completions.parse(
                model=model,
                messages=messages,
                response_format=base_model
            )
            return completion.choices[0].message.parsed
        except Exception as e:
            print(f"Error: {e}")
    return None
We can then create a function that passes the required arguments and calls a model served on Databricks. For this task, we recommend using a scale-to-zero enabled provisioned throughput endpoint to avoid hitting throughput limits. Provisioned throughput endpoints help ensure optimized inference with performance guarantees for production workloads; for development purposes, pay-per-token is a more cost-effective option.
class PromptModel(BaseModel):
    reasoning: str
    questions: list[str]

# Calls generate_structured_object and passes the required arguments
def generate_one_prompt(
    example_question: str,
    example_category: str,
    in_scope: str,
    credentials: dict[str, dict[str, str]] | None,
    model: str,
    num_questions: int = 1,
    prompt: dict[str, str] | None = None
) -> list[str]:
    prompt_args = {}
    prompt_args["n"] = num_questions
    prompt_args["example_question"] = example_question
    prompt_args["example_category"] = example_category
    prompt_args["in_scope"] = in_scope
    response: PromptModel = generate_structured_object(
        prompt=prompt,
        prompt_args=prompt_args,
        base_model=PromptModel,
        model=model,
        credentials=credentials
    )
    if response is None:
        return []
    # Drop empty or whitespace-only questions
    questions = [q for q in response.questions if q and len(q.strip()) > 0]
    return questions
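Before scaling out, it can help to smoke-test the generator on a single seed example. The snippet below is a minimal sketch; the endpoint URL, token, and model name are placeholders you would substitute with your own.
# Hypothetical single-record test of the question generator
test_credentials = {
    "<model_name>": {
        "api_key": "<your_api_token>",
        "base_url": "<your_endpoint>",  # OpenAI-compatible Databricks serving endpoint URL
    }
}

sample_questions = generate_one_prompt(
    example_question="How do I set up a direct deposit?",
    example_category="Account Management",
    in_scope="Yes",
    credentials=test_credentials,
    model="<model_name>",
    num_questions=3,
    prompt=QUESTION_PROMPT,
)
print(sample_questions)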
The function below implements the higher-level logic that takes a batch of inputs and distributes the calls to the LLM endpoint concurrently. This is essential for speeding up the generation of a large dataset.
import itertools
from multiprocessing.pool import ThreadPool

# Applies a function to each record in parallel using a thread pool
def apply_function_parallel(
    func: callable, data: list[dict[str, Any]], n_processes: int
) -> list[dict[str, Any]]:
    with ThreadPool(processes=n_processes) as pool:
        result = pool.map(func, data)
    return result

# Generates questions for every seed record and flattens the results into one list
def generate_n_prompts(
    records: list[dict[str, Any]],
    credentials: dict[str, dict[str, str]],
    num_questions: int,
    n_processes: int,
    model: str,
) -> list[dict[str, Any]]:
    def _process(record: dict[str, Any]) -> list[dict[str, Any]]:
        # Column names must match the seed table (here: Message and Topic)
        questions = generate_one_prompt(
            example_question=record["Message"],
            example_category=record["Topic"],
            in_scope=record.get("in_scope", ""),
            credentials=credentials,
            model=model,
            num_questions=num_questions,
            prompt=QUESTION_PROMPT,
        )
        return [{"question": q} for q in questions]

    return list(itertools.chain(*apply_function_parallel(_process, records, n_processes)))
Finally, we use the Apache Spark APIs to read from the Delta table that holds our sample questions, call the function to generate prompts in parallel, and write the results back to a Delta table.
from pyspark.sql import DataFrame
import pandas as pd

input_table_path = '<your_catalog.your_schema.seed_examples>'
output_table_path = '<your_catalog.your_schema.synth_examples>'
serving_endpoint = "<your_endpoint>"
api_token = "<your_api_token>"
model_name = "<model_name>"
num_questions = 10
processes = 2
generation_kwargs = {"temperature": 0.99, "max_tokens": 200}

# Read the seed examples from the Delta table
dataframe = spark.sql(f"SELECT * FROM {input_table_path}")
records = dataframe.toPandas().to_dict(orient="records")

question_dataframe = generate_n_prompts(
    records=records,
    model=model_name,
    num_questions=num_questions,
    n_processes=processes,
    credentials={
        f"{model_name}": {
            "base_url": serving_endpoint,
            "api_key": api_token,
        }
    },
)

final_df = spark.createDataFrame(pd.DataFrame(question_dataframe))
display(final_df)
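The batch inference step in the next section reads the synthetic questions from the output table defined above, so we also persist final_df there; a minimal sketch:
# Persist the synthetic questions so the ai_query step below can read them
final_df.write.mode("overwrite").saveAsTable(output_table_path)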
Now that we have the synthetic questions, we will use Databricks AI Functions for batch inference. This simplifies a number of the tasks we performed above, such as creating structured outputs, schema validation, and parallel processing.
We recommend decoupling the generation of prompts from the responses, as this will help the model to focus on what kind of questions and tasks it needs to perform at each step for better quality. It also allows you to tune each part of the pipeline independently to achieve more realistic outputs.
In this step, we create an AI query to produce AI-generated responses in parallel across the synthetic questions in our table.
synth_responses = "<your_catalog.your_schema.synth_responses>"

# rules_text holds the bank's policy and tone-of-voice rules as plain text (defined elsewhere)
QUERY = f"""
CREATE OR REPLACE TABLE {synth_responses}
AS
SELECT
question,
ai_query(
'{model_name}',
'You are a knowledgeable expert for a banking app. Your role is to generate a well-formed response to the provided question. To ensure the answers are safe and adhere to the banking app policies and tone of voice, follow these guidelines:
{rules_text}
##Reasoning
Use step-by-step reasoning before generating the final output. Format your thought process clearly, separating analysis from final outputs.
Generate a response for our banking app, following these steps:
1. First analyze the question and what the user is asking for
Input:' || question ||
'
2. Analyze if the question is in scope for the banking app
3. Generate an initial response that follows each rule
4. Analyze if your response is safe and adheres to each of the rules
5. Generate a final response, modifying your response as needed
6. Keep responses straight to the point
7. Only generate the final answer
',
responseFormat => 'STRUCT<content:STRUCT<content:STRING>>',
modelParameters => named_struct('temperature', 0.7, 'max_tokens', 512)
) AS response
FROM {output_table_path}
"""
synth_generation = spark.sql(QUERY)
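Before formatting the data for fine-tuning, it is worth spot-checking a sample of the generated responses for policy adherence and tone; a quick sketch using the table created above:
# Sample a few synthetic question/response pairs for a manual quality review
display(spark.sql(f"SELECT question, response FROM {synth_responses} LIMIT 5"))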
For chat completion fine-tuning, chat-formatted data must be in a .jsonl file, where each line is a separate JSON object representing a single chat session. Each chat session is represented as a JSON object with a single key, messages, that maps to an array of message objects.
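For illustration, a single training record in this format could look like the following; the question and answer text here are made-up examples.
import json

# One chat session per line of the .jsonl file (contents are illustrative)
example = {
    "messages": [
        {"role": "user", "content": "How do I set up a direct deposit?"},
        {"role": "assistant", "content": "You can set up a direct deposit from the Payments tab in the app..."},
    ]
}
print(json.dumps(example))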
We create a Delta table that represents the format required for chat completion:
CREATE OR REPLACE TABLE <your_catalog.your_schema.formatted_synth_responses>
AS
SELECT to_json(named_struct(
  'messages', array(
    named_struct('role', 'user', 'content', question),
    named_struct('role', 'assistant', 'content', content)
  )
)) AS messages
FROM (
  SELECT question, from_json(response, 'STRUCT<content:STRING>').content AS content
  FROM <your_catalog.your_schema.synth_responses>
)
We can then store this as a .jsonl file in a Databricks Volume once we are happy with the results.
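A minimal sketch of exporting the chat-formatted rows to a .jsonl file in a Unity Catalog Volume; the table and Volume paths are placeholders, and collecting to the driver assumes a modestly sized dataset.
formatted_table = "<your_catalog.your_schema.formatted_synth_responses>"
volume_path = "/Volumes/<your_catalog>/<your_schema>/<your_volume>/train.jsonl"

# Each row already contains a complete JSON chat session, so write one per line
rows = spark.table(formatted_table).select("messages").collect()
with open(volume_path, "w") as f:
    for row in rows:
        f.write(row["messages"] + "\n")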
Once we have all the synthetic and seed data for model training stored in Databricks, we can use the ML Experiments UI or the Databricks Foundation Model API to run fine-tuning.
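As a sketch, kicking off a chat completion fine-tuning run with the Foundation Model fine-tuning Python API could look like the following; the student model name, data path, and training duration are illustrative assumptions, so check the documentation for the options supported in your workspace.
from databricks.model_training import foundation_model as fm

run = fm.create(
    model="meta-llama/Meta-Llama-3.1-70B-Instruct",  # illustrative student model
    train_data_path="/Volumes/<your_catalog>/<your_schema>/<your_volume>/train.jsonl",
    task_type="CHAT_COMPLETION",  # matches our chat-formatted .jsonl data
    register_to="<your_catalog>.<your_schema>",  # register the fine-tuned model in Unity Catalog
    training_duration="3ep",  # illustrative: three epochs
)
print(run)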
Databricks provides a scalable framework for Agent Evaluation to evaluate the quality, cost, and latency of agentic AI applications, including RAG applications and chains. Alternatively, you can add an evaluation dataset to the fine-tuning experiment in the UI to evaluate the model during training.
In this blog, we looked at an example of how you can generate synthetic data utilizing Databricks Model Serving and Serverless Batch Inference. We also shared key best practices such as ensuring structured outputs, decoupling question generation from response generation, and ensuring you have scoped the fine-tuned model to a predefined set of topics.
Synthetic data generation is a powerful and scalable solution for addressing challenges such as data scarcity, privacy concerns, and domain specificity when fine-tuning large language models (LLMs). By incorporating Databricks platform features, the process becomes even more robust, efficient, and impactful for organizations across industries.
Some key benefits of synthetic data generation on the Databricks platform include scalable serverless batch inference, built-in support for structured outputs, and keeping seed data, synthetic data, and fine-tuned models governed in one place.
These features make Databricks a comprehensive platform for leveraging synthetic data to fine-tune LLMs, driving innovation while solving critical challenges uniquely associated with real-world datasets.