Help Needed: Executor Lost Error in Multi-Node Distributed Training with PyTorch

SB93
New Contributor II

Hi everyone,

I'm currently working on distributed training of a PyTorch model, following the example provided here. The training runs perfectly on a single node with a single GPU. However, when I attempt multi-node training using the following configuration:

TorchDistributor(num_processes=4, local_mode=False, use_gpu=True).run(...)

I encounter an "ExecutorLostFailure". The error typically appears after roughly an hour of training, although in the best case it did not occur until about eight hours in.

I have followed all the suggested solutions, including:

  • Including all import statements (e.g., import torch) both at the top of the training function and inside any other user-defined functions it calls.
  • Setting the environment variable: os.environ["NCCL_SOCKET_IFNAME"] = "eth0" (both steps are shown in the sketch below).
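
For context, my training function roughly follows the structure below. This is a simplified sketch: the actual model and data-loading code are elided, and args is just a placeholder for whatever is passed to TorchDistributor.run.

def main_fn(args):
    # Re-import inside the function: TorchDistributor pickles main_fn and
    # executes it on remote executors, where notebook-level imports are
    # not visible.
    import os
    import torch
    import torch.distributed as dist

    # Pin NCCL to the cluster's primary network interface.
    os.environ["NCCL_SOCKET_IFNAME"] = "eth0"

    # TorchDistributor launches the worker processes and sets the standard
    # torchrun variables (RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, ...).
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # ... build the model, wrap it in
    # torch.nn.parallel.DistributedDataParallel, run the training loop ...

    dist.destroy_process_group()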

Despite these efforts, the error persists.

Has anyone experienced and resolved this issue? Any insights or suggestions would be greatly appreciated. I'm currently stuck and could really use some help!

Thank you in advance! Below is the stack trace for reference:

Py4JJavaError: An error occurred while calling o515.collectToPython.
: org.apache.spark.SparkException: Job aborted due to stage failure: Could not recover from a failed barrier ResultStage. Most recent failure reason: Stage failed because barrier task ResultTask(0, 2) finished unsuccessfully.
ExecutorLostFailure (executor 1 exited caused by one of the running tasks) Reason: spot instance preemption, spot instance kill
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$failJobAndIndependentStages$1(DAGScheduler.scala:4362)
	at scala.Option.getOrElse(Option.scala:189)
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:4360)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:4274)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:4261)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:4261)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskCompletion(DAGScheduler.scala:3645)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:4620)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.liftedTree1$1(DAGScheduler.scala:4524)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:4523)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:4509)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:55)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$runJob$1(DAGScheduler.scala:1486)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:94)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:1472)
	at org.apache.spark.SparkContext.runJobInternal(SparkContext.scala:3217)
	at org.apache.spark.sql.execution.collect.Collector.$anonfun$runSparkJobs$1(Collector.scala:296)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:94)
	at org.apache.spark.sql.execution.collect.Collector.runSparkJobs(Collector.scala:292)
	at org.apache.spark.sql.execution.collect.Collector.$anonfun$collect$1(Collector.scala:377)
	at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:94)
	at org.apache.spark.sql.execution.collect.Collector.collect(Collector.scala:374)
	at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:116)
	at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:124)
	at org.apache.spark.sql.execution.qrc.InternalRowFormat$.collect(cachedSparkResults.scala:94)
	at org.apache.spark.sql.execution.qrc.InternalRowFormat$.collect(cachedSparkResults.scala:90)
	at org.apache.spark.sql.execution.qrc.InternalRowFormat$.collect(cachedSparkResults.scala:78)
	at org.apache.spark.sql.execution.qrc.ResultCacheManager.$anonfun$computeResult$1(ResultCacheManager.scala:588)
	at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:94)
	at org.apache.spark.sql.execution.qrc.ResultCacheManager.collectResult$1(ResultCacheManager.scala:579)
	at org.apache.spark.sql.execution.qrc.ResultCacheManager.computeResult(ResultCacheManager.scala:596)
	at org.apache.spark.sql.execution.qrc.ResultCacheManager.$anonfun$getOrComputeResultInternal$1(ResultCacheManager.scala:433)
	at scala.Option.getOrElse(Option.scala:189)
	at org.apache.spark.sql.execution.qrc.ResultCacheManager.getOrComputeResultInternal(ResultCacheManager.scala:432)
	at org.apache.spark.sql.execution.qrc.ResultCacheManager.getOrComputeResult(ResultCacheManager.scala:351)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeCollectResult$1(SparkPlan.scala:590)
	at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:94)
	at org.apache.spark.sql.execution.SparkPlan.executeCollectResult(SparkPlan.scala:587)
	at org.apache.spark.sql.Dataset.$anonfun$collectToPython$1(Dataset.scala:4701)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$3(Dataset.scala:4974)
	at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:1245)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:4972)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId0$10(SQLExecution.scala:470)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:756)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId0$1(SQLExecution.scala:340)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:1426)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId0(SQLExecution.scala:206)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:693)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:4972)
	at org.apache.spark.sql.Dataset.collectToPython(Dataset.scala:4699)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:569)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:397)
	at py4j.Gateway.invoke(Gateway.java:306)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:199)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:119)
	at java.base/java.lang.Thread.run(Thread.java:840)
File <command-3554622095355509>, line 4
      1 from pyspark.ml.torch.distributor import TorchDistributor
      3 # output = TorchDistributor(num_processes=4, local_mode=True, use_gpu=True).run(main_fn, args) # Single node, multiple GPUs
----> 4 output_dist = TorchDistributor(num_processes=4, local_mode=False, use_gpu=True).run(main_fn, args)
File /databricks/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/protocol.py:326, in get_return_value(answer, gateway_client, target_id, name)
    324 value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
    325 if answer[1] == REFERENCE_TYPE:
--> 326     raise Py4JJavaError(
    327         "An error occurred while calling {0}{1}{2}.\n".
    328         format(target_id, ".", name), value)
    329 else:
    330     raise Py4JError(
    331         "An error occurred while calling {0}{1}{2}. Trace:\n{3}\n".
    332         format(target_id, ".", name, value))
1 REPLY

cgrant
Databricks Employee

We do not recommend using spot instances with distributed ML training workloads that use barrier mode, such as TorchDistributor, because these workloads are extremely sensitive to executor loss. Your stack trace confirms this is the cause ("Reason: spot instance preemption, spot instance kill"). Please disable spot/preemptible instances and try again.
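
As an illustration, for an AWS workspace the relevant Clusters API fields look roughly like this. This is a sketch only: the rest of the cluster spec is elided, and Azure/GCP workspaces use azure_attributes/gcp_attributes with analogous settings instead.

cluster_spec = {
    "aws_attributes": {
        # ON_DEMAND avoids spot instances entirely, so executors cannot
        # be lost to spot preemption mid-training.
        "availability": "ON_DEMAND",
    },
    # ... node types, worker counts, runtime version, etc. ...
}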
