Optimizing .collect() Usage in Spark

jeremy98

Hi all!

I'm running into driver memory problems on a cluster with 14 GB of memory. My code keeps the cluster running continuously (it never shuts down, because at the moment that is the only way I can communicate with the Azure PostgreSQL database). While reviewing the code, I noticed that some parts use .collect() to pull an entire Spark DataFrame onto the driver as a list of rows.

Since I need to import the data row by row, I'm looking for an alternative approach that avoids .collect() while achieving the same result efficiently.
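
A minimal sketch of one option, assuming the driver can hold one partition at a time: DataFrame.toLocalIterator() returns rows to the driver one partition at a time instead of all at once, and the rows can then be grouped into fixed-size batches for the database. The helper iter_batches and the batch size of 1000 are hypothetical; records_to_delete_df refers to the DataFrame in the snippet below.

```python
from itertools import islice
from typing import Iterable, Iterator, List, Tuple


def iter_batches(rows: Iterable, batch_size: int = 1000) -> Iterator[List[Tuple]]:
    """Yield lists of at most batch_size tuples from any iterable of rows.

    Intended input is records_to_delete_df.toLocalIterator(), which fetches
    one partition at a time instead of the whole DataFrame.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    it = iter(rows)
    while True:
        # pyspark.sql.Row subclasses tuple, so tuple(row) gives the column values in order
        batch = [tuple(row) for row in islice(it, batch_size)]
        if not batch:
            return
        yield batch


# Hypothetical usage inside the existing job:
# for batch in iter_batches(records_to_delete_df.toLocalIterator(), batch_size=1000):
#     ...send batch to PostgreSQL...
```

The driver still needs enough memory for the largest partition, and because toLocalIterator() launches one Spark job per partition it can be slower than .collect().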

Here's the current (inefficient) code:

if num_rows > 0:
    # .collect() copies every row of records_to_delete_df into driver memory at once
    delete_data = [tuple(row) for row in records_to_delete_df.collect()]
    delete_query = syncer._generate_delete_statement(table_name, info_logic['primary_keys'])
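
The snippet builds delete_data and delete_query but does not show how they are executed. A minimal sketch of running them in batches through a DB-API 2.0 connection such as psycopg2's, assuming the rows arrive as an iterator (for example from records_to_delete_df.toLocalIterator()) rather than a fully collected list. delete_in_batches and pg_conn are hypothetical names; the helper commits after every batch, so a failure partway through leaves the earlier batches applied.

```python
from itertools import islice
from typing import Iterable, Sequence


def delete_in_batches(conn, query: str, rows: Iterable[Sequence], batch_size: int = 1000) -> int:
    """Execute a parameterized DELETE for each row, batch_size rows at a time.

    conn: any DB-API 2.0 connection (psycopg2 in the original setup).
    Commits after every batch and returns the number of parameter tuples sent.
    """
    it = iter(rows)
    sent = 0
    cur = conn.cursor()
    try:
        while True:
            # Materialize only one batch of parameters at a time
            batch = [tuple(row) for row in islice(it, batch_size)]
            if not batch:
                break
            cur.executemany(query, batch)
            conn.commit()
            sent += len(batch)
    finally:
        cur.close()
    return sent


# Hypothetical usage; pg_conn is an open psycopg2 connection:
# delete_query = syncer._generate_delete_statement(table_name, info_logic['primary_keys'])
# delete_in_batches(pg_conn, delete_query, records_to_delete_df.toLocalIterator())
```

With psycopg2, cursor.executemany() is essentially a loop of single executes; psycopg2.extras.execute_batch(cur, query, batch, page_size=...) sends several statements per round trip and could replace the executemany call.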

The _generate_delete_statement method returns a parameterized DELETE statement with one %s placeholder per primary-key column:

def _generate_delete_statement(self, table_name: str, primary_keys: str) -> str:
    """Generate DELETE SQL statement."""
    columns = [col.strip() for col in primary_keys.split(",")]
    # One %s placeholder per key column; the database driver fills in the values
    where_conditions = " AND ".join([f"{col} = %s" for col in columns])
    return f"DELETE FROM {table_name} WHERE {where_conditions};"
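
If the Spark workers can reach the Azure PostgreSQL database directly (an assumption, along with psycopg2 being installed on the workers), the deletes can run on the executors via DataFrame.foreachPartition, so no rows are sent to the driver at all. A minimal sketch; make_partition_deleter and pg_params are hypothetical. Each partition opens its own connection and commits once, and the DataFrame's columns must be in the same order as primary_keys, because the %s placeholders are filled positionally.

```python
from itertools import islice
from typing import Callable, Iterable, Sequence


def make_partition_deleter(
    connect: Callable, query: str, batch_size: int = 1000
) -> Callable[[Iterable[Sequence]], None]:
    """Return a function suitable for DataFrame.foreachPartition.

    The returned function runs on an executor: it opens its own connection
    with connect(), deletes that partition's rows in batches, commits once,
    and closes the connection.
    """

    def delete_partition(rows: Iterable[Sequence]) -> None:
        conn = connect()
        try:
            cur = conn.cursor()
            it = iter(rows)
            while True:
                batch = [tuple(row) for row in islice(it, batch_size)]
                if not batch:
                    break
                cur.executemany(query, batch)
            # One transaction per partition; an exception skips the commit and
            # closing without a commit rolls the partition's deletes back
            conn.commit()
        finally:
            conn.close()

    return delete_partition


# Hypothetical usage; pg_params holds the psycopg2.connect() keyword arguments:
# import psycopg2
# columns = [c.strip() for c in info_logic['primary_keys'].split(",")]
# delete_query = syncer._generate_delete_statement(table_name, info_logic['primary_keys'])
# records_to_delete_df.select(*columns).foreachPartition(
#     make_partition_deleter(lambda: psycopg2.connect(**pg_params), delete_query)
# )
```

If a task fails, Spark may retry it and re-run some deletes; that is harmless here because deleting an already-deleted row matches nothing.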

Is there a way to avoid using .collect() while maintaining the same functionality?
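
Another option is to avoid row-by-row statements entirely: write the keys to a staging table with Spark's JDBC writer, then remove the matching rows with a single PostgreSQL DELETE ... USING statement. A minimal sketch of generating that statement, mirroring _generate_delete_statement; the staging table name "staging_deletes", jdbc_url, and jdbc_props are hypothetical, and the JDBC write assumes the PostgreSQL JDBC driver is available on the cluster.

```python
def generate_delete_using_statement(table_name: str, staging_table: str, primary_keys: str) -> str:
    """Build one PostgreSQL DELETE ... USING that removes every row of
    table_name whose primary key appears in staging_table."""
    columns = [col.strip() for col in primary_keys.split(",")]
    # Join the target (t) and staging (s) tables on every primary-key column
    join_conditions = " AND ".join(f"t.{col} = s.{col}" for col in columns)
    return f"DELETE FROM {table_name} AS t USING {staging_table} AS s WHERE {join_conditions};"


# Hypothetical usage; jdbc_url, jdbc_props and "staging_deletes" are placeholders:
# columns = [c.strip() for c in info_logic['primary_keys'].split(",")]
# records_to_delete_df.select(*columns).write.jdbc(
#     jdbc_url, "staging_deletes", mode="overwrite", properties=jdbc_props
# )
# Then execute this once over the existing PostgreSQL connection:
# generate_delete_using_statement(table_name, "staging_deletes", info_logic['primary_keys'])
```

Jobs running at the same time would overwrite each other's keys in a shared staging table, so each job would need its own staging table name.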

Thanks in advance!