How do I save a model produced by distributed training?

kng88

I am trying to save a model after distributed training with the following code:

```python
from spark_tensorflow_distributor import MirroredStrategyRunner
import mlflow.keras
import tensorflow as tf
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

mlflow.keras.autolog()
mlflow.log_param("learning_rate", 0.001)


def train():
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    # tf.distribute.experimental.CollectiveCommunication.NCCL

    with strategy.scope():
        data = load_breast_cancer()
        X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.3)
        N, D = X_train.shape  # number of observations and features

        # Standardize features using statistics from the training split only
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)

        # Logistic regression: one dense unit with a sigmoid output for binary classification
        model = tf.keras.models.Sequential([
            tf.keras.layers.Input(shape=(D,)),
            tf.keras.layers.Dense(1, activation='sigmoid'),
        ])

        model.compile(optimizer='adam',  # Adam: adaptive moment estimation
                      loss='binary_crossentropy',
                      metrics=['accuracy'])

        # Train the model
        r = model.fit(X_train, y_train, validation_data=(X_test, y_test))
        print("Train score:", model.evaluate(X_train, y_train))  # evaluate returns [loss, accuracy]

        mlflow.keras.log_model(model, "mymodel")


MirroredStrategyRunner(num_slots=4, use_custom_strategy=True).run(train)
```

MirroredStrategyRunner source: https://github.com/tensorflow/ecosystem/blob/master/spark/spark-tensorflow-distributor/spark_tensorflow_distributor/mirrored_strategy_runner.py

I have a few questions:

  1. Setting num_slots=4 causes MLflow to log 4 models, none of which predicts the dataset well. I expect the chief node to log a single model with at least 80% accuracy. Is there a way to save only one model, or to merge the models?
  2. How can I save the model without mlflow.keras.log_model? If I save via dbutils, I get a race condition, and the Spark distributor does not make it clear which node is the chief.
  3. Is every node getting all of the data instead of only a partition of it?
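A possible approach, sketched under assumptions. With synchronous `MultiWorkerMirroredStrategy`, gradients are all-reduced at every step, so all workers finish training with the same weights: the four logged models are copies of one model, and there is nothing to merge. (Their accuracy is a separate training question; note that the `model.fit` call above uses Keras's default of a single epoch.) To write only one copy, TensorFlow's multi-worker Keras tutorial has every worker call `save`, with non-chief workers writing to a throwaway location; when the cluster has only `worker` tasks, worker 0 is the chief. Recent TensorFlow versions expose the task through `strategy.cluster_resolver.task_type` and `task_id`; the sketch below instead reads the `TF_CONFIG` environment variable directly, assuming `MirroredStrategyRunner` exports one for `MultiWorkerMirroredStrategy` on each Spark task (worth confirming in the runner source). The helper names `chief_status` and `save_model_chief_only` are hypothetical.

```python
import json
import os
import shutil
import tempfile


def chief_status():
    """Return (is_chief, task_index) for this process, based on TF_CONFIG.

    Hypothetical helper. Rules, following MultiWorkerMirroredStrategy conventions:
    - No TF_CONFIG: a single process, treated as the chief.
    - Cluster with an explicit "chief" task: only that task is the chief.
    - Cluster with only "worker" tasks: worker 0 is the chief.
    """
    raw = os.environ.get("TF_CONFIG")
    if not raw:
        return True, 0
    config = json.loads(raw)
    task = config.get("task", {})
    task_type = task.get("type")
    task_index = task.get("index", 0)
    if "chief" in config.get("cluster", {}):
        return task_type == "chief", task_index
    return task_type == "worker" and task_index == 0, task_index


def save_model_chief_only(model, final_path):
    """Call on every worker; only the chief's copy ends up at final_path.

    Hypothetical helper. Non-chief workers still call model.save (as in the TF
    multi-worker tutorial), but into a private temp directory that is deleted
    afterwards, so only one process ever writes to final_path.
    Returns True on the chief, False on every other worker.
    """
    is_chief, task_index = chief_status()
    if is_chief:
        model.save(final_path)
        return True
    tmp_dir = tempfile.mkdtemp(prefix=f"worker{task_index}_")
    try:
        # Same file name as the real target, but inside a throwaway directory
        model.save(os.path.join(tmp_dir, os.path.basename(final_path)))
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)
    return False
```

Inside `train()`, the unconditional `mlflow.keras.log_model(model, "mymodel")` could then become `if save_model_chief_only(model, "/dbfs/tmp/mymodel.keras"): mlflow.keras.log_model(model, "mymodel")`, where the `/dbfs/tmp/...` path is a hypothetical example. Only the chief then writes to DBFS and logs to MLflow, which avoids several processes writing the same path. On question 3: every worker runs `train()` in full, calling `load_breast_cancer()` and `train_test_split` itself, so each starts from the whole dataset, and because `train_test_split` is called without a `random_state`, each worker draws a different train/test split.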