Machine Learning

Not able to log xgboost model to mlflow

New Contributor III

I have been trying to log an MLflow model, but it does not seem to be working: it logs only the last run (which is also the worst run).

#------------------------------------------------------- 13.0 ML XGBoost -------------------------------------------------------

from hyperopt import fmin, tpe, Trials, hp, space_eval
import numpy as np
import mlflow
import mlflow.spark

from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import RegressionEvaluator
from xgboost.spark import SparkXGBRegressor
from mlflow.models.signature import infer_signature

# Assumption: ordinal_encoder, vec_assembler, train_df and test_df are defined earlier in the notebook (not shown in the post)
#vec_assembler = VectorAssembler(inputCols=train_df.columns[1:], outputCol="features")

xgb = SparkXGBRegressor(num_workers=1, label_col="price", missing=0.0)
# pipeline = Pipeline(stages=[vec_assembler, xgb])
pipeline = Pipeline(stages=[ordinal_encoder, vec_assembler, xgb])
regression_evaluator = RegressionEvaluator(predictionCol="prediction", labelCol="price")

def objective_function(params):
    # set the hyperparameters that we want to tune (cast numpy ints to plain Python ints)
    max_depth = int(params["max_depth"])
    n_estimators = int(params["n_estimators"])

    with mlflow.start_run(nested=True):
        estimator = pipeline.copy({xgb.max_depth: max_depth, xgb.n_estimators: n_estimators})
        model = estimator.fit(train_df)

        preds = model.transform(test_df)
        rmse = regression_evaluator.evaluate(preds)
        #r2 = regression_evaluator.setMetricName("r2").evaluate(preds)
        mlflow.log_metric("rmse", rmse)
        #mlflow.log_metric("r2", r2)

    return rmse

search_space = {
    "max_depth": hp.choice("max_depth", np.arange(12, 15, dtype=int)),
    "n_estimators": hp.choice("n_estimators", np.arange(50, 80, dtype=int)),
}

num_evals = 1
trials = Trials()
best_hyperparam = fmin(fn=objective_function,
                       space=search_space,
                       algo=tpe.suggest,
                       max_evals=num_evals,
                       trials=trials)

# fmin returns the index chosen for each hp.choice parameter; map it back to the actual value
best_params = space_eval(search_space, best_hyperparam)

# Retrain model on train & validation dataset and evaluate on test dataset
with mlflow.start_run():
    best_max_depth = int(best_params["max_depth"])
    best_n_estimators = int(best_params["n_estimators"])
    estimator = pipeline.copy({xgb.max_depth: best_max_depth, xgb.n_estimators: best_n_estimators})
    #combined_df = train_df.union(test_df) # Combine train & validation together

    pipeline_model = estimator.fit(train_df)
    pred_df = pipeline_model.transform(test_df)
    #signature = infer_signature(test_df, pred_df)
    rmse = regression_evaluator.evaluate(pred_df)
    r2 = regression_evaluator.setMetricName("r2").evaluate(pred_df)

    # Log param and metrics for the final model
    mlflow.log_param("maxdepth", best_max_depth)
    mlflow.log_param("n_estimators", best_n_estimators)
    mlflow.log_metric("rmse", rmse)
    mlflow.log_metric("r2", r2)
    # Other flavors tried (commented out in the post): mlflow.transformers.log_model, mlflow.sklearn.log_model
    mlflow.spark.log_model(pipeline_model, "model")

Valued Contributor III

Hi @raghagra,

Thank you for posting your question in the Databricks community.

The code logs only the last run because mlflow.start_run() is called inside objective_function(), so every call to objective_function() starts a new run. mlflow.spark.log_model() logs the model only to the run that is active when it is called, and in your code it is called only once, in the final run, so that is the only run with a model.

To fix this, you can move the mlflow.start_run() call outside of objective_function(), so that one parent run wraps the whole search and each trial's nested run is recorded under it. To get a model for every trial, the log_model() call also has to be made inside objective_function().

Please check how it works.

New Contributor III

@Kumaran It still did not work. I am getting the following output:
2023/07/22 11:30:21 INFO mlflow.spark: Inferring pip requirements by reloading the logged model from the databricks artifact repository, which can be time-consuming. To speed up, explicitly specify the conda_env or pip_requirements when calling log_model().
2023/07/22 11:31:02 WARNING mlflow.utils.environment: Encountered an unexpected error while inferring pip requirements (model URI: dbfs:/databricks/mlflow-tracking/590967242928602/e3bd64c64535425192a510bd4ee66dec/artifacts/xgb_model/sparkml, flavor: spark), fall back to return ['pyspark==3.4.0']. Set logging level to DEBUG to see the full traceback.
/databricks/python/lib/python3.10/site-packages/_distutils_hack/ UserWarning: Setuptools is replacing distutils.
warnings.warn("Setuptools is replacing distutils.")
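The INFO line itself suggests the remedy: pass pip_requirements (or conda_env) to log_model() so MLflow does not have to reload the model to infer them. A sketch, assuming pyspark and xgboost are the only extra packages the pipeline needs (an assumption; adjust the list), with a hypothetical helper pinned_requirements() that reads installed versions via importlib.metadata. On managed runtimes pyspark may not be registered as an installed distribution, in which case the helper skips it and its pin would have to be added by hand.

```python
from importlib.metadata import PackageNotFoundError, version


def pinned_requirements(packages):
    """Return 'name==version' pins for the given installed distributions.

    Distributions that importlib.metadata cannot find are skipped rather than
    guessed, so the caller may need to add those pins by hand.
    """
    pins = []
    for name in packages:
        try:
            pins.append(f"{name}=={version(name)}")
        except PackageNotFoundError:
            continue
    return pins


def log_spark_pipeline(pipeline_model, artifact_path="model"):
    """Log a fitted Spark ML PipelineModel with explicit pip requirements.

    Because pip_requirements is given, MLflow can skip the slow step of
    reloading the model to infer them.
    """
    import mlflow.spark  # deferred so pinned_requirements() can be used without MLflow

    return mlflow.spark.log_model(
        pipeline_model,
        artifact_path,
        # Assumption: pyspark and xgboost are the only extra packages the pipeline needs
        pip_requirements=pinned_requirements(["pyspark", "xgboost"]),
    )


print(pinned_requirements(["pip", "setuptools", "not-a-real-package-xyz"]))
```

Inside the question's final run, the last line would then become log_spark_pipeline(pipeline_model).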

Valued Contributor III

Hi @raghagra,

Can you try the following code instead (please modify it according to your needs) to log your run:

import mlflow

with mlflow.start_run(experiment_id="1234") as run:
    mlflow.set_tag("status", "started")
    mlflow.log_param("git_hash", "1234")
    mlflow.log_param("env", "stg")
    mlflow.log_param("pipeline_id", "es_s3_to_raw")

    run_id = run.info.run_id
    mlflow.log_param("run_id", run_id)

    mlflow.set_tag("run_url", "URL of your model")
    mlflow.log_param("id_in_job", "1726769")
    mlflow.log_param("context.user", "your email id")


New Contributor III

@Kumaran I ran this code, but is there any specific log that I should be looking for?
