Hi @PJ11 ,
As per the UpSetPlot documentation:
UpSetPlot internally works with data based on Pandas data structures: a Series when all you care about is counts, or a DataFrame when you’re interested in visualising additional properties of the data, such as with the UpSet.add_catplot method.
UpSetPlot expects the Series or DataFrame to have a MultiIndex as input, with this index being an indicator matrix. Specifically, each category is a level in the pandas.MultiIndex with boolean values.
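To make that input format concrete, here is a minimal sketch of such a Series built with pandas alone. The category names `cat_a` and `cat_b` and the counts are made up for illustration and are not taken from your data.

```python
import pandas as pd

# Hypothetical category names and counts, purely for illustration.
index = pd.MultiIndex.from_tuples(
    [(False, False), (True, False), (False, True), (True, True)],
    names=["cat_a", "cat_b"],
)
counts = pd.Series([5, 3, 2, 1], index=index)

# Each index level is one category; each row is one combination of
# memberships, and the value is how many items have that combination.
print(counts)

n_categories = counts.index.nlevels
total_items = counts.sum()
```

A Series shaped like this is what the quoted documentation describes as the expected input when you only care about counts.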
Make sure to convert your Spark DataFrames to pandas DataFrames:
df_pandas = df.toPandas()
Please check the end-to-end example below, which includes:
1. Initial data preparation (to simulate some data you can read from a table)
import pandas as pd
from upsetplot import UpSet
import matplotlib.pyplot as plt
from pyspark.sql.functions import explode, array, lit, col
from pyspark.sql.types import BooleanType
data = [
(1, False, False, False, False),
(2, False, True, False, False),
(3, False, True, True, False),
(4, False, False, True, False),
(5, True, False, False, False),
(6, True, True, False, False),
(7, False, False, False, True),
(8, True, False, True, False),
(9, True, True, True, False),
(10, False, True, False, True),
(11, False, False, True, True),
(12, False, True, True, True),
(13, True, False, False, True),
(14, True, True, False, True),
(15, True, True, True, True),
(16, True, False, True, True),
]
counts = [169, 159, 110, 108, 84, 69, 49, 46, 43, 35, 32, 31, 22, 20, 14, 9]
columns = ['member_id', 'Feature_A', 'Feature_B', 'Feature_C', 'Feature_D']
df_spark = spark.createDataFrame([dict(zip(columns, row)) for row in data])
df_with_counts = df_spark.withColumn('count', lit(0))
for i, count in enumerate(counts):
    df_with_counts = df_with_counts.union(
        df_spark.filter(col('member_id') == data[i][0])
        .withColumn('count', array([lit(1)] * count))
        .withColumn('count', explode(col('count')))
    )
df_with_counts = df_with_counts.filter(col('count') != 0)
df_transformed = df_with_counts.drop('member_id', 'count')
for col_name in ['Feature_A', 'Feature_B', 'Feature_C', 'Feature_D']:
    df_transformed = df_transformed.withColumn(col_name, col(col_name).cast(BooleanType()))
display(df_transformed)
2. Conversions to pandas objects and then creation of the upset plot
df_pandas = df_transformed.toPandas()
df_counts = df_pandas.groupby(['Feature_A', 'Feature_B', 'Feature_C', 'Feature_D']).size().reset_index(name='count')
index = pd.MultiIndex.from_frame(df_counts[['Feature_A', 'Feature_B', 'Feature_C', 'Feature_D']])
data_counts = pd.Series(df_counts['count'].values, index=index)
print("Series with MultiIndex:")
display(data_counts)
upset = UpSet(
    data_counts,
    subset_size='auto',
    show_counts='%d',
    sort_by='cardinality',
    sort_categories_by='-input',
    facecolor='blue',
)
fig = plt.figure(figsize=(12, 8))
upset.plot(fig=fig)
plt.show()
Final output:

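As a side note on step 2: `groupby(...).size()` on its own already returns a Series indexed by a MultiIndex, so the `reset_index` / `MultiIndex.from_frame` round trip is optional. The sketch below compares both routes on a tiny stand-in DataFrame. The rows are made up and only two feature columns are used, so treat it as an illustration rather than a drop-in replacement.

```python
import pandas as pd

# Small stand-in for df_pandas (hypothetical rows, two features only).
df_pandas = pd.DataFrame({
    "Feature_A": [True, True, False, True],
    "Feature_B": [False, False, True, True],
})
features = ["Feature_A", "Feature_B"]

# Short route: the grouped sizes are already a Series with a MultiIndex.
data_counts_short = df_pandas.groupby(features).size()

# Three-step route from the example above, for comparison.
df_counts = df_pandas.groupby(features).size().reset_index(name="count")
index = pd.MultiIndex.from_frame(df_counts[features])
data_counts_long = pd.Series(df_counts["count"].values, index=index)

# Both routes should describe the same combinations and counts.
same_index = data_counts_short.index.equals(data_counts_long.index)
same_values = (data_counts_short.values == data_counts_long.values).all()
print(same_index, same_values)
```

Either Series can be passed to `UpSet` in the same way as `data_counts` above.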