Hello!
I'm using databricks-connect to launch Spark jobs from Python.
I've verified that the Python version (3.8.10) and the Databricks Runtime version (8.1) are supported by the installed databricks-connect (8.1.10).
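That compatibility check can be sketched in plain Python. The sketch assumes the pairing rules from the (legacy) Databricks Connect documentation as I understand them: the databricks-connect major.minor version must equal the Databricks Runtime version, and the local Python minor version must equal the cluster's. The helper names are hypothetical.

```python
import sys
from typing import Tuple


def major_minor(version: str) -> Tuple[int, int]:
    """Return (major, minor) from a dotted version string such as '8.1.10'."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def connect_matches_runtime(connect_version: str, runtime_version: str) -> bool:
    # Assumed rule: databricks-connect X.Y.* pairs with Databricks Runtime X.Y.
    return major_minor(connect_version) == major_minor(runtime_version)


def python_matches_cluster(cluster_python: str) -> bool:
    # Assumed rule: the local Python minor version must equal the cluster's.
    return sys.version_info[:2] == major_minor(cluster_python)


if __name__ == "__main__":
    print(connect_matches_runtime("8.1.10", "8.1"))  # True for the setup in this post
    print(python_matches_cluster("3.8.10"))          # True only when run under Python 3.8
```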
Every time a mapPartitions/foreachPartition operation runs, two Spark jobs execute one after the other, and each of them repeats every stage/step that comes before the partition operation.
Here is an example:
#!/usr/bin/env python
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, LongType

schema = StructType([
    StructField('key', LongType(), True),
    StructField('value', StringType(), True)
])

spark = SparkSession.builder.appName('test').getOrCreate()

data = spark.read.schema(schema) \
    .option('header', 'true') \
    .csv('s3://path/to.csv')

def fun(rows):
    # Runs on the executors, so this output appears in the executor logs,
    # not in the local console.
    print(f"Got a partition with {len(list(rows))} rows")

# these only trigger one job
# data.collect()
# data.count()

# this triggers two!
data.foreachPartition(fun)
This executes two jobs (fast in this example, but not in real-world code):
The first job, the one I can't explain:
org.apache.spark.rdd.RDD.foreach(RDD.scala:1015)
com.databricks.service.RemoteServiceExec.doExecute(RemoteServiceExec.scala:244)
org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:196)
org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:240)
org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:236)
org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:192)
org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:163)
org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:162)
org.apache.spark.sql.Dataset.javaToPython(Dataset.scala:3569)
sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
java.lang.reflect.Method.invoke(Method.java:498)
py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
py4j.Gateway.invoke(Gateway.java:295)
py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
py4j.commands.CallCommand.execute(CallCommand.java:79)
py4j.GatewayConnection.run(GatewayConnection.java:251)
And then the actual job:
org.apache.spark.rdd.RDD.collect(RDD.scala:1034)
org.apache.spark.api.python.PythonRDD$.collectAndServe(PythonRDD.scala:260)
org.apache.spark.api.python.PythonRDD.collectAndServe(PythonRDD.scala)
sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
java.lang.reflect.Method.invoke(Method.java:498)
py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
py4j.Gateway.invoke(Gateway.java:295)
py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
py4j.commands.CallCommand.execute(CallCommand.java:79)
py4j.GatewayConnection.run(GatewayConnection.java:251)
java.lang.Thread.run(Thread.java:748)
Any idea why this happens, and how I can prevent the first job from running so that only the actual code runs?
I've confirmed that during the first job, none of the code in the function passed to foreachPartition runs.
Using .cache() is not an option in real-world scenarios: the datasets are large, so persisting them would take even longer than recomputing them, and it could fail if disk space runs out.
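A toy model in plain Python (not Spark, and not how databricks-connect works internally) of the trade-off described above: without persistence, every action re-runs the whole lineage, so an extra job repeats the read and every transformation, while caching avoids the repeat only by holding all of the data. `LazyPipeline` and `source_reads` are hypothetical names for illustration.

```python
from typing import Callable, Iterator, List, Optional


class LazyPipeline:
    """Toy model of an uncached Spark lineage (plain Python, not Spark)."""

    def __init__(self, source: Callable[[], Iterator[int]]) -> None:
        self._source = source            # stands in for the expensive read (e.g. the CSV scan)
        self.source_reads = 0            # how many times the source has been scanned
        self._cache_requested = False
        self._cached: Optional[List[int]] = None

    def cache(self) -> "LazyPipeline":
        # Like Spark's cache(), this only marks the data; nothing is computed yet.
        self._cache_requested = True
        return self

    def _rows(self) -> Iterator[int]:
        if self._cached is not None:
            return iter(self._cached)    # served from the cache, no new scan
        self.source_reads += 1           # an uncached action re-runs the whole lineage
        rows = (x * 2 for x in self._source())  # the "transformation" stage
        if self._cache_requested:
            self._cached = list(rows)    # the first action after cache() fills the cache
            return iter(self._cached)
        return rows

    def count(self) -> int:
        # An "action": consumes every row produced by the lineage.
        return sum(1 for _ in self._rows())


if __name__ == "__main__":
    uncached = LazyPipeline(lambda: iter(range(5)))
    uncached.count()
    uncached.count()
    print(uncached.source_reads)  # 2: each action re-ran the source scan

    cached = LazyPipeline(lambda: iter(range(5))).cache()
    cached.count()
    cached.count()
    print(cached.source_reads)    # 1: the second action was served from memory
```

In this model the cached version pays for the repeat with memory proportional to the whole dataset, which is the cost that rules out .cache() for large inputs.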
The first stack trace suggests this is related to Databricks' RemoteServiceExec code (RemoteServiceExec.doExecute calls RDD.foreach). Could it be unintentionally causing the Dataset/RDDs to be materialized?
Can anyone help?
Thanks