Technical Blog
Explore in-depth articles, tutorials, and insights on data analytics and machine learning in the Databricks Technical Blog. Stay updated on industry trends, best practices, and advanced techniques.
Dan_Z
Databricks Employee

Echoing what we said in Part 1: Test Data Curation, when your team is migrating script code and pipeline code to Databricks, there are three main steps:

  1. Look at the original code, understand what it's doing
  2. Convert the code to run on Databricks (convert SQL to Spark SQL or other code to Python/Scala/etc.)
  3. Test the code to make sure it works, fixing as needed

This article focuses on the third step. To confirm that a script has been converted correctly, you typically want to verify that:

  • the code compiles
  • there are no runtime errors
  • any updated tables match what the previous system would have created
  • any output files match what the previous system would have created

In Part 1 we focused on how to curate comprehensive data to use for running tests. The solution we landed on was to create:

  • One devoted schema (or database) for each test with all the tables and views (pre tables) needed for that script to run.
  • A separate schema in Databricks which held all the post tables which we reference to see if tests were successful.

In Part 2, we will focus on how to set up the test suites themselves. Some questions we will answer are:

  • How will we know if a test suite is passing? What metrics will be used?
  • How will two tables be deemed the same? 
  • How can we automate the creation of test suites?
  • How do we compare output files?
  • What about scalability?

Solution Requirements

When it comes to test suites, the most important aspects are:

  • Easy to use
  • Quick to run

The great thing about our data curation approach is that all the data needed to decide whether a test suite passes is already in Databricks. There is no need to connect (via JDBC or otherwise) to the source system to get this information. As such, the requirements for an excellent test suite are:

  • Run in Databricks Notebooks: so that setup is minimal for the conversion engineers and tests can be deployed and maintained in a central location.
    • The creation of these test suite notebooks should be automated.
  • “One click” runs: it should only require running a single function (or notebook cell) to see if the test succeeds or fails.
  • Compare all changed tables: if a script or pipeline changes a table, it should check to see if it’s correct.
    • Columns match
    • Row counts match
    • Row contents match
  • Compare output files: if a script or pipeline outputs a file, it should check to see if it’s correct.
    • Columns match
    • Row contents match

Now that we’ve established what we need, let’s start from the user’s perspective and dive deeper one step at a time.

Test Suite Code

To run the test suite, we suggest you devote one notebook to each pipeline being run. For instance, a pipeline with two scripts in it may have notebook code that looks like this:

 

# Databricks notebook source
from tests.unit_tests.testing_utils import *
repo_path = "<path to your scripts being converted>"
timeout_mins = 60*15  # passed to dbutils.notebook.run, which expects seconds: 900 s = 15 minutes
job = "products_monthly"

# COMMAND ----------

# DBTITLE 1,1 - products_backup
script = "products_backup.sql"
notebook_args = {
	"SCHEMA" : "products_monthly__products_backup",
	"CATALOG" : "dev_test_suites",
	"STAGING" : "abfss://staging@devstorageaccount.dfs.core.windows.net/input/"
  }
dbutils.notebook.run(repo_path + "raw/products_backup", timeout_mins, notebook_args)
test_1_results = validate_and_reset("products_monthly", "products_backup")

# COMMAND ----------

# DBTITLE 1,2 - products_load
script = "products_load.sql"
notebook_args = {
	"SCHEMA" : "products_monthly__products_load",
	"CATALOG" : "dev_test_suites",
	"STAGING" : "abfss://staging@devstorageaccount.dfs.core.windows.net/input/"
  }
dbutils.notebook.run(repo_path + "raw/products_load", timeout_mins, notebook_args)
test_2_results = validate_and_reset("products_monthly", "products_load")

 

Let’s break this down. Each cell will:

  • Be devoted to one script.
  • Define the input parameters for the notebook being tested.
    • Configuring the scripts to parameterize the catalog and schema is a best practice and makes code promotion simple.
  • Run the notebook being tested with the defined parameters. This puts the test suite schema for that notebook in the post state.
  • Call a validation function that checks whether the test passed and then resets the test suite schema to its pre state.

As discussed above, each notebook runs with its own devoted schema, which enables many developers to work on scripts at the same time.

 

Validation Code

Validation Approach

Digging one level deeper, let’s think about what needs to happen in the validation function validate_and_reset(). We need to:

  1. Set the catalog and schema to the test’s catalog and schema.
  2. Grab the names of the post tables needed to run the validation, with a simple pattern search of <job>__<script>__%.
  3. For each post table, compare it to the associated pre table in the test schema, whose name can be extracted from the post table’s name using .*__.*__(.*)__post. Compare the schema, counts, and row contents for each table.
  4. Reset the test suite by using Delta’s RESTORE on each table in the test schema that changed (i.e. tables which have a post table).
  5. Compare output files if needed.
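
Step 3 depends entirely on the post-table naming convention. A minimal sketch of that lookup follows, assuming post tables are named `<job>__<script>__<table>__post` and that none of the three parts contains a double underscore of its own. The helper name `pre_table_from_post` is hypothetical and not part of the validation code below.

```python
import re

# Assumed naming convention: <job>__<script>__<table>__post
POST_TABLE_PATTERN = re.compile(r"(?P<job>.+?)__(?P<script>.+?)__(?P<table>.+)__post")


def pre_table_from_post(post_table: str) -> str:
    """Return the pre-table name encoded in a post-table name."""
    match = POST_TABLE_PATTERN.fullmatch(post_table)
    if match is None:
        # Fail loudly rather than comparing against the wrong table
        raise ValueError(f"Unexpected post table name: {post_table}")
    return match.group("table")


print(pre_table_from_post("products_monthly__products_load__products__post"))
```

Raising on an unexpected name is a deliberate choice here: a silently wrong pre table would make the comparison meaningless.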

Table Row Comparison Options

How do we actually compare the rows between two tables (or DataFrames)? At Databricks, when we run these migrations ourselves, we consider two options:

  1. If you have primary keys (or some composite join key) for each table, you can simply do a join. Many teams use the Python package datacompy to automate this.
  2. If you don’t have join keys, you need to compare every column of every row and see whether any row lacks a match. The most scalable Spark-native way of doing this is the DataFrame exceptAll method. If exceptAll returns zero records and the row counts match, then you know the DataFrames are identical.
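
To see why an empty exceptAll result plus matching row counts is enough, here is a plain-Python sketch that uses `collections.Counter` as a stand-in for exceptAll's duplicate-preserving semantics. It is an illustration only, not Spark code, and the function names and sample rows are hypothetical.

```python
from collections import Counter


def except_all(left: list[tuple], right: list[tuple]) -> list[tuple]:
    """Rows of `left` not matched by a row in `right`, keeping duplicates."""
    remaining = Counter(left) - Counter(right)  # non-positive counts are dropped
    return list(remaining.elements())


def tables_identical(left: list[tuple], right: list[tuple]) -> bool:
    # Equal counts plus an empty one-directional difference means the two
    # multisets of rows are equal, so a second except_all call is not needed.
    return len(left) == len(right) and not except_all(left, right)


pre = [(1, "a"), (1, "a"), (2, "b")]
post_same = [(2, "b"), (1, "a"), (1, "a")]  # same rows, different order
post_dup = [(1, "a"), (2, "b"), (2, "b")]   # same count, different duplicates

print(tables_identical(pre, post_same))
print(tables_identical(pre, post_dup))
print(except_all(pre, post_dup))
```

The second comparison fails even though the counts match, because one copy of `(1, "a")` has no partner in `post_dup`. A set-based comparison that ignores duplicates would miss this case.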

Code to achieve table validation (using exceptAll) looks like this:

 

from pyspark.sql.functions import col, min
from databricks.sdk.runtime import *
import re


def set_schema(job_name: str, script_name: str, catalog: str = "dev_test_suites"):
    spark.sql(f"USE {catalog}.{job_name}__{script_name}")
    print(f"Schema set to: `{catalog}.{job_name}__{script_name}`")


def get_post_tables(job_name: str, script_name: str, script_num: int, post_table_schema: str = "post_tables", catalog: str = "dev_test_suites"):
    
    # grab the post tables and coerce to list[str]
    post_tables_rows = spark.sql(f"show tables in {catalog}.{post_table_schema}") \
        .filter(col("tableName").like(f'{job_name.lower()}__{script_name.lower()}%') | col("tableName").like(f'{job_name.lower()}__{str(script_num)}%')) \
        .select("tableName").collect()
    return [post_table_row[0] for post_table_row in post_tables_rows]


def run_unit_test(post_tables: list[str], post_table_schema: str, catalog: str):
    
    out = {}
    if len(post_tables) == 0:
        print("No tables to compare. Investigate output files before marking as passing")
        return out
    else:
        for table in post_tables:
            post_table = table
            pre_table = post_table.split("__")[2]
            out[pre_table] = {}
            test_results = [False, False, False]

            # schema check
            pre_schema = spark.sql(f"select * from {pre_table}").drop("rowid").schema
            post_schema = spark.sql(f"select * from {catalog}.{post_table_schema}.{post_table}").drop("rowid").schema

            # check if the schema length is equal
            if len(pre_schema) == len(post_schema):
                # check if the column names and datatypes are identical
                diff_columns = [(col_pre, col_post) for col_pre, col_post in zip(pre_schema, post_schema) if col_pre.name != col_post.name or col_pre.dataType != col_post.dataType]
                if len(diff_columns) == 0:
                    out[pre_table]["schema"] = "passed"
                    test_results[0] = True
                else:
                    print("Some schema columns not matching")
                    out[pre_table]["schema"] = diff_columns 
            else:
                out[pre_table]["schema"] = f"Schema lengths not equal, pre is {len(pre_schema)} fields but post has {len(post_schema)}"

            # row count check
            pre_count = spark.sql(f"select count(*) from {pre_table}").collect()[0][0]
            post_count = spark.sql(f"select count(*) from {catalog}.{post_table_schema}.{post_table}").collect()[0][0]
            if pre_count != post_count:
                print(f"Table count mismatched, {pre_table} count is : {pre_count} and {post_table} count is : {post_count}")
                out[pre_table]["counts"] = f"failed - Table count mismatched, {pre_table} count is : {pre_count} and {post_table} count is : {post_count}"
            else:
                out[pre_table]["counts"] = "passed"
                test_results[1] = True

            ## data check
            pre_data = spark.sql(f"select * from {pre_table}").drop("rowid").drop("ROWID")
            post_data = spark.sql(f"select * from {catalog}.{post_table_schema}.{post_table}").drop("rowid").drop("ROWID")
            data_mismatch = pre_data.exceptAll(post_data) 
            mismatch_count = data_mismatch.count()

            if mismatch_count != 0:
                print(f"Data check failed. {pre_table} and {post_table} has {mismatch_count} rows different")

                # avoid large collects to driver
                if mismatch_count > 1000000:
                    print(f"\tfound more than 1M mis-matching records so only saving first 1M in out['{pre_table}']['data_check']")
                    data_mismatch = data_mismatch.limit(1000000)

                out[pre_table]["data_check"] = data_mismatch.toPandas()
            else:
                out[pre_table]["data_check"] = "passed"
                test_results[2] = True

            if all(test_results):
                print(f"Unit test successfully passed for {pre_table} and {post_table}")
                out[pre_table] = "passed"

        # check if all tests passed:
        if all([state == "passed" for tbl, state in out.items()]):
            print("All tests passed!")

        return out


# define function to do the Delta table RESTORE
def restore_from_post_table_names(post_tables: list[str]) -> None:
    for table in post_tables:
        pre_table = table.split("__")[2]
        print(f"Restoring {pre_table}...", end="")
        first_version = spark.sql(f"DESCRIBE HISTORY {pre_table}").orderBy(col("version").asc()).first().asDict()

        if first_version["version"] == 0:
            # first_version is already the row for version 0, so reuse it
            first_operation = first_version['operation']
            
            def get_first_write(schema_table):
                # get the first write, which in some cases may not exist
                first_write = spark.sql(f"DESCRIBE HISTORY {schema_table}").filter("operation = 'WRITE'").groupBy().agg(min("version")).collect()[0][0]
                # handle the case where it's just a clone and truncate or just a clone
                if first_write is not None:
                    return first_write
                else:
                    # if there are no writes, we just go to the latest operation
                    version = spark.sql(f"DESCRIBE HISTORY {schema_table}").count() - 1
                    return version

            restore_version = get_first_write(pre_table) if first_operation == "CLONE" else 0
            spark.sql(f"RESTORE TABLE {pre_table} TO VERSION AS OF {restore_version}")
            print("Done")
        else:
            print(f"corrupted because first version is {first_version['version']}...", end="")
            spark.sql(f"RESTORE TABLE {pre_table} TO VERSION AS OF {first_version['version']}")
            print(f"restored to version {first_version['version']}, {first_version['operation']}")


def validate_and_reset(job_name: str, script_name: str, script_num: int = 0, post_table_schema: str = "post_tables", catalog: str = "dev_test_suites"):
    
    set_schema(job_name, script_name, catalog)

    # grab the post tables and coerce to list[str]
    post_tables = get_post_tables(job_name, script_name, script_num, post_table_schema, catalog)

    # run the unit test safely to compare post tables
    results = run_unit_test(post_tables, post_table_schema, catalog)
    restore_from_post_table_names(post_tables)

    return results
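
The dictionary returned by validate_and_reset() maps each pre table either to the string "passed" or to a per-check dictionary. One possible way to turn that into a single pass/fail signal is sketched below. The helper names are hypothetical, and the sample results are made up for illustration.

```python
def failed_checks(results: dict) -> dict[str, list[str]]:
    """Map each failing table to the names of the checks that did not pass."""
    failures = {}
    for table, state in results.items():
        if isinstance(state, str) and state == "passed":
            continue
        # state is a dict such as {"schema": ..., "counts": ..., "data_check": ...}.
        # The isinstance guard matters because a failed data_check may hold a
        # pandas DataFrame, which cannot be compared to a string with ==.
        failures[table] = [
            check for check, outcome in state.items()
            if not (isinstance(outcome, str) and outcome == "passed")
        ]
    return failures


def suite_passed(results: dict) -> bool:
    # An empty result means no post tables were found, which the validation
    # code flags for investigation, so it is not counted as a pass here.
    return bool(results) and not failed_checks(results)


# Made-up example results
sample = {
    "products": "passed",
    "products_backup": {"schema": "passed", "counts": "failed - ...", "data_check": "passed"},
}
print(failed_checks(sample))
print(suite_passed(sample))
```

A boolean like this is convenient if you want the test notebook cell to raise on failure, so that a scheduled run of the suite surfaces failures without anyone reading the printed output.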

 


File Validation

We may need to compare output flat files as well, such as CSV or other delimited files. This is complicated because we have not discussed an effective way of bringing these files into Databricks. Assuming you have the source-of-truth files (post files) that were generated by the source system and want to ensure that the files generated by the test suite (pre files) match them, you can use this approach:

  • Read the pre and post files with a delimiter that never occurs in the data, so that each record comes in as a single column.
  • Compare the record counts.
  • If the record counts match, use exceptAll to check whether every record in one DataFrame is also in the other. If you have join keys, you can join on them instead.

A simplified version of code that achieves this looks like:

 

from databricks.sdk.runtime import *  # same import as the validation code; supplies spark

def compare_output_file(file_name: str, path: str, validation_path: str):

    post_file = spark.read.option("sep", "!!!!!").csv(validation_path + file_name)
    pre_file = spark.read.option("sep", "!!!!!").csv(path + file_name)
    pre_count = pre_file.count()
    post_count = post_file.count()

    if pre_count != post_count:
        raise Exception(f"File count mismatched, pre file count is : {pre_count} and post file count is : {post_count}")
    else: 
        print(f"Counts match: {pre_count}")

    data_mismatch = pre_file.exceptAll(post_file)
    mismatch_count = data_mismatch.count()
    if mismatch_count == 0:
        print("Test passed!")
    else:
        print(f"Data mismatch ({mismatch_count} rows), displaying mismatches")
        display(data_mismatch)
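
For output files small enough to read on the driver, the same idea can be prototyped without Spark. The sketch below is a local variant under several assumptions: a hypothetical function name, local file paths, and UTF-8 text. It treats each line as one record and reports unmatched lines in both directions, which shows what the converted script produced in place of the expected record.

```python
from collections import Counter
from pathlib import Path


def diff_output_files(pre_path: str, post_path: str) -> dict[str, list[str]]:
    """Compare two text files as multisets of lines, ignoring line order."""
    pre_lines = Counter(Path(pre_path).read_text(encoding="utf-8").splitlines())
    post_lines = Counter(Path(post_path).read_text(encoding="utf-8").splitlines())
    return {
        # Lines the converted script wrote that the source system did not
        "only_in_pre": sorted((pre_lines - post_lines).elements()),
        # Lines the source system wrote that the converted script did not
        "only_in_post": sorted((post_lines - pre_lines).elements()),
    }
```

Both lists being empty means the files contain the same records. This does not scale to large files, so treat it as a debugging aid rather than a replacement for the Spark version.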

 


Test Notebook Generation Code

Now that we have the code to be able to run an effective test suite, let’s talk about how to automate the creation of these notebooks. The approach is:

 

  1. For each job, get the scripts being run. You can do this by parsing the pipeline code or encoding it another way.
  2. Open a file, write the header:
    def write_header(job: str, job_num: int, sprint_num: int):
        job_name = job.removeprefix('nz.').removesuffix('.job')
        job_file = f"all_unit_tests/sprint_{str(sprint_num)}/{str(job_num).zfill(3)} - {job_name}.py"
        with open(job_file, "x") as sink:
            header = f"""# Databricks notebook source
    from tests.unit_tests.testing_utils import *
    repo_path = \"<repo path>/sprint_{str(sprint_num)}/\"
    timeout_mins = 60*15
    job = \"{job}\""""
            sink.write(header)
    
        return job_file

  3. Then iterate through each script and:
    1. Get the widgets in the script being tested
    2. Write the script’s test

 

import re


def get_widgets(script_path, job, script) -> str:

    pattern = r'.*=dbutils.widgets.*'

    # Read the file
    with open(script_path.lower().removesuffix('.sql') + '.py') as file:
        lines = file.readlines()

    # get widgets
    widgets = [re.match(r'(.*)=dbutils.widgets.*', line).group(1) for line in lines if re.match(pattern, line)]

    if "catchup" in job:
        schema = f"{job.replace('.', '_')}__{script.replace('.', '_')}"
    else:
        catchup_script = script.replace('.', '_').replace('_sql', f'{job.removeprefix("nz.").removesuffix(".job")}_sql'.replace('.', '_'))
        schema = f"{job.replace('.', '_')}__{catchup_script}"

    out = []
    for widget in widgets:
        match widget:
            case 'CATALOG':
                out.append("\"CATALOG\": \"dev_test_suites\"")
            case 'SCHEMA':
                out.append(f"\"SCHEMA\": \"{schema}\"")
            case 'STAGING':
                out.append("\"STAGING\": \"abfss://staging@devstorageaccount.dfs.core.windows.net/input/\"")
            case 'INPUT':
                out.append("\"INPUT\": \"abfss://staging@devstorageaccount.dfs.core.windows.net/input/\"")
            case 'CURRENT_DATE':
                out.append("\"CURRENT_DATE\": run_date")
            case 'PIPELINERUNID':
                pass
            case 'PIPELINE_NAME':
                pass
            case 'ACTIVITY_NAME':
                pass
            case w:
                out.append(f"\"{w}\" : \"<add>\"")

    return ',\n\t'.join(out)


def write_test_caller(script: str, script_num: int, job: str, job_file: str, raw_script_path: str):
    script_path = raw_script_path + script
    widgets = get_widgets(script_path, job, script)
    with open(job_file, "a") as sink:
        header = f"""
# COMMAND ----------

# DBTITLE 1,{str(script_num)} - {script.removesuffix('.sql')}
script = \"{script}\"
notebook_args = {{
  {widgets}
  }}
dbutils.notebook.run(repo_path + "raw/" + "{script.removesuffix(".sql").lower()}", timeout_mins, notebook_args)
test_{script_num}_results = validate_and_reset(job.replace('.', '_'), "{script.replace('.', '_')}")"""
        sink.write(header)
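
The widget extraction in get_widgets() only recognizes lines written as NAME=dbutils.widgets... with no whitespace around the equals sign. If your converted notebooks are formatted differently, a more tolerant pattern may be needed. A minimal sketch follows, assuming each widget value is read into a variable on its own line. The function name and the sample notebook text are hypothetical.

```python
import re

# Tolerates leading indentation and spaces around "="
WIDGET_ASSIGNMENT = re.compile(r"^\s*(\w+)\s*=\s*dbutils\.widgets\.")


def find_widget_names(notebook_source: str) -> list[str]:
    """Return variable names assigned from dbutils.widgets, in file order."""
    names = []
    for line in notebook_source.splitlines():
        match = WIDGET_ASSIGNMENT.match(line)
        if match:
            names.append(match.group(1))
    return names


# Made-up notebook text for illustration
sample_source = '''
CATALOG = dbutils.widgets.get("CATALOG")
SCHEMA=dbutils.widgets.get("SCHEMA")
# STAGING = dbutils.widgets.get("STAGING")
spark.sql(f"USE {CATALOG}.{SCHEMA}")
'''
print(find_widget_names(sample_source))
```

Commented-out widget lines are skipped because the pattern requires the line to start with an identifier.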

 

My suggestion is to find a way to completely automate the creation of these notebooks and then add complexity to do things like compare output files and meet the unique needs of your migration.


Parting Thoughts

In Part 2 of this series we looked at how to run validations. As in Part 1, the code provided is detailed inspiration, meant to get your team thinking about implementation and edge cases. You will need to take this code and run with it, and it may end up looking fairly different. Whatever you do, if you have a lot of pipelines and scripts, automate the process as best you can. In the recent migration I developed this code for, we were running over 200 pipelines with 1700+ scripts being called. Being able to create and deploy these notebooks each sprint in 15-20 minutes while fixing edge cases as I went along was instrumental in keeping our team executing at a rapid pace and keeping up with our aggressive timelines.

When it comes to the row comparison code, we discussed using datacompy versus exceptAll. If there is any way to get the join keys, I recommend you go with the datacompy approach. Not only is it convenient to use a pre-built package, but it also makes reconciliation much easier when tests fail. You will be able to identify the exact row that did not have a match in the other table. With exceptAll we needed to create custom code that would do this, but it was not correct in some cases (think of a table that only had one column to begin with).
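
To make the trade-off concrete, here is a plain-Python sketch of what a join key buys you during reconciliation. It is neither datacompy nor Spark. The function name and the sample rows are assumptions for illustration, and it assumes the key is unique within each table.

```python
def compare_by_key(pre_rows: list[dict], post_rows: list[dict], key: str) -> dict:
    """Report missing keys and per-column differences between two row sets."""
    # Assumes `key` is unique per table; duplicates would overwrite each other
    pre = {row[key]: row for row in pre_rows}
    post = {row[key]: row for row in post_rows}
    changed = {}
    for k in pre.keys() & post.keys():
        diffs = {
            col: (pre[k][col], post[k].get(col))
            for col in pre[k]
            if pre[k][col] != post[k].get(col)
        }
        if diffs:
            changed[k] = diffs
    return {
        "only_in_pre": sorted(pre.keys() - post.keys()),
        "only_in_post": sorted(post.keys() - pre.keys()),
        "changed": changed,
    }


# Made-up rows keyed on "id"
pre_rows = [{"id": 1, "price": 10.0}, {"id": 2, "price": 20.0}, {"id": 3, "price": 30.0}]
post_rows = [{"id": 1, "price": 10.0}, {"id": 2, "price": 25.0}, {"id": 4, "price": 40.0}]
print(compare_by_key(pre_rows, post_rows, key="id"))
```

With a key, each mismatch is reported as a specific row and column. Without one, all you know is that some record on one side has no identical partner on the other.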

Thanks for reading and happy coding!