davidhuang, Databricks Employee

This is part 1 of a two-part series on Structured Extraction with LLM on Databricks. Read here for part 2!

 

What is structured extraction?

Structured extraction, sometimes referred to as “key information extraction,” “entity extraction,” or simply as “text-to-JSON,” is a process that transforms unstructured text into a structured format, such as JSON, making it easily accessible for further processing, analysis, or storage. With the rise of large language models (LLMs), this task can now be accomplished efficiently and affordably, enabling enterprises to unlock valuable insights from large volumes of unstructured data, including PDFs, text files, and scanned documents.
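As a toy illustration (the clause and field names below are invented for this example), a sentence such as “This lease commences on January 1, 2024 and terminates on December 31, 2026” could be extracted into:

{
  "start_date": "January 1, 2024",
  "end_date": "December 31, 2026"
}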

Here are some examples of real-life applications of structured extraction across different industries:

  • Financial Services: In loan underwriting and credit scoring, extracting key financial details (e.g., income, debt, assets) from bank statements, tax forms, and payslips to help streamline underwriting and assess creditworthiness accurately.
  • Healthcare: Extracting patient information from medical records, lab reports, and radiology notes to populate Electronic Health Records (EHR) reduces manual errors and improves data accessibility.
  • E-commerce: Sentiment and trend analysis involves extracting product sentiment, common complaints, and frequently mentioned features to understand product performance and guide development.

With Databricks’ Mosaic AI platform, the entire structured extraction process can be completed end-to-end without the need to transfer data or combine tools from different providers.

In this two-part tutorial, I first demonstrate Databricks features such as the Databricks Foundation Model API for structured extraction and Databricks SQL with AI_QUERY for high-performance batch inference.

Then, in Part 2, I use the Databricks Model Training API to fine-tune a smaller model that outperforms our baseline, and use the Databricks Model Serving API to deploy the fine-tuned model for batch inference. In addition, I demonstrate how you can create synthetic training data to help with LLM fine-tuning when you don’t have enough human-annotated data.

Note: Some of the Databricks functionalities in this tutorial are currently in Public Preview, which means they can be at different levels of product maturity and production readiness. It’s best to consult your Databricks account team before using any of these features for production use cases.

 

Dataset for structured extraction

For this tutorial, we work with a lease contract dataset from the paper “A Benchmark for Lease Contract Review” (Leivaditi et al., 2020). To begin, we have already ingested the lease documents as a Delta table in Unity Catalog. If you’d like to see some resources on how to do unstructured document ingestion on Databricks, I suggest starting with this one.

Here’s an example of the dataset: the lease_id column is the primary key, the lease_doc column is the raw OCR’d text from the lease documents, and the labels column contains the manual annotations, which serve as the expected structured outputs.

Dataset example
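If you want to take a quick look at the ingested table yourself, here is a minimal sketch (the table name is a placeholder for wherever you landed the documents):

df = spark.table("catalog_name.schema_name.lease_docs")  # placeholder table name
display(df.select("lease_id", "lease_doc", "labels").limit(5))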

Let’s use an LLM to extract, from the lease_doc column, an output similar to the annotated labels.

 

Structured extraction with an LLM

Structured extraction can be done using prompt engineering on powerful LLMs such as OpenAI’s GPT-4o model, Anthropic’s Claude 3.5 Sonnet model, or Meta’s Llama family of models.

The process is quite simple: write a prompt specifying the fields you want to extract, then append the lease contract’s raw text to it. Like this:

 

prompt = """You are an AI assistant specialized in analyzing legal contracts. 
Your task is to extract relevant information from a given contract document. 
Your output must be a structured JSON object.

Instructions:
1. Carefully read the entire contract document provided at the end of this prompt.
2. Extract the relevant information.
3. Present your findings in JSON format as specified below.

Important Notes:
- Extract only relevant information. 
- Consider the context of the entire contract when determining relevance.
- Do not be verbose, only respond with the correct format and information.
- Some docs may have multiple relevant excerpts -- include all that apply.
- Some questions may have no relevant excerpts -- just return ["N/A"].
- Do not include additional JSON keys beyond the ones listed here.
- Do not include the same key multiple times in the JSON.

Expected JSON keys and explanation of what they are:
- 'end_date': The end date of the lease.
- 'leased_space': Description of the space that is being leased.
- 'lessee': The lessee's name (and possibly address).
- 'lessor': The lessor's name (and possibly address).
- 'signing_date': The date the contract was signed.
- 'start_date': The start date of the lease.
- 'term_of_payment': Description of the payment terms.
- 'designated_use': Description of the designated use of the property being leased.
- 'extension_period': Description of the extension options for the lease.
- 'expiration_date_of_lease': The expiration date of the lease.

Contract to analyze: 
{lease}
"""

 

Here, I’m using the Databricks Foundation Model API to call Meta’s Llama 3.1 70B Instruct model. This method is compatible with the OpenAI API, but you will need to specify your Databricks workspace URL and your access token. In addition, you can use a wide range of model endpoints, such as supported pay-per-token endpoints, provisioned throughput endpoints, supported proprietary closed models, or any deployed custom models.

Regardless of which type of endpoint you want to use, you can invoke it via an API call. This service provides an easy way to deploy, manage, and query AI models for real-time and batch inference via a unified REST API… and completely serverless!
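For illustration, here is a minimal sketch of that raw REST call using the requests library; the workspace URL and token below are placeholders, and in practice you would pull them from your Databricks credentials:

import requests

host = "https://<your-workspace>.cloud.databricks.com"  # placeholder workspace URL
token = "<your-access-token>"  # placeholder personal access token

response = requests.post(
    f"{host}/serving-endpoints/databricks-meta-llama-3-1-70b-instruct/invocations",
    headers={"Authorization": f"Bearer {token}"},
    json={
        "messages": [{"role": "user", "content": "Say hello."}],
        "max_tokens": 50,
    },
)
print(response.json())

In the rest of this post, however, I’ll use the OpenAI-compatible client: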

 

import openai
from mlflow.utils.databricks_utils import get_databricks_host_creds

creds = get_databricks_host_creds("databricks")

client = openai.OpenAI(
    api_key=creds.token, 
    base_url=f"{creds.host}/serving-endpoints"
)

# example_lease holds the raw OCR'd text of a single lease document (the lease_doc column)
chat_completion = client.chat.completions.create(
    messages=[
        {
            "role": "user", 
            "content": prompt.format(lease=example_lease)
        }
    ],
    model="databricks-meta-llama-3-1-70b-instruct",
    max_tokens=500,
)

print(chat_completion.choices[0].message.content)

 

When you’ve decided on an endpoint, simply pass in the prompt with the lease doc, and run the cell. Here’s the result:

Raw LLM output

And here’s the corresponding labels:

Actual human-annotated labels

At first glance, the result looks right. However, when you compare it to the true labels, you can see two issues with the output format:

  1. The LLM’s output is overly verbose, which makes it an invalid JSON object. You’d need to do some regex-based post-processing to preserve only the JSON part of the generated string (see the sketch after this list).
  2. The generated output did not match the exact JSON schema.
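For reference, here is a minimal sketch of that kind of regex-based clean-up (the helper function name is hypothetical):

import json
import re

def extract_json(raw_response: str):
    # Grab everything between the first "{" and the last "}" in the response,
    # then try to parse it; return None if no valid JSON object is found.
    match = re.search(r"\{.*\}", raw_response, re.DOTALL)
    if match is None:
        return None
    try:
        return json.loads(match.group(0))
    except json.JSONDecodeError:
        return None

# e.g., extract_json(chat_completion.choices[0].message.content)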

There are a number of prompt engineering practices you can apply here to get better results, such as including few-shot examples. However, the simplest way to improve accuracy in this scenario is to use the “structured output” feature, which is now available on Databricks.

 

Structured output with response_format

Databricks now supports structured outputs on the Foundation Model API, for example with the 70B and 405B variants of the Meta Llama 3.1 model family. The feature is available for any LLM served via the Foundation Model API.

To do this, simply write a JSON schema and pass it into the LLM call as an additional argument. This method is flexible because you can write whatever JSON schema fits your use case.

Here’s my JSON schema:

 

response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "lease_contract_extractions",
        "schema": {
            "type": "object",
            "properties": {
                "end_date": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "The end date of the lease.",
                },
                "leased_space": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Description of the space that is being leased.",
                },
                "lessee": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "The lessee's name (and possibly address).",
                },
                "lessor": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "The lessor's name (and possibly address).",
                },
                "signing_date": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "The date the contract was signed.",
                },
                "start_date": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "The start date of the lease.",
                },
                "term_of_payment": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Description of the payment terms.",
                },
                "designated_use": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Designated use of the property being leased.",
                },
                "extension_period": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Description of the extension options for the lease.",
                },
                "expiration_date_of_lease": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "The expiration date of the lease.",
                },
            },
            "required": [
                "end_date",
                "leased_space",
                "lessee",
                "lessor",
                "signing_date",
                "start_date",
                "term_of_payment",
                "designated_use",
                "extension_period",
                "expiration_date_of_lease",
            ],
        },
        "strict": True,
    },
}

import json
import openai
from mlflow.utils.databricks_utils import get_databricks_host_creds

creds = get_databricks_host_creds("databricks")

client = openai.OpenAI(
    api_key=creds.token, 
    base_url=f"{creds.host}/serving-endpoints"
)

chat_completion = client.chat.completions.create(
    messages=[
        {
            "role": "user", 
            "content": prompt.format(lease=example_lease)
        }
    ],
    model="databricks-meta-llama-3-1-70b-instruct",
    response_format=response_format,  # structured output argument
    max_tokens=500,
)

json.loads(chat_completion.choices[0].message.content)

 

Notice the additional argument response_format here. When passing the JSON schema written above into this argument, the LLM's response will be constrained to adhere to your desired response format. Here's the result:

Structured LLM output

As you can see, the LLM’s output is now a valid JSON object that looks closer to the labels.

 

Batch inference and evaluation

Now, how can you determine if the extracted fields are accurate and complete in a scalable way? Let’s say you have thousands of these lease documents, but you have only 100 samples with true labels. You’d want to perform a baseline evaluation of this approach before applying it to the rest of your documents. Let’s see how you can do that on Databricks.

Before we continue, let’s split the dataset into hold-out and training sets. Ignore the training set for now. We will use that dataset in part 2 of this tutorial.

 

df_train, df_holdout = df.randomSplit([0.50, 0.50], seed=614)

df_holdout.write.format("delta").mode("overwrite").saveAsTable(
    "catalog_name.schema_name.lease_docs_holdout"
)

df_train.write.format("delta").mode("overwrite").saveAsTable(
    "catalog_name.schema_name.lease_docs_train"
)

 

To do LLM batch inference on Databricks, here are some options:

  1. Instantiate a Python class or function that calls an LLM endpoint and wrap it in a Pandas user-defined function (UDF), which can then be applied to a Spark DataFrame in a distributed manner.
  2. Use the powerful AI_QUERY function, which is optimized for LLM batch inference. This is currently the approach recommended by Databricks.

Let’s go through each of these approaches.

 

Batch inference with Pandas UDF

In this approach, you leverage a Databricks cluster to send concurrent requests to the LLM endpoint via a Pandas UDF. This can be done by wrapping a simple function that calls the endpoint, but I added some error handling so that the entire batch inference job does not fail because of a single failed request.

 

import openai
from mlflow.utils.databricks_utils import get_databricks_host_creds
creds = get_databricks_host_creds("databricks")

class Extractor:
    def __init__(
        self, creds, endpoint_name, prompt, temperature, num_output_tokens
    ):
        self.client = openai.OpenAI(
            api_key=creds.token,
            base_url=f"{creds.host}/serving-endpoints",
            timeout=300,
            max_retries=3,
        )
        self.endpoint = endpoint_name
        self.prompt = prompt
        self.temperature = temperature
        self.num_output_tokens = num_output_tokens

    def predict(self, text):
        try:
            response = self.client.chat.completions.create(
                model=self.endpoint,
                messages=[
                    {
                        "role": "user", 
                        "content": self.prompt.format(lease=text)
                    }
                ],
                temperature=float(self.temperature),
                max_tokens=int(self.num_output_tokens),
                response_format=response_format,
            )        
            return (
                response.choices[0].message.content,
                response.usage.completion_tokens,
                response.usage.prompt_tokens,
                response.usage.total_tokens,
                None
            )

        except Exception as e:
            return None, 0, 0, 0, str(e)

 

Now, wrap it in a Pandas UDF so that when executed, it is distributed across the available worker cores in your Databricks cluster.

 

import pandas as pd
from pyspark.sql.functions import pandas_udf
from typing import Iterator

@pandas_udf(
    "output string, completion_tokens int, prompt_tokens int, "
    "total_tokens int, error string"
)
def extract_udf(
    content_batches: Iterator[pd.Series],
) -> Iterator[pd.DataFrame]:
    client = Extractor(
        creds=creds,
        endpoint_name="databricks-meta-llama-3-1-70b-instruct",
        prompt=prompt,
        temperature=0.0,
        num_output_tokens=4000,
    )
    for content_batch in content_batches:
        yield pd.DataFrame.from_records(
            content_batch.apply(client.predict).tolist()
        )

 

Finally, simply apply the function using PySpark on your hold-out dataset. 

 

from pyspark.sql.functions import col

# Read back the hold-out split saved earlier and repartition it for parallel requests
holdout_df = spark.table("catalog_name.schema_name.lease_docs_holdout").repartition(25)

(
    holdout_df
    .withColumn("extraction", extract_udf(col("lease_doc")))
    .selectExpr(
        "*",
        "extraction.output as output",
        "extraction.completion_tokens as completion_tokens",
        "extraction.prompt_tokens as prompt_tokens",
        "extraction.total_tokens as total_tokens",
        "extraction.error as error",
    )
    .drop("extraction")
    .write.mode("overwrite")
    .saveAsTable("catalog_name.schema_name.lease_extracted_llama70b")
)

 

The speed of the batch inference job depends on the size of your dataset, the total number of input and output tokens, the number of workers in your Databricks cluster, and the concurrency limit of your LLM endpoint.

Here’s the result:

Batch structured outputs

Under this approach, you can see there’s a lot of code that needs to be written. Not to mention the effort it takes to configure the Databricks cluster. There are a few knobs you can tune to ensure you’re not over- or under-utilizing the Model Serving endpoint, such as:

  • Partitioning your data in a way that it can be distributed evenly across your cluster’s workers, or
  • Setting the SPARK_WORKER_CORES environment variable on your cluster to try to match your LLM endpoint’s maximum concurrency (see the example after this list). The idea is to over-subscribe your CPUs with “virtual” cores, effectively pushing the cluster to run more parallel tasks. Finding the optimal value is an iterative exercise.
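As a rough illustration, the environment variable can be set under the cluster’s Advanced options (Spark > Environment variables); the value below is only an assumption and depends on your worker instance type and the endpoint’s concurrency limit:

SPARK_WORKER_CORES=96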

That’s pretty complex! And you still might not fully utilize the maximum throughput of your LLM endpoint. To address this complexity, Databricks has recently made significant optimizations and improvements to the AI_QUERY function. Let’s test it out next.

 

Batch inference with AI_QUERY

AI_QUERY is part of Databricks SQL’s AI Functions. You can simply use the function in a Databricks SQL query over a Delta table, point it to a Model Serving endpoint with a prompt, and it will perform the LLM inference in a highly performant and distributed manner.

The syntax is super simple. It’s a function that can be applied directly to a Delta table. For our use case, you can use the SQL Editor on Databricks, with the prompt I used earlier in this post. Like this:

 

%sql
create or replace table catalog_name.schema_name.lease_extracted_llama70b as
select *,
  ai_query(
    "databricks-meta-llama-3-1-70b-instruct",
    "You are an AI assistant specialized in analyzing legal contracts. 
     Your task is to extract relevant information from a given contract document. 
     Your output must be a structured JSON object.

     Instructions:
     1. Carefully read the entire contract document provided at the end of this prompt.
     2. Extract the relevant information.
     3. Present your findings in JSON format as specified below.

     Important Notes:
     - Extract only relevant information. 
     - Consider the context of the entire contract when determining relevance.
     - Do not be verbose, only respond with the correct format and information.
     - Some docs may have multiple relevant excerpts. Include all that apply.
     - Some questions may have no relevant excerpts. If so, return ['N/A'] for that field.
     - Do not include additional JSON keys beyond the ones listed here.
     - Do not include the same key multiple times in the JSON.

     Expected JSON keys and explanation of what they are:
     - 'end_date': The end date of the lease.
     - 'leased_space': Description of the space that is being leased.
     - 'lessee': The lessee's name (and possibly address).
     - 'lessor': The lessor's name (and possibly address).
     - 'signing_date': The date the contract was signed.
     - 'start_date': The start date of the lease.
     - 'term_of_payment': Description of the payment terms.
     - 'designated_use': Description of the designated use of the property being leased.
     - 'extension_period': Description of the extension options for the lease.
     - 'expiration_date_of_lease': The expiration date of the lease.

     Contract to analyze: " || lease_doc
  ) as output
from catalog_name.schema_name.lease_docs_holdout

 

This creates a new Delta table in Unity Catalog with the LLM-generated results. Using AI_QUERY removes the need for cluster configuration to manage concurrency, as it automatically saturates your chosen endpoint to optimize usage. That is also why Databricks recommends using AI_QUERY with a Provisioned Throughput endpoint for faster performance. Additionally, AI_QUERY has built-in fault tolerance with automatic retries and error handling.

Note: At the time of writing, response_format is not supported with AI_QUERY, which means you may need to do some regex-based post-processing to ensure the output is in a valid JSON format.
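For example, here is a minimal sketch of that clean-up in PySpark, assuming the table and column produced by the query above (the (?s) flag lets the pattern match across newlines):

from pyspark.sql.functions import col, regexp_extract

raw_df = spark.table("catalog_name.schema_name.lease_extracted_llama70b")

# Keep only the text between the first "{" and the last "}" of each response.
cleaned_df = raw_df.withColumn(
    "output", regexp_extract(col("output"), r"(?s)\{.*\}", 0)
)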

Now, with the batch inference output generated, we’re ready to run an evaluation and establish a baseline metric.

 

Evaluating LLM-generated structured outputs

There are a number of ways to evaluate LLM-generated structured outputs. Since we have an evaluation dataset, we can use traditional ML classification metrics such as Precision, Recall, and F1-score. We will use fuzzy matching here, since LLMs can be verbose while still producing correct results.

First, we transform the LLM-generated results table so that each field from each sample can be compared to its corresponding annotation. To do this, we first define a PySpark DataFrame schema.

 

from pyspark.sql.types import (
    StructType, StructField, StringType, ArrayType
)

fields = [
    "end_date", "leased_space", "lessee", "lessor", "signing_date", 
    "start_date", "term_of_payment", "designated_use", "extension_period", 
    "expiration_date_of_lease",
]

schema = StructType(
    [StructField(field, ArrayType(StringType()), True) for field in fields]
)

 

Then, create a melted Pandas dataframe for the LLM-generated results, and one for the labels.

 

import pandas as pd
from pyspark.sql.functions import from_json, col

def create_melted_df(df, column_name, schema):
    melted_df = (
        df.select("lease_id", column_name)
        .withColumn(column_name, from_json(col(column_name), schema))
        .select("lease_id", f"{column_name}.*")
        .toPandas()
    )
    return pd.melt(
        melted_df, 
        id_vars=["lease_id"], 
        var_name="variable", 
        value_name=column_name
    )

# load the batch inference results table generated earlier
llama70b_df = spark.table("catalog_name.schema_name.lease_extracted_llama70b")

# create labels df
labels_melted_df = create_melted_df(llama70b_df, "labels", schema)

# create outputs df
output_melted_df = create_melted_df(llama70b_df, "output", schema)

 

Finally, merge the two Pandas dataframes so that the labels and the generated outputs sit side by side.

 

# combine both dfs
comp_df = (
    pd.merge(
        labels_melted_df,
        output_melted_df,
        left_on=["lease_id", "variable"],
        right_on=["lease_id", "variable"],
        how="left"
    )
    .sort_values(by="lease_id")
)

# clean up nulls and unpack arrays
comp_df['labels'] = comp_df['labels'].apply(
    lambda x: ', '.join(x) if x is not None else 'N/A'
)

comp_df['output'] = comp_df['output'].apply(
    lambda x: ', '.join(x) if x is not None else 'N/A'
)

 

The results should look like this:

Side-by-side comparison

Now, with the dataframe processed and transformed, let’s define the evaluation function that calculates the desired metrics. As mentioned before, fuzzy matching is more appropriate than exact matching in this use case, since in natural language two sentences do not need to be identical to be semantically equivalent.

Here, we use difflib to compare the two string sequences; if they are more than 80% similar, we call it a match. After that, we calculate and return the metrics.

 

import difflib

def is_null_like(value):
    # Helper function to determine if a value is null-like.
    return value is None or str(value).strip().lower() in [
        "n/a", "none", "null", "nan", "",
    ]

def compute_value_metrics(y_true, y_pred):
    tp = fn = fp = tn = 0

    for expected_value, generated_value in zip(y_true, y_pred):
        expected_null = is_null_like(expected_value)
        generated_null = is_null_like(generated_value)

        # Case-1: True Positive (both values present and similar enough)
        if not expected_null and not generated_null:
            # Fuzzy match instead of exact match
            similarity = difflib.SequenceMatcher(
                None, 
                str(expected_value), 
                str(generated_value)
            ).ratio()
            if similarity > 0.8:
                tp += 1
            else:
                fp += 1

        # Case-2: False Positive
        if not generated_null and expected_null:
            fp += 1

        # Case-3: False Negative
        if generated_null and not expected_null:
            fn += 1

        # Case-4: True Negative
        if generated_null and expected_null:
            tn += 1

    # Calculate additional metrics
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1_score = (
        2 * (precision * recall) / (precision + recall) 
        if (precision + recall) > 0 
        else 0
    )

    return {
        "true_positive": tp,
        "false_negative": fn,
        "false_positive": fp,
        "true_negative": tn,
        "precision": round(precision, 4),
        "recall": round(recall, 4),
        "f1_score": round(f1_score, 4),
    }

y_true = comp_df["labels"].tolist()
y_pred = comp_df["output"].tolist()

compute_value_metrics(y_true, y_pred)

 

Now, let’s look at the results.

Baseline evaluation metrics

We got an F1-score of 71%. Not terrible, but hardly a score you’d be comfortable putting into production.

In part 2 of this tutorial, we will use the training dataset we set aside earlier, create more synthetic examples, and use them all to fine-tune a smaller LLM that achieves a higher F1-score.

-----

OpenAI Python Library is provided under the Apache 2.0 License, Copyright 2023 OpenAI

1 Comment
KyraWulffert, Databricks Employee

Great blog! What would be your opinion about packaging the llm with the prompt and the structured output in place as a model, serving it in a PT serving endpoint and then running a batch inference with the ai_query?