Data Engineering

MLflow Spark UDF Error

coltonflowers
New Contributor III

After running

spark_udf = mlflow.pyfunc.spark_udf(spark, model_uri=logged_model, env_manager="virtualenv")

We get the following error:

org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 145.0 failed 4 times, most recent failure: Lost task 0.3 in stage 145.0 (TID 101) (ip-192-168-0-127.ec2.internal executor driver): org.apache.spark.api.python.PythonException: 'requests.exceptions.ConnectionError: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))'. Full traceback below:
Traceback (most recent call last):
  File "/databricks/python/lib/python3.10/site-packages/mlflow/pyfunc/__init__.py", line 1540, in udf
    os.kill(scoring_server_proc.pid, signal.SIGTERM)
  File "/databricks/python/lib/python3.10/site-packages/mlflow/pyfunc/__init__.py", line 1317, in _predict_row_batch
    result = predict_fn(pdf, params)
  File "/databricks/python/lib/python3.10/site-packages/mlflow/pyfunc/__init__.py", line 1507, in batch_predict_fn
    return client.invoke(pdf, params=params).get_predictions()
  File "/databricks/python/lib/python3.10/site-packages/mlflow/pyfunc/scoring_server/client.py", line 83, in invoke
    response = requests.post(
  File "/databricks/python/lib/python3.10/site-packages/requests/api.py", line 115, in post
    return request("post", url, data=data, json=json, **kwargs)
  File "/databricks/python/lib/python3.10/site-packages/requests/api.py", line 59, in request
    return session.request(method=method, url=url, **kwargs)
  File "/databricks/python/lib/python3.10/site-packages/requests/sessions.py", line 587, in request
    resp = self.send(prep, **send_kwargs)
  File "/databricks/python/lib/python3.10/site-packages/requests/sessions.py", line 701, in send
    r = adapter.send(request, **kwargs)
  File "/databricks/python/lib/python3.10/site-packages/requests/adapters.py", line 547, in send
    raise ConnectionError(err, request=request)
requests.exceptions.ConnectionError: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:614)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:117)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:567)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage4.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.$anonfun$encodeUnsafeRows$5(UnsafeRowBatchUtils.scala:88)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
	at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.$anonfun$encodeUnsafeRows$3(UnsafeRowBatchUtils.scala:88)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
	at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.$anonfun$encodeUnsafeRows$1(UnsafeRowBatchUtils.scala:68)
	at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
	at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.encodeUnsafeRows(UnsafeRowBatchUtils.scala:62)
	at org.apache.spark.sql.execution.collect.Collector.$anonfun$processFunc$2(Collector.scala:214)
	at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$3(ResultTask.scala:82)
	at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
	at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$1(ResultTask.scala:82)
	at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:62)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:196)
	at org.apache.spark.scheduler.Task.doRunTask(Task.scala:181)
	at org.apache.spark.scheduler.Task.$anonfun$run$5(Task.scala:146)
	at com.databricks.unity.UCSEphemeralState$Handle.runWith(UCSEphemeralState.scala:41)
	at com.databricks.unity.HandleImpl.runWith(UCSHandle.scala:99)
	at com.databricks.unity.HandleImpl.$anonfun$runWithAndClose$1(UCSHandle.scala:104)
	at scala.util.Using$.resource(Using.scala:269)
	at com.databricks.unity.HandleImpl.runWithAndClose(UCSHandle.scala:103)
	at org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:146)
	at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
	at org.apache.spark.scheduler.Task.run(Task.scala:99)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$8(Executor.scala:930)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:102)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:933)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:825)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:750)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:3588)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:3520)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:3509)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:3509)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1516)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1516)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1516)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3834)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3746)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3734)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:51)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$runJob$1(DAGScheduler.scala:1240)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:94)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:1228)
	at org.apache.spark.SparkContext.runJobInternal(SparkContext.scala:2974)
	at org.apache.spark.sql.execution.collect.Collector.$anonfun$runSparkJobs$1(Collector.scala:355)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:94)
	at org.apache.spark.sql.execution.collect.Collector.runSparkJobs(Collector.scala:299)
	at org.apache.spark.sql.execution.collect.Collector.$anonfun$collect$1(Collector.scala:384)
	at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:94)
	at org.apache.spark.sql.execution.collect.Collector.collect(Collector.scala:381)
	at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:122)
	at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:131)
	at org.apache.spark.sql.execution.qrc.InternalRowFormat$.collect(cachedSparkResults.scala:94)
	at org.apache.spark.sql.execution.qrc.InternalRowFormat$.collect(cachedSparkResults.scala:90)
	at org.apache.spark.sql.execution.qrc.InternalRowFormat$.collect(cachedSparkResults.scala:78)
	at org.apache.spark.sql.execution.qrc.ResultCacheManager.$anonfun$computeResult$1(ResultCacheManager.scala:506)
	at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:94)
	at org.apache.spark.sql.execution.qrc.ResultCacheManager.collectResult$1(ResultCacheManager.scala:500)
	at org.apache.spark.sql.execution.qrc.ResultCacheManager.$anonfun$computeResult$2(ResultCacheManager.scala:515)
	at org.apache.spark.sql.execution.adaptive.ResultQueryStageExec.$anonfun$doMaterialize$1(QueryStageExec.scala:528)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:1123)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$5(SQLExecution.scala:623)
	at com.databricks.util.LexicalThreadLocal$Handle.runWith(LexicalThreadLocal.scala:63)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$4(SQLExecution.scala:623)
	at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$3(SQLExecution.scala:622)
	at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$2(SQLExecution.scala:621)
	at org.apache.spark.sql.execution.SQLExecution$.withOptimisticTransaction(SQLExecution.scala:642)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$1(SQLExecution.scala:620)
	at java.util.concurrent.CompletableFuture$AsyncSupply.run(CompletableFuture.java:1604)
	at org.apache.spark.util.threads.SparkThreadLocalCapturingRunnable.$anonfun$run$1(SparkThreadLocalForwardingThreadPoolExecutor.scala:118)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at com.databricks.spark.util.IdentityClaim$.withClaim(IdentityClaim.scala:48)
	at org.apache.spark.util.threads.SparkThreadLocalCapturingHelper.$anonfun$runWithCaptured$4(SparkThreadLocalForwardingThreadPoolExecutor.scala:81)
	at com.databricks.unity.UCSEphemeralState$Handle.runWith(UCSEphemeralState.scala:41)
	at org.apache.spark.util.threads.SparkThreadLocalCapturingHelper.runWithCaptured(SparkThreadLocalForwardingThreadPoolExecutor.scala:80)
	at org.apache.spark.util.threads.SparkThreadLocalCapturingHelper.runWithCaptured$(SparkThreadLocalForwardingThreadPoolExecutor.scala:66)
	at org.apache.spark.util.threads.SparkThreadLocalCapturingRunnable.runWithCaptured(SparkThreadLocalForwardingThreadPoolExecutor.scala:115)
	at org.apache.spark.util.threads.SparkThreadLocalCapturingRunnable.run(SparkThreadLocalForwardingThreadPoolExecutor.scala:118)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:750)
Caused by: org.apache.spark.api.python.PythonException: 'requests.exceptions.ConnectionError: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))'. Full traceback below:
Traceback (most recent call last):
  File "/databricks/python/lib/python3.10/site-packages/mlflow/pyfunc/__init__.py", line 1540, in udf
    os.kill(scoring_server_proc.pid, signal.SIGTERM)
  File "/databricks/python/lib/python3.10/site-packages/mlflow/pyfunc/__init__.py", line 1317, in _predict_row_batch
    result = predict_fn(pdf, params)
  File "/databricks/python/lib/python3.10/site-packages/mlflow/pyfunc/__init__.py", line 1507, in batch_predict_fn
    return client.invoke(pdf, params=params).get_predictions()
  File "/databricks/python/lib/python3.10/site-packages/mlflow/pyfunc/scoring_server/client.py", line 83, in invoke
    response = requests.post(
  File "/databricks/python/lib/python3.10/site-packages/requests/api.py", line 115, in post
    return request("post", url, data=data, json=json, **kwargs)
  File "/databricks/python/lib/python3.10/site-packages/requests/api.py", line 59, in request
    return session.request(method=method, url=url, **kwargs)
  File "/databricks/python/lib/python3.10/site-packages/requests/sessions.py", line 587, in request
    resp = self.send(prep, **send_kwargs)
  File "/databricks/python/lib/python3.10/site-packages/requests/sessions.py", line 701, in send
    r = adapter.send(request, **kwargs)
  File "/databricks/python/lib/python3.10/site-packages/requests/adapters.py", line 547, in send
    raise ConnectionError(err, request=request)

We are using mlflow-skinny 2.7.1 on the 14.1 ML Runtime, on a G4 instance.

Accepted Solution

Kaniz
Community Manager

Hi @coltonflowers, the error you’re encountering is a connection failure: the remote end closed the HTTP connection without sending a response.

Let’s explore some potential solutions:

  1. Check Network Connectivity:

    • Ensure that the machine running your Spark job has proper network connectivity. Verify that it can reach the necessary endpoints (e.g., the MLflow server).
    • Confirm that there are no firewall rules or network restrictions preventing the connection.
  2. Retry Mechanism:

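    • Transient network issues can cause connection failures, so consider implementing a retry mechanism in your code. Note, though, that Spark already retried this task four times (“failed 4 times” in the error), which suggests the failure may not be transient.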
    • You can use libraries like tenacity or implement custom retry logic.
  3. Increase Timeout Settings:

    • Adjust the timeout settings for the HTTP requests made by MLflow. If the default timeout is too short, it might lead to connection errors.
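    • For HTTP calls you make yourself, pass a longer timeout to requests (for example, the timeout argument of requests.post) or to whichever HTTP client you use. The request in this traceback is issued by MLflow’s own scoring-server client, not by your code.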
  4. Check MLflow Server Status:

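    • Ensure that the MLflow scoring server is up and running. With env_manager="virtualenv", spark_udf starts a local scoring server inside the restored model environment and sends each batch to it over HTTP (the client.invoke and requests.post frames in the traceback), so “Remote end closed connection without response” most likely means that process exited or crashed mid-request.
    • Check the driver and executor logs for the scoring server’s own error output.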
  5. Inspect the Model Signature:

    • The error posted here does not mention a data type, but if your input DataFrame contains a Spark ML vector column (stored as struct<type:tinyint,size:int,indices:array<int>,values:array<double>>), the scoring server may not be able to handle it.
    • Review the model signature (input and output schema) and make sure it matches the data you’re actually passing.
  6. Use mlflow.spark.load_model Instead:

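    • If the model was logged with the mlflow.spark flavor (a Spark ML pipeline), try loading it with mlflow.spark.load_model instead of mlflow.pyfunc.spark_udf.
    • This runs the pipeline natively with transform(), so no virtualenv or separate scoring-server process is involved.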
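Item 2 suggests custom retry logic as an alternative to tenacity. A minimal standard-library sketch of retry with exponential backoff is below; `retry` and its parameters are hypothetical names, not part of MLflow or Spark, and retrying only helps when the failure is transient.

```python
import random
import time
from functools import wraps


def retry(retry_on=(ConnectionError,), attempts=4, base_delay=1.0, max_delay=30.0):
    """Hypothetical helper: retry the wrapped call on `retry_on` exceptions,
    sleeping with exponential backoff (plus a little jitter) between attempts,
    and re-raising the last error once `attempts` calls have failed."""
    if attempts < 1:
        raise ValueError("attempts must be >= 1")

    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(1, attempts + 1):
                try:
                    return fn(*args, **kwargs)
                except retry_on:
                    if attempt == attempts:
                        raise  # out of attempts: surface the original error
                    # base_delay, 2*base_delay, 4*base_delay, ... capped at max_delay
                    delay = min(max_delay, base_delay * 2 ** (attempt - 1))
                    # up to 10% random jitter so parallel workers don't retry in lockstep
                    time.sleep(delay + random.uniform(0, delay * 0.1))
        return wrapper

    return decorator
```

For example, decorating a hypothetical `score_batch()` function with `@retry(retry_on=(requests.exceptions.ConnectionError,))` retries the exact error from this thread. The exception has to be passed explicitly because requests.exceptions.ConnectionError is not a subclass of Python’s built-in ConnectionError.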

Here’s an example of how to load the model using mlflow.spark.load_model:

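import mlflow.spark

model_path = 'runs:/e905f5759d434a131bbe1e54a2b/best-model'
# Returns a pyspark.ml.PipelineModel; only works for models logged with the mlflow.spark flavor
loaded_model = mlflow.spark.load_model(model_path)

# Predict on a Spark DataFrame `df`: transform() appends the pipeline's output
# columns (e.g. "prediction") and returns a new DataFrame, not a list of rows
result_df = loaded_model.transform(df)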

Remember to adjust the model_path according to your specific experiment run ID and model location. If the issue persists, consider examining the model signature and ensuring compatibility with your d...
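Items 1, 3 and 4 can be checked together by probing a scoring server directly. The sketch below polls the server’s GET /ping health endpoint with a per-request timeout until it answers HTTP 200 or an overall deadline passes. `wait_for_ping` and its parameters are hypothetical, and the server URL is an assumption, since the host and port that spark_udf picked are not shown in the error output.

```python
import http.client
import time
import urllib.request


def wait_for_ping(base_url, total_timeout=60.0, request_timeout=5.0, interval=1.0):
    """Poll GET <base_url>/ping until it returns HTTP 200 (True) or
    total_timeout seconds pass without success (False)."""
    # Empty ProxyHandler: talk to a local server directly, ignoring http_proxy env vars
    opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
    url = base_url.rstrip("/") + "/ping"
    deadline = time.monotonic() + total_timeout
    while True:
        try:
            # request_timeout bounds each connect/read so one hung call cannot block forever
            with opener.open(url, timeout=request_timeout) as resp:
                if resp.status == 200:
                    return True
        except (OSError, http.client.HTTPException):
            # Refused, reset (RemoteDisconnected), timed out, or 4xx/5xx (HTTPError is an OSError)
            pass
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)
```

One way to use it, assuming a shell on the same cluster: start the model outside Spark with `mlflow models serve -m <model_uri> -p 5001 --env-manager virtualenv`, call `wait_for_ping("http://127.0.0.1:5001")`, and watch that command’s own output. If the server dies while restoring the environment or while scoring, its traceback should appear there instead of surfacing only as a RemoteDisconnected on the client side.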
