I'm working on a task where I transform a dataset and re-save it to an S3 bucket. This involves joining the dataset to two others, dropping the fields from the initial dataset that overlap with fields from the other two, hashing certain fields with pyspark.sql.functions.sha2(col, 256), and writing the result to S3. There is also a third join, but that one is against a <1MB dataset and should be easily handled via a broadcast join. A simplified version of my code is below:
import pandas as pd
from pyspark.sql import functions as F
from pyspark.sql.window import Window

gcol = <Column grouped on during data generation process>
initial_names = df.schema.names
# Load datasets
df_600MB = spark.read.parquet(...).withColumnRenamed(id_col, id_col+'_600MB')
df_400MB = spark.read.parquet(...)
df_tiny = pd.read_csv(...) #3kB
df_tiny.loc[df_tiny[gcol].isna(), gcol] = None
df_tiny = spark.createDataFrame(df_tiny).withColumnRenamed(gcol, gcol+'_tmp')
# Find which fields overlap and drop from df
colset_1 = [c for c in df_600MB.schema.names if (c != gcol and c != id_col)]
colset_2 = [c for c in df_400MB.schema.names if c != gcol]
df = df.drop(*(colset_1 + colset_2))
df_400MB = df_400MB.withColumnRenamed(gcol, gcol+'_400MB')
# Assign row numbers randomly to create many-to-one join between df and df_400MB
win_df = Window.partitionBy(gcol).orderBy(F.rand())
win_df_400MB = Window.partitionBy(gcol+'_400MB').orderBy(F.rand())
df_400MB = df_400MB.withColumn('rn', F.row_number().over(win_df_400MB) - 1)
df = df.join(F.broadcast(df_tiny), on= df[gcol].eqNullSafe(df_tiny[gcol+'_tmp'])) \
.drop(gcol+'_tmp')
#print(df.count()) #2.6Billion
# Cap 'rn' in df so that it is always <= 'rn' in df_400MB
# The max value for 'rn' in df_400MB is contained in the 'count' field of df_tiny, hence the join above
df = df.withColumn('rn', F.row_number().over(win_df) % F.col('count'))
df = df.join(
df_400MB, how='inner', on = (
(df[gcol].eqNullSafe(df_400MB[gcol+'_400MB'])) & (df['rn'] == df_400MB['rn'])
)
)
#print(df.count()) #2.6Billion
# id_col+'_600MB' is a unique key in df_600MB
df = df.join(df_600MB, on = df[id_col] == df_600MB[id_col+'_600MB'], how='left')
df = df.select(initial_names) # To keep the same schema
# print(df.count()) 2.6Billion
if save_df:
print('Hashing')
df = hash_di(df, fields_to_hash)
print('Saving')
df.write.parquet('S3 bucket name')
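For reference, hash_di is just a small helper that applies sha2 to a list of fields. A rough sketch of what it does (the exact implementation differs a little, but this is the idea):

from pyspark.sql import functions as F

def hash_di(df, fields_to_hash):
    # Approximate sketch of my helper: replace each listed column with its SHA-256 hash
    for c in fields_to_hash:
        df = df.withColumn(c, F.sha2(F.col(c).cast('string'), 256))
    return df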
Some numbers: the initial dataset is 694GB with 2.6 billion rows and about 100 fields. The other two are 400MB and 600MB, with about 2 million and 100 million rows respectively, and about 6 fields each. Each join is many-to-one, i.e. each row from the larger dataset can match only a single row in each of the smaller datasets. I included calls to df.count() to verify that the dataframe was still the same size after each join, and it was, within 1%. (Oddly enough, df.count() went off without a hitch three times in a row.)
My cluster had 24 executors with 8 cores and 61GB of memory each (r4.2xlarge, for those who use Databricks), which comes out to about 1.4TB of memory and 192 cores in total. These are the Spark config variables I set:
spark.dynamicAllocation.enabled True
spark.executor.memory 40G
spark.shuffle.file.buffer 1024k
spark.sql.shuffle.partitions 720
spark.network.timeout 360s
spark.maxRemoteBlockSizeFetchToMem 2147483135
spark.sql.adaptive.enabled True
spark.sql.execution.arrow.pyspark.enabled True
spark.default.parallelism 720
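For what it's worth, the equivalent of that config in code would look roughly like this (a sketch; in practice I set these values in the Databricks cluster's Spark config rather than building the session myself):

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .config('spark.dynamicAllocation.enabled', 'true')
    .config('spark.executor.memory', '40G')
    .config('spark.shuffle.file.buffer', '1024k')
    .config('spark.sql.shuffle.partitions', '720')
    .config('spark.network.timeout', '360s')
    .config('spark.maxRemoteBlockSizeFetchToMem', '2147483135')
    .config('spark.sql.adaptive.enabled', 'true')
    .config('spark.sql.execution.arrow.pyspark.enabled', 'true')
    .config('spark.default.parallelism', '720')
    .getOrCreate()
)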
As I said, this cluster was able to execute df.count() just fine. However, when I called df.write.parquet, things went poorly. Over the course of the job, spill reached multiple terabytes, several times the size of my data; shuffle reads and writes were in the tens of gigabytes; and after running for many hours, the job failed with the following error:
"
Caused by: org.apache.spark.SparkException: Job aborted due to stage failure: ShuffleMapStage 19 (parquet at NativeMethodAccessorImpl.java:0) has failed the maximum allowable number of times: 4. Most recent failure reason:
org.apache.spark.shuffle.FetchFailedException
...
...
Caused by: java.io.IOException: Failed to connect to /10.41.61.123:4048
"
I should also mention that the stage this failed on had, for some reason, only 82 tasks, far less parallelism than I imagined Spark would use.
I'm guessing an executor died along the way. Does anyone have advice for how I can get this job to run? Should I configure my cluster differently? Is there a Spark variable I can set that would make a difference? Any help would be appreciated.