Technical Blog
Explore in-depth articles, tutorials, and insights on data analytics and machine learning in the Databricks Technical Blog. Stay updated on industry trends, best practices, and advanced techniques.
SergeRielau
Databricks Employee

It’s not often that a DBMS surprises me when it comes to SQL; I kind of think I have seen it all. However, there is one feature in Spark SQL that made me go: “Huh! Now that’s cool!” when I first encountered it. In fact, I am baffled that, as far as I know, Spark SQL is the only SQL dialect that has this capability:

Higher-order functions and Lambda functions.

Higher-order functions - which take functions as arguments - are not new outside of SQL. Amazing things can be done in C by passing functions, and I have known lambda functions - which are functions without a name - since my LISP days at university in the early 90s. Let’s take a look into this unsung corner of SQL and what it can do.

Note: You can find a notebook with all the SQL here.

Example: quicksort

Perhaps the most famous example of a higher-order function is qsort(), residing in C's stdlib:

void qsort(void *, size_t, size_t,
           int (*)(const void *, const void *));

Chances are you have used it by choice, in an interview, or as classwork. As you see or remember: qsort() takes a comparator function as an argument. The comparator accepts pointers to two values, a and b, and returns an integer whose sign indicates the outcome: positive means "a > b", 0 means "a == b", and negative means "a < b".

The definition of the actual comparison is left entirely up to the function. qsort() does not prescribe the meaning.

Here is how it’s used (courtesy of ChatGPT):

#include <stdio.h>
#include <stdlib.h>

int compareIntegers(const void *a, const void *b) {
    return (*(int *)a - *(int *)b);
}

int main() {
    int array[] = {5, 2, 8, 1, 3};
    size_t arraySize = sizeof(array) / sizeof(array[0]);

    qsort(array, arraySize, sizeof(array[0]), compareIntegers);

    for (size_t i = 0; i < arraySize; ++i) {
        printf("%d ", array[i]);
    }

    return 0;
}

So how do we do this in SQL? The equivalent of qsort() in SQL is array_sort(array, func).

Unlike qsort() which takes a named function as argument, array_sort() takes a lambda function. Here is the SQL implementation of the example above:

SELECT array_sort(array(5, 2, 8, 1, 3),
                  (a, b) -> a - b);
[1, 2, 3, 5, 8]

This is very dense! (😅 We omitted NULL handling, but so did chatGPT... 🤔)
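For readers more at home outside SQL, the same comparator contract can be sketched in Python, where functools.cmp_to_key adapts a two-argument comparator for sorted() (this is an analogy, not part of the Spark SQL example):

```python
from functools import cmp_to_key

# Two-argument comparator with the same contract as the SQL lambda (a, b) -> a - b:
# negative means a < b, zero means a == b, positive means a > b.
compare = lambda a, b: a - b

print(sorted([5, 2, 8, 1, 3], key=cmp_to_key(compare)))  # → [1, 2, 3, 5, 8]
```

Just like array_sort(), sorted() never asks what the comparison *means*; it only looks at the sign of the result.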

Lambda Functions

A lambda function is a special kind of expression and uses the -> infix operator.

Syntax

{ param -> expr |
  (param1 [, ...] ) -> expr }

On the left, the operator expects an identifier or a tuple of identifiers naming the parameters, e.g., a or (a, b). Note that, unlike in a SQL UDF definition, the signature of the lambda function does not include any types! The lambda function expects the higher-order function to tell it what these types must be.

In the array_sort() example above, the type is INTEGER because the function was passed an ARRAY<INTEGER>.

On the right side of -> is the implementation of the function. It is an expression that uses the parameters of the lambda function to compute a result. In some cases the type of this expression determines the result type of the higher-order function. Other times, as in the case of array_sort(), the higher-order function expects a certain type and specific result values.

In most cases lambda functions tend to be self-contained. That is, the expression is based solely on the parameters. But that is not actually required. For example, we can influence the sort order using a temporary variable:

DECLARE sortorder = -1;

SELECT array_sort(array(5, 2, 8, 1, 3),
                   (a, b) -> (a - b) * sortorder);
[8, 5, 3, 2, 1]

Or we can use a column reference, even a lateral correlation:

SELECT sortorder, sorted
  FROM VALUES(-1), (1) AS t(sortorder),
       LATERAL (SELECT array_sort(array(5, 2, 8, 1, 3),
                                  (a, b) -> (a - b) * sortorder) AS sorted);
 1 [1, 2, 3, 5, 8]
-1 [8, 5, 3, 2, 1]

You can even use subqueries in the lambda function expression.
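This ability of a lambda to reference its surrounding context is what functional languages call a closure. A rough Python parallel of the sortorder trick (illustrative only, not Spark code):

```python
from functools import cmp_to_key

def sort_with_order(values, sortorder):
    # The lambda closes over sortorder, just as the SQL lambda referenced
    # the session variable (or the laterally correlated column).
    return sorted(values, key=cmp_to_key(lambda a, b: (a - b) * sortorder))

print(sort_with_order([5, 2, 8, 1, 3], 1))   # → [1, 2, 3, 5, 8]
print(sort_with_order([5, 2, 8, 1, 3], -1))  # → [8, 5, 3, 2, 1]
```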

User-defined aggregates: maximum distance on a plane

While you can define user-defined aggregates using Scala UDFs, there is no API to accomplish this using SQL... or is there?

Databricks provides a reduce() (or aggregate()) function that takes an array and allows us to pass an initial value, an aggregation lambda function, and a finalization lambda function. This is pretty much what user-defined aggregation functions do!
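The shape of reduce() - an initial value, a merge step, and an optional finishing step - is a classic fold. A minimal Python sketch of those three pieces (the function name fold is mine, not Spark's):

```python
def fold(array, init, merge, finish=lambda acc: acc):
    # Thread the accumulator through the merge step for every element,
    # then apply the finishing step once at the end.
    acc = init
    for element in array:
        acc = merge(acc, element)
    return finish(acc)

# Sum of squares, finished by taking the square root:
print(fold([3, 4], 0, lambda acc, x: acc + x * x, lambda acc: acc ** 0.5))  # → 5.0
```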

All we need to do is first aggregate a group into an array using array_agg(), and then use reduce() to collapse the array into the desired scalar result:

SELECT reduce(array_agg(struct(x, y)),
              named_struct('x', null::integer, 'y', null::integer, 'len', null::integer),
              (acc, point) -> CASE WHEN acc.len IS NULL
                                     OR acc.len < point.x * point.x + point.y * point.y
                                   THEN named_struct('x', point.x, 'y', point.y,
                                                     'len', point.x * point.x + point.y * point.y)
                                   ELSE acc END,
              acc -> struct(acc.x, acc.y))
  FROM VALUES(1, 10), (2, -10), (-10, 3) AS points(x, y);
{ x: -10, y: 3 }

What’s going on here exactly?

  • We define the shape and initial state of the accumulator, which is also our scratchpad.
    In this case, we decide to collect not only the interim maximum but also the computed length of the vector (or rather its square, since we don’t need the actual length).

  • The accumulator lambda function takes two parameters:
    • acc: typed as the struct of the initial state.
    • point: typed as the element of the array we reduce.

    The implementation of the lambda function compares the previous maximum to the point and updates it to the point if necessary. It returns the updated acc value.

  • Since we added a scratchpad in the form of the len field, we use a final lambda function which accepts acc again and returns whatever we want the reduce() function to return. In this case, a struct<x int, y int>.
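To make the data flow concrete, here is the same accumulator logic transcribed into plain Python (a sketch only; dicts stand in for the SQL structs):

```python
from functools import reduce

points = [(1, 10), (2, -10), (-10, 3)]

def step(acc, point):
    # Keep the point with the larger squared length, mirroring the SQL CASE.
    x, y = point
    length = x * x + y * y
    if acc["len"] is None or acc["len"] < length:
        return {"x": x, "y": y, "len": length}
    return acc

init = {"x": None, "y": None, "len": None}
result = reduce(step, points, init)
# Finalization: drop the len scratchpad field.
print({"x": result["x"], "y": result["y"]})  # → {'x': -10, 'y': 3}
```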

So this works, but it is not pretty. Can we persist the logic in a UDF?

CREATE FUNCTION max_distance(a array<struct<x int, y int>>)
  RETURN reduce(a,
                named_struct('x', null::integer, 'y', null::integer, 'len', null::integer),
                (acc, point) -> CASE WHEN acc.len IS NULL
                                       OR acc.len < point.x * point.x + point.y * point.y
                                     THEN named_struct('x', point.x, 'y', point.y,
                                                       'len', point.x * point.x + point.y * point.y)
                                     ELSE acc END,
                acc -> struct(acc.x, acc.y));

SELECT max_distance(array_agg(struct(x, y)))
  FROM VALUES (1, 10), (2, -10), (-10, 3) AS points(x, y);
{ x: -10, y: 3 }

Note: This isn't the fastest way to perform this particular aggregation, but it serves to illustrate the purpose.

What else is there?

There are a decent number of higher-order functions that you can use. Here is a list as of the writing of this blog:

  • aggregate() / reduce()
  • array_sort()
  • exists()
  • filter()
  • forall()
  • map_filter()
  • map_zip_with()
  • transform()
  • transform_keys()
  • transform_values()
  • zip_with()

Conclusion

Higher-order functions using lambda functions are a very powerful, yet underappreciated feature of Spark SQL. They can be used to transform arrays and maps and to perform user-defined aggregations.
