09-05-2024 05:01 AM
The code snippet below takes each group (based on id) and performs recursive joins to build parent-child relations (id1 and id2) within the group. The code produces the correct output, an array in the column 'path'.
However, in my real-world use case, this snippet takes over 30 minutes to run (1300 iterations), and I haven't been able to store the result in a table because that cell just keeps running.
Based on the code below, is there something obvious I can do to improve performance? Could this code perhaps benefit from caching? Is the recursive join the bottleneck, or is it the part where I loop through the 1300 DataFrames (one for each group) collected in a list (results_dfs) and UNION them?
Anything that could push me in the right direction would be appreciated.
import pyspark.sql.functions as F
from pyspark.sql.functions import isnull, when, col, ntile, count
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
# CREATE TEST DATA FRAME
schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("level", IntegerType(), True),
    StructField("id1", IntegerType(), True),
    StructField("id2", IntegerType(), True)
])
# Create a list of data to populate the DataFrame
data = [
    (10, 1, 0, 100),
    (10, 2, 100, 101),
    (10, 3, 101, 102),
    (10, 3, 101, 103),
    (10, 4, 102, 148),
    (10, 4, 102, 149),
    (10, 4, 103, 150),
    (10, 4, 103, 151),
    (20, 1, 0, 200),
    (20, 2, 200, 201),
    (20, 2, 200, 202),
    (20, 2, 200, 203),
    (20, 2, 200, 204),
    (20, 3, 201, 221),
    (20, 3, 201, 222),
    (20, 3, 202, 223),
    (20, 4, 222, 231),
    (20, 5, 231, 299),
    (30, 1, 0, 300),
    (30, 2, 300, 302),
    (30, 3, 302, 303),
    (30, 4, 303, 310),
    (30, 4, 303, 311),
    (30, 4, 303, 312),
    (30, 5, 311, 321),
    (30, 5, 312, 322),
]
test = spark.createDataFrame(data, schema)
test = test.withColumn("path", F.array("id1"))
# Groupby to find the number of levels for each id.
depth_list = test.groupBy("id").agg(F.max("level").alias("depth")).collect()
def tree_path(group):
    current_level = group['depth']
    tmp = test.filter(col('id') == group['id'])
    original_group = tmp
    while current_level > 1:
        # Recursive join to get the parent-child relationship on each level.
        # Heavy lifting is here - can it be optimized?
        joined_df = tmp.alias("child").join(
            original_group.alias("parent"),
            (F.col("child.id1") == F.col("parent.id2")) & (F.col("child.id1") != F.col("parent.id1")),
            "left"
        ).select(
            F.col("child.id"),
            F.col("child.level"),
            F.col("parent.id1").alias("id1"),
            F.col("child.id2"),
            # Append the latest parent to path if not null
            F.expr("CASE WHEN parent.id1 IS NOT NULL THEN array_union(child.path, array(parent.id1)) ELSE child.path END").alias("path")
        )
        tmp = joined_df
        # Adjust level
        current_level -= 1
    # Union with result_df
    return joined_df

results_dfs = []
for g in depth_list:
    results_dfs.append(tree_path(g))

new_result_df = results_dfs[0]
for df in results_dfs[1:]:
    new_result_df = new_result_df.union(df)
Labels: Spark
Accepted Solutions
09-05-2024 06:07 AM
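The recursive join is definitely a performance killer. It will blow up the query plan: every pass through the loop stacks another join on top of the lazily built plan, so the plan keeps growing with each level and each group.
So I would advise against using it.
Alternatives? Well, a fixed number of joins, for example, if that is an option of course.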
Using a graph algorithm is also an option.
It is important that you figure out what kind of graph you have, or whether you even have multiple graphs (is it directed, is it connected, is it acyclic, do you want to visit all edges and vertices, etc.).
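To make that check concrete, here is a minimal pure-Python sketch for one group's edges. It assumes the (id1, id2) pairs are read as (parent, child); the function name `describe_edges` and the returned keys are made up for illustration and are not part of any library.

```python
from collections import defaultdict


def describe_edges(edges):
    """Summarise a directed edge list given as (parent, child) pairs."""
    parents = defaultdict(set)   # child  -> set of its parents
    children = defaultdict(set)  # parent -> set of its children
    for parent, child in edges:
        parents[child].add(parent)
        children[parent].add(child)

    nodes = set(parents) | set(children)
    # A root is a node that never appears as a child.
    roots = sorted(n for n in nodes if n not in parents)
    # In a tree, every non-root node has exactly one parent.
    single_parent = all(len(p) == 1 for p in parents.values())

    # Cycle check: repeatedly remove nodes with no remaining incoming
    # edges. If some nodes can never be removed, they sit on a cycle.
    indegree = {n: len(parents.get(n, ())) for n in nodes}
    queue = [n for n in nodes if indegree[n] == 0]
    visited = 0
    while queue:
        node = queue.pop()
        visited += 1
        for child in children.get(node, ()):
            indegree[child] -= 1
            if indegree[child] == 0:
                queue.append(child)
    acyclic = visited == len(nodes)

    return {
        "roots": roots,
        "single_parent": single_parent,
        "acyclic": acyclic,
        "is_tree": len(roots) == 1 and single_parent and acyclic,
    }


# Edges of group id=10 from the test data in the question.
group_10 = [(0, 100), (100, 101), (101, 102), (101, 103),
            (102, 148), (102, 149), (103, 150), (103, 151)]
summary = describe_edges(group_10)
```

If every group comes back as a single-rooted tree, walking from each node up to the root is enough, and a full graph library is probably more than is needed.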
Once you have that, you have the choice of either:
- use GraphFrames/GraphX in Spark (not easy to use!)
- use pure Python with some graph processing package (only an option if the amount of data is reasonable)
- use some kind of graph software outside of Databricks
IIRC there was some talk of introducing Cypher (of Neo4j) into Spark or Databricks, but that apparently never happened.
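As a rough illustration of the pure-Python option, the sketch below builds the same kind of 'path' array without any joins, using only the standard library rather than a graph package. It assumes the rows fit in driver memory (for example after a collect()) and that each id2 has exactly one parent id1 within its id group; `build_paths` is a hypothetical helper name, not an existing API.

```python
def build_paths(rows):
    """Append the chain of ancestors to each (id, level, id1, id2) row."""
    rows = list(rows)
    # Lookup table: (group id, child) -> parent. Assumes one parent per child.
    parent_of = {(gid, child): parent for gid, _level, parent, child in rows}

    result = []
    for gid, level, parent, child in rows:
        path = []
        seen = set()  # guards against an endless loop if the data has a cycle
        node = parent
        while node is not None and node not in seen:
            path.append(node)
            seen.add(node)
            node = parent_of.get((gid, node))  # None once the root is passed
        result.append((gid, level, parent, child, path))
    return result


# A few rows of group id=10 from the test data in the question.
sample = [
    (10, 1, 0, 100),
    (10, 2, 100, 101),
    (10, 3, 101, 102),
    (10, 4, 102, 148),
]
paths = build_paths(sample)
```

Each row is resolved with a handful of dictionary lookups, so the cost scales with the number of rows times the tree depth. The result can be turned back into a DataFrame with spark.createDataFrame if it needs to land in a table.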