
Optimizing recursive joins on groups and UNION operations.

Henrik_
New Contributor III

The code snippet below takes each group (based on id) and performs recursive joins to build the parent-child relations (id1 and id2) within a group. The code produces the correct output, an array in the 'path' column.

However, in my real-world use case, this code snippet takes over 30 minutes to run (1300 iterations), and I haven't been able to store the result in a table since that cell just keeps running.

Based on the code below, is there something obvious I can do to improve performance? Could this code perhaps benefit from caching? Is the recursive join the bottleneck, or the part where I loop through 1300 DataFrames (one per group) collected in a list (results_dfs) and perform a UNION?

Anything that could push me in the right direction would be appreciated. 

import pyspark.sql.functions as F
from pyspark.sql.functions import isnull, when, col, ntile, count
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

# CREATE TEST DATA FRAME
schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("level", IntegerType(), True),
    StructField("id1", IntegerType(), True),
    StructField("id2", IntegerType(), True)
])

# Create a list of data to populate the DataFrame
data = [
    (10, 1, 0, 100),
    (10, 2, 100, 101),
    (10, 3, 101, 102),
    (10, 3, 101, 103),
    (10, 4, 102, 148),
    (10, 4, 102, 149),
    (10, 4, 103, 150),
    (10, 4, 103, 151),
    (20, 1, 0, 200),
    (20, 2, 200, 201),
    (20, 2, 200, 202),
    (20, 2, 200, 203),
    (20, 2, 200, 204),
    (20, 3, 201, 221),
    (20, 3, 201, 222),
    (20, 3, 202, 223),
    (20, 4, 222, 231),
    (20, 5, 231, 299),
    (30, 1, 0, 300),
    (30, 2, 300, 302),
    (30, 3, 302, 303),
    (30, 4, 303, 310),
    (30, 4, 303, 311),
    (30, 4, 303, 312),
    (30, 5, 311, 321),
    (30, 5, 312, 322),
    ]
test = spark.createDataFrame(data, schema)
test = test.withColumn("path", F.array("id1"))
# Groupby to find the number of levels for each id.
depth_list = test.groupBy("id").agg(F.max("level").alias("depth")).collect()

def tree_path(group):
    # Walk one group (identified by 'id') from its deepest level up to the root.
    current_level = group['depth']
    tmp = test.filter(col('id') == group['id'])
    original_group = tmp

    while current_level > 1:
        # Recursive join to get the parent-child relationship on each level.
        # Heavy lifting is here - can it be optimized? 

        joined_df = tmp.alias("child").join(
            original_group.alias("parent"),
            (F.col("child.id1") == F.col("parent.id2")) & (F.col("child.id1") != F.col("parent.id1")),
            "left" 
            ).select(
                F.col("child.id"),
                F.col("child.level"),
                F.col("parent.id1").alias("id1"),
                F.col("child.id2"),
                # Append the latest parent to path if not null
                F.expr("CASE WHEN parent.id1 IS NOT NULL THEN array_union(child.path, array(parent.id1)) ELSE child.path END").alias("path")
            )
            
        tmp = joined_df
        # Move up one level and repeat the join.
        current_level -= 1

    # Return the fully expanded paths for this group (falls back to the
    # original rows if the group only has one level).
    return tmp

results_dfs=[]
for g in depth_list:
    results_dfs.append(tree_path(g))

new_result_df = results_dfs[0]
for df in results_dfs[1:]:
    new_result_df = new_result_df.union(df)
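
For reference, a minimal sketch of the caching and union parts asked about above, keeping tree_path unchanged. Caching test is what avoids rescanning the source for each of the 1300 per-group filters; the reduce is only a tidier way to write the same union fold and does not change the plan by itself:

from functools import reduce
from pyspark.sql import DataFrame

# Cache the base DataFrame so each per-group filter inside tree_path
# does not rescan the source.
test = test.cache()
test.count()  # force materialization before the loop

results_dfs = [tree_path(g) for g in depth_list]

# Same union as the loop above, written as a single fold
# (unionByName matches columns by name rather than position).
new_result_df = reduce(DataFrame.unionByName, results_dfs)

That said, most of the time is likely spent in the recursive joins themselves (see the accepted answer below), so this alone may not bring the runtime down much.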

1 ACCEPTED SOLUTION


-werners-
Esteemed Contributor III

The recursive join is definitely a performance killer. It will blow up the query plan, so I would advise against using it.
Alternatives? Well, a fixed number of joins, for example, if that is an option of course (a sketch of that idea is at the end of this reply).
Using a graph algorithm is also an option.
It is important that you figure out what kind of graph you have, or even multiple graphs (is it directed, are all edges connected, is it acyclic, do you want to visit all edges and vertices, etc.).
Once you have that, you have the choice of either:
- use GraphFrames/GraphX in Spark (not easy to use!)
- use pure Python with some graph processing package (only an option if the amount of data is reasonable; see the sketch below)
- use some kind of graph software outside of Databricks
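
As an illustration of the pure-Python option, here is a minimal sketch. It assumes the edge list fits comfortably on the driver, that each group is a tree rooted at node 0 (as in the sample data in the question), and it picks networkx purely as an example of a graph processing package:

import networkx as nx

# Collect the edges to the driver (only viable for reasonably small data).
rows = test.select("id", "id1", "id2").collect()

# Build one directed graph per group id, with edges pointing parent -> child.
graphs = {}
for r in rows:
    graphs.setdefault(r["id"], nx.DiGraph()).add_edge(r["id1"], r["id2"])

# For each edge, the 'path' is the chain of ancestors of id2: the unique
# path from the root (assumed to be 0) down to id1, reversed.
paths = [
    (r["id"], r["id1"], r["id2"],
     list(reversed(nx.shortest_path(graphs[r["id"]], source=0, target=r["id1"]))))
    for r in rows
]
path_df = spark.createDataFrame(paths, ["id", "id1", "id2", "path"])

If a group is not a simple tree (multiple paths to a node, cycles), something like nx.all_simple_paths or an explicit ancestor walk would be needed instead of shortest_path.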

IIRC there was some talk of introducing Cypher (from Neo4j) into Spark or Databricks, but that apparently never happened.
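
And a minimal sketch of the "fixed number of joins" idea applied directly to the DataFrame from the question (column names id, level, id1, id2, path as defined there): every group is walked up one level per iteration, so there are max_depth - 1 joins in a single plan instead of a separate join loop and job per group. The plan still deepens with the number of levels, so treat this as a sketch rather than a guarantee; the path column comes out the same as in the per-group version, while the helper id1 column simply ends up null once a row has reached the root.

import pyspark.sql.functions as F

# One join per tree level, across all groups at once.
max_depth = test.agg(F.max("level")).first()[0]

parents = test.cache().alias("parent")   # lookup side, reused in every iteration
current = test                           # 'test' already carries the initial path column

for _ in range(max_depth - 1):
    current = (
        current.alias("child")
        .join(
            parents,
            (F.col("child.id") == F.col("parent.id"))           # stay inside the group
            & (F.col("child.id1") == F.col("parent.id2"))
            & (F.col("child.id1") != F.col("parent.id1")),
            "left",
        )
        .select(
            F.col("child.id"),
            F.col("child.level"),
            F.col("parent.id1").alias("id1"),
            F.col("child.id2"),
            F.when(
                F.col("parent.id1").isNotNull(),
                F.array_union(F.col("child.path"), F.array(F.col("parent.id1"))),
            )
            .otherwise(F.col("child.path"))
            .alias("path"),
        )
    )

new_result_df = current

If the plan still grows too deep, DataFrame.checkpoint() (with a checkpoint directory configured) or writing intermediate results to a table are common ways to truncate the lineage between iterations.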

