Hi @sanjay, to optimize a PySpark UDF that calls a SageMaker endpoint, consider batch inference (one endpoint request per batch of rows rather than one per row), connection pooling (reusing the boto3 client across batches), distributed inference, caching, and efficient data serialization. Here's a snippet that demonstrates batch inference and connection pooling with an iterator-style pandas UDF:
@pandas_udf("double", PandasUDFType.SCALAR_ITER)
def score_udf(batch_df):
sm_rt = boto3.client('runtime.sagemaker', config=Config(retries={'max_attempts': 10}))
result = []
for _, row in batch_df.iterrows():
# Send batch to SageMaker and process results
return result
With the iterator UDF, the client is created once per executor task instead of once per row, and each invoke_endpoint call scores an entire Arrow batch, which reduces both connection overhead and the number of requests hitting the endpoint.
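If it helps, here's a minimal usage sketch; the DataFrame df, the input column "features", and the batch size of 500 are placeholders you'd swap for your own setup:

# Placeholder usage: "features" column and batch size are illustrative
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 500)  # rows sent per invoke_endpoint call
scored_df = df.withColumn("score", score_udf(df["features"]))
scored_df.cache()  # avoid re-invoking the endpoint if the scored data is reused downstream

Tuning spark.sql.execution.arrow.maxRecordsPerBatch controls how many rows go into each request (keep it within your endpoint's payload limit), and caching the scored DataFrame avoids calling the endpoint again when the results are reused.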