Technical Blog
Explore in-depth articles, tutorials, and insights on data analytics and machine learning in the Databricks Technical Blog. Stay updated on industry trends, best practices, and advanced techniques.
li_yu
Databricks Employee

As machine learning (ML) workloads continue to grow in complexity and scale, organizations are looking for efficient and scalable solutions to manage their ML lifecycle. Databricks offers a powerful platform for ML workloads, providing scalability, security, and collaboration. In this blog post, we'll explore how to migrate ML workloads to Databricks using financial use cases as examples. We'll discuss optimization techniques, cluster management, and model lifecycle management using Databricks' Unity Catalog (UC).

mlmodel_udf - Page 1.png

Model Training and Registration

Migrating ML workloads to Databricks brings distributed compute and collaborative data science together on one unified platform. One of the most powerful features of Spark on Databricks is the ability to parallelize model training across multiple subsets of data. This is especially useful when each group (such as a currency pair, customer segment, or other logical partition) needs its own specialized model because of its unique data patterns. By combining Spark's group-by operations with pandas UDFs, each group's data can be processed in parallel, which can significantly reduce end-to-end training time.

The example below demonstrates how to use applyInPandas to train a separate anomaly detection model for each currency symbol. In the example, each grouped subset of data is distributed across a Databricks cluster, so you can fit many models simultaneously. When combined with MLflow, each trained model is versioned, tracked, and governed under Unity Catalog, ensuring you have an auditable history of each model and the ability to manage permissions and lineage in a centralized manner.

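import mlflow
import pandas as pd
from mlflow.models import infer_signature
from pyspark.sql.types import StringType, StructField, StructType
from sklearn.ensemble import IsolationForest

# If Unity Catalog is not already the default model registry in your environment,
# point MLflow at it before logging models.
mlflow.set_registry_uri("databricks-uc")

input_example = pd.DataFrame({
   "spread": [0.01]
})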
output_example = {
   "anomaly": 1
}
signature = infer_signature(input_example, output_example)
# Spark schema for the returned DataFrame from this Pandas function
output_schema = StructType([
   StructField("symbol", StringType(), True),
   StructField("model_version_info", StringType(), True),
])

def train_and_log_anomaly_model(pdf: pd.DataFrame) -> pd.DataFrame:
   """
   Each PDF is the subset of rows for a single symbol group.
   We'll train an IsolationForest and log it to MLflow (Unity Catalog).
   """
   symbol_value = pdf["symbol"].iloc[0]
   pdf["spread"] = pdf["ask_price"] - pdf["bid_price"]


   # Train a simple IsolationForest on the spread
   X = pdf[["spread"]]
   iso_model = IsolationForest(random_state=42)
   iso_model.fit(X)
   uc_model_name = f"my_catalog.my_schema.fx_spread_{symbol_value.replace('/','_')}"

   with mlflow.start_run(nested=True) as run:
       mlflow.sklearn.log_model(
           sk_model=iso_model,
           artifact_path="model",
           registered_model_name=uc_model_name,
           signature=signature,          
           input_example=input_example,  
       )

       mlflow.log_param("symbol", symbol_value)
       mlflow.log_metric("num_records", len(pdf))
       run_id = run.info.run_id
       version_info = f"symbol={symbol_value}, run_id={run_id}"

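   return pd.DataFrame([[symbol_value, version_info]], columns=["symbol","model_version_info"])

# applyInPandas is lazy: training only runs once an action is called on result_df
# (for example result_df.collect() or display(result_df)).
result_df = (
   df.groupBy("symbol")
     .applyInPandas(train_and_log_anomaly_model, schema=output_schema)
)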

One common issue when training models on Spark is running out of executor memory when a group's data is large. In this case, spark.task.cpus can be increased so that fewer tasks run concurrently on each executor and each task gets a larger share of the executor's memory. For example, if spark.task.cpus is set to the number of cores on a worker node, only one task runs on that node at a time and it can use all of the memory available to the executor. Choosing worker nodes with more memory also helps.
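The effect of spark.task.cpus on per-task memory can be estimated with simple arithmetic. The sketch below is a rough illustration only: the helper names and the 16-core, 64 GB executor are hypothetical, and Spark does not hard-partition executor memory between tasks, so treat the result as an approximate share rather than a guarantee.

```python
def concurrent_tasks_per_executor(executor_cores: int, task_cpus: int) -> int:
    """Number of tasks that can run at the same time on one executor."""
    if task_cpus < 1 or executor_cores < task_cpus:
        raise ValueError("task_cpus must be between 1 and executor_cores")
    return executor_cores // task_cpus


def approx_memory_per_task_gb(
    executor_memory_gb: float, executor_cores: int, task_cpus: int
) -> float:
    """Rough share of executor memory per task (ignores overhead and caching)."""
    slots = concurrent_tasks_per_executor(executor_cores, task_cpus)
    return executor_memory_gb / slots


# Hypothetical executor with 16 cores and 64 GB of memory.
for cpus in (1, 4, 16):
    share = approx_memory_per_task_gb(64, 16, cpus)
    print(f"spark.task.cpus={cpus}: about {share} GB per task")
```

The trade-off is fewer concurrent tasks, so raising spark.task.cpus exchanges parallelism for memory headroom.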

Logging Models to Unity Catalog

One of the big wins from migrating to Databricks is the ability to govern model lifecycles with MLflow and Unity Catalog. MLflow logs artifacts such as trained models, metrics, and parameters. By storing these artifacts inside Unity Catalog, you can benefit from fine-grained access controls and enterprise governance for all your ML assets. Each model is registered with a fully qualified, three-level name (e.g., my_catalog.my_schema.model_name).

The code snippet above uses mlflow.sklearn.log_model to register each trained model in Unity Catalog under an explicit model name derived from the currency symbol. It also includes a model signature and an input example, which enable data type enforcement and make it straightforward for others to understand how to call the model at inference time.
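Because the model name is derived from the symbol at training time and rebuilt at inference time, it can help to keep that derivation in one place. The helper below is a minimal sketch, not part of the original code: the function names are hypothetical, and replacing every character other than letters, digits, and underscores is a more conservative assumption than the original, which replaces only "/".

```python
import re


def uc_model_name(
    symbol: str,
    catalog: str = "my_catalog",
    schema: str = "my_schema",
    prefix: str = "fx_spread_",
) -> str:
    """Build the three-level model name for a currency symbol."""
    # Assumption: map anything outside [A-Za-z0-9_] to an underscore.
    safe_symbol = re.sub(r"[^A-Za-z0-9_]", "_", symbol)
    return f"{catalog}.{schema}.{prefix}{safe_symbol}"


def champion_uri(symbol: str) -> str:
    """Model URI that resolves the 'Champion' alias for a symbol."""
    return f"models:/{uc_model_name(symbol)}@Champion"


print(uc_model_name("EUR/USD"))
print(champion_uri("EUR/USD"))
```

Using the same helper in both the training and the inference function avoids the two string templates drifting apart.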

In Unity Catalog, model aliases replace the fixed stages of the legacy Workspace Model Registry for managing ML model lifecycles. Model versions still exist; an alias such as "Champion" or "Challenger" is a mutable, named reference to one version that designates its deployment status. The following code assigns the "Champion" alias to the latest version of the model once it has demonstrated better performance than the current production model.

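from mlflow import MlflowClient

# model_info is assumed to be the value returned by mlflow.sklearn.log_model(...),
# and uc_model_name the three-level Unity Catalog name the model was registered under.
client = MlflowClient()

client.set_registered_model_alias(uc_model_name, "Champion", model_info.registered_model_version)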
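The "better performance" check that gates the alias change is not shown in the original. One way to express it, as a sketch under stated assumptions, is a small function that compares a single evaluation score and only then moves the alias. The function name, the score arguments, and the `AliasClient` protocol are hypothetical; the only MLflow call it relies on is `set_registered_model_alias`, so an `MlflowClient` instance can be passed in as `client`.

```python
from typing import Optional, Protocol


class AliasClient(Protocol):
    """Anything with MlflowClient's set_registered_model_alias method."""

    def set_registered_model_alias(self, name: str, alias: str, version: str) -> None:
        ...


def promote_if_better(
    client: AliasClient,
    model_name: str,
    new_version: str,
    new_score: float,
    champion_score: Optional[float],
    higher_is_better: bool = True,
) -> bool:
    """Move the 'Champion' alias to new_version if it beats the current champion."""
    if champion_score is None:
        better = True  # no champion yet, so the first version is promoted
    elif higher_is_better:
        better = new_score > champion_score
    else:
        better = new_score < champion_score
    if better:
        client.set_registered_model_alias(model_name, "Champion", new_version)
    return better
```

How the scores are produced (for example, from a labeled validation set of known anomalies) is left open here, since the original does not specify an evaluation metric.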

Distributed Inference with mapInPandas

Migrating to Databricks also enables parallel processing for inference. The code example below highlights how mapInPandas can be used to load each model dynamically from Unity Catalog and apply it to the corresponding partition of data. 

To increase the parallelism of a Spark job and reduce data skew, repartition(partitionNum) can be used to adjust the number of partitions. A common recommendation is to set partitionNum to roughly 2 to 4 times the number of cores in the cluster. With more partitions, tasks can be distributed more evenly across the cluster. However, too many partitions add overhead from shuffling data and scheduling many small tasks.
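The 2 to 4 times rule of thumb can be captured in a small helper. This is only an illustrative sketch: the function name is hypothetical, and the 16-core cluster is an assumed example rather than a figure from the article.

```python
def suggested_partitions(total_cores: int, factor: int = 2) -> int:
    """Partition count as a small multiple (2 to 4) of the cluster's cores."""
    if total_cores < 1:
        raise ValueError("total_cores must be positive")
    if not 2 <= factor <= 4:
        raise ValueError("factor is expected to be between 2 and 4")
    return total_cores * factor


# Hypothetical cluster with 16 cores in total.
print(suggested_partitions(16))     # lower end of the suggested range
print(suggested_partitions(16, 4))  # upper end of the suggested range
```

In a notebook, the result would be passed to df.repartition(...). On many clusters spark.sparkContext.defaultParallelism reflects the total core count, but that depends on the cluster configuration.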

Spark partitions are converted into Arrow record batches during processing. Each batch is passed to the pandas function as a pandas DataFrame, as shown in predict_spread_anomalies below. To reduce the risk of running out of memory, you can limit how many rows go into each Arrow batch by adjusting the spark.sql.execution.arrow.maxRecordsPerBatch configuration. The default is 10,000 records per batch. If a partition has fewer rows than this limit, it produces only one batch.
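To make the relationship between partition size and batch count concrete, the sketch below computes how many Arrow batches a non-empty partition would be split into for a given limit. The helper name is hypothetical; the default of 10,000 is the one stated above.

```python
import math


def arrow_batches_per_partition(
    rows_in_partition: int, max_records_per_batch: int = 10_000
) -> int:
    """Number of Arrow batches needed to carry the rows of one partition."""
    if rows_in_partition < 0 or max_records_per_batch < 1:
        raise ValueError("row count must be >= 0 and batch limit must be >= 1")
    return math.ceil(rows_in_partition / max_records_per_batch)


print(arrow_batches_per_partition(125))            # small partition
print(arrow_batches_per_partition(25_000))         # larger than the default limit
print(arrow_batches_per_partition(25_000, 5_000))  # with a lower limit
```

In a notebook, the limit itself can be changed with spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "5000"), which trades more, smaller batches for lower peak memory per batch.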

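from typing import Iterator

import mlflow
import numpy as np
import pandas as pd
from pyspark.sql.types import DoubleType, StringType, StructField, StructType, TimestampType

inference_schema = StructType([
   StructField("symbol", StringType(), True),
   StructField("timestamp", TimestampType(), True),
   StructField("bid_price", DoubleType(), True),
   StructField("ask_price", DoubleType(), True),
   StructField("spread", DoubleType(), True),
   StructField("prediction", DoubleType(), True)])

def predict_spread_anomalies(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
   models = {}  # models already loaded by this task, keyed by symbol
   for pdf in iterator:
       if pdf.empty:
           continue  # nothing to score; mapInPandas does not need one output per input batch
       pdf["spread"] = pdf["ask_price"] - pdf["bid_price"]
       # repartition("symbol") hash-partitions the data, so a single batch can
       # contain several symbols; score each symbol with its own model.
       for symbol_val, group in pdf.groupby("symbol"):
           if symbol_val not in models:
               safe_symbol = symbol_val.replace("/", "_")
               model_uri = f"models:/my_catalog.my_schema.fx_spread_{safe_symbol}@Champion"
               models[symbol_val] = mlflow.pyfunc.load_model(model_uri)
           preds = models[symbol_val].predict(group[["spread"]])
           # Cast to float64 to match the DoubleType "prediction" column
           yield group.assign(prediction=np.asarray(preds, dtype="float64").ravel())

repart_df = df.repartition("symbol")
inference_df = repart_df.mapInPandas(
   predict_spread_anomalies,
   schema=inference_schema
)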

 

Monte Carlo Simulation

In the financial services industry, Monte Carlo simulations (e.g., for loan risk management) can be migrated from SAS using the same approach. When these simulations are implemented as Databricks workflows or jobs, they can use Databricks compute resources to scale approximately linearly with the number of loans, scenarios, and simulation runs.

The code example below illustrates how Monte Carlo simulations can be implemented. A dummy portfolio DataFrame of 1,000 loans is created and cross-joined with 8 scenarios, giving one row per loan and scenario. Real scenario data can be joined on the scenario IDs for actual simulations. In the example, mapInPandas applies the process_simulation function to each batch of the portfolio DataFrame.

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType,StructField,IntegerType,DoubleType
import numpy as np
import pandas as pd

num_loans = 1000
num_scenarios = 8
num_simulations = 65
num_partitions = 64 # roughly 2 to 4 times the total number of cores

loans_pdf = pd.DataFrame({
   "loan_id": range(1, num_loans + 1),
   "principal_amount": np.random.randint(10000, 100000, size=num_loans),
   "interest_rate": np.random.uniform(0.01, 0.15, size=num_loans),
})
loans_df = spark.createDataFrame(loans_pdf)

scenarios_pdf = pd.DataFrame({
   "scenario_id": range(1, num_scenarios + 1),
   "scenario_factor": np.random.uniform(0.9, 1.1, size=num_scenarios),
})
scenarios_df = spark.createDataFrame(scenarios_pdf)
portfolio_df = loans_df.crossJoin(scenarios_df)

def monte_carlo_batch(iterator):
   """
   Each Spark partition is split into one or more Arrow/Pandas batches.
   For each batch, we call 'process_simulation' and yield the result.
   """
   for pdf in iterator:
       out_pdf = process_simulation(pdf, num_simulations)
       yield out_pdf

expanded_schema = (
   StructType()
   .add("loan_id", IntegerType(), nullable=True)
   .add("scenario_id", IntegerType(), nullable=True)
   .add("simulation_id", IntegerType(), nullable=True)
   .add("random_draw", DoubleType(), nullable=True)
   .add("simulated_loss", DoubleType(), nullable=True)
)

# Apply the function using mapInPandas
expanded_df = portfolio_df.repartition(num_partitions).mapInPandas(monte_carlo_batch, schema=expanded_schema)

In this example, each partition of data is serialized and passed to the monte_carlo_batch function as an iterator of pandas DataFrames (batches). The maximum number of rows per batch can be set with spark.sql.execution.arrow.maxRecordsPerBatch.
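The process_simulation function called inside monte_carlo_batch is not defined in the example above. The following is a minimal sketch of what it might look like, assuming it should emit one row per loan, scenario, and simulation run with the columns of expanded_schema. The loss formula and the optional seed argument are purely illustrative assumptions, not the author's risk model.

```python
from typing import Optional

import numpy as np
import pandas as pd


def process_simulation(
    pdf: pd.DataFrame, num_simulations: int, seed: Optional[int] = None
) -> pd.DataFrame:
    """Expand each (loan, scenario) row into num_simulations simulated rows."""
    rng = np.random.default_rng(seed)
    n = len(pdf)
    # Repeat every input row once per simulation run (positional, index-safe).
    out = pdf.iloc[np.repeat(np.arange(n), num_simulations)].reset_index(drop=True)
    out["simulation_id"] = np.tile(np.arange(1, num_simulations + 1), n)
    out["random_draw"] = rng.random(n * num_simulations)
    # Hypothetical loss rule: scenario-adjusted interest scaled by the random draw.
    out["simulated_loss"] = (
        out["principal_amount"]
        * out["interest_rate"]
        * out["scenario_factor"]
        * out["random_draw"]
    )
    # Match the IntegerType columns declared in expanded_schema.
    out = out.astype(
        {"loan_id": "int32", "scenario_id": "int32", "simulation_id": "int32"}
    )
    return out[
        ["loan_id", "scenario_id", "simulation_id", "random_draw", "simulated_loss"]
    ]
```

Because the expansion is vectorized with NumPy rather than looping over rows, the cost per batch stays low even when num_simulations is large.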

With 8,000 input rows (1,000 loans for each of 8 scenarios) repartitioned into 64 partitions, each partition contains about 125 rows. In the screenshot below, the Spark stage contains 64 tasks, each with about 125 shuffle read records, and the tasks are evenly distributed across 4 executors (workers).

mc_stages.png
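The row counts in this example follow directly from the parameters in the code. The sketch below reproduces that arithmetic; the output row count rests on the assumption that process_simulation emits one row per simulation run, which the simulation_id column in expanded_schema suggests but the article does not state.

```python
def workload_shape(
    num_loans: int, num_scenarios: int, num_simulations: int, num_partitions: int
) -> dict:
    """Row counts for the Monte Carlo example, before and after simulation."""
    input_rows = num_loans * num_scenarios  # rows produced by the cross join
    return {
        "input_rows": input_rows,
        "rows_per_partition": input_rows / num_partitions,
        # Assumption: one output row per (loan, scenario, simulation run).
        "output_rows": input_rows * num_simulations,
    }


shape = workload_shape(
    num_loans=1000, num_scenarios=8, num_simulations=65, num_partitions=64
)
print(shape)
```

Under that assumption the 8,000 input rows grow to 520,000 output rows, so the output side, not the input side, is what drives memory use per batch.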

Summary

Migrating ML workloads to Databricks offers a streamlined experience for data engineering, analytics, and model lifecycle management. By taking advantage of Databricks distributed compute, data scientists can efficiently train and serve ML models at scale, an approach especially beneficial for financial services scenarios such as loan risk assessment or FX anomaly detection. MLflow integrates with Unity Catalog to provide robust model versioning and governance for the ML models trained and deployed within the Databricks environment.