You're facing a classic small files problem in S3, which is challenging to solve efficiently. Your current Auto Loader approach has performance limitations when processing millions of small files individually. Let me suggest several optimizations and alternative approaches to speed up this process.
## Key Performance Issues
- Your `copy_files` function processes files one by one with `dbutils.fs.cp`, which adds significant per-file overhead.
- Calling `batch_df.collect()` brings all file paths to the driver, creating memory pressure.
- Individual S3 operations have high latency, especially when they are performed sequentially.
import os
import re
import time

from pyspark import SparkFiles
from pyspark.sql.functions import (
    col,
    concat,
    input_file_name,
    lit,
    regexp_extract,
    regexp_replace,
    udf,
)
from pyspark.sql.types import BooleanType, StringType
# --- Locations ---------------------------------------------------------------
# Prefix Auto Loader watches, and the prefix the files are copied under.
source_path = "s3://source-bucket/source-folder/"
destination_path = "s3://destination-bucket/destination-folder/"
# Stream progress (checkpoint) and Auto Loader schema-tracking locations.
# NOTE(review): both live under /tmp — confirm this storage is durable enough
# for the job to resume from after a cluster restart.
checkpoint_path = "/tmp/autoloader_checkpoints/file_copy_job"
schema_path = "/tmp/autoloader_schemas/file_copy_job"

# --- Spark session tuning ------------------------------------------------------
# Settings are applied in the order listed. The two values of 100 are
# placeholders that should be sized against the actual cluster.
for _conf_key, _conf_value in (
    # Upper bound on the bytes packed into a single file-read partition.
    ("spark.sql.files.maxPartitionBytes", "128m"),
    # Turn on adaptive query execution.
    ("spark.sql.adaptive.enabled", "true"),
    # Default partition count for operations that do not specify one.
    ("spark.default.parallelism", 100),
    # Partition count used by DataFrame shuffles.
    ("spark.sql.shuffle.partitions", 100),
    # Databricks IO cache; only meaningful when running on Databricks.
    ("spark.databricks.io.cache.enabled", "true"),
):
    spark.conf.set(_conf_key, _conf_value)
# Per-file copy routine; it is wrapped as a Spark UDF so it can be applied row by row.
def copy_single_file(src_path: str, dest_path: str) -> str:
    """Copy one file with ``dbutils.fs.cp`` and report the outcome as a string.

    The outcome is returned rather than raised so that a single failing file
    does not abort the whole batch it belongs to.

    NOTE(review): this function is registered as a UDF and therefore runs on
    executors; Databricks documents that ``dbutils`` may not work there —
    confirm on the target runtime.

    Args:
        src_path: Full URI of the file to copy.
        dest_path: Full URI the copy should be written to.

    Returns:
        ``"success"`` when the parent directory creation and the copy both
        complete; otherwise ``"error: <message>"`` where ``<message>`` is the
        text of the exception that was raised.
    """
    try:
        # Everything up to the last "/" is the parent directory. When there is
        # no "/" at all this yields "/", exactly as the split/join form did.
        parent_dir = dest_path.rpartition("/")[0] + "/"
        dbutils.fs.mkdirs(parent_dir)

        dbutils.fs.cp(src_path, dest_path)
        return "success"
    except Exception as e:
        # Deliberate catch-all: the failure is surfaced through the return
        # value so the caller can count and log it per file.
        return f"error: {e}"
# Wrap the per-file copy routine as a Spark UDF whose result column is a string
# ("success" or "error: ...").
copy_file_udf = udf(f=copy_single_file, returnType=StringType())
# Auto Loader source definition. All reader options are gathered in one mapping
# so they can be reviewed (and tuned) in a single place.
autoloader_options = {
    # Files are read with the binary-file source.
    "cloudFiles.format": "binaryFile",
    "cloudFiles.schemaLocation": schema_path,
    # Files already present when the stream first starts are not picked up.
    "cloudFiles.includeExistingFiles": "false",
    # Descend into every subdirectory and accept every file name.
    "recursiveFileLookup": "true",
    "pathGlobFilter": "*",
    # Discover new files through S3 notifications, if available.
    "cloudFiles.useNotifications": "true",
    "cloudFiles.fetchParallelism": 64,
    # Cap on the number of files handed to a single micro-batch.
    "cloudFiles.maxFilesPerTrigger": 10000,
    "cloudFiles.region": "eu-central-1",
}

file_stream = (
    spark.readStream
    .format("cloudFiles")
    .options(**autoloader_options)
    .load(source_path)
    # Keep only the file metadata columns; the copy step needs nothing else.
    .select("path", "length", "modificationTime")
)
# foreachBatch callback: copies the files of one micro-batch and logs a summary.
def process_batch(batch_df, batch_id):
    """Copy every file listed in one micro-batch and print a summary.

    For each row, the destination is derived by stripping ``source_path`` from
    the front of ``path`` and prepending ``destination_path``, so the directory
    layout under the source prefix is preserved. The copy itself is delegated
    to ``copy_file_udf``. Rows whose copy failed are shown (first 10) and
    appended as Parquet under ``{destination_path}/_error_logs/{batch_id}``.

    Args:
        batch_df: Micro-batch DataFrame; must contain a ``path`` column with
            the source file URI.
        batch_id: Identifier of the micro-batch, used in log output and in the
            error-log location.

    Returns:
        None.
    """
    start_time = time.time()

    # Count once and reuse the number for the empty check, the partition
    # sizing and the throughput metric.
    file_count = batch_df.count()
    if file_count == 0:
        print(f"Batch {batch_id}: No files to process")
        return

    # Get timestamp for logging
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S")

    # One partition per ~1000 files, clamped to the range 8..128.
    optimal_partitions = min(max(file_count // 1000, 8), 128)

    # Escape the prefix and anchor it at the start so that regex
    # metacharacters in the path (e.g. ".") are taken literally and only the
    # leading occurrence is removed.
    source_prefix_pattern = "^" + re.escape(source_path)

    processed_df = (
        batch_df
        # Repartition BEFORE the copy column is added: the UDF then runs on
        # the redistributed rows. Repartitioning afterwards would only
        # shuffle the already-computed results.
        .repartition(optimal_partitions)
        .withColumn("relative_path", regexp_replace("path", source_prefix_pattern, ""))
        .withColumn("destination_path", concat(lit(destination_path), col("relative_path")))
        .withColumn("copy_result", copy_file_udf(col("path"), col("destination_path")))
    )

    # Cache so the actions below reuse the computed copy results where possible.
    result_df = processed_df.cache()
    try:
        # The first action triggers the copies; later ones read the cache.
        success_count = result_df.filter(col("copy_result").startswith("success")).count()
        error_count = result_df.filter(col("copy_result").startswith("error")).count()

        # Log errors for investigation
        if error_count > 0:
            errors_df = result_df.filter(col("copy_result").startswith("error"))
            print(f"Batch {batch_id}: Found {error_count} errors. Sample errors:")
            errors_df.select("path", "copy_result").show(10, truncate=False)
            # Persist the failed rows so they can be inspected or retried later.
            errors_df.write.mode("append").parquet(f"{destination_path}/_error_logs/{batch_id}")

        # Calculate performance metrics
        duration = time.time() - start_time
        files_per_second = file_count / duration if duration > 0 else 0

        # Log summary
        print(f"""
Batch {batch_id} completed at {timestamp}:
- Files processed: {file_count}
- Success: {success_count}
- Errors: {error_count}
- Duration: {duration:.2f} seconds
- Performance: {files_per_second:.2f} files/second
""")
    finally:
        # Release the cached data even when the batch fails part-way.
        result_df.unpersist()
# Launch the copy job: every micro-batch is handed to process_batch, progress
# is tracked at checkpoint_path, and the availableNow trigger makes the query
# process the files currently available and then terminate.
stream_writer = (
    file_stream.writeStream
    .foreachBatch(process_batch)
    .option("checkpointLocation", checkpoint_path)
    .trigger(availableNow=True)
)
copy_query = stream_writer.start()
# Block until the query has finished.
copy_query.awaitTermination()