Echoing what we said in Part 1: Test Data Curation, when your team is migrating script code and pipeline code to Databricks, there are three main steps:
This article focuses on the third point. So, to ensure that a script is converted correctly you typically want to make sure:
In Part 1 we focused on how to curate comprehensive data to use for running tests. The solution we landed on was to create:
In Part 2, we will focus on how to set up the test suites themselves. Some questions we will answer are:
When it comes to test suites, the most important aspects are:
The great thing about our data curation approach is that we brought all the data needed to tell us whether the test suite is passing into Databricks. There is no need to connect (via JDBC or otherwise) to the source system to get this information. As such, the requirements for an excellent test suite are:
Now that we’ve established what we need, let’s start from the user’s perspective and dive deeper one step at a time.
For running the test suite, we suggest you dedicate one notebook to each pipeline being run. For instance, a pipeline with two scripts in it may have notebook code that looks like this:
# Databricks notebook source
from tests.unit_tests.testing_utils import *
repo_path = "<path to your scripts being converted>"
timeout_mins = 60*15
job = "products_monthly"
# COMMAND ----------
# DBTITLE 1,1 - products_backup
script = "products_backup.sql"
notebook_args = {
    "SCHEMA" : "products_monthly__products_backup",
    "CATALOG" : "test_suites",
    "STAGING" : "abfss://staging@devstorageaccount.dfs.core.windows.net/input/"
}
dbutils.notebook.run(repo_path + "raw/products_backup", timeout_mins, notebook_args)
test_1_results = validate_and_reset("products_monthly", "products_backup")
# COMMAND ----------
# DBTITLE 1,2 - products_load
script = "products_load.sql"
notebook_args = {
    "SCHEMA" : "products_monthly__products_load",
    "CATALOG" : "test_suites",
    "STAGING" : "abfss://staging@devstorageaccount.dfs.core.windows.net/input/"
}
dbutils.notebook.run(repo_path + "raw/products_load", timeout_mins, notebook_args)
test_2_results = validate_and_reset("products_monthly", "products_load")
Let’s break this down. Each cell will:
1. Set the script being tested and build the notebook_args dictionary with the widget values (schema, catalog, staging path) the converted notebook expects.
2. Run the converted notebook with dbutils.notebook.run().
3. Call validate_and_reset() to compare the script's output against the expected post tables and then reset the test schema for the next run.
As discussed above, each notebook runs with its own dedicated schema, which enables many developers to work on scripts at the same time.
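To make that concrete, here is a minimal sketch of the layout the test suite relies on. The catalog and table names are placeholders, and Part 1 covers how the curation is actually done; the only point here is the naming convention and the fact that the input tables arrive as clones.

catalog = "test_suites"
job, script = "products_monthly", "products_backup"
schema = f"{job}__{script}"  # one dedicated schema per script

spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog}.{schema}")

# Part 1's curation clones the input tables into this schema (the restore logic below
# checks whether the first operation on each table was a CLONE) and stores the expected
# outputs in a post_tables schema, named <job>__<script>__<table>__post.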
Digging one level deeper, let’s think about what needs to happen in the validation function validate_and_reset(). We need to:
1. Set the current schema to the script's dedicated test schema.
2. Find the expected-output ("post") tables for the script, which follow the naming convention <job>__<script>__% (the pattern .*__.*__(.*)__post recovers the original table name).
3. Compare the schema, counts, and row contents for each table.
4. Run a Delta RESTORE on each table in the test schema that changed (i.e. tables which have a post table).

How do we actually compare the rows between two tables (or DataFrames)? At Databricks, when we run these migrations ourselves, we consider two options:
1. datacompy, which requires join keys but pinpoints exactly which rows failed to reconcile.
2. exceptAll: if pre.exceptAll(post) returns zero records (and the counts match), then you know the DataFrames are identical.

Code to achieve table validation (using exceptAll) looks like this:
from pyspark.sql.functions import col, min
from databricks.sdk.runtime import *
import re

def set_schema(job_name: str, script_name: str, catalog: str = "dev_test_suites"):
    spark.sql(f"USE {catalog}.{job_name}__{script_name}")
    print(f"Schema set to: `{catalog}.{job_name}__{script_name}`")

def get_post_tables(job_name: str, script_name: str, script_num: int, post_table_schema: str = "post_tables", catalog: str = "dev_test_suites"):
    # grab the post tables and coerce to list[str]
    post_tables_rows = spark.sql(f"show tables in {catalog}.{post_table_schema}") \
        .filter(col("tableName").like(f'{job_name.lower()}__{script_name.lower()}%') | col("tableName").like(f'{job_name.lower()}__{str(script_num)}%')) \
        .select("tableName").collect()
    return [post_table_row[0] for post_table_row in post_tables_rows]

def run_unit_test(post_tables: list[str], post_table_schema: str, catalog: str):
    out = {}
    if len(post_tables) == 0:
        print("No tables to compare. Investigate output files before marking as passing")
        return out
    else:
        for table in post_tables:
            post_table = table
            pre_table = post_table.split("__")[2]
            out[pre_table] = {}
            test_results = [False, False, False]

            # schema check
            pre_schema = spark.sql(f"select * from {pre_table}").drop("rowid").schema
            post_schema = spark.sql(f"select * from {catalog}.{post_table_schema}.{post_table}").drop("rowid").schema
            # check if the schema length is equal
            if len(pre_schema) == len(post_schema):
                # check if the column names and datatypes are identical
                diff_columns = [(col_pre, col_post) for col_pre, col_post in zip(pre_schema, post_schema) if col_pre.name != col_post.name or col_pre.dataType != col_post.dataType]
                if len(diff_columns) == 0:
                    out[pre_table]["schema"] = "passed"
                    test_results[0] = True
                else:
                    print("Some schema columns not matching")
                    out[pre_table]["schema"] = diff_columns
            else:
                out[pre_table]["schema"] = f"Schema lengths not equal, pre is {len(pre_schema)} fields but post has {len(post_schema)}"

            # row count check
            pre_count = spark.sql(f"select count(*) from {pre_table}").collect()[0][0]
            post_count = spark.sql(f"select count(*) from {catalog}.{post_table_schema}.{post_table}").collect()[0][0]
            if pre_count != post_count:
                print(f"Table count mismatched, {pre_table} count is : {pre_count} and {post_table} count is : {post_count}")
                out[pre_table]["counts"] = f"failed - Table count mismatched, {pre_table} count is : {pre_count} and {post_table} count is : {post_count}"
            else:
                out[pre_table]["counts"] = "passed"
                test_results[1] = True

            # data check
            pre_data = spark.sql(f"select * from {pre_table}").drop("rowid").drop("ROWID")
            post_data = spark.sql(f"select * from {catalog}.{post_table_schema}.{post_table}").drop("rowid").drop("ROWID")
            data_mismatch = pre_data.exceptAll(post_data)
            mismatch_count = data_mismatch.count()
            if mismatch_count != 0:
                print(f"Data check failed. {pre_table} and {post_table} has {mismatch_count} rows different")
                # avoid large collects to driver
                if mismatch_count > 1000000:
                    print(f"\tfound more than 1M mis-matching records so only saving first 1M in out['{pre_table}']['data_check']")
                    data_mismatch = data_mismatch.limit(1000000)
                out[pre_table]["data_check"] = data_mismatch.toPandas()
            else:
                out[pre_table]["data_check"] = "passed"
                test_results[2] = True

            if all(test_results):
                print(f"Unit test successfully passed for {pre_table} and {post_table}")
                out[pre_table] = "passed"

        # check if all tests passed:
        if all([state == "passed" for tbl, state in out.items()]):
            print("All tests passed!")
        return out

# define function to do the Delta table RESTORE
def restore_from_post_table_names(post_tables: list[str]) -> None:
    for table in post_tables:
        pre_table = table.split("__")[2]
        print(f"Restoring {pre_table}...", end="")
        first_version = spark.sql(f"DESCRIBE HISTORY {pre_table}").orderBy(col("version").asc()).first().asDict()
        if first_version["version"] == 0:
            first_operation = spark.sql(f"DESCRIBE HISTORY {pre_table}").filter("version = 0").collect()[0].asDict()['operation']

            def get_first_write(schema_table):
                # get the first write, which in some cases may not exist
                first_write = spark.sql(f"DESCRIBE HISTORY {schema_table}").filter("operation = 'WRITE'").groupBy().agg(min("version")).collect()[0][0]
                # handle the case where it's just a clone and truncate or just a clone
                if first_write:
                    return first_write
                else:
                    # if there are no writes, we just go to the latest operation
                    version = spark.sql(f"DESCRIBE HISTORY {schema_table}").count() - 1
                    return version

            restore_version = get_first_write(pre_table) if first_operation == "CLONE" else 0
            spark.sql(f"RESTORE TABLE {pre_table} TO VERSION AS OF {restore_version}")
            print("Done")
        else:
            print(f"corrupted because first version is {first_version['version']}...", end="")
            spark.sql(f"RESTORE TABLE {pre_table} TO VERSION AS OF {first_version['version']}")
            print(f"restored to version {first_version['version']}, {first_version['operation']}")

def validate_and_reset(job_name: str, script_name: str, script_num: int = 0, post_table_schema: str = "post_tables", catalog: str = "dev_test_suites"):
    set_schema(job_name, script_name, catalog)
    # grab the post tables and coerce to list[str]
    post_tables = get_post_tables(job_name, script_name, script_num, post_table_schema, catalog)
    # run the unit test safely to compare post tables
    results = run_unit_test(post_tables, post_table_schema, catalog)
    restore_from_post_table_names(post_tables)
    return results
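Back in the test notebook, the dictionary returned by validate_and_reset() makes triage straightforward. A hypothetical inspection (the table name customers is purely illustrative) might look like:

# each table maps to "passed", or to a dict with "schema", "counts", and "data_check" entries
customers_result = test_1_results.get("customers")  # "customers" is a placeholder table name
if isinstance(customers_result, dict):
    mismatches = customers_result.get("data_check")
    if not isinstance(mismatches, str):
        # a pandas DataFrame holding the rows present on only one side
        display(mismatches)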
We may need to compare output flat files as well, such as CSV or other delimited files. This is complicated because we have not discussed an effective way of bringing these files into Databricks. Assuming you have the source-of-truth files (post files) that were generated by the source system and want to ensure that the files generated by the test suite (pre files) match, you can use this approach:
1. Read both files into DataFrames, using a separator that will not occur in the data so that each line lands in a single column.
2. Check that the row counts match.
3. Use exceptAll to compare whether all the records from one DataFrame are also in the other. If you have join keys you can also do it that way.

A simplified version of code that achieves this looks like:
import re

def compare_output_file(file_name: str, path: str, validation_path: str):
    post_file = spark.read.option("sep", "!!!!!").csv(validation_path + file_name)
    pre_file = spark.read.option("sep", "!!!!!").csv(path + file_name)
    pre_count = pre_file.count()
    post_count = post_file.count()
    if pre_count != post_count:
        raise Exception(f"File count mismatched, pre file count is : {pre_count} and post file count is : {post_count}")
    else:
        print(f"Counts match: {pre_count}")
    data_mismatch = pre_file.exceptAll(post_file)
    mismatch_count = data_mismatch.count()
    if mismatch_count == 0:
        print("Test passed!")
    else:
        print(f"Data mismatch ({mismatch_count} rows), displaying mismatches")
        display(data_mismatch)
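A call to this helper might look like the following; the file name and the output/validation paths are placeholders for wherever your pipeline writes its files and wherever you landed the source-of-truth files:

compare_output_file(
    "products_extract.csv",                                                   # hypothetical output file name
    "abfss://staging@devstorageaccount.dfs.core.windows.net/output/",         # files written by the converted scripts (pre)
    "abfss://staging@devstorageaccount.dfs.core.windows.net/validation/",     # source-of-truth files from the source system (post)
)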
Now that we have the code to run an effective test suite, let’s talk about how to automate the creation of these notebooks. The approach is to generate each test notebook from the job’s scripts: write a header cell for the job, read each converted script to find the widget parameters it expects, and append one test-caller cell per script. The helper functions below do exactly that:
def write_header(job: str, job_num: int, sprint_num: int):
    job_name = job.removeprefix('nz.').removesuffix('.job')
    job_file = f"all_unit_tests/sprint_{str(sprint_num)}/{str(job_num).zfill(3)} - {job_name}.py"
    with open(job_file, "x") as sink:
        header = f"""# Databricks notebook source
from tests.unit_tests.testing_utils import *
repo_path = \"<repo path>/sprint_{str(sprint_num)}/\"
timeout_mins = 60*15
job = \"{job}\""""
        sink.write(header)
    return job_file

def get_widgets(script_path, job, script) -> str:
    pattern = r'.*=dbutils.widgets.*'
    # Read the file
    with open(script_path.lower().removesuffix('.sql') + '.py') as file:
        lines = file.readlines()
    # get widgets
    widgets = [re.match(r'(.*)=dbutils.widgets.*', line).group(1) for line in lines if re.match(pattern, line)]
    if "catchup" in job:
        schema = f"{job.replace('.', '_')}__{script.replace('.', '_')}"
    else:
        catchup_script = script.replace('.', '_').replace('_sql', f'{job.removeprefix("nz.").removesuffix(".job")}_sql'.replace('.', '_'))
        schema = f"{job.replace('.', '_')}__{catchup_script}"
    out = []
    for widget in widgets:
        match widget:
            case 'CATALOG':
                out.append("\"CATALOG\": \"dev_test_suites\"")
            case 'SCHEMA':
                out.append(f"\"SCHEMA\": \"{schema}\"")
            case 'STAGING':
                out.append("\"STAGING\": \"abfss://staging@devstorageaccount.dfs.core.windows.net/input/\"")
            case 'INPUT':
                out.append("\"INPUT\": \"abfss://staging@devstorageaccount.dfs.core.windows.net/input/\"")
            case 'CURRENT_DATE':
                out.append("\"CURRENT_DATE\": run_date")
            case 'PIPELINERUNID':
                pass
            case 'PIPELINE_NAME':
                pass
            case 'ACTIVITY_NAME':
                pass
            case w:
                out.append(f"\"{w}\" : \"<add>\"")
    return ',\n\t'.join(out)

def write_test_caller(script: str, script_num: int, job: str, job_file: str, raw_script_path: str):
    script_path = raw_script_path + script
    widgets = get_widgets(script_path, job, script)
    with open(job_file, "a") as sink:
        header = f"""
# COMMAND ----------
# DBTITLE 1,{str(script_num)} - {script.removesuffix('.sql')}
script = \"{script}\"
notebook_args = {{
{widgets}
}}
dbutils.notebook.run(repo_path + "raw/" + "{script.removesuffix('.sql').lower()}", timeout_mins, notebook_args)
test_{script_num}_results = validate_and_reset(job.replace('.', '_'), "{script.replace('.', '_')}")"""
        sink.write(header)
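These helpers still need a driver that loops over a sprint's jobs and their scripts. The article leaves that wiring to you; a hypothetical sketch, where the jobs mapping, sprint number, and raw_script_path are placeholders for however you track your workload, could look like:

# hypothetical mapping of each job to the ordered scripts it runs
jobs = {
    "nz.products_monthly.job": ["products_backup.sql", "products_load.sql"],
}
sprint_num = 7                                            # placeholder
raw_script_path = "<path to the raw converted scripts>/"  # placeholder

# assumes the all_unit_tests/sprint_<n>/ folder already exists
for job_num, (job, scripts) in enumerate(jobs.items(), start=1):
    job_file = write_header(job, job_num, sprint_num)
    for script_num, script in enumerate(scripts, start=1):
        write_test_caller(script, script_num, job, job_file, raw_script_path)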
My suggestion is to find a way to completely automate the creation of these notebooks and then add complexity to do things like compare output files and meet the unique needs of your migration.
In Part 2 of this series we looked at how to run validations. As in Part 1, the code provided is detailed inspiration, meant to get your team thinking about implementation and edge cases. You will need to take this code and run with it, and it may end up looking fairly different. Whatever you do, if you have a lot of pipelines and scripts, automate the process as best you can. In the recent migration I developed this code for, we were running over 200 pipelines with 1,700+ scripts being called. Being able to create and deploy these notebooks each sprint in 15-20 minutes, while fixing edge cases as I went along, was instrumental in keeping our team executing at a rapid pace and keeping up with our aggressive timelines.
When it comes to the row comparison code, we discussed using datacompy versus exceptAll. If there is any way to get the join keys, I recommend you go with the datacompy approach. Not only is it convenient to use a pre-built package, but it makes reconciliation much easier when tests fail: you will be able to identify the exact row that did not have a match in the other table. With exceptAll we needed to create custom code to do this, and it was not correct in all cases (think of a table that only had one column to begin with).
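For reference, here is a minimal sketch of the datacompy route, assuming the tables are small enough to bring to the driver and that you know a join key (the table names and customer_id are purely illustrative):

import datacompy

# bring both sides to pandas; only sensible for tables that fit comfortably on the driver
pre_pdf = spark.table("customers").drop("rowid").toPandas()
post_pdf = spark.table("dev_test_suites.post_tables.products_monthly__products_load__customers__post").drop("rowid").toPandas()

comparison = datacompy.Compare(
    pre_pdf,
    post_pdf,
    join_columns=["customer_id"],  # hypothetical join key
    df1_name="pre",
    df2_name="post",
)

print(comparison.matches())  # True only if every row and column reconciles
print(comparison.report())   # summarizes the columns and sample rows that did not match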
Thanks for reading and happy coding!