
How to implement early stop in SparkXGBRegressor with Pipeline?

bbashuk
New Contributor II

I'm trying to implement early stopping for a SparkXGBRegressor model inside a Pipeline:

from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml import Pipeline, PipelineModel
from xgboost.spark import SparkXGBRegressor
from xgboost.callback import EarlyStopping

assembler = VectorAssembler() \
    .setInputCols(relevant_model_cols) \
    .setOutputCol("features") \
    .setHandleInvalid("keep")

early_stop = EarlyStopping(
    rounds=5,
    min_delta=1e-3,
    save_best=True,
    maximize=True,
    data_name='validation_0',
    metric_name="auc",
)

xgboost_regressor = SparkXGBRegressor()
xgboost_regressor.setParams(
    gamma=0.2,
    max_depth=6,
    objective="reg:logistic",       # logistic regression, output probability
    missing=MISSING_VALUE_NUM_DEFAULT,
    num_workers=60,
    subsample=0.5,
    colsample_bytree=0.7,
    learning_rate=0.01,
    random_state=1234,
    reg_alpha=0.35,
    reg_lambda=0.3,
    n_estimators=50,
    eval_metric='auc',
    callbacks=[early_stop]
    )

pipeline = (
    Pipeline()
    .setStages([assembler,
                xgboost_regressor])
)

trained_model = pipeline.fit(train_dataset)

 

 

But I get this error:

Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.

The same error occurs with a small dataset.

I also tried setting:

xgboost_regressor.setParams(
    early_stopping_rounds=10,
    validation_indicator_col='validation_0')

 

 

1 REPLY

bbashuk
New Contributor II

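OK, I finally solved it: I added a column named validation_0 to the dataset, set validation_indicator_col='validation_0' on the regressor together with early_stopping_rounds (instead of the EarlyStopping callback), and did not pass that column to the VectorAssembler: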

xgboost_regressor = SparkXGBRegressor()
xgboost_regressor.setParams(
    gamma=0.2,
    max_depth=6,
    objective="reg:logistic",       # logistic regression, output probability
    missing=MISSING_VALUE_NUM_DEFAULT,
    num_workers=60,
    subsample=0.5,
    colsample_bytree=0.7,
    learning_rate=0.01,
    random_state=1234,
    reg_alpha=0.35,
    reg_lambda=0.3,
    n_estimators=600,
    eval_metric='auc',
    early_stopping_rounds=5,
    validation_indicator_col='validation_0',
    maximize=True,
    verbose=True,
    )
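
For anyone landing here later, the two data-side steps of this fix (adding the indicator column and keeping it out of the assembler) could be sketched roughly as below. This is only an illustration under assumptions: the helper names `feature_columns` and `add_validation_indicator`, the label column name `label`, the example column names, and the random 20% hold-out are all hypothetical and not from the original code. The indicator column is boolean, with True marking the rows used for evaluation.

```python
def feature_columns(all_cols, label_col, indicator_col="validation_0"):
    """Columns to feed VectorAssembler: everything except the label
    and the validation indicator (hypothetical helper)."""
    excluded = {label_col, indicator_col}
    return [c for c in all_cols if c not in excluded]


def add_validation_indicator(df, indicator_col="validation_0",
                             fraction=0.2, seed=1234):
    """Add a boolean column; True marks rows held out for early stopping.

    Assumes a random split is acceptable for the data at hand.
    """
    # Imported lazily so the pure-Python helper above works without Spark.
    from pyspark.sql import functions as F
    return df.withColumn(indicator_col, F.rand(seed) < fraction)


# Hypothetical column names, for illustration only.
cols = ["age", "income", "label", "validation_0"]
print(feature_columns(cols, label_col="label"))  # ['age', 'income']
```

With these, `train_dataset = add_validation_indicator(train_dataset)` would run before `pipeline.fit(train_dataset)`, and `relevant_model_cols = feature_columns(train_dataset.columns, "label")` would be what goes into `setInputCols`.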

 
