How to save a model produced by distributed training?

New Contributor II

I am trying to save a model after distributed training with the following code:

import mlflow.keras
import tensorflow as tf
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from spark_tensorflow_distributor import MirroredStrategyRunner

mlflow.log_param("learning_rate", 0.001)

def train():
    # use_custom_strategy=True means train() must build its own strategy
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

    # Load and preprocess the data (this runs on every worker)
    data = load_breast_cancer()
    # A fixed random_state gives every worker the same train/test split
    X_train, X_test, y_train, y_test = train_test_split(
        data.data, data.target, test_size=0.3, random_state=42)
    N, D = X_train.shape  # number of observations and features
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # Only model creation and compilation need to be inside the strategy scope
    with strategy.scope():
        model = tf.keras.models.Sequential([
            tf.keras.layers.Input(shape=(D,)),
            tf.keras.layers.Dense(1, activation='sigmoid'),  # sigmoid output for binary classification
        ])
        model.compile(optimizer='adam',  # adaptive moment estimation
                      loss='binary_crossentropy',
                      metrics=['accuracy'])

    # Train the model
    r = model.fit(X_train, y_train, validation_data=(X_test, y_test))
    print("Train score:", model.evaluate(X_train, y_train))  # evaluate returns [loss, accuracy]
    mlflow.keras.log_model(model, "mymodel")

MirroredStrategyRunner(num_slots=4, use_custom_strategy=True).run(train)


I have a few questions:

  1. Setting num_slots=4 causes MLflow to log 4 models, none of which predicts the dataset well. I expect the chief node to log a single model with at least 80% accuracy. Is there a way to save only one model, or to merge the models?
  2. How can I save the model without mlflow.keras.log_model? If I save it via dbutils I would get a race condition, and spark-tensorflow-distributor does not make it clear which node is the chief.
  3. Is every node getting the full dataset instead of only part of it?
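On question 3: in the script above every worker loads the dataset and calls train_test_split() itself, so each one starts from the full data; how that input is then divided among workers depends on TensorFlow's input auto-sharding, so it is worth checking rather than assuming. As an illustration of what "partial data" means, the plain-Python sketch below mimics the documented behaviour of tf.data.Dataset.shard(num_shards, index), which keeps every num_shards-th element starting at index. The helper name shard is hypothetical and this is not TensorFlow code.

```python
def shard(elements, num_shards, index):
    """Mimic tf.data.Dataset.shard: keep elements whose position % num_shards == index."""
    if not 0 <= index < num_shards:
        raise ValueError("index must be in [0, num_shards)")
    return [x for i, x in enumerate(elements) if i % num_shards == index]


# Four workers (as with num_slots=4) splitting ten samples between them.
samples = list(range(10))
shards = [shard(samples, 4, worker) for worker in range(4)]
for worker, part in enumerate(shards):
    print(worker, part)
```

Every sample lands in exactly one shard, so no worker sees the whole dataset.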


New Contributor II

The ModelCheckpoint callback is used together with model.fit() to save a model or its weights (in a checkpoint file) at some interval, so that the model or weights can be loaded later to resume training from the saved state.
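In multi-worker training every worker runs the same callbacks, so a pattern shown in TensorFlow's multi-worker Keras tutorial is to give only the chief the real checkpoint path and point the other workers at throwaway directories. The sketch below covers only that path logic, in plain Python. The helper names and the "workertemp_" prefix are assumptions, and it assumes a cluster with only "worker" tasks, where worker 0 acts as chief.

```python
import os


def is_chief(task_type, task_id):
    # Assumes only "worker" tasks exist (no separate "chief" task), so worker 0
    # is treated as chief; task_type None means plain single-process training.
    return task_type is None or (task_type == "worker" and task_id == 0)


def checkpoint_path(filepath, task_type, task_id):
    """Return where this worker should write: the real path for the chief,
    a per-worker temporary subdirectory for everyone else."""
    dirpath, base = os.path.split(filepath)
    if not is_chief(task_type, task_id):
        dirpath = os.path.join(dirpath, "workertemp_" + str(task_id))
    return os.path.join(dirpath, base)


print(checkpoint_path("/tmp/ckpt/model.keras", "worker", 0))
print(checkpoint_path("/tmp/ckpt/model.keras", "worker", 2))
```

Each worker would pass its returned path to tf.keras.callbacks.ModelCheckpoint. The non-chief workers still write somewhere instead of skipping the save because TensorFlow's guide says all workers should take part in saving; their temporary copies can be deleted afterwards.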


New Contributor II

How does ModelCheckpoint know which node is the chief?

Shouldn't there be an API that yields a single resulting model from distributed training?
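MultiWorkerMirroredStrategy learns each process's role from the TF_CONFIG environment variable, a JSON object with a "cluster" map and a "task" entry of the form {"type": ..., "index": ...}. spark-tensorflow-distributor is expected to set this variable for each Spark task, though it is worth printing it in your own job to confirm. A minimal sketch of reading the role from it; the helper task_role is hypothetical, and the chief rule (an explicit "chief" task if present, otherwise worker 0) follows TensorFlow's convention.

```python
import json
import os


def task_role(env=None):
    """Return (task_type, task_index, is_chief) parsed from TF_CONFIG."""
    env = os.environ if env is None else env
    raw = env.get("TF_CONFIG")
    if not raw:
        # No TF_CONFIG: single-process training, which is its own chief.
        return None, 0, True
    config = json.loads(raw)
    task = config.get("task", {})
    task_type = task.get("type")
    task_index = int(task.get("index", 0))
    # With an explicit "chief" task, that task is the chief; otherwise worker 0 is.
    if "chief" in config.get("cluster", {}):
        is_chief = task_type == "chief"
    else:
        is_chief = task_type == "worker" and task_index == 0
    return task_type, task_index, is_chief


example = {
    "cluster": {"worker": ["host1:2222", "host2:2222"]},
    "task": {"type": "worker", "index": 1},
}
print(task_role({"TF_CONFIG": json.dumps(example)}))
```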

New Contributor III

Is there any update on the answer? I am curious too.

Is there a merge operation after all the distributed training has finished?
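With MultiWorkerMirroredStrategy a merge step should not be needed: training is synchronous, and at every step the gradients from all replicas are combined with an all-reduce before each replica applies the same update, so all copies keep identical weights. The toy sketch below illustrates this with a one-parameter linear model in plain Python. The data, learning rate, and mean-of-gradients reduction are illustrative assumptions; TensorFlow's actual reduction sums per-replica gradients of a loss already scaled by the global batch size, which amounts to the same average when shards are equal in size.

```python
def local_gradient(w, shard):
    """Gradient of mean squared error for y = w * x on one worker's shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values):
    # Stand-in for the collective all-reduce: every replica receives the mean.
    mean = sum(values) / len(values)
    return [mean] * len(values)


# Data follows y = 3x; each of 4 workers holds a different, equal-sized shard.
shards = [
    [(1.0, 3.0), (5.0, 15.0)],
    [(2.0, 6.0), (6.0, 18.0)],
    [(3.0, 9.0), (7.0, 21.0)],
    [(4.0, 12.0), (8.0, 24.0)],
]
weights = [0.0] * len(shards)  # one model replica per worker
learning_rate = 0.01

for step in range(200):
    grads = [local_gradient(w, s) for w, s in zip(weights, shards)]
    reduced = all_reduce_mean(grads)
    weights = [w - learning_rate * g for w, g in zip(weights, reduced)]

print(weights)
```

Although every worker sees different data, all four replicas end with the same weight (close to 3.0), so saving any one copy, typically the chief's, is enough.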

New Contributor III

I suspect spark_tensorflow_distributor is obsolete, since it has not been updated since 2020.

Horovod seems a better choice for using TensorFlow with Spark on Databricks.

New Contributor III

I think I finally worked this out.

Here is the extra code to save the model only once, from the first node (partition 0):

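import pyspark

# Replace the final mlflow.keras.log_model call in train() with this.
# MirroredStrategyRunner runs train() as Spark barrier tasks, so each task
# can read its own partition ID; only partition 0 logs the model.
context = pyspark.BarrierTaskContext.get()
if context.partitionId() == 0:
    mlflow.keras.log_model(model, "mymodel")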

