Technical Blog
Explore in-depth articles, tutorials, and insights on data analytics and machine learning in the Databricks Technical Blog. Stay updated on industry trends, best practices, and advanced techniques.


MLflow makes it easy to log and deploy machine learning models for use across your organization. However, the standard MLflow interfaces do not cover every use case. For example, a simple scikit-learn model logged with MLflow will use the predict function by default, and doesn’t return probabilities from predict_proba or multiple class predictions out of the box. This blog documents an approach to extending the standard MLflow interfaces so that a model can return multiple predicted classes along with their probabilities. More generally, this approach can be used to log custom functionality alongside a model. We'll walk through an end-to-end example below, which you can follow along with by cloning this notebook.

To demonstrate the use of MLflow for predicting the probability of multiple classes, we’ll create a random sensor dataset to set up our example. Next, we’ll train a model to predict the resulting classes based on the sensor data. This could be useful for making predictions about which state an engine is in at a given point in time, which factory station a part should be assembled in, or what fault mode a device might be experiencing. For each of these scenarios, it helps to produce the top N predicted classes and their probabilities in case the top prediction ends up being wrong. By providing the top N predictions, we create fallback options for downstream consumers of the predictions. 

For example, if we predict that a machine should be assembled at a factory station that happens to already be in use, the coordinator can have other station options at the ready. Similarly, if we predict an error code that turns out to be incorrect, a technician can check whether one of the other high-probability error codes is more accurate. By the end of our sensor example, we will understand how MLflow can be used to predict the probability of multiple classes in order to address these sorts of classification problems.
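The fallback idea can be sketched in a few lines of plain Python. This is only an illustration: the helper `first_available`, the station names, and the probabilities are hypothetical and are not part of the example that follows.

```python
from typing import Optional, Sequence, Set, Tuple


def first_available(
    ranked: Sequence[Tuple[str, float]],
    unavailable: Set[str],
    min_probability: float = 0.0,
) -> Optional[str]:
    """Return the most probable class that is usable, or None if none qualifies.

    `ranked` is assumed to be sorted from most to least probable.
    """
    for label, probability in ranked:
        if label not in unavailable and probability >= min_probability:
            return label
    return None


# Hypothetical top-3 prediction for one part, sorted by probability.
ranked_stations = [("station_a", 0.55), ("station_b", 0.30), ("station_c", 0.15)]

# station_a is busy, so the coordinator falls back to the next option.
choice = first_available(ranked_stations, unavailable={"station_a"})
print(choice)  # station_b
```

A consumer can also pass `min_probability` to reject fallbacks that the model considers too unlikely.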

Step 1: Install and import libraries

First, we’ll install the dbldatagen library which we’ll use to produce an example dataset and import the libraries that we’ll require throughout the example. In this case we’ll use a GradientBoostingClassifier from scikit-learn, but any model which can be used to produce classification probabilities could be swapped in.


%pip install dbldatagen

import dbldatagen as dg
from pyspark.sql.types import StringType, FloatType
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier
import mlflow
from mlflow.models import infer_signature
from pyspark.sql.functions import struct, col


Step 2: Generate example dataset

Next, we’ll use dbldatagen to create a 1,000-row DataFrame that will be used to train our model, and later to make classification predictions. In the code below, df_features is the resulting Spark DataFrame with three sensor features and a result column that we’ll try to classify.


data_spec = (dg.DataGenerator(spark, rows=1000)
           .withColumn("sensor1", FloatType(), minValue=0, maxValue=100, random=True)
           .withColumn("sensor2", FloatType(), minValue=0, maxValue=50, random=True)
           .withColumn("sensor3", FloatType(), minValue=0, maxValue=25, random=True)
           .withColumn("result", StringType(), values=['class1', 'class2', 'class3', 'class4', 'class5',
                                                       'class6', 'class7', 'class8', 'class9', 'class10'], random=True))

# Build the Spark DataFrame from the data specification
df_features = data_spec.build()


Step 3: Train a classification model

The code below converts the Spark DataFrame to pandas for training and runs a standard train/test split. We’ll fit the model to our data so that we can use it for predictions. Note that we skip some standard ML best practices, such as evaluating on the test set, so that we can keep this example focused on multi-class predictions in MLflow.


# Create features in Pandas and train the sklearn classifier
features = df_features.toPandas()

# Use the three sensor columns generated in Step 2 as features
X = features[['sensor1', 'sensor2', 'sensor3']]
y = features['result']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

model = GradientBoostingClassifier(random_state=42).fit(X_train, y_train)
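The walkthrough skips evaluation, but a top-N accuracy check is a natural fit for this kind of model. The following is a minimal sketch, assuming a probability matrix whose columns follow the order of the class labels (as with scikit-learn's predict_proba and classes_). The helper name `top_n_accuracy` and the toy numbers are illustrative and not part of the original notebook.

```python
import numpy as np


def top_n_accuracy(proba, classes, y_true, n=3):
    """Fraction of rows whose true label is among the n most probable classes."""
    proba = np.asarray(proba)
    classes = np.asarray(classes)
    # Column index of each row's true label (assumes every label is in `classes`).
    true_idx = np.array([np.flatnonzero(classes == label)[0] for label in y_true])
    true_proba = proba[np.arange(len(proba)), true_idx]
    # Count how many classes are strictly more probable than the true one.
    rank = (proba > true_proba[:, None]).sum(axis=1)
    return float((rank < n).mean())


# Toy probabilities for four rows and three classes (made-up values).
classes = np.array(["class1", "class2", "class3"])
proba = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.5, 0.4, 0.1],
    [0.2, 0.1, 0.7],
])
y_true = np.array(["class1", "class2", "class3", "class3"])

print(top_n_accuracy(proba, classes, y_true, n=1))  # 0.5
print(top_n_accuracy(proba, classes, y_true, n=2))  # 0.75
```

With the model above, the same helper could be called as top_n_accuracy(model.predict_proba(X_test), model.classes_, y_test, n=3). Because the labels in this dataset are random, the score should be close to chance. scikit-learn also provides top_k_accuracy_score in sklearn.metrics for the same purpose.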


Step 4: Create a custom MLflow model

To predict the probabilities of multiple classes with MLflow, we can create a custom PyFunc model that returns the top N classes and their corresponding probabilities. This means wrapping the trained model in a custom implementation that extracts the required information. The custom model accepts the fitted scikit-learn model and a parameter n that sets how many classes to return. It uses the output of the scikit-learn model’s predict_proba() function together with its classes_ attribute to pair each probability with its class label. It returns the predicted classes and their probabilities as a pandas DataFrame, as the simple example in the last line shows.


# Create a custom MLflow model that meets our requirements
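class ProbabilityModel(mlflow.pyfunc.PythonModel):
    def __init__(self, sklearn_model, n=3):
        self.sklearn_model = sklearn_model  # fitted scikit-learn classifier
        self.n = n                          # number of top classes to return

    def predict(self, context, model_input):
        # Class probabilities: one row per input, one column per class
        predictions = self.sklearn_model.predict_proba(model_input)
        # Column indices of the n highest probabilities in each row
        top_n_indices = np.argsort(-predictions, axis=1)[:, :self.n]
        # Map those indices to class labels and to their probabilities
        top_n_classes = self.sklearn_model.classes_[top_n_indices]
        top_n_probabilities = predictions[np.arange(len(predictions))[:, None], top_n_indices]
        # Emit one class column and one probability column per rank
        pred_dict = {}
        for i in range(self.n):
            pred_dict['predicted_class_'+str(i+1)] = top_n_classes[:, i]
            pred_dict['predicted_probability_'+str(i+1)] = top_n_probabilities[:, i]
        return pd.DataFrame(pred_dict)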

prob_model = ProbabilityModel(model)
predictions = prob_model.predict('', X_test[:5])
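To make the indexing in predict easier to follow, here is the same argsort and fancy-indexing logic applied to a tiny hand-written probability matrix. The class labels and probabilities are made up for illustration; in the real model they come from classes_ and predict_proba.

```python
import numpy as np

# Stand-ins for model.classes_ and model.predict_proba(...): 2 rows, 4 classes.
classes = np.array(["class1", "class2", "class3", "class4"])
proba = np.array([
    [0.10, 0.60, 0.05, 0.25],
    [0.40, 0.10, 0.30, 0.20],
])
n = 2

# Negating the matrix makes argsort order each row from highest to lowest.
top_n_indices = np.argsort(-proba, axis=1)[:, :n]

# Indexing the label array with a 2-D index array yields a 2-D label array.
top_n_classes = classes[top_n_indices]

# Pair each row number with its selected column indices to pull the values.
top_n_probabilities = proba[np.arange(len(proba))[:, None], top_n_indices]

print(top_n_indices.tolist())        # [[1, 3], [0, 2]]
print(top_n_classes.tolist())        # [['class2', 'class4'], ['class1', 'class3']]
print(top_n_probabilities.tolist())  # [[0.6, 0.25], [0.4, 0.3]]
```

np.take_along_axis(proba, top_n_indices, axis=1) gives the same probabilities and may read more clearly than the explicit row index.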


Step 5: Register the model and make predictions

With the custom MLflow model defined, we’ll deploy it to the Unity Catalog model registry to make it available for inference. This involves logging and registering the model, which lets us manage its lifecycle and versions. Loading the logged model as an MLflow Spark user-defined function (UDF) lets us predict the probabilities of multiple classes, which is more flexible than the default interface. In this case, our model returns a struct, and we unpack its fields into their own columns.

These predictions can be run on any Spark DataFrame (batch or streaming) from any notebook or pipeline that has permission to access our model. Because MLflow separates model training and deployment from the generic inference code shown in this example, we can repeatedly retrain and redeploy the model without disrupting the inference process. Downstream consumers of our model don’t need to worry about dependencies or library changes; they only need to understand the expected input and output schemas.


# Log the custom model
# Use the same rows that produced `predictions` in Step 4 as the input example
sample = X_test[:5]
signature = infer_signature(sample, predictions)
with mlflow.start_run() as run:
    logged_model = mlflow.pyfunc.log_model("model", python_model=prob_model, input_example=sample, signature=signature)

# Load the model from MLflow into a Spark UDF
model_uri = f"runs:/{run.info.run_id}/model"
loaded_model = mlflow.pyfunc.load_model(model_uri)
predict_udf = mlflow.pyfunc.spark_udf(spark, model_uri)

# Apply the model to a Spark DataFrame
predictions_df = (
    df_features.withColumn("predictions", predict_udf(struct(*df_features.columns)))
    .select("*", "predictions.*")
)
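Since downstream consumers only need to know the output schema, it can help to write that contract down. The sketch below derives the expected column names from the naming used in ProbabilityModel and checks a pandas DataFrame against them. The helper names `expected_output_columns` and `check_predictions` and the sample row are hypothetical.

```python
import numpy as np
import pandas as pd


def expected_output_columns(n):
    """Column names produced for the top n classes, in output order."""
    columns = []
    for rank in range(1, n + 1):
        columns.append(f"predicted_class_{rank}")
        columns.append(f"predicted_probability_{rank}")
    return columns


def check_predictions(df, n):
    """Raise ValueError if df does not look like top-n model output."""
    missing = [c for c in expected_output_columns(n) if c not in df.columns]
    if missing:
        raise ValueError(f"missing columns: {missing}")
    prob_cols = [f"predicted_probability_{rank}" for rank in range(1, n + 1)]
    probabilities = df[prob_cols].to_numpy(dtype=float)
    # Rank 1 should be the most probable class, so values must not increase.
    if (np.diff(probabilities, axis=1) > 0).any():
        raise ValueError("probabilities are not sorted in descending order")


# Made-up output row shaped like the model's top-2 predictions.
demo = pd.DataFrame({
    "predicted_class_1": ["class2"],
    "predicted_probability_1": [0.60],
    "predicted_class_2": ["class4"],
    "predicted_probability_2": [0.25],
})

print(expected_output_columns(2))
check_predictions(demo, n=2)  # passes silently
```

For Spark output, the same check could be applied to a small sample converted with toPandas().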



In this blog post, we demonstrated how to create a custom MLflow PyFunc model to predict the probability of multiple classes in Databricks. By extending the standard MLflow interfaces, you can tailor the model's behavior to your specific requirements and improve your machine learning workflows. This approach works for streaming, batch, or real-time serving use cases, and can be extended further to include just-in-time featurization or other custom functionality. We invite you to learn more about ML on Databricks with the Big Book of ML.