Parallelization in training machine learning models using MLFlow

ianchenmu
New Contributor II

I'm training an ML model (e.g., XGBoost) and I have a large grid of 5 hyperparameters; if each parameter has 5 candidates, that's 5^5 = 3,125 combinations.

Now I want to parallelize the grid search over all the hyperparameter combinations to find the best-performing model.

So how can I achieve this on Databricks, especially using MLflow? I've been told I can define a function that trains and evaluates the model (using MLflow), define an array with all of the hyperparameter combinations, sc.parallelize the array, and then map the function over it.

I have come up with the code to sc.parallelize the array, like:

paras_combo_test = [(x, y) for x in [50, 100, 150] for y in [0.8, 0.9, 0.95]]
sc.parallelize(paras_combo_test, 3).glom().collect()

(For simplicity, I'm using just two parameters, x and y; there are 9 combos in total, which I divided into 3 partitions.)

How can I map the function that does the model training and evaluation (probably using MLflow) over these partitions, so that 3 workers run in parallel (each worker training 3 models)?
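To make the pattern concrete, here is a local, pure-Python sketch of what I mean. The `train_and_eval` body is a toy stand-in; on the cluster it would do the real XGBoost training plus MLflow logging, and the manual chunking would be replaced by `sc.parallelize(...).mapPartitions(train_partition).collect()`:

```python
def train_and_eval(params):
    """Stand-in for real XGBoost training; returns (params, metric).
    On Databricks this body would wrap mlflow.start_run() and log the metric."""
    n_estimators, subsample = params
    metric = 1.0 / n_estimators + (1 - subsample)  # toy "error" for illustration
    return params, metric

def train_partition(iterator):
    # This is the shape of the function you would pass to mapPartitions:
    # it receives an iterator over the combos in one partition and
    # yields one (params, metric) result per combo.
    for params in iterator:
        yield train_and_eval(params)

paras_combo_test = [(x, y) for x in [50, 100, 150] for y in [0.8, 0.9, 0.95]]

# Simulate 3 partitions locally; on Spark this whole loop would be:
#   results = sc.parallelize(paras_combo_test, 3) \
#               .mapPartitions(train_partition).collect()
partitions = [paras_combo_test[i::3] for i in range(3)]
results = [r for part in partitions for r in train_partition(iter(part))]

# Return only small (params, metric) tuples, not fitted models, so the
# collect() back to the driver stays cheap.
best_params, best_metric = min(results, key=lambda r: r[1])
```

Note the workers yield only the parameters and metric; the fitted models themselves would be logged as MLflow artifacts from inside `train_and_eval` rather than collected to the driver.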

5 REPLIES

Anonymous
Not applicable

This blog should be very helpful:

https://www.databricks.com/blog/2021/04/15/how-not-to-tune-your-model-with-hyperopt.html

Here are the docs on xgboost

https://docs.databricks.com/machine-learning/train-model/xgboost.html

A simple rule: never use sc.parallelize for this.

ianchenmu
New Contributor II

Thanks @Joseph Kambourakis! It seems we can do distributed XGBoost training using num_workers, set according to how many workers are in the cluster. But can we also speed things up by setting a parameter that utilizes the number of cores in the cluster?


Hubert-Dudek
Esteemed Contributor III

collect() runs on the driver, so it provides no parallelism; with a large result set it is more likely to produce an OOM error.

Anonymous
Not applicable

Hi @Chen Mu,

Hope all is well!

Just wanted to check in to see if you were able to resolve your issue. Would you be happy to share the solution or mark an answer as best? Otherwise, please let us know if you need more help.

We'd love to hear from you.

Thanks!
