The code snippet below takes each group (based on id) and performs recursive joins to build the parent-child relations (id1 and id2) within the group. The code produces the correct output: an array in the column 'path'.
However, in my real-world use case this snippet takes over 30 minutes to run (1,300 iterations), and I haven't been able to store the result in a table since that cell just keeps running.
Based on the code below, is there something obvious I can do to improve performance? Could this code perhaps benefit from caching? Is the recursive join the bottleneck, or is it the part where I loop through the 1,300 DataFrames (one per group) collected in a list (results_dfs) and union them?
Anything that could push me in the right direction would be appreciated.
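To make the caching question concrete, this is roughly what I have in mind for the snippet below (a sketch only, not tested at scale; /tmp/checkpoints is a placeholder path):

# Cache the source once so each per-group filter reads from memory
# instead of rescanning the input.
test = test.cache()
test.count()  # force materialization before the loop

# Checkpointing inside the while loop would truncate the lineage that
# the iterative joins keep growing (requires a checkpoint directory):
spark.sparkContext.setCheckpointDir("/tmp/checkpoints")  # placeholder path
# ... and inside tree_path's while loop, instead of tmp = joined_df:
# tmp = joined_df.checkpoint()

The full snippet: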
import pyspark.sql.functions as F
from pyspark.sql.functions import col
from pyspark.sql.types import StructType, StructField, IntegerType
# Create a test DataFrame.
schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("level", IntegerType(), True),
    StructField("id1", IntegerType(), True),
    StructField("id2", IntegerType(), True),
])
# Test data: (id, level, id1, id2), where id1 is the parent of id2.
data = [
    (10, 1, 0, 100),
    (10, 2, 100, 101),
    (10, 3, 101, 102),
    (10, 3, 101, 103),
    (10, 4, 102, 148),
    (10, 4, 102, 149),
    (10, 4, 103, 150),
    (10, 4, 103, 151),
    (20, 1, 0, 200),
    (20, 2, 200, 201),
    (20, 2, 200, 202),
    (20, 2, 200, 203),
    (20, 2, 200, 204),
    (20, 3, 201, 221),
    (20, 3, 201, 222),
    (20, 3, 202, 223),
    (20, 4, 222, 231),
    (20, 5, 231, 299),
    (30, 1, 0, 300),
    (30, 2, 300, 302),
    (30, 3, 302, 303),
    (30, 4, 303, 310),
    (30, 4, 303, 311),
    (30, 4, 303, 312),
    (30, 5, 311, 321),
    (30, 5, 312, 322),
]
test = spark.createDataFrame(data, schema)
# Seed each row's path with its immediate parent id.
test = test.withColumn("path", F.array("id1"))

# Group by id to find the depth (maximum level) of each tree.
depth_list = test.groupBy("id").agg(F.max("level").alias("depth")).collect()
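# depth_list is a small list of Rows collected to the driver; for the
# test data above it is [Row(id=10, depth=4), Row(id=20, depth=5),
# Row(id=30, depth=5)] (row order may vary).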
def tree_path(group):
    current_level = group['depth']
    tmp = test.filter(col('id') == group['id'])
    original_group = tmp
    while current_level > 1:
        # Recursive join to get the parent-child relationship at each level.
        # The heavy lifting is here - can it be optimized?
        joined_df = tmp.alias("child").join(
            original_group.alias("parent"),
            (F.col("child.id1") == F.col("parent.id2")) & (F.col("child.id1") != F.col("parent.id1")),
            "left"
        ).select(
            F.col("child.id"),
            F.col("child.level"),
            F.col("parent.id1").alias("id1"),
            F.col("child.id2"),
            # Append the latest ancestor to path if not null.
            F.expr("CASE WHEN parent.id1 IS NOT NULL THEN array_union(child.path, array(parent.id1)) ELSE child.path END").alias("path")
        )
        tmp = joined_df
        # Move up one level.
        current_level -= 1
    # tmp now holds the fully expanded paths for this group.
    return tmp
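# Example for group 10: the row (level=4, id1=102, id2=148) starts with
# path=[102]; the first join appends 101 (the parent of 102), the second
# appends 100, and the third appends the root marker 0, giving
# path=[102, 101, 100, 0].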
results_dfs = []
for g in depth_list:
    results_dfs.append(tree_path(g))

# Union the per-group results into a single DataFrame.
new_result_df = results_dfs[0]
for df in results_dfs[1:]:
    new_result_df = new_result_df.union(df)
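For completeness, I know the union loop can also be written as a single reduce; as far as I understand, the resulting plan is equivalent to the loop above, so I don't expect this alone to fix the runtime:

from functools import reduce
from pyspark.sql import DataFrame

# Same chain of unions as the loop above, expressed as one reduce.
new_result_df = reduce(DataFrame.union, results_dfs)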