szymon_dybczak
Esteemed Contributor III

Hi @YS1 ,

As a workaround, you can rewrite the PIVOT as plain SQL using CASE expressions.

The PIVOT query below:

data = [
    ("ProductA", "North", 100),
    ("ProductA", "South", 150),
    ("ProductA", "East", 200),
    ("ProductA", "West", 250),
    ("ProductB", "North", 300),
    ("ProductB", "South", 350),
    ("ProductB", "East", 400),
    ("ProductB", "West", 450)
]

columns = ["product", "region", "sales"]

# Build a DataFrame from the sample rows and expose it to SQL as a temp view
df = spark.createDataFrame(data, columns)
df.createOrReplaceTempView("sales_data")

sql_query = """
SELECT * FROM (
    SELECT product, region, sales
    FROM sales_data
) PIVOT (
    SUM(sales) FOR region IN ('North', 'South', 'East', 'West')
)
"""


# Run the PIVOT query and show the result
pivot_df_sql = spark.sql(sql_query)
display(pivot_df_sql)

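is equivalent to this SQL for the sample data (one difference: if a product has no rows for a region, PIVOT returns NULL while ELSE 0 returns 0, so drop the ELSE 0 if you need NULLs):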

sql_query = """
SELECT
    product,
    SUM(CASE WHEN region = 'North' THEN sales ELSE 0 END) AS North,
    SUM(CASE WHEN region = 'South' THEN sales ELSE 0 END) AS South,
    SUM(CASE WHEN region = 'East' THEN sales ELSE 0 END) AS East,
    SUM(CASE WHEN region = 'West' THEN sales ELSE 0 END) AS West
FROM sales_data
GROUP BY product
"""

display(spark.sql(sql_query))
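
If there are many pivot values, the CASE expressions do not have to be typed by hand. Below is a minimal sketch that generates the same query text from a list of values. Note that `build_case_pivot_sql` is a hypothetical helper name for illustration, not a Spark or Databricks API, and it assumes the table, column, and pivot values are trusted strings that are also valid as column aliases (no quotes or spaces).

```python
def build_case_pivot_sql(table, group_col, pivot_col, value_col, pivot_values):
    """Return a SUM(CASE ...) query that mimics PIVOT for the given values.

    Hypothetical helper: names and values are interpolated as-is, so only
    pass trusted identifiers and simple string values.
    """
    # One conditional aggregate per pivot value, aliased with the value itself
    case_exprs = [
        f"    SUM(CASE WHEN {pivot_col} = '{v}' THEN {value_col} ELSE 0 END) AS {v}"
        for v in pivot_values
    ]
    # The grouping column comes first, followed by the generated columns
    select_list = ",\n".join([f"    {group_col}"] + case_exprs)
    return f"SELECT\n{select_list}\nFROM {table}\nGROUP BY {group_col}"


sql_query = build_case_pivot_sql(
    "sales_data", "product", "region", "sales",
    ["North", "South", "East", "West"],
)
print(sql_query)
```

In a notebook, the generated string can then be passed to spark.sql(sql_query) exactly as in the examples above.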