Machine Learning
Dive into the world of machine learning on the Databricks platform. Explore discussions on algorithms, model training, deployment, and more. Connect with ML enthusiasts and experts.

How to save a model produced by distributed training?

kng88
New Contributor II

I am trying to save a model after distributed training via the following code:

import sys
import time

import mlflow.keras
import tensorflow as tf
from sklearn.datasets import load_breast_cancer  # spelled out here; the forum's word filter mangled the name in the original post
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from spark_tensorflow_distributor import MirroredStrategyRunner

mlflow.keras.autolog()
mlflow.log_param("learning_rate", 0.001)

def train():
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    # tf.distribute.experimental.CollectiveCommunication.NCCL

    with strategy.scope():
        data = load_breast_cancer()
        X_train, X_test, y_train, y_test = train_test_split(
            data.data, data.target, test_size=0.3)
        N, D = X_train.shape  # number of observations and variables

        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)

        model = tf.keras.models.Sequential([
            tf.keras.layers.Input(shape=(D,)),
            tf.keras.layers.Dense(1, activation='sigmoid')  # sigmoid output for binary classification
        ])

        model.compile(optimizer='adam',  # adaptive moment estimation
                      loss='binary_crossentropy',
                      metrics=['accuracy'])

        # Train the model
        r = model.fit(X_train, y_train, validation_data=(X_test, y_test))
        print("Train score:", model.evaluate(X_train, y_train))  # evaluate returns loss and accuracy

        mlflow.keras.log_model(model, "mymodel")

MirroredStrategyRunner(num_slots=4, use_custom_strategy=True).run(train)

Reference: https://github.com/tensorflow/ecosystem/blob/master/spark/spark-tensorflow-distributor/spark_tensorflow_distributor/mirrored_strategy_runner.py

I have a couple of questions:

  1. Setting num_slots=4 causes MLflow to log 4 models, none of which predicts the dataset well. I expected the chief node to log a single model with at least 80% accuracy. Is there a way to save only one model, or to merge the models?
  2. How can the model be saved without mlflow.log_model? Saving via dbutils leads to a race condition, and the Spark distributor does not make it clear which node is the chief.
  3. Does every node receive all of the data rather than a partition of it?

6 REPLIES

Alexx02
New Contributor II

It is very good that there are now many useful programs that make this easy, such as CAT ET software. I recommend it to everyone.

Frost69
New Contributor II

The ModelCheckpoint callback is used in conjunction with training via model.fit() to save the model or its weights (in a checkpoint file) at some interval, so the model or weights can be loaded later to continue training from the saved state.
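
A minimal sketch of that pattern, assuming the Keras model and data splits from the original post (the DBFS filepath, the monitored metric, and the epoch count are illustrative choices, not from this thread):

import tensorflow as tf

# Assumed checkpoint location on DBFS; adjust to your workspace.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath='/dbfs/tmp/mymodel_checkpoint.h5',
    monitor='val_accuracy',  # metric used to compare checkpoints
    save_best_only=True)     # overwrite only when the metric improves

# fit() runs the callback at the end of each epoch by default, so the
# best model so far is always on disk and can be reloaded later with
# tf.keras.models.load_model().
r = model.fit(X_train, y_train,
              validation_data=(X_test, y_test),
              epochs=10,
              callbacks=[checkpoint_cb])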


kng88
New Contributor II

How does ModelCheckpoint know which node is the chief?

Shouldn't there be an API that returns a single resulting model from distributed training?
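
For context, TensorFlow's multi-worker tutorials determine the chief from the strategy's cluster resolver rather than inside the callback itself. A minimal sketch, where _is_chief is a hypothetical helper name and the strategy is the one from the original post:

def _is_chief(task_type, task_id):
    # Follows the TF multi-worker convention: worker 0 acts as chief,
    # and task_type is None when running outside a real cluster.
    return (task_type is None or task_type == 'chief'
            or (task_type == 'worker' and task_id == 0))

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
resolver = strategy.cluster_resolver  # may be None on a single machine
if resolver is None or _is_chief(resolver.task_type, resolver.task_id):
    mlflow.keras.log_model(model, "mymodel")  # only the chief logs the model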

Xiaowei
New Contributor III

Is there any update on the answer? I am curious too.

Is there a merge operation after all the distributed training has finished?

Xiaowei
New Contributor III

I suspect spark_tensorflow_distributor is obsolete, since it has not been updated since 2020.

Horovod (https://github.com/horovod) seems a better choice for using TensorFlow with Spark on Databricks.
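
On Databricks that route is typically wrapped in HorovodRunner. A minimal sketch, assuming the same breast-cancer Keras model as above (np=4 mirrors the original num_slots and is illustrative):

import horovod.tensorflow.keras as hvd
import mlflow.keras
import tensorflow as tf
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sparkdl import HorovodRunner

def train_hvd():
    hvd.init()  # one Horovod process per slot; ranks run 0..np-1
    data = load_breast_cancer()
    X_train, X_test, y_train, y_test = train_test_split(
        data.data, data.target, test_size=0.3)
    model = tf.keras.models.Sequential([
        tf.keras.layers.Input(shape=(X_train.shape[1],)),
        tf.keras.layers.Dense(1, activation='sigmoid')])
    # DistributedOptimizer averages gradients across all workers.
    model.compile(optimizer=hvd.DistributedOptimizer(tf.keras.optimizers.Adam()),
                  loss='binary_crossentropy', metrics=['accuracy'])
    # Broadcast initial weights from rank 0 so all workers start identically.
    callbacks = [hvd.callbacks.BroadcastGlobalVariablesCallback(0)]
    model.fit(X_train, y_train, validation_data=(X_test, y_test),
              callbacks=callbacks)
    if hvd.rank() == 0:  # rank 0 acts as the chief, so only one model is logged
        mlflow.keras.log_model(model, "mymodel")

HorovodRunner(np=4).run(train_hvd)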

Xiaowei
New Contributor III

I think I finally worked this out.

Here is the extra code to save the model only once, from the first node:

from pyspark import BarrierTaskContext

context = BarrierTaskContext.get()
if context.partitionId() == 0:
    mlflow.keras.log_model(model, "mymodel")
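
For placement: MirroredStrategyRunner runs train() as a barrier-mode Spark task, which appears to be why BarrierTaskContext.get() is available inside it. A sketch of where the guard goes, with the body elided to the code from the original post:

from pyspark import BarrierTaskContext

def train():
    # ... build, compile, and fit the model as in the original post ...
    context = BarrierTaskContext.get()  # available because train() runs in a barrier stage
    if context.partitionId() == 0:      # partition 0 stands in for the chief
        mlflow.keras.log_model(model, "mymodel")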


