Batch Python UDFs in Unity Catalog and Spark SQL

stefan-vulpe
New Contributor II

Hello datanauts 🧑🏻‍🚀,

I'm encountering a conceptual challenge regarding Batch Python UDFs within Spark SQL in Databricks. My primary question is: can Batch Python UDFs be used directly via Spark SQL? As a Databricks beginner, I'm seeking to understand the underlying reasons for the behavior I'm observing.

For testing, I've created a simple Batch Python UDF in my default catalog. It processes the NYC Taxi trips sample dataset and simulates a data processing step: the handler receives an iterator of tuples of pd.Series (one Series per input column) and, for each batch, yields a single pd.Series containing one JSON string per row. Below is the UDF's code:

CREATE OR REPLACE FUNCTION process_taxi_data(
  pickup TIMESTAMP,
  dropoff TIMESTAMP,
  trip_dist DOUBLE,
  fare DOUBLE,
  pickup_zip INT,
  dropoff_zip INT
) RETURNS STRING LANGUAGE PYTHON PARAMETER STYLE PANDAS HANDLER 'handler_function' AS $$
import json
import pandas as pd
from typing import Iterator, Tuple

def handler_function(batch_iter: Iterator[Tuple[pd.Series, pd.Series, pd.Series, pd.Series, pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    for pickup, dropoff, trip_dist, fare, pickup_zip, dropoff_zip in batch_iter:
        # Process each row in the batch
        results = []
        
        for i in range(len(trip_dist)):
            # Extract individual row values
            p_time = pickup.iloc[i]
            d_time = dropoff.iloc[i]
            dist = trip_dist.iloc[i] * 2  # Double the distance
            fare_val = fare.iloc[i]
            p_zip = pickup_zip.iloc[i]
            d_zip = dropoff_zip.iloc[i]
            
            # Create result object for this row
            result = {
                "trip_distance": dist,
                "fare": format(float(fare_val), '.2f'),
                "pickup_time": str(p_time),
                "dropoff_time": str(d_time),
                "pickup_zip": str(p_zip),
                "dropoff_zip": str(d_zip)
            }
            
            results.append(json.dumps(result))
        
        # Return a Series with one JSON string per row
        yield pd.Series(results)
$$;
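To rule out the handler logic itself, the same function body can be exercised locally with plain pandas, outside of Spark and Unity Catalog. This is only a minimal sketch, assuming pandas is installed: the two-row sample batch is made up for illustration, and running the handler this way says nothing about how Databricks serializes the batches it actually passes in.

```python
import json
from typing import Iterator, Tuple

import pandas as pd

# Same logic as the UDF handler above, copied so it can run outside Databricks.
def handler_function(
    batch_iter: Iterator[Tuple[pd.Series, ...]]
) -> Iterator[pd.Series]:
    for pickup, dropoff, trip_dist, fare, pickup_zip, dropoff_zip in batch_iter:
        results = []
        for i in range(len(trip_dist)):
            result = {
                "trip_distance": trip_dist.iloc[i] * 2,  # Double the distance
                "fare": format(float(fare.iloc[i]), ".2f"),
                "pickup_time": str(pickup.iloc[i]),
                "dropoff_time": str(dropoff.iloc[i]),
                "pickup_zip": str(pickup_zip.iloc[i]),
                "dropoff_zip": str(dropoff_zip.iloc[i]),
            }
            results.append(json.dumps(result))
        # One JSON string per input row
        yield pd.Series(results)

# Hypothetical sample batch: one Series per UDF parameter, two rows each.
batch = (
    pd.Series(pd.to_datetime(["2016-02-13 21:47:53", "2016-02-13 18:29:09"])),
    pd.Series(pd.to_datetime(["2016-02-13 21:57:15", "2016-02-13 18:37:23"])),
    pd.Series([1.5, 2.0]),
    pd.Series([8.0, 7.5]),
    pd.Series([10103, 10023]),
    pd.Series([10110, 10023]),
)

# Feed a single batch through the handler and decode the JSON strings.
outputs = list(handler_function(iter([batch])))
rows = [json.loads(s) for s in outputs[0]]
for row in rows:
    print(row)
```

If this yields one populated JSON string per row, the handler logic is probably fine and the NULLs are more likely related to how the function is invoked than to the Python code.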

The function is created successfully, and I can call it from the SQL Editor in test queries like the following:

SELECT
  process_taxi_data(
    tpep_pickup_datetime, tpep_dropoff_datetime, trip_distance, fare_amount, pickup_zip, dropoff_zip
  ) AS processed_data
FROM
  samples.nyctaxi.trips
LIMIT 5;

-- Output:
processed_data
"{""trip_distance"": 2.8, ""fare"": ""8.00"", ""pickup_time"": ""2016-02-13 21:47:53"", ""dropoff_time"": ""2016-02-13 21:57:15"", ""pickup_zip"": ""10103"", ""dropoff_zip"": ""10110""}"
"{""trip_distance"": 2.62, ""fare"": ""7.50"", ""pickup_time"": ""2016-02-13 18:29:09"", ""dropoff_time"": ""2016-02-13 18:37:23"", ""pickup_zip"": ""10023"", ""dropoff_zip"": ""10023""}"
"{""trip_distance"": 3.6, ""fare"": ""9.50"", ""pickup_time"": ""2016-02-06 19:40:58"", ""dropoff_time"": ""2016-02-06 19:52:32"", ""pickup_zip"": ""10001"", ""dropoff_zip"": ""10018""}"
"{""trip_distance"": 4.6, ""fare"": ""11.50"", ""pickup_time"": ""2016-02-12 19:06:43"", ""dropoff_time"": ""2016-02-12 19:20:54"", ""pickup_zip"": ""10044"", ""dropoff_zip"": ""10111""}"
"{""trip_distance"": 5.2, ""fare"": ""18.50"", ""pickup_time"": ""2016-02-23 10:27:56"", ""dropoff_time"": ""2016-02-23 10:58:33"", ""pickup_zip"": ""10199"", ""dropoff_zip"": ""10022""}"

The problem occurs when I try to use this UDF from a Python script or notebook. spark.sql appears to resolve the function and execute it without raising an error, but every returned row is NULL. Here is a code example:

simple_test = spark.sql("""
SELECT
  <my_catalog>.<my_schema>.process_taxi_data(
    tpep_pickup_datetime, tpep_dropoff_datetime, trip_distance, fare_amount, pickup_zip, dropoff_zip
  ) AS processed_data
FROM
  samples.nyctaxi.trips;
""")

print("Simple UDF test:")
simple_test.groupBy('processed_data').count().show()
display(simple_test.select("processed_data").limit(5))

# Output:
processed_data:string
Simple UDF test:
+--------------+-----+
|processed_data|count|
+--------------+-----+
|          NULL|21932|
+--------------+-----+
processed_data
null
null
null
null
null

Does anyone know why this happens? Is there a way to share UDFs between Spark sessions and Unity Catalog? Any response would be highly appreciated 🙏🏻!