Optimizing recursive joins per group and UNION operations

Henrik_
New Contributor III

The code snippet below takes each group (based on id) and performs recursive joins to build parent-child relations (id1 and id2) within the group. The code produces the correct output: an array in column 'path'.

However, in my real-world use case this code snippet takes over 30 minutes to run (1,300 iterations), and I haven't been able to store the result in a table because that cell just keeps running.

Based on the code below, is there something obvious I can do to improve performance? Could this code benefit from caching? Is the recursive join the bottleneck, or is it the part where I loop through the 1,300 DataFrames (one per group) collected in a list (results_dfs) and UNION them?

Anything that could push me in the right direction would be appreciated. 

 

 

import pyspark.sql.functions as F
from pyspark.sql.functions import isnull, when, col, ntile, count
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

 

 

 

# CREATE TEST DATA FRAME
# Four nullable integer columns, declared from a single tuple of names.
schema = StructType(
    [StructField(name, IntegerType(), True) for name in ("id", "level", "id1", "id2")]
)

# Rows are (id, level, id1, id2): `id` is the group, and each row links
# id1 (parent) -> id2 (child) at the given level of that group's tree.
data = [
    # group 10
    (10, 1, 0, 100),
    (10, 2, 100, 101),
    (10, 3, 101, 102),
    (10, 3, 101, 103),
    (10, 4, 102, 148),
    (10, 4, 102, 149),
    (10, 4, 103, 150),
    (10, 4, 103, 151),
    # group 20
    (20, 1, 0, 200),
    (20, 2, 200, 201),
    (20, 2, 200, 202),
    (20, 2, 200, 203),
    (20, 2, 200, 204),
    (20, 3, 201, 221),
    (20, 3, 201, 222),
    (20, 3, 202, 223),
    (20, 4, 222, 231),
    (20, 5, 231, 299),
    # group 30
    (30, 1, 0, 300),
    (30, 2, 300, 302),
    (30, 3, 302, 303),
    (30, 4, 303, 310),
    (30, 4, 303, 311),
    (30, 4, 303, 312),
    (30, 5, 311, 321),
    (30, 5, 312, 322),
]

# Every path starts out holding just the row's own id1.
test = spark.createDataFrame(data, schema).withColumn("path", F.array("id1"))

# The deepest level of each group decides how many join passes it needs.
depth_list = (
    test.groupBy("id")
    .agg(F.max("level").alias("depth"))
    .collect()
)

def tree_path(group, source_df=None):
    """Walk one ``id`` group up its parent chain, collecting ancestors in ``path``.

    Each pass left-joins the current rows (as "child") to the group's original
    rows (as "parent") on ``child.id1 == parent.id2`` (skipping self-links where
    ``child.id1 == parent.id1``). The row's ``id1`` is replaced by the parent's
    ``id1``, and that id is appended to ``path`` via ``array_union``, which also
    de-duplicates. A group of depth ``d`` gets ``d - 1`` passes.

    Args:
        group: a Row/mapping with ``id`` and ``depth`` keys, such as one element
            of ``depth_list``.
        source_df: DataFrame with columns id, level, id1, id2 and path. When
            omitted, the module-level ``test`` DataFrame is used.

    Returns:
        DataFrame with columns id, level, id1, id2, path restricted to this
        group. For a group of depth <= 1 no join is needed, and the group's
        rows are returned unchanged (this used to raise UnboundLocalError
        because the result variable was only assigned inside the loop).
    """
    if source_df is None:
        source_df = test

    original_group = source_df.filter(col('id') == group['id'])
    tmp = original_group
    current_level = group['depth']

    while current_level > 1:
        # One join pass per level moves every row one step closer to the root.
        # Filtering to a single group keeps each join small, but the query plan
        # still grows by one join per pass.
        tmp = tmp.alias("child").join(
            original_group.alias("parent"),
            (F.col("child.id1") == F.col("parent.id2")) & (F.col("child.id1") != F.col("parent.id1")),
            "left",
        ).select(
            F.col("child.id"),
            F.col("child.level"),
            # Null once the row has no parent left inside the group.
            F.col("parent.id1").alias("id1"),
            F.col("child.id2"),
            # Append the latest parent to path if not null
            F.expr("CASE WHEN parent.id1 IS NOT NULL THEN array_union(child.path, array(parent.id1)) ELSE child.path END").alias("path"),
        )
        current_level -= 1

    return tmp

# Process ALL groups in one set-based pass instead of building one DataFrame
# per group and UNION-ing them. The per-group approach produces a plan with
# roughly (number of groups x depth) joins plus a 1300-way union, all of which
# the driver must analyze before any work runs. Here the plan holds only
# (max depth - 1) joins, however many groups there are, and it avoids the
# IndexError the old `results_dfs[0]` raised when there were no groups.
#
# Equivalence with tree_path(): a group of depth d gets exactly d - 1 passes
# there. Here pass `step` only acts on rows whose group depth exceeds `step`;
# rows of shallower groups pass through untouched.

# tree_path() selects a group with col('id') == <id>, which never matches a
# null id, so null-id rows never reached the result; keep that behaviour.
paths_df = test.filter(F.col("id").isNotNull()).withColumn(
    "depth", F.max("level").over(Window.partitionBy("id"))
)
parents_df = test.select("id", "id1", "id2")

max_depth = max(
    (row["depth"] for row in depth_list if row["depth"] is not None), default=1
)

for step in range(1, max_depth):
    # Rows whose group still needs this pass.
    active = F.col("child.depth") > step
    paths_df = (
        paths_df.alias("child")
        .join(
            parents_df.alias("parent"),
            (F.col("child.id") == F.col("parent.id"))
            & (F.col("child.id1") == F.col("parent.id2"))
            & (F.col("child.id1") != F.col("parent.id1"))
            & active,
            "left",
        )
        .select(
            F.col("child.id"),
            F.col("child.level"),
            # Active rows take the parent's id1 (null when there is no parent,
            # as in tree_path); rows of finished groups keep their id1.
            F.when(active, F.col("parent.id1"))
            .otherwise(F.col("child.id1"))
            .alias("id1"),
            F.col("child.id2"),
            # Append the latest parent to path if not null
            F.expr(
                "CASE WHEN parent.id1 IS NOT NULL "
                "THEN array_union(child.path, array(parent.id1)) "
                "ELSE child.path END"
            ).alias("path"),
            F.col("child.depth"),
        )
    )
    # NOTE(review): for very deep trees, consider
    # `paths_df = paths_df.localCheckpoint()` every few passes to truncate the
    # growing lineage.

new_result_df = paths_df.drop("depth")