Machine Learning
Dive into the world of machine learning on the Databricks platform. Explore discussions on algorithms, model training, deployment, and more. Connect with ML enthusiasts and experts.

How to save a model produced by distributed training?

New Contributor II

I am trying to save a model after distributed training via the following code:

import mlflow.keras
import tensorflow as tf
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from spark_tensorflow_distributor import MirroredStrategyRunner

mlflow.log_param("learning_rate", 0.001)

def train():
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    with strategy.scope():
        data = load_breast_cancer()
        X_train, X_test, y_train, y_test = train_test_split(
            data.data, data.target, test_size=0.3)
        N, D = X_train.shape  # number of observations and features
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)
        model = tf.keras.models.Sequential([
            tf.keras.layers.Dense(1, input_shape=(D,), activation='sigmoid')
        ])
        model.compile(optimizer='adam',  # adaptive momentum
                      loss='binary_crossentropy',
                      metrics=['accuracy'])
        # Train the model
        r = model.fit(X_train, y_train, validation_data=(X_test, y_test))
        print("Train score:", model.evaluate(X_train, y_train))  # returns loss and accuracy
        mlflow.keras.log_model(model, "mymodel")

MirroredStrategyRunner(num_slots=4, use_custom_strategy=True).run(train)


I have a couple of questions:

  1. Setting num_slots = 4 causes MLflow to log 4 models, and none of them predicts the dataset well. I expected the chief node to log a single model with at least 80% accuracy. Is there a way to save only one model, or to merge the models?
  2. How do you save the model without mlflow.keras.log_model? If I save via dbutils I get a race condition, and it is not clear from the Spark distributor which node is the chief node.
  3. Does every node get all of the data, or only a partition of it?



New Contributor II

The ModelCheckpoint callback is used in conjunction with model.fit() to save the model or its weights (in a checkpoint file) at some interval, so the model or weights can be loaded later to continue training from the saved state.
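For reference, a minimal sketch of that callback on a local, non-distributed run (the data, model, and filepath are illustrative, not part of the question's setup):

```python
# Minimal sketch: save weights at the end of every epoch with
# tf.keras.callbacks.ModelCheckpoint, then restore them later.
import numpy as np
import tensorflow as tf

# Tiny synthetic binary-classification data, just for illustration.
rng = np.random.default_rng(0)
X = rng.normal(size=(32, 4)).astype("float32")
y = (X[:, 0] > 0).astype("float32")

model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(1, input_shape=(4,), activation="sigmoid")
])
model.compile(optimizer="adam", loss="binary_crossentropy",
              metrics=["accuracy"])

# Save only the weights (not the full model) after each epoch.
ckpt = tf.keras.callbacks.ModelCheckpoint(
    filepath="ckpt.weights.h5",   # example path, not prescribed
    save_weights_only=True,
    save_freq="epoch",
)
model.fit(X, y, epochs=2, verbose=0, callbacks=[ckpt])

# Later (or in another run), restore and continue training:
model.load_weights("ckpt.weights.h5")
```

Note that this saves a checkpoint on whichever worker runs the callback; it does not by itself answer which node in a distributed run should do the saving.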


New Contributor II

How does ModelCheckpoint know which node is the chief node?

Shouldn't there be an API that returns a single resulting model from distributed training?

New Contributor III

Is there any update on the answer? I am curious too.

Is there a merge operation after all the distributed training finishes?

New Contributor III

I guess spark_tensorflow_distributor is probably obsolete, since there have been no updates since 2020.

Horovod seems a better choice for using TensorFlow with Spark on Databricks.

New Contributor III

I think I finally worked this out.

Here is the extra code to save the model only once, from the first node:

from pyspark import BarrierTaskContext

context = BarrierTaskContext.get()
if context.partitionId() == 0:
    mlflow.keras.log_model(model, "mymodel")
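The idea above can be sketched as a tiny helper. Treating partition 0 as the "chief" is an assumption here: spark_tensorflow_distributor does not expose a chief explicitly, so partition 0 simply serves as the one task allowed to log.

```python
# Hypothetical helper: pick the single task that should log/save the model.
# The "partition 0 is chief" convention is an assumption, not an API guarantee.
def is_chief(partition_id: int) -> bool:
    """Return True only for the task that should log the model."""
    return partition_id == 0

# Inside train(), after fitting (requires running in a Spark barrier task):
#   from pyspark import BarrierTaskContext
#   if is_chief(BarrierTaskContext.get().partitionId()):
#       mlflow.keras.log_model(model, "mymodel")
```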
