Slow stream static join in Spark Structured Streaming

EDDatabricks
Databricks Partner

Situation

Records are streamed from an input Delta table via a Spark Structured Streaming job. The streaming job performs the following steps:

  1. Read from input Delta table (readStream)
  2. Static join on small JSON
  3. Static join on big Delta table
  4. Write to three Delta tables using foreachBatch logic

Problem

Step 3 is extremely slow. It takes more than 15 minutes to process a single batch of data on a job compute cluster with 2 Standard_DS3_v2 workers. Moreover, after 2-4 hours the job fails with an out-of-memory (OOM) exception. The metrics tab of the cluster shows data spilling to disk, as can be seen in the screenshot below.

[Screenshot EDDatabricks_1-1703760391974.png: cluster metrics tab showing the data spill to disk]

Code snippets

The code snippet below shows step 3, the static join on the big Delta table. In essence, the big Delta table is loaded, de-duplicated and joined to the streaming records, and this whole sequence is repeated for every batch that is processed.

from pyspark.sql import DataFrame, Window
from pyspark.sql import functions as F

# Load big Delta table
big_delta_table = (
    spark.read
    .format('delta')
    .table('big_delta_table')
)

# De-duplicate big Delta table on c1: keep the latest row per c1 value
c1_id_window = Window.partitionBy('c1').orderBy(F.col('updatedOn').desc())
c1_data = (
    big_delta_table
    .filter(F.col('c1').isNotNull())
    .withColumn(
        'row_num',
        F.row_number().over(c1_id_window)
    )
    .filter(F.col('row_num') == 1)
    .drop('row_num')
)

# De-duplicate big Delta table on c2 (rows without c1): keep the latest row per c2 value
c2_id_window = Window.partitionBy('c2').orderBy(F.col('updatedOn').desc())
c2_data = (
    big_delta_table
    .filter(F.col('c2').isNotNull() & F.col('c1').isNull())
    .withColumn(
        'row_num',
        F.row_number().over(c2_id_window)
    )
    .filter(F.col('row_num') == 1)
    .drop('row_num')
)

clean_big_delta_table = c1_data.union(c2_data)

# Join streaming records with de-duplicated big Delta table
joined_records = (
    streaming_records
    .join(
        F.broadcast(clean_big_delta_table),
        join_condition,
        'left'
    )
)
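To make the intended result of the de-duplication concrete, the plain-Python sketch below mimics it on a handful of rows, without Spark. The sample rows and the helper names `latest_per_key` and `clean` are hypothetical and only serve as illustration. One difference to keep in mind: when two rows share the same updatedOn, row_number() picks one of them arbitrarily, whereas this sketch keeps the first one it sees.

```python
from typing import Optional

Row = dict[str, Optional[str]]


def latest_per_key(rows: list[Row], key: str) -> list[Row]:
    """Keep only the most recently updated row for each value of `key`."""
    latest: dict[str, Row] = {}
    for row in rows:
        k = row[key]
        # ISO date strings sort chronologically, so string comparison suffices.
        if k not in latest or row["updatedOn"] > latest[k]["updatedOn"]:
            latest[k] = row
    return list(latest.values())


def clean(rows: list[Row]) -> list[Row]:
    # Rows with a c1 value are de-duplicated on c1.
    c1_rows = [r for r in rows if r["c1"] is not None]
    # Rows without c1 but with a c2 value are de-duplicated on c2.
    c2_rows = [r for r in rows if r["c1"] is None and r["c2"] is not None]
    # Rows where both keys are null match neither filter and are dropped.
    return latest_per_key(c1_rows, "c1") + latest_per_key(c2_rows, "c2")


# Hypothetical sample data.
sample: list[Row] = [
    {"c1": "a", "c2": "x", "updatedOn": "2023-01-01"},
    {"c1": "a", "c2": "y", "updatedOn": "2023-06-01"},    # newest for c1 = a
    {"c1": None, "c2": "z", "updatedOn": "2023-02-01"},
    {"c1": None, "c2": "z", "updatedOn": "2023-03-01"},   # newest for c2 = z
    {"c1": None, "c2": None, "updatedOn": "2023-04-01"},  # dropped
]

result = clean(sample)
```

In this sample, five input rows collapse to two: the newest row for c1 = a and the newest row for c2 = z.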

The following code snippet shows step 4, the foreachBatch logic. The batch is persisted so that the same data is not recomputed for each of the three writes, then written to three distinct sinks, and finally unpersisted.

def write_gold_tables(input_df: DataFrame, batch_id: int):
    input_df.persist()

    # Write to Delta table sink 1

    # Write to Delta table sink 2

    # Write to Delta table sink 3

    input_df.unpersist(blocking=True)
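For readers who want something more concrete, the skeleton above could be filled in roughly as follows. This is a minimal sketch rather than the job's actual code: the sink table names in `SINK_TABLES` and the append write mode are assumptions, and the try/finally is an addition so that the cached batch is released even if one of the writes fails. The pyspark import is guarded because it is only needed for the type hint.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:  # imported for the type hint only
    from pyspark.sql import DataFrame

# Hypothetical sink names; the real job writes to three Delta tables.
SINK_TABLES = ("gold_sink_1", "gold_sink_2", "gold_sink_3")


def write_gold_tables(input_df: "DataFrame", batch_id: int) -> None:
    """Write one micro-batch to every sink, caching it for the duration."""
    input_df.persist()
    try:
        for table_name in SINK_TABLES:
            (
                input_df.write
                .format("delta")
                .mode("append")  # assumption: the sinks are append-only
                .saveAsTable(table_name)
            )
    finally:
        # Release the cached batch even when one of the writes raises.
        input_df.unpersist(blocking=True)
```

A function like this is handed to the stream with `joined_records.writeStream.foreachBatch(write_gold_tables)`, followed by the usual checkpoint option and `start()`.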

Questions

  • Is there a way to optimize the code above to speed up the streaming job?
  • How could we avoid the data spill to disk?
  • What is the root cause of the OOM exception?