I am running a process which has 4 steps.
- Query s3 file paths from DynamoDB based on parameters given by the user (the function to do so is provided by the client, I just have to import it). Returns a list of files.
- Check whether those file paths have already been queried. Get the distinct files and append them to a files delta table.
- Fetch the data behind the s3 file paths queried earlier (again a client-provided function that I import and call with the file path as a parameter). Returns a list of objects where each key is a timestamp and each value is a pd.DataFrame; see the sketch after this list.
- Concatenate all the dataframes from all the objects in the list and append the result to a dataframe delta table.
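For reference, each object returned in step 4 has roughly this shape (a made-up sketch based on the description above; the timestamps and values are illustrative, the real objects come from the client-provided fetch function):

import pandas as pd

# One PITD object: keys are section timestamps (as strings), values are dataframes of samples.
pitd_object = {
    "1696118400": pd.DataFrame({"a1": [0.1, 0.2], "a2": [0.0, 0.1], "a3": [0.3, 0.3],
                                "roll": [0.01, 0.02], "pitch": [0.00, 0.01]}),
    "1696118401": pd.DataFrame({"a1": [0.2, 0.1], "a2": [0.1, 0.0], "a3": [0.2, 0.4],
                                "roll": [0.02, 0.01], "pitch": [0.01, 0.00]}),
}
pitd_objects = [pitd_object]  # step 4 returns a list of these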
code:
import time
import botocore.exceptions
import pandas as pd
from pyspark import StorageLevel

def querying_dynamodb(start_date, end_date):
    wmp_file_meta_data = []
    query_timestamp1 = time.time()
    query_resp = perform_multipart_accel_data_query(env, id, start_date, end_date)  # client-provided query function
    # Traverse the DynamoDB response and collect wmp metadata. This is used later to avoid data overlapping.
    if len(query_resp) != 0:
        for response in query_resp:
            id_new = int(response["id"]['N'])
            wmp_file_path = response['file_path']['S']
            accel_data_file_path = response['accel_data']['S']
            ts = response['timestamp']['N']
            wmp_file_meta_data.append((query_uuid, start_date, end_date, id_new, wmp_file_path))
    else:
        # dbutils.notebook.exit("True")
        return []  # nothing queried for this window
    query_timestamp2 = time.time()
    query_difference = query_timestamp2 - query_timestamp1
    return wmp_file_meta_data
def get_distinct_wmp_files(wmp_file_meta_data):
    columns = ["uuid", "start_date", "end_date", "id", "wmp_file_path"]
    dataframe = spark.createDataFrame(wmp_file_meta_data, columns)
    table_name = 'wmp_metadata_temp_' + str(id)
    # Register the metadata as a temp view so it can be anti-joined against the existing table
    dataframe.persist(StorageLevel.MEMORY_AND_DISK)
    dataframe.createOrReplaceTempView(table_name)
    # Keep only files not already present in wmp_metadata_partitioned, to avoid data overlapping
    new_wmp_files = spark.sql(
        "SELECT * FROM {0} WHERE NOT EXISTS ("
        "SELECT 1 FROM wmp_metadata_partitioned "
        "WHERE {0}.id = wmp_metadata_partitioned.id "
        "AND {0}.wmp_file_path = wmp_metadata_partitioned.wmp_file_path)".format(table_name)
    )
    return new_wmp_files
def convert_to_pitd_wout_collect(wmp_file_meta_data):
    """
    Convert the WMP files to PITD objects.
    Parameters:
        wmp_file_meta_data: list of s3 file paths
    Returns:
        pitd_objects: list of dicts ({timestamp: pd.DataFrame})
    """
    print("Converting to PITD")
    print()
    pitd_objects = []
    pitd_timestamp1 = time.time()
    for file_path in wmp_file_meta_data:
        try:
            # Fetch the PITD object from s3 (client-provided function) and keep it as a dict.
            ts = int(file_path.split('/')[1].split('.')[0])  # timestamp encoded in the file path
            pitd = get_pitd_for_file_path(
                file_path=file_path,
                data_retriever=s3_dr,
                timestamp=ts,
                id=id,
            )
            pitd_objects.append(pitd.to_dict())
        except botocore.exceptions.ClientError as e:
            if e.response['Error']['Code'] == "404":
                print("The object does not exist.")
    pitd_timestamp2 = time.time()
    pitd_difference = pitd_timestamp2 - pitd_timestamp1
    print("PITD conversion successful!")
    return pitd_objects
def process_pitd_objects(time_data):
    # Runs on the executors: converts one PITD object (dict of sections) into a single pandas dataframe.
    section_frames = []
    for section in time_data:
        try:
            td_df = pd.DataFrame.from_dict(time_data[section])
            td_df.index = td_df.index + int(section)  # shift the index by the section timestamp
            td_df[['a1', 'a2', 'a3', 'roll', 'pitch']] = td_df[['a1', 'a2', 'a3', 'roll', 'pitch']].astype('float64')
            section_frames.append(td_df)
        except Exception:
            pass  # malformed sections are silently skipped
    if section_frames:
        complete_frame = pd.concat(section_frames)
        complete_frame["index"] = complete_frame.index
        complete_frame["id"] = id            # id and query_uuid come from the notebook-level globals
        complete_frame["uuid"] = query_uuid
        return complete_frame
def process_dataframes(pitd_objects, id, uuid):
    if pitd_objects:
        print("Processing PITD objects...")
        pitd_objects_rdd = sc.parallelize(pitd_objects)
        section_frames_rdd = pitd_objects_rdd.map(process_pitd_objects)
        print("Processing completed. Concatenating dataframes....")
        # Reduce the RDD of pandas dataframes into a single pandas dataframe
        result_df = section_frames_rdd.treeReduce(lambda x, y: pd.concat([x, y]))
        print("Concatenation completed. Dumping to delta table...")
        spark_df = spark.createDataFrame(result_df)  # build the Spark dataframe once and reuse it
        status = dump_accel_data(spark_df)
        if status:
            print("Dumped successfully...")
        else:
            print("Dumping failure.")
        return spark_df
wmp_files_paths = querying_dynamodb(start_date, end_date)
new_wmp_files = get_distinct_wmp_files(wmp_files_paths) # returns a pyspark dataframe
wmp_file_list = new_wmp_files.rdd.map(lambda x: x.wmp_file_path).collect() # convert pyspark dataframe column (wmp_file_path) to a list
pitd_objects = convert_to_pitd_wout_collect(wmp_file_list)
process_dataframes(pitd_objects, id, query_uuid)
This whole code is in a notebook, and multiple (read: hundreds) instances of this notebook run in parallel through a Python ThreadPoolExecutor. Spark crashes when the data volume gets too large. How can I improve the code?
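For context, the launcher looks roughly like this (a simplified sketch; the notebook path, parameter names and worker count below are placeholders, not the real values):

from concurrent.futures import ThreadPoolExecutor

def run_notebook(params):
    # Each call starts one instance of the notebook above on the same cluster.
    return dbutils.notebook.run(
        "/path/to/this_notebook",   # placeholder path
        0,                          # no timeout
        {
            "id": str(params["id"]),
            "start_date": params["start_date"],
            "end_date": params["end_date"],
        },
    )

with ThreadPoolExecutor(max_workers=100) as pool:            # hundreds of runs in flight
    results = list(pool.map(run_notebook, all_param_sets))   # all_param_sets: list of parameter dicts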
Cluster details: Driver: i3.4xlarge · Workers: c4.4xlarge · 4-8 workers · On-demand and Spot · fall back to On-demand · 11.3 LTS (includes Apache Spark 3.3.0, Scala 2.12) · us-east-1a (12-20 DBU)