spark.sql with CTEs (10 minutes) vs. PySpark code + spark.sql without CTEs (3 seconds): why?

biafch
Contributor

Hello,

I have two pieces of code that produce the same result. One takes 7-10 minutes to run and the other takes about 3 seconds, and I'm trying to understand why:

This takes 7-10 minutes:

F_IntakeStepsPerDay = spark.sql("""
WITH BASE AS (
    SELECT
        s.JobApplicationFK,
        s.StepDate,
        ja.CandidateFK
    FROM steps AS s
    INNER JOIN jobapplication AS ja ON ja.RawDataIsCurrent = 1
    WHERE s.RawDataIsCurrent = 1
      AND s.StepDate >= DATE_SUB(CURRENT_DATE(), 35)
),

REPARTITIONED_BASE AS (
    SELECT * FROM BASE DISTRIBUTE BY JobApplicationFK
),

BASE_WITH_NEXT AS (
    SELECT
        b.*,
        LEAD(b.StepDate) OVER (PARTITION BY b.JobApplicationFK ORDER BY b.StepDate) AS NextStepDate
    FROM REPARTITIONED_BASE b
),

JOIN_WITH_DATE AS ( 
    SELECT /*+ RANGE_JOIN(f, 7) */
        f.JobApplicationFK,
        f.CandidateFK,
        f.StepDate,
        f.NextStepDate
    FROM BASE_WITH_NEXT f
    INNER JOIN dim_date d 
      ON d.Date >= f.StepDate 
     AND d.Date < COALESCE(f.NextStepDate, DATE_ADD(f.StepDate, 1))
)

SELECT *
FROM JOIN_WITH_DATE 
ORDER BY JobApplicationFK, StepDate
""")

display(F_IntakeStepsPerDay)

 

This takes 3 seconds:

from pyspark.sql.functions import col, lead, expr
from pyspark.sql.window import Window

df = (
    spark.table("steps").alias("s")
    .join(spark.table("jobapplication").alias("ja"), on=col("s.JobApplicationFK") == col("ja.JobApplicationBK"))
    .filter(
        (col("s.RawDataIsCurrent") == 1) &
        (col("ja.RawDataIsCurrent") == 1) &
        (col("s.StepDate") >= expr("DATE_SUB(CURRENT_DATE(), 35)"))
    )
    .select("s.JobApplicationFK", "s.StepDate", "ja.CandidateFK")
)

df_repartitioned = df.repartition("JobApplicationFK")

window_spec = Window.partitionBy("JobApplicationFK").orderBy("StepDate")
df_with_next = df_repartitioned.withColumn("NextStepDate", lead("StepDate").over(window_spec))

df_with_next.createOrReplaceTempView("BASE_WITH_NEXT")
spark.catalog.cacheTable("BASE_WITH_NEXT")

F_IntakeStepsPerDay = spark.sql("""
SELECT /*+ RANGE_JOIN(f, 7) */
    d.Date,
    f.JobApplicationFK,
    f.CandidateFK,
    f.StepDate,
    f.NextStepDate
FROM BASE_WITH_NEXT f
INNER JOIN dim_date d ON d.Date >= f.StepDate AND d.Date < COALESCE(f.NextStepDate, DATE_ADD(f.StepDate, 1))
ORDER BY f.JobApplicationFK, f.StepDate
""")

F_IntakeStepsPerDay.display()

 

I just want to understand why.

Can anyone explain, please? To clarify: even without the repartition and without the cache in my PySpark code, it still takes only 2-3 seconds!
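
One visible difference between the two snippets is the join condition. The SQL version joins jobapplication only on ja.RawDataIsCurrent = 1, while the PySpark version joins on s.JobApplicationFK = ja.JobApplicationBK and filters afterwards. Whether this accounts for the timing gap is an assumption, not something measured here. A minimal plain-Python sketch with made-up rows (the real tables are not shown in this post) of how the two conditions differ in the number of rows they produce:

```python
# Made-up rows for illustration only; the real tables are not shown in the post.
steps = [
    {"JobApplicationFK": 1, "StepDate": "2024-01-01"},
    {"JobApplicationFK": 1, "StepDate": "2024-01-05"},
    {"JobApplicationFK": 2, "StepDate": "2024-01-03"},
]
jobapplication = [
    {"JobApplicationBK": 1, "CandidateFK": 10, "RawDataIsCurrent": 1},
    {"JobApplicationBK": 2, "CandidateFK": 20, "RawDataIsCurrent": 1},
    {"JobApplicationBK": 3, "CandidateFK": 30, "RawDataIsCurrent": 1},
]

# Condition as written in the SQL version: ON ja.RawDataIsCurrent = 1.
# No column from steps appears in it, so every step pairs with every current row.
sql_style = [
    (s, ja)
    for s in steps
    for ja in jobapplication
    if ja["RawDataIsCurrent"] == 1
]

# Condition as written in the PySpark version:
# s.JobApplicationFK == ja.JobApplicationBK, plus the RawDataIsCurrent filter.
keyed = [
    (s, ja)
    for s in steps
    for ja in jobapplication
    if s["JobApplicationFK"] == ja["JobApplicationBK"]
    and ja["RawDataIsCurrent"] == 1
]

print(len(sql_style))  # 9 pairs: 3 steps x 3 current job applications
print(len(keyed))      # 3 pairs: one matching job application per step
```

With real table sizes, the first form grows with the product of the two row counts, so every later stage (the window, the range join, the sort) would have far more rows to process.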