In Databricks, the most efficient way to handle multiple machine learning models for inference, especially when each model has its own inference logic, is to use batch inference with Spark DataFrames and Pandas UDFs. Instead of looping over your models sequentially in Python, you can parallelize inference across your data and model configurations using Spark's distributed capabilities.
Batch Inference with Spark DataFrames
Databricks recommends structuring your data in a Spark DataFrame, where each row represents an item for prediction and may include metadata indicating which model to use. The workflow typically includes:
- Loading your data into a Spark DataFrame (from Unity Catalog, Delta tables, or external sources).
- Loading your models from the MLflow Model Registry.
- Creating Spark UDFs for inference using mlflow.pyfunc.spark_udf().
- Applying the UDFs to your DataFrame to generate predictions in bulk.
Example:
import mlflow

# Wrap the registered model as a Spark UDF for distributed batch scoring
predict_udf = mlflow.pyfunc.spark_udf(spark, model_uri="models:/my_model/Production")

# Pass the DataFrame's columns as model input and persist the scored table
df = df.withColumn("prediction", predict_udf(*df.columns))
df.write.mode("overwrite").saveAsTable("predictions_output")
This approach allows Spark to distribute inference tasks across multiple executors, avoiding Python's sequential bottlenecks.
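If a handful of independent models should each score the same rows, the same spark_udf mechanism extends naturally: register one UDF per model and add each as its own prediction column. The driver-side loop below only builds the query plan; the scoring itself still runs in parallel on the cluster. This is a minimal sketch, and the model names, URIs, and output table are placeholders rather than objects assumed to exist in your workspace:

import mlflow

# Hypothetical registry URIs; each one becomes its own distributed scoring UDF
model_uris = {
    "churn": "models:/churn_model/Production",
    "upsell": "models:/upsell_model/Production",
}

input_cols = df.columns  # capture the feature columns before adding prediction columns
for name, uri in model_uris.items():
    predict_udf = mlflow.pyfunc.spark_udf(spark, model_uri=uri)
    df = df.withColumn(f"{name}_prediction", predict_udf(*input_cols))

df.write.mode("overwrite").saveAsTable("multi_model_predictions")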
Parallel Multi-Model Inference
When dealing with multiple models (e.g., per client or product), Databricks supports parallel batch inference using the groupBy.applyInPandas() method combined with Pandas UDFs. Each Spark worker can handle inference for a different model, allowing you to:
- Load each model once per group of rows rather than once per prediction.
- Process the data for different models in parallel across the cluster.
Example pattern:
from pyspark.sql.types import DoubleType, StructField, StructType

def run_inference(pdf):
    # All rows in a group share one model_path, so the model loads once per group
    model_path = pdf['model_path'].iloc[0]
    model = mlflow.pyfunc.load_model(model_path)
    pdf['prediction'] = model.predict(pdf[['features']])
    return pdf

# Output schema = the input schema plus the new prediction column
output_schema = StructType(df.schema.fields + [StructField("prediction", DoubleType())])
result_df = df.groupBy("model_id").applyInPandas(run_inference, schema=output_schema)
This design reduces redundant model loading and uses Spark's distributed compute layer efficiently.
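For this pattern to work, every row needs a model_path column telling the UDF which registered model to load. Here is a minimal sketch of attaching that metadata with a small lookup DataFrame; the model IDs, registry URIs, and table name are placeholder assumptions:

# Hypothetical mapping from group identifier to a registered model URI
model_map = spark.createDataFrame(
    [("client_a", "models:/client_a_model/Production"),
     ("client_b", "models:/client_b_model/Production")],
    ["model_id", "model_path"],
)

# Join so each row carries the URI of the model that should score it
df = spark.table("my_input_table").join(model_map, on="model_id", how="inner")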
Mosaic AI & AI Functions
If your inference needs involve standard ML or LLM models, you can simplify further with AI Functions or Mosaic AI batch inference. These let you run inference directly from SQL with functions like ai_query(), which routes each row to a Mosaic AI Model Serving endpoint, so there is no manual looping or pipeline code to maintain.
Example SQL:
SELECT input_text, ai_query('my_model_endpoint', input_text) AS prediction
FROM my_input_table
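The same query can run from a notebook on a recent Databricks Runtime or serverless compute where ai_query() is available, if you prefer to keep the whole job in Python. A minimal sketch, assuming the endpoint and input table above exist and writing the scored rows to a hypothetical predictions_output table:

result = spark.sql("""
    SELECT input_text,
           ai_query('my_model_endpoint', input_text) AS prediction
    FROM my_input_table
""")
result.write.mode("overwrite").saveAsTable("predictions_output")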
Summary of Best Practices
| Technique | Use Case | Advantage |
|---|---|---|
| Spark UDFs (mlflow.pyfunc.spark_udf) | Standard ML models | Simplifies batch scoring |
| Pandas UDFs with groupBy.applyInPandas | Many models (per group) | Parallel per-model inference |
| Mosaic AI / AI Functions | LLMs or unified inference | Simplified SQL-based scaling |
| Delta Live Tables (DLT) | Scheduled, repeatable jobs | Automates production batch runs |
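Delta Live Tables appears only in the table above, but the same UDF-based scoring drops into a DLT pipeline with little change. A minimal sketch, reusing the placeholder model and table names from earlier (the pipeline itself is configured separately, e.g. in the DLT UI):

import dlt
import mlflow

@dlt.table(name="predictions_output")
def predictions_output():
    # Create the scoring UDF inside the table function so the pipeline owns its lifecycle
    predict_udf = mlflow.pyfunc.spark_udf(spark, model_uri="models:/my_model/Production")
    source = spark.read.table("my_input_table")
    return source.withColumn("prediction", predict_udf(*source.columns))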
In short, replace your Python loops with Spark-level Pandas UDFs or Databricks batch inference functions. This takes advantage of cluster parallelism and avoids sequential execution, allowing all your model inferences to run efficiently in parallel across nodes.