Technical Blog
Explore in-depth articles, tutorials, and insights on data analytics and machine learning in the Databricks Technical Blog. Stay updated on industry trends, best practices, and advanced techniques.
KyraWulffert
Databricks Employee

Managing the bias-variance trade-off at scale

by John Karlsson, Kyra Wulffert, Maria Zervou 

 

Introduction

When building machine learning models on datasets that are segmented by categorical groups, such as various store locations or product categories, it’s common to face a recurring question: Should we train one global model or one model per category? The decision can drastically impact our model’s performance, interpretability, and maintainability.

The complexity arises from the need to balance model accuracy and generalisation. A single "global" model applied across all segments might suffer from high bias if it fails to capture the nuances specific to each group, potentially oversimplifying complex patterns. On the other hand, creating separate models for each segment can reduce bias by better capturing group-specific characteristics. However, because each model is fitted only on its own segment's data, this approach increases variance and the risk of overfitting, especially for small groups. This is the classic bias-variance trade-off, where reducing one source of error often increases the other. The challenge lies in finding an approach that maintains accuracy while avoiding overfitting.
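For squared-error loss, the trade-off has a well-known closed form, included here only as a reference for the terms used above. Assuming y = f(x) + ε with zero-mean noise of variance σ², and taking expectations over training sets and noise, the expected error of a fitted model f̂ at a point x decomposes as follows. The experiments in this post are scored with F1, to which this decomposition does not apply directly.

```latex
\mathbb{E}\big[(y - \hat{f}(x))^2\big]
  = \underbrace{\big(\mathbb{E}[\hat{f}(x)] - f(x)\big)^2}_{\text{bias}^2}
  + \underbrace{\mathbb{E}\big[(\hat{f}(x) - \mathbb{E}[\hat{f}(x)])^2\big]}_{\text{variance}}
  + \underbrace{\sigma^2}_{\text{irreducible noise}}
```

Read this way, a global model tends to carry a larger bias term when groups follow different functions f, while a per-group model fitted on fewer samples tends to carry a larger variance term.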

To address this challenge, we will explore the two approaches below:

  1. Global: Train a single model on all your data at once with hyperparameter tuning and cross-validation.
  2. Segmented: Train a separate model with hyperparameter tuning and cross-validation within each group.

We will demonstrate how to implement and scale these solutions using Apache Spark™ on Databricks. In our experiments, hyperparameter tuning selects each model's configuration and cross-validation then provides a more reliable estimate of its performance, helping ensure that each model balances bias and variance and generalises well across groups. Furthermore, we will address the common question of how to efficiently manage and deploy a proliferation of models, outlining several options, each with its own advantages and limitations.

 

Experimental Setup

Data Preparation 

For demo purposes, we use a synthetically generated dataset with three numerical features, a group feature, and a binary target variable specifically constructed to highlight bias-variance trade-offs. The dataset contains 3,200 observations across 8 groups (400 samples each). The target is group-dependent: each group was generated using a different polynomial function, creating distinct feature-target relationships. For example, num_1's correlation with the target ranges from -0.395 to +0.304 across groups, while the global correlation is nearly zero (+0.008) as shown in Figure 1. This setup intentionally creates a scenario where a single global model struggles to capture diverse group-specific patterns, demonstrating when group-based modeling is essential.


Figure 1. Feature-Target correlations by group 
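The exact generator behind the dataset is not published, so the following is only a hypothetical sketch of the same idea: each group gets its own random polynomial (here, linear plus quadratic terms) of three features, and the binary target is the score thresholded at the group median. The function name, coefficient ranges and noise level are assumptions; the sizes (8 groups of 400 rows) match the description above.

```python
import numpy as np
import pandas as pd


def make_grouped_dataset(n_groups: int = 8, n_per_group: int = 400, seed: int = 42) -> pd.DataFrame:
    """Hypothetical generator: every group follows its own polynomial of the features."""
    rng = np.random.default_rng(seed)
    frames = []
    for g in range(n_groups):
        X = rng.normal(size=(n_per_group, 3))
        # Group-specific coefficients for the linear and quadratic terms (assumed form)
        lin = rng.uniform(-2, 2, size=3)
        quad = rng.uniform(-1, 1, size=3)
        score = X @ lin + (X ** 2) @ quad + rng.normal(scale=0.5, size=n_per_group)
        frame = pd.DataFrame(X, columns=["num_1", "num_2", "num_3"])
        frame["group"] = f"group_{g}"
        # Thresholding at the group median keeps every group roughly balanced
        frame["target"] = (score > np.median(score)).astype(int)
        frames.append(frame)
    return pd.concat(frames, ignore_index=True)


df = make_grouped_dataset()
# Correlation of num_1 with the target per group and globally (compare Figure 1)
per_group = df.groupby("group")[["num_1", "target"]].apply(lambda d: d["num_1"].corr(d["target"]))
print(per_group.round(3))
print("global:", round(df["num_1"].corr(df["target"]), 3))
```

Because the sign and strength of each feature's effect change from group to group, the per-group correlations can differ widely while the pooled correlation stays close to zero, which is the pattern Figure 1 illustrates.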

 

Training a Global Model 

We start by training a single model on the entire dataset, including the group identifier as a categorical feature. This baseline approach is simple, computationally efficient and works well if groups share similar patterns. However, unless sufficiently expressive, it may underperform when predictive relationships differ substantially across groups.

In this blog, we’ve used Optuna to perform the hyperparameter tuning. Optuna is a powerful open source hyperparameter optimization framework that can be combined with Databricks and Apache Spark™ to offer a robust solution for distributed hyperparameter tuning. Optuna is pre-installed in Databricks ML Runtime, so no extra installation is required.

In this experiment, we optimise hyperparameters for a single global model trained on the entire dataset. The process involves:

1. Define the Search Space: Use Optuna to define the hyperparameter ranges to explore (e.g., n_estimators, max_depth and learning_rate).

2. Run Hyperparameter Search: Optuna iteratively samples hyperparameter configurations and evaluates them using a validation split. For computational efficiency, we use single-fold validation during this search phase, running 50 trials to identify the best configuration.


3. Retrain with Cross-Validation: Once the optimal hyperparameters are identified, we retrain the model using 5-fold cross-validation to obtain more reliable estimates of generalization performance than a single validation split provides.

4. Log Results to MLflow: Track all experiments, hyperparameters, and metrics using MLflow, including the final global model artifact.

 

Training a Segmented Model

Different groups (e.g., geographic regions, customer types) can exhibit markedly different feature-target relationships, so a single global model may fail to capture group-specific patterns. To address this, we train a separate model for each group, performing hyperparameter tuning individually for each segment to account for differences in data characteristics and improve per-group performance. By searching the parameter space per group, we can tailor model complexity to each segment, reducing bias while controlling variance even in smaller groups. Careful tuning, including techniques such as cross-validation, helps balance model complexity and mitigate overfitting, resulting in improved overall generalization.


Training one model per group can be broken down into five main steps:

1. Distribute the search by group: For each group's data subset, run the Optuna search in parallel using Spark's groupBy and applyInPandas (50 trials per group). During this phase, we use single-fold validation for speed.

2. Retrain with Cross-Validation: Once optimal hyperparameters are identified for each group, retrain each model using 5-fold cross-validation to obtain robust performance metrics.

3. MLflow Tracking: Each group’s best hyperparameters are identified and logged to MLflow as child runs under a parent "Optimal Wrapper Model" run.

4. Package the best models across all groups: Aggregate the best models into a single “wrapper model,” making it convenient to load or deploy a single entity that internally knows which model to apply for each group.

5. Gather results: Collect and compare metrics and run information (e.g., F1 scores, best hyperparameters) from MLflow. Each group has its own associated model, organised in a run hierarchy that also records the explored hyperparameters.


Figure 2. Model metrics per group with optimised hyperparameters per model

 

Overall Experiment Evaluation

In our experiments, per-group hyperparameter tuning consistently outperformed both the global model and per-group models trained without hyperparameter tuning. Not only did each group achieve a higher F1-score when its parameters were individually optimised, but overall performance aggregated across the entire dataset also improved.


Figure 3. Global and per-group model metrics for all experiments.
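The post does not state exactly how the overall score was aggregated. One common choice, assumed here, is to compute F1 on the pooled predictions of all groups, which is not the same as averaging the per-group F1 scores. A minimal sketch with hypothetical column names group, target and prediction:

```python
import pandas as pd
from sklearn.metrics import f1_score


def f1_by_group(pred: pd.DataFrame) -> pd.Series:
    """Per-group F1 plus an 'overall' F1 computed on all rows pooled together."""
    scores = pred.groupby("group")[["target", "prediction"]].apply(
        lambda d: f1_score(d["target"], d["prediction"])
    )
    scores["overall"] = f1_score(pred["target"], pred["prediction"])
    return scores


toy = pd.DataFrame({
    "group": list("aaaabbbb"),
    "target": [1, 1, 0, 0, 1, 0, 1, 0],
    "prediction": [1, 0, 0, 0, 1, 1, 1, 0],
})
print(f1_by_group(toy))
```

In this toy example, group a scores 2/3, group b scores 0.8 and the pooled F1 is 0.75, whereas the plain mean of the two group scores is about 0.733, so it is worth stating which aggregate is reported.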

 

Packaging Multiple Models

After training multiple group-specific models, the question becomes: How do we package them in a way that’s easy to deploy, monitor, and maintain? Typically, the following strategies dominate:

  • Static Linking Grouped Packaging: All group-specific models are bundled together into a single atomic MLflow Model and are served together. Once created, these artifacts never change, simplifying version control and monitoring. This is the pattern used in the blog.
  • Dynamic Linking Grouped Packaging: All group-specific models are loaded by a parent MLflow Model which holds references to the group-specific model URIs. Because these references are resolved each time the parent model is instantiated or deployed, the group-specific models it loads may have changed since the parent was logged.

In both dynamic and static linking approaches, multiple models are grouped together and served from a single serving endpoint. The key difference lies in how the artifacts are loaded and managed. With dynamic linking, the parent model references child models using URIs (e.g., 'models:/' URIs or Unity Catalog model aliases) rather than directly including them as artifacts. This allows child models to be updated without re-logging the parent model. When the endpoint is initialized, it loads the parent model, which then dynamically loads the referenced child models. If a child model is updated, the endpoint can use the latest version by simply reloading the parent model, without requiring a full redeployment.

Static linking, on the other hand, packages all models (parent and children) into a single artifact. This creates an atomic unit where all components are versioned together, simplifying deployment and ensuring consistency across all models. However, any update to a child model requires re-logging and redeploying the entire package. Static linking is therefore better suited to environments that prioritise simplicity and consistency, since a single versioned artifact is easier to monitor, debug, and reproduce. In our demo, we implemented static packaging using the WrapperModel() class.


Figure 4. Static Model Packaging in our demo. 
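The WrapperModel() class itself is not shown in this post. The following is a minimal sketch, under our own assumptions, of what a statically linked wrapper could look like: the fitted child models are stored as attributes of a single mlflow.pyfunc.PythonModel subclass, so logging the wrapper serializes all of them into one artifact. The class name StaticGroupWrapper, the group column name and the dispatch logic are hypothetical; the try/except only lets the sketch run where MLflow is not installed.

```python
import pandas as pd

try:
    from mlflow.pyfunc import PythonModel
except ImportError:  # fallback so the sketch also runs without MLflow installed
    PythonModel = object


class StaticGroupWrapper(PythonModel):
    """Hypothetical static wrapper: the fitted per-group models live on the instance,
    so they are pickled together with it into a single, immutable MLflow artifact."""

    def __init__(self, models_by_group, group_col="group"):
        self.models_by_group = dict(models_by_group)  # {group value: fitted model}
        self.group_col = group_col

    def predict(self, context, model_input):
        """Route each row to the model of its group; the output keeps the input's row order."""
        preds = pd.Series(index=model_input.index, dtype="float64")
        for group, rows in model_input.groupby(self.group_col):
            if group not in self.models_by_group:
                raise KeyError(f"No model packaged for group {group!r}")
            features = rows.drop(columns=[self.group_col])
            preds.loc[rows.index] = self.models_by_group[group].predict(features)
        return preds
```

Logging then takes a single call, for example mlflow.pyfunc.log_model(..., python_model=StaticGroupWrapper(best_models)), where best_models maps each group to its tuned model. A dynamically linked parent would differ mainly in what it stores: a mapping from group to model URI (for example, a Unity Catalog alias), resolved with mlflow.pyfunc.load_model inside load_context each time the parent is loaded.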

 

Spark Performance Tuning

Pandas UDFs (User Defined Functions) in Apache Spark™ distribute tasks across a cluster in a way that leverages both the distributed nature of Spark and the efficiency of Pandas operations. When Spark executes a groupBy operation, it triggers a shuffle so that all records belonging to the same group are co-located on the same partition, enabling grouped computation across the cluster. In this blog, UDFs are used both for training across multiple groups and for batch (or streaming) inference on non-grouped data.
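For the non-grouped inference case, one option (an assumption; the post's own inference code is not shown) is DataFrame.mapInPandas, which streams each partition to Python as an iterator of pandas batches. The helper make_batch_predictor below is hypothetical and works with any model object exposing a predict method, such as a loaded pyfunc model.

```python
from typing import Callable, Iterator, List

import pandas as pd


def make_batch_predictor(model, feature_cols: List[str]) -> Callable[[Iterator[pd.DataFrame]], Iterator[pd.DataFrame]]:
    """Hypothetical helper: returns a function suitable for DataFrame.mapInPandas.

    `model` can be any object with a .predict(pandas.DataFrame) method.
    """

    def predict_batches(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
        # Spark passes an iterator of pandas DataFrames per partition; each batch holds
        # at most spark.sql.execution.arrow.maxRecordsPerBatch rows.
        for batch in batches:
            scored = batch.copy()
            scored["prediction"] = model.predict(batch[feature_cols])
            yield scored

    return predict_batches
```

It would be applied as spark_df.mapInPandas(make_batch_predictor(model, cols), schema=...), where the output schema is the input columns plus a prediction column of type double. When the model is logged to MLflow, mlflow.pyfunc.spark_udf is the built-in alternative.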

When running distributed training and inference, performance depends heavily on how Spark distributes tasks and resources. This section describes how to tune Spark's configuration to balance CPU, memory, and shuffle efficiency.

  • Tune CPU allocation per task - Monitor executor CPU usage. If tasks consistently run near 100% utilization, they are likely CPU-bound. If the model supports multi-threaded training, increase the number of CPUs allocated to each task by setting spark.task.cpus in the Spark config of the cluster settings; this requires a cluster restart to take effect. Note that a higher value also reduces the number of tasks that can run concurrently on each executor.
    spark.task.cpus 2

 

  • Disable Adaptive Query Execution (AQE) - AQE can dynamically merge small partitions after shuffle operations, which may override manual repartitioning and reduce parallelism. Disable it to retain precise control over partitioning: 
    spark.conf.set("spark.sql.adaptive.enabled", "false")
  • Batch Size for Inference - When applying a model UDF (e.g. mlflow.pyfunc.spark_udf) in batch or streaming inference, Spark converts data partitions into Arrow record batches. If your model's memory footprint during inference is high and depends on the batch size (default 10,000 records), you may configure a lower value. Note that this setting does not affect the groupBy operation during training.
    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1000")  # For models with a high inference memory footprint
  • Memory usage in groupBy operations - With groupBy().applyInPandas, all rows of a group are loaded into memory as a single pandas DataFrame within one task. Large groups can therefore cause memory bottlenecks, spills to disk, or out-of-memory (OOM) errors. Monitor task-level memory usage in the Spark UI and increase executor memory if needed.
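Because each group is materialized in a single task, it can help to check group sizes before training. The sketch below uses a hypothetical helper, group_footprint, that reports rows and a rough in-memory size per group for a pandas DataFrame, assuming every row costs the average number of bytes across the frame. On a full Spark DataFrame, df.groupBy("group").count() gives the exact row counts.

```python
import pandas as pd


def group_footprint(pdf: pd.DataFrame, group_col: str = "group") -> pd.DataFrame:
    """Rows per group and a rough in-memory size estimate (MB), largest groups first."""
    # Average bytes per row across all columns (deep=True also counts string contents)
    row_bytes = pdf.memory_usage(deep=True, index=False).sum() / max(len(pdf), 1)
    counts = pdf[group_col].value_counts()
    footprint = pd.DataFrame({"rows": counts, "approx_mb": counts * row_bytes / 1e6})
    return footprint.sort_values("rows", ascending=False)
```

Groups at the top of this table are the ones most likely to need more executor memory, or a smaller per-group workload, when applyInPandas runs.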

 

Conclusion

In this blog, we demonstrated how to tackle the bias-variance trade-off by training and tuning multiple models, one for each group, in a distributed setting. We compared three approaches: a global model, segmented models without hyperparameter optimization, and optimal models with group-specific hyperparameter tuning. The results consistently showed that per-group hyperparameter tuning outperformed both the global model and untuned per-group approaches, achieving a higher F1-score for each group and improved overall performance across the entire dataset. On this dataset, the approach effectively balanced the need to capture group-specific nuances with the need to avoid overfitting.

The implementation leveraged Apache Spark™ on Databricks for distributed computing, Optuna for hyperparameter optimization, and MLflow for experiment tracking and model packaging. You can also swap the distributed framework for Ray on Spark and use Ray Data's groupby('model_num').map_groups. For distributed hyperparameter tuning, you could also use Ray Tune with Optuna or Hyperopt. The common theme across these approaches is grouping the data on a key and then applying the training and tuning functions to each group in parallel. Finally, we addressed the challenge of packaging multiple models, focusing on a static linking approach for deployments where serving them as a unified solution is a key requirement.

Interesting Reads