Hi @YS1,
As a workaround, you can rewrite the PIVOT as plain SQL with CASE expressions. The PIVOT below:
data = [
    ("ProductA", "North", 100),
    ("ProductA", "South", 150),
    ("ProductA", "East", 200),
    ("ProductA", "West", 250),
    ("ProductB", "North", 300),
    ("ProductB", "South", 350),
    ("ProductB", "East", 400),
    ("ProductB", "West", 450),
]
columns = ["product", "region", "sales"]
df = spark.createDataFrame(data, columns)
df.createOrReplaceTempView("sales_data")
sql_query = """
SELECT * FROM (
    SELECT product, region, sales
    FROM sales_data
)
PIVOT (
    SUM(sales) FOR region IN ('North', 'South', 'East', 'West')
)
"""
pivot_df_sql = spark.sql(sql_query)
display(pivot_df_sql)
is equivalent to this CASE-based SQL (one caveat: PIVOT returns NULL for a missing product/region combination, while ELSE 0 returns 0; drop the ELSE 0 if you need an exact match):
sql_query = """
SELECT
    product,
    SUM(CASE WHEN region = 'North' THEN sales ELSE 0 END) AS North,
    SUM(CASE WHEN region = 'South' THEN sales ELSE 0 END) AS South,
    SUM(CASE WHEN region = 'East' THEN sales ELSE 0 END) AS East,
    SUM(CASE WHEN region = 'West' THEN sales ELSE 0 END) AS West
FROM sales_data
GROUP BY product
"""
display(spark.sql(sql_query))
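For this sample data, both queries return the same result:

product   North  South  East  West
ProductA    100    150   200   250
ProductB    300    350   400   450

If you hit the same limitation from the DataFrame API, the CASE rewrite translates directly to when/otherwise. A minimal sketch, reusing the df defined above (pivot_df_api is just an illustrative name):

from pyspark.sql import functions as F

# One sum(when(...)) aggregate per region, mirroring the CASE expressions
pivot_df_api = df.groupBy("product").agg(
    F.sum(F.when(F.col("region") == "North", F.col("sales")).otherwise(0)).alias("North"),
    F.sum(F.when(F.col("region") == "South", F.col("sales")).otherwise(0)).alias("South"),
    F.sum(F.when(F.col("region") == "East", F.col("sales")).otherwise(0)).alias("East"),
    F.sum(F.when(F.col("region") == "West", F.col("sales")).otherwise(0)).alias("West"),
)
display(pivot_df_api)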