davidhuang
Databricks Employee

This is part 2 of a two-part series on Structured Extraction with LLM on Databricks. Read here for part 1!

 

Introduction

In part 1 of this series, I demonstrated how to use a large language model (LLM) with structured output and AI_QUERY to perform large-scale batch extraction. Using a Lease Contract dataset, I established a baseline performance with an F1-score of 71%.

Where can we go from here? We could explore additional prompt engineering techniques, such as incorporating few-shot examples or using tools like DSPy for prompt optimization. These are all worthwhile options. However, we will focus on customizing and refining the LLM itself to improve its performance on our specific task of structured extraction — and we’ll do this through fine-tuning.

In this follow-up blog post, I demonstrate how to improve the F1-score through LLM fine-tuning on Databricks. Using the training samples previously set aside, we generate additional synthetic samples with the LLM and fine-tune a smaller model to achieve better performance and accuracy than the baseline.

 

Synthetic data

I won’t go into detail here about the benefits or methods of fine-tuning on Databricks, as my colleagues have covered that thoroughly in a technical blog post. Instead, I want to delve deeper into synthetic data generation specifically for LLM fine-tuning.

As the name suggests, synthetic data generation involves creating artificial training data — a task at which LLMs excel. This approach has gained popularity for two key reasons: (1) LLMs are increasingly capable of producing high-quality synthetic data that closely resembles real-world data, and (2) certain open-source LLMs, like Meta’s Llama 3.1 herd of models, come with permissive licenses that allow their outputs to be used to improve other models.

Why generate synthetic data in the first place? Fine-tuning requires a large volume of training samples. Databricks recommends providing enough tokens to match at least one full context length of the model being fine-tuned (e.g., roughly 130,000 tokens for the Meta Llama 3.1 8B Instruct model).
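If you want a quick sanity check that your dataset clears that bar before launching a run, you can count tokens up front. Here’s a minimal sketch using the Hugging Face tokenizer for the base model; the table and column names match the ones used later in this post, and a rough characters-divided-by-four estimate works too if you don’t have tokenizer access.

from transformers import AutoTokenizer

# Rough token count for the training table (sketch). Requires the transformers
# package and access to the gated Llama repo on Hugging Face.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

rows = (
    spark.table("catalog_name.schema_name.lease_docs_train")
    .select("lease_doc", "labels")
    .toPandas()
)

total_tokens = sum(
    len(tokenizer.encode(doc)) + len(tokenizer.encode(labels))
    for doc, labels in zip(rows["lease_doc"], rows["labels"])
)
print(f"Approximate training tokens: {total_tokens:,}")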

In the previous post, we set aside 50 training samples with annotated labels. In the following sections, I’ll evaluate the results from fine-tuning without synthetic data, then with a doubled and tripled sample size by incorporating synthetic datasets, to see how each impacts the F1-score.

 

Fine-tuning with only the real training data

For this initial round, I fine-tuned a Meta Llama 3.1 8B Instruct model using only the 50 real training samples. First, we transform the training data into a chat format, where the prompt becomes a “user” message, and the annotated labels are presented as an “assistant” message. Next, we split this dataset into training and evaluation sets (e.g., an 80%-20% training-evaluation split).

This evaluation set is not the same hold-out dataset from part 1 of the blog post. During fine-tuning, the evaluation set will help monitor model improvement, while the hold-out set remains completely unused to avoid overfitting. This best practice helps ensure that the model remains generalizable and doesn’t become too tailored to the hold-out set.

Here’s the code for creating and splitting the initial training data:

ft_df = spark.sql(
    """
    SELECT 
    ARRAY(
        STRUCT('user' AS role, CONCAT('{0}', '\n', lease_doc) AS content),
        STRUCT('assistant' AS role, labels AS content)
    ) AS messages
    FROM catalog_name.schema_name.lease_docs_train
    """.format(prompt.replace("'", '"'))
)

train_df, eval_df = ft_df.randomSplit([0.8, 0.2], seed=1)

train_df.write.format("delta").mode("overwrite").saveAsTable(
    "catalog_name.schema_name.ft_training_no_synths"
)

eval_df.write.format("delta").mode("overwrite").saveAsTable(
    "catalog_name.schema_name.ft_eval"
)

Here’s an example of a prompt to use as an instruction. You’ll notice it’s considerably shorter than the one used for structured extraction in part 1. With fine-tuning, extensive prompt engineering is less critical because we’re directly modifying the model’s underlying behavior through weight adjustments, reducing the need to allocate extra tokens to crafting detailed prompts.

prompt = """You are an AI assistant specialized in analyzing legal contracts. 
Your task is to extract relevant information from a given contract document to a structured JSON schema.

Instructions:
1. Carefully read the entire contract document provided at the end of this prompt.
2. Extract the relevant information.
3. Present your findings in JSON format as specified below.

Output Format:
- The output response should be valid JSON, enclosed in triple backticks with a JSON tag.
- The output values should be extracted from the contract document directly. They should not be paraphrased or otherwise edited.

Contract to analyze:
{lease}
"""

With the datasets ready, here’s how you can start fine-tuning with the Databricks Foundation Model Training API:

import json
from databricks.model_training import foundation_model as fm

base_model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
registered_model_name = "catalog_name.schema_name.ft_model_llama3_1_8b_no_synth"
training_data_path = "catalog_name.schema_name.ft_training_no_synths"
eval_data_path = "catalog_name.schema_name.ft_eval"
experiment_path = "/Users/user@email.com/synthgen-lease-doc-experiment"
task_type = "CHAT_COMPLETION"
training_duration = "5ep"
current_cluster_id = (
    json.loads(
        dbutils.notebook.entry_point.getDbutils().notebook().getContext().safeToJson()
    )["attributes"]["clusterId"]
)
lr_list = ["1e-6", "2e-6", "3e-6", "4e-6"] # for LR sweep

for learning_rate in lr_list:
    run = fm.create(
        data_prep_cluster_id=current_cluster_id,
        model=base_model_name,  
        train_data_path=training_data_path,
        eval_data_path=eval_data_path,
        task_type=task_type,
        training_duration=training_duration,
        register_to=registered_model_name,
        learning_rate=learning_rate,
        experiment_path=experiment_path
    )

In this fine-tuning run, we test a range of learning rates, training each for five epochs. This approach, known as “hyperparameter sweeping,” creates multiple models with different hyperparameter combinations. It can be made more complex by experimenting with a range of epochs as well. Here, four different models are fine-tuned, automatically logged with MLflow, and registered in Unity Catalog.
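For example, if you also want to sweep over training duration, you can nest the loops. This is just a sketch reusing the same fm.create call from above; the epoch values are arbitrary, and note that this launches twelve runs instead of four.

epoch_list = ["3ep", "5ep", "8ep"]  # arbitrary values for illustration

# Sweep both training duration and learning rate (12 runs in total).
for training_duration in epoch_list:
    for learning_rate in lr_list:
        run = fm.create(
            data_prep_cluster_id=current_cluster_id,
            model=base_model_name,
            train_data_path=training_data_path,
            eval_data_path=eval_data_path,
            task_type=task_type,
            training_duration=training_duration,
            register_to=registered_model_name,
            learning_rate=learning_rate,
            experiment_path=experiment_path,
        )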

Once the run is complete, you can view the results in the MLflow Experiment tab. During fine-tuning, three metrics are recorded on your evaluation dataset: Cross-Entropy, Perplexity, and Token Accuracy scores. While these metrics help assess which model is “learning” more effectively, they don’t directly indicate which model will achieve the best F1-score — the primary metric we’re optimizing for in this case.

With the models ready, the next step is to deploy all four models and run batch inference on the hold-out set to obtain four different F1-scores for comparison. As a shortcut, I select the model with the highest Token Accuracy, which, in my experience, tends to align with better F1-score performance. However, this isn’t a guaranteed method for every use case, so calculating F1-scores for all models is ideal if you want to ensure the best choice.

Fine-tuning evaluation metrics
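If you’d rather pull these metrics programmatically than eyeball the Experiment tab, MLflow’s search API can rank the sweep runs by Token Accuracy for you. This is only a sketch: the metric and parameter column names below are assumptions, so confirm the exact keys on one run in the MLflow UI before relying on them.

import mlflow

# Sketch: rank the sweep runs by evaluation token accuracy.
# The metric/parameter column names are placeholders -- confirm them in the MLflow UI.
runs = mlflow.search_runs(experiment_names=[experiment_path])

metric_col = "metrics.eval/token_accuracy"  # placeholder metric key
best = runs.sort_values(metric_col, ascending=False).iloc[0]

print("Best run:", best["run_id"])
print("Learning rate:", best.get("params.learning_rate"))  # placeholder param key
print("Token accuracy:", best[metric_col])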

After deciding on a model, which at this point is registered in Unity Catalog, I deployed it to Databricks Model Serving via a Provisioned Throughput (PT) endpoint. PT endpoints are available for supported models (base or fine-tuned), such as the Meta Llama 3.1 models.
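Here’s a sketch of creating a PT endpoint with the Databricks Python SDK. You can just as easily do this through the Serving UI; the endpoint name, entity version, and throughput values below are placeholders, and the valid throughput increments depend on the model.

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import EndpointCoreConfigInput, ServedEntityInput

w = WorkspaceClient()

# Sketch: serve the chosen fine-tuned model on a Provisioned Throughput endpoint.
# Endpoint name, entity version, and throughput values are placeholders.
w.serving_endpoints.create(
    name="ft-llama3-1-8b-lease-extraction",
    config=EndpointCoreConfigInput(
        served_entities=[
            ServedEntityInput(
                entity_name="catalog_name.schema_name.ft_model_llama3_1_8b_no_synth",
                entity_version="1",                # the Unity Catalog model version you picked
                min_provisioned_throughput=0,      # placeholder; valid bands depend on the model
                max_provisioned_throughput=970,    # placeholder; check the Serving UI for your model
            )
        ]
    ),
)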

Following the same batch inference process with AI_QUERY as described in part 1, I ran the fine-tuned model over the hold-out set. Below is a minimal sketch of that inference call, followed by the results from fine-tuning with only the 50 real training samples.
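This sketch assumes the PT endpoint name from the deployment step above and a placeholder hold-out table name; parsing the responses and computing the F1-score follow the same steps as part 1, and the generation parameters are illustrative only.

HOLDOUT_TABLE = "catalog_name.schema_name.lease_docs_holdout"  # placeholder hold-out table
FT_ENDPOINT_NAME = "ft-llama3-1-8b-lease-extraction"           # the PT endpoint created above

# Sketch: batch inference over the hold-out set with the fine-tuned endpoint.
QUERY = """
    SELECT lease_id,
        ai_query(
            '{1}',
            CONCAT('{0}', '\n', lease_doc),
            modelParameters => named_struct('max_tokens', 4000, 'temperature', 0.0)
        ) AS predicted_labels
    FROM {2}
""".format(prompt.replace("'", '"'), FT_ENDPOINT_NAME, HOLDOUT_TABLE)

predictions = spark.sql(QUERY)
display(predictions)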

Round 1 results

We got about 70%. While this F1-score shows no improvement over the baseline score of 71%, it is worth noting that the fine-tuned model is an 8-billion-parameter LLM versus the 70-billion-parameter LLM used for the baseline, making it almost 9 times smaller. This illustrates the main benefit of fine-tuning: you can produce a much smaller model (and therefore a cheaper one to serve) with similar quality and accuracy on a narrow, domain-specific task.
Next, let’s see if we can improve this metric by adding synthetic data to the fine-tuning mix.

 

Generate synthetic data with LLM

LLMs, especially the larger and more capable ones, excel at “hallucinating” training data. Meta’s Llama 3.1 405B Instruct model is particularly effective for this due to its high capability and permissive open-source license, as mentioned earlier.

For structured extraction use cases, here are the steps:

  1. First, generate synthetic labels. Point the LLM at a real set of labels and ask it to create another set with different values.
  2. Then, give the LLM the generated labels along with a sample of a real lease document, and ask it to generate a new lease that looks like the real one but contains the details from the generated labels. This step is critical, because the generated labels must match the generated lease documents.
  3. Finally, ask the LLM to evaluate both the generated labels and the generated lease documents to make sure they are consistent with each other.

Here’s the query to generate synthetic labels and lease documents:

INPUT_TRAINING_TABLE = "catalog_name.schema_name.lease_docs_train"
OUTPUT_GENERATED_TABLE = "catalog_name.schema_name.generated_labels_and_leases"
ENDPOINT_NAME = "databricks-meta-llama-3-1-405b-instruct"

QUERY = f"""
create or replace table {OUTPUT_GENERATED_TABLE} as
select *, 
    ai_query(
        '{ENDPOINT_NAME}', 
        'You are an expert in generating JSON strings. 
        Given a JSON string, generate a new JSON string that is similar but with different details.
        Important Notes:
        - The output response should be valid JSON, enclosed in triple backticks with a JSON tag.
        - Every key from the Input needs to be included; no more and no less.
        - Do not include the same key multiple times in the JSON.
        - Do not be verbose in your answer -- only respond with the correct format and the correct information.
        Input: ' || labels || 
        'Your output: ',
        modelParameters => named_struct('max_tokens', 4000 ,'temperature', 1.0)
    ) as generated_labels,
    ai_query(
        '{ENDPOINT_NAME}',
        'You are an expert in generating a lease contract. 
        Given a JSON string, generate a new lease contract that incorporates the details from the provided JSON.
        Important Notes:
        - The output should look similar to the provided sample lease document
        - Only change the details as specified in the input JSON.
        Your new lease contract should look like this: ' || lease_doc ||
        'Input: ' || generated_labels || 
        'Your output: ',
        modelParameters => named_struct('max_tokens', 4000 ,'temperature', 1.0)
    ) as generated_lease_doc
from {INPUT_TRAINING_TABLE}
;
"""

synth_generation = spark.sql(QUERY)
display(synth_generation)

display(spark.sql(f"select * from {OUTPUT_GENERATED_TABLE};"))

Here are the results:

Synthetic labels and lease documents

Next, use the same AI_QUERY process to evaluate the synthetic data. Here’s example code to do that.

INPUT_GENERATED_TABLE = "catalog_name.schema_name.generated_labels_and_leases"
OUTPUT_EVALUATED_TABLE = "catalog_name.schema_name.evaluated_synth_labels_and_leases"
ENDPOINT_NAME = "databricks-meta-llama-3-1-405b-instruct"

QUERY = f"""
create or replace table {OUTPUT_EVALUATED_TABLE} as
select *, 
    ai_query(
        '{ENDPOINT_NAME}', 
        'You are an expert in scoring a JSON object. 
        Given a JSON string, you are to score it based on the following criteria:
        - The JSON string must have the following keys: 
          "end_date", "leased_space" ,"lessee", "lessor", "signing_date", 
          "start_date", "term_of_payment", "designated_use", 
          "extension_period", "expiration_date_of_lease" 
        - If any one of the keys is missing, simply return "0"; otherwise return "1"
        - Do not offer any explanation.
        - *VERY IMPORTANT:* DO NOT EXPLAIN, ONLY RETURN "0" OR "1".
        Score the following JSON object: ' || generated_labels
    ) as eval_labels,
    ai_query(
        '{ENDPOINT_NAME}',
        'You are an expert in scoring a lease contract according to a given JSON object. 
        Given a JSON object along with a lease contract, you are to score the lease contract based on the following criteria:
        - All details from the given JSON object must be present in the given lease contract.
        - If any of those details are missing from the given lease contract, simply return "0"; otherwise return "1"
        - Do not offer any explanation.
        - *VERY IMPORTANT:* DO NOT EXPLAIN, ONLY RETURN "0" OR "1".
        The JSON object to refer to: ' || generated_labels ||
        '\nScore the following lease contract: ' || generated_lease_doc
    ) as eval_leases
from {INPUT_GENERATED_TABLE}
;
"""

synth_evaluation = spark.sql(QUERY)
display(synth_evaluation)

display(spark.sql(
    f"""
    select * from {OUTPUT_EVALUATED_TABLE}
    where eval_labels = '1' AND eval_leases = '1'
    ;
    """
))

The results look like this:

Evaluation of synthetic data

With the generated scores, you can filter out samples that scored “0” in either column. In this run, about 7 of the generated samples failed the evaluation criteria. This process can be repeated as many times as you want. It can be more efficient to generate a small batch of synthetic data first and test how much it improves your fine-tuned model; doing this iteratively helps ensure you’re not generating more samples than you need. After all, generating synthetic samples is not free! Here’s the query that inserts the samples that passed both checks into the synthetic training table:

INSERT INTO catalog_name.schema_name.lease_docs_train_synth 
    (lease_id, lease_doc, labels)
SELECT lease_id, 
    generated_lease_doc AS lease_doc, 
    replace(replace(generated_labels, '```json', ''), '```', '') as labels
FROM catalog_name.schema_name.evaluated_synth_labels_and_leases
WHERE eval_labels = '1' AND eval_leases = '1'
;

Finally, let’s see if these generated samples, along with the real training samples, improve the F1-score.

 

Fine-tuning with real and synthetic samples

With the generated samples, we can union them with our real training dataset.

ft_df = spark.sql(
    """
    -- real training data
    SELECT 
    ARRAY(
        STRUCT('user' AS role, CONCAT('{0}', '\n', lease_doc) AS content),
        STRUCT('assistant' AS role, labels AS content)
    ) AS messages
    FROM catalog_name.schema_name.lease_docs_train

    -- union synthetic data
    UNION
    SELECT 
    ARRAY(
        STRUCT('user' AS role, CONCAT('{0}', '\n', lease_doc) AS content),
        STRUCT('assistant' AS role, labels AS content)
    ) AS messages
    FROM catalog_name.schema_name.lease_docs_train_synth
    ;
    """.format(prompt.replace("'", '"'))
)

train_df, eval_df = ft_df.randomSplit([0.8, 0.2], seed=1)

train_df.write.format("delta").mode("overwrite").saveAsTable(
    "catalog_name.schema_name.ft_training_with_synths"
)

eval_df.write.format("delta").mode("overwrite").saveAsTable(
    "catalog_name.schema_name.ft_eval"
)

Following the same fine-tuning and batch inference process described earlier in this blog post, here are the results:

Round 2 results

As you can see, the F1-score improves by about 5%!

You can keep generating synthetic datasets and fine-tuning to see how much more improvement you can squeeze out. In fact, in this example I generated one more round of synthetic samples. With a total of ~130 samples (50 real, 43 from the first round of synthetic generation, and 38 from the second round), here are the results:

Round 3 results

The F1-score increased by nearly another 5%! Here are the side-by-side F1-score results.

Side-by-side comparison

I can’t say for sure whether generating another round of synthetic data would improve the F1-score by a similar margin, since there could be diminishing returns. But, as this dataset illustrates, it is a worthwhile approach for improving the accuracy of your structured extraction use case, especially when you don’t have many human-annotated samples to use for fine-tuning.