ai_query() is a built-in Databricks SQL function that lets you invoke a model serving endpoint from a SQL query. This enables data scientists to apply machine learning models (including traditional ML like XGBoost, not just LLMs) directly in dashboards, notebooks, or pipelines. In this tutorial, we’ll demonstrate how to use ai_query() with a customer churn XGBoost model on Databricks – from training and registering the model to getting predictions via SQL.
Important: ai_query() is best suited for interactive and batch inference scenarios within Databricks (e.g. exploring results in a notebook, powering BI dashboards, or scheduled jobs in Delta Live Tables). It is not designed for low-latency, high-QPS online inference for external applications. In those real-time API use cases, you would typically call the model serving endpoint’s REST API directly or use a dedicated serving solution. Here, we focus on using ai_query() for on-demand SQL-based predictions in Databricks.
In this walkthrough, we will:
Create sample data for a customer churn prediction use case.
Train an XGBoost model using this data.
Log and register the model with MLflow (using Unity Catalog for governance).
Deploy the model to Databricks Model Serving (creating a serving endpoint).
Query the model using ai_query() from SQL – demonstrating single-record prediction, batch processing, and aggregations on predictions.
Evaluate performance of ai_query() for a batch of predictions and discuss best practices.
Before you begin, make sure to enable the AI Functions preview for custom models in your Databricks workspace:
Go to Settings > Previews and enable "AI_Query for Custom Models and External Models".
Restart your cluster or SQL warehouse to activate this feature.
Also ensure you have access to Databricks Model Serving (to create a serving endpoint for the model), and a cluster running Databricks Runtime 15.4 LTS or above (for best performance with ai_query()).
First, we need to install XGBoost and import our Python libraries. In a Databricks notebook, you can use %pip to install packages. We also reset the Python session to ensure a clean environment and then import necessary modules for ML and Databricks interactions:
# Install XGBoost if not already available
%pip install xgboost==3.0.2 -q
# Restart Python to ensure the new library is picked up (Databricks utility)
dbutils.library.restartPython()
After restarting, import the libraries for model training, evaluation, and deployment:
import mlflow
import mlflow.xgboost
import xgboost as xgb
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from mlflow.deployments import get_deploy_client
import time
print("XGBoost version:", xgb.__version__)
print("MLflow version:", mlflow.__version__)
This sets up XGBoost and MLflow. We’ll use MLflow to log the model and get_deploy_client to programmatically create a serving endpoint. The version prints confirm we have the correct library versions.
For our example, we'll create a synthetic customer churn dataset with 10,000 customers. Each record has various features (demographics, usage, account info) and a churn label indicating if the customer left. We simulate realistic patterns for churn: e.g., low satisfaction or many support tickets increase churn probability, while having tech support or auto-pay reduces it.
# Create a comprehensive sample dataset for customer churn prediction
np.random.seed(42)
# Generate 10,000 customer records
n_samples = 10000
# Customer features
data = {
'customer_id': range(1, n_samples + 1),
'age': np.random.randint(18, 80, n_samples),
'tenure_months': np.random.randint(1, 120, n_samples),
'monthly_charges': np.round(np.random.uniform(20, 200, n_samples), 2),
'total_charges': np.round(np.random.uniform(100, 8000, n_samples), 2),
'contract_length': np.random.choice([1, 12, 24], n_samples, p=[0.3, 0.4, 0.3]), # months
'payment_delay_days': np.random.exponential(2, n_samples).astype(int),
'support_tickets': np.random.poisson(1.5, n_samples),
'service_calls': np.random.poisson(0.8, n_samples),
'satisfaction_score': np.random.uniform(1, 10, n_samples), # 1 = very unhappy, 10 = very satisfied
'has_tech_support': np.random.choice([0, 1], n_samples, p=[0.6, 0.4]),
'has_online_security': np.random.choice([0, 1], n_samples, p=[0.5, 0.5]),
'paperless_billing': np.random.choice([0, 1], n_samples, p=[0.4, 0.6]),
'auto_pay': np.random.choice([0, 1], n_samples, p=[0.3, 0.7])
}
# Create pandas DataFrame
df = pd.DataFrame(data)
# Create churn probability based on logical rules (higher with low satisfaction, payment issues, etc.)
churn_probability = (
0.1 + # Base probability
0.3 * (df['satisfaction_score'] < 5) + # Low satisfaction increases churn chance
0.2 * (df['payment_delay_days'] > 5) + # Many payment delays increases churn
0.15 * (df['support_tickets'] > 3) + # Many support tickets increases churn
0.1 * (df['contract_length'] == 1) + # Month-to-month contract (no long-term commitment)
0.05 * (df['tenure_months'] < 12) - # New customers (less loyalty)
0.1 * (df['has_tech_support'] == 1) - # Tech support available (reduces churn)
0.05 * (df['auto_pay'] == 1) # Auto pay set up (reduces churn)
)
# Clip probabilities to [0, 1]
churn_probability = np.clip(churn_probability, 0, 1)
# Generate binary churn labels from the probabilities
df['churn'] = np.random.binomial(1, churn_probability)
print(f"Dataset created with {len(df)} records")
print(f"Churn rate: {df['churn'].mean():.3f}")
print(f"Features: {list(df.columns[1:-1])}") # list of feature columns (excluding ID and label)
print("\nFirst 5 rows:")
display(df.head())
We used domain-inspired rules to assign a churn probability to each customer, then sampled the churn outcome. This gives us a reasonably realistic synthetic dataset. The printouts show the dataset size, overall churn rate, and the feature columns for reference, and we display the first few rows to verify the data.
Next, configure MLflow to log models to the Unity Catalog Model Registry and define names for our model and serving endpoint. Update the placeholders with your own catalog and schema if needed:
# Configuration – update these with your actual values
catalog_name = "your_catalog" # e.g. default catalog or a specified one
schema_name = "your_schema" # e.g. a schema/database name
model_name = "customer_churn_xgboost"
endpoint_name = "churn-xgboost-endpoint"
table_name = "customer_churn_data"
# Set MLflow to use Unity Catalog for model registry
mlflow.set_registry_uri("databricks-uc")
print(f"Model will be registered as: {catalog_name}.{schema_name}.{model_name}")
print(f"Endpoint will be created as: {endpoint_name}")
Here we choose a model name and an endpoint name. By setting the MLflow registry URI to "databricks-uc", we ensure the model is registered in Unity Catalog (so it can be easily referenced by catalog.schema.model_name). We’ll use these variables throughout the workflow.
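If the target catalog and schema don't exist yet in your workspace, you can create the schema up front. A minimal sketch, assuming you have the required privileges on the catalog:
# Optional: make sure the target schema exists before writing tables or registering models.
# Assumes you have CREATE SCHEMA privileges on `catalog_name`.
spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog_name}.{schema_name}")
spark.sql(f"USE {catalog_name}.{schema_name}")
print(f"Using {catalog_name}.{schema_name}")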
To use ai_query(), we need data accessible via SQL. Let's save our pandas DataFrame as a Spark DataFrame and then to a Delta table. This will allow us to run SQL queries that join the data with model predictions:
# Convert pandas DataFrame to Spark DataFrame and save as a Delta table
spark_df = spark.createDataFrame(df)
full_table_name = f"{catalog_name}.{schema_name}.{table_name}"
spark_df.write \
.mode("overwrite") \
.option("mergeSchema", "true") \
.saveAsTable(full_table_name)
print(f"✅ Created table: {full_table_name}")
print(f"Table has {spark_df.count()} rows")
# Quick verification: show first 5 rows from the table
display(spark.sql(f"SELECT * FROM {full_table_name} LIMIT 5"))
We now have a Delta table customer_churn_data (in the specified catalog and schema) containing all our features and the churn label. We’ll use this table in our SQL queries with ai_query() to fetch features and make predictions.
Now, prepare the data for model training. We’ll split the DataFrame into features (X) and label (y), then into training and test sets:
# Prepare feature matrix X and target vector y
feature_columns = [col for col in df.columns if col not in ['customer_id', 'churn']]
X = df[feature_columns]
y = df['churn']
print(f"Feature columns: {feature_columns}")
print(f"Training data shape: {X.shape}")
print(f"Target distribution: {np.bincount(y)}") # counts of 0/1 labels
# Split the data into training and test sets (80/20 split)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
print(f"Training set size: {X_train.shape[0]}")
print(f"Test set size: {X_test.shape[0]}")
We stratify the split by the churn label to maintain the same churn rate in training and test. The outputs confirm the feature columns and the sizes of each split (8,000 training and 2,000 test samples, given our 10k total).
With data ready, we train a binary classification model using XGBoost. We define some hyperparameters and train an XGBClassifier on the training set, using the test set for evaluation (with early stopping to prevent overfitting):
# Train XGBoost model
xgb_params = {
'objective': 'binary:logistic',
'max_depth': 6,
'learning_rate': 0.1,
'n_estimators': 100,
'subsample': 0.8,
'colsample_bytree': 0.8,
'random_state': 42,
'eval_metric': 'logloss',
'early_stopping_rounds': 10
}
# Create and train the XGBoost model
xgb_model = xgb.XGBClassifier(**xgb_params)
xgb_model.fit(
X_train, y_train,
eval_set=[(X_test, y_test)],
verbose=False # silent training
)
# Make predictions on test set
y_pred = xgb_model.predict(X_test)
y_pred_proba = xgb_model.predict_proba(X_test)[:, 1]
# Evaluate model performance
accuracy = accuracy_score(y_test, y_pred)
print("✅ Model Training Complete!")
print(f"Accuracy: {accuracy:.4f}")
print("Feature importance (top 5):")
# Compute feature importance
feature_importance = pd.DataFrame({
'feature': feature_columns,
'importance': xgb_model.feature_importances_
}).sort_values('importance', ascending=False)
display(feature_importance.head()) # display top 5 important features
print("\nClassification Report:")
print(classification_report(y_test, y_pred))
The model trains quickly on 8k samples. We print out the accuracy and a classification report to see how it performs. We also show the top features influencing the model (for example, we might see satisfaction_score or support_tickets as important predictors of churn, given how we constructed the data). The accuracy and report give an idea of baseline performance on this synthetic data.
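Since we already imported confusion_matrix during setup, you can also print the raw confusion matrix on the same test predictions as an optional check of how false positives and false negatives break down:
# Optional: confusion matrix on the test set (rows = actual 0/1, columns = predicted 0/1)
cm = confusion_matrix(y_test, y_pred)
print("Confusion matrix (rows = actual, cols = predicted):")
print(cm)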
Now that the model is trained, we log it to MLflow and register it to the Model Registry. This makes the model available for serving. We also log parameters, metrics, and a model signature for transparency and reproducibility:
# Create an input example for model signature
input_example = X_train.head(3)
# Start an MLflow run and log the model and artifacts
with mlflow.start_run(run_name="xgboost_churn_prediction") as run:
    # Log model parameters
    mlflow.log_params(xgb_params)
    # Log additional metadata
    mlflow.log_param("features_count", len(feature_columns))
    mlflow.log_param("training_samples", len(X_train))
    mlflow.log_param("model_type", "XGBoost")
    # Log evaluation metrics
    mlflow.log_metric("accuracy", accuracy)
    mlflow.log_metric("churn_rate", y.mean())
    # Create a model signature capturing the input/output schema
    signature = mlflow.models.infer_signature(X_train, y_pred_proba)
    # Log the XGBoost model to MLflow and register it in Unity Catalog
    model_info = mlflow.xgboost.log_model(
        xgb_model=xgb_model,
        artifact_path="xgboost_model",
        signature=signature,
        input_example=input_example,
        registered_model_name=f"{catalog_name}.{schema_name}.{model_name}"
    )
    run_id = run.info.run_id
    model_uri = model_info.model_uri
    print("✅ Model logged successfully!")
    print(f"Run ID: {run_id}")
    print(f"Model URI: {model_uri}")
    print(f"Registered as: {catalog_name}.{schema_name}.{model_name}")
We used mlflow.xgboost.log_model which saves the model and registers it under the name we specified (including catalog and schema, so it’s a Unity Catalog model). The infer_signature captures the input and output schema (the model outputs a single probability). The print statements confirm the model is registered and provide its URI and run ID for reference.
At this point, you should be able to see the model in the Databricks Model Registry (in the UI) with version 1 ready.
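If you prefer to verify this programmatically rather than in the UI, a quick check with the MLflow client is sketched below (assuming the registry URI is still set to Unity Catalog as configured earlier):
# Optional: confirm the registered model and its versions via the MLflow client
from mlflow import MlflowClient

mlflow_client = MlflowClient()
versions = mlflow_client.search_model_versions(f"name = '{catalog_name}.{schema_name}.{model_name}'")
for v in versions:
    print(f"Version {v.version} - status: {v.status}")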
With a registered model, we set up a Model Serving endpoint. This endpoint loads the model and provides a REST API that ai_query() and other clients can call to get predictions. We use the MLflow Deployment client to create the endpoint programmatically:
# Initialize the Databricks deployment client for Model Serving
client = get_deploy_client("databricks")
# Create a serving endpoint with the registered model
try:
    print(f"Creating endpoint: {endpoint_name}")
    endpoint = client.create_endpoint(
        name=endpoint_name,
        config={
            "served_entities": [
                {
                    "entity_name": f"{catalog_name}.{schema_name}.{model_name}",
                    "entity_version": "1",          # serve model version 1
                    "workload_size": "Small",       # smallest workload size, for the demo
                    "scale_to_zero_enabled": True   # auto-scale down when not in use
                }
            ],
            "traffic_config": {
                "routes": [
                    {
                        "served_model_name": f"{model_name}-1",
                        "traffic_percentage": 100
                    }
                ]
            }
        }
    )
    print(f"✅ Endpoint '{endpoint_name}' created successfully!")
    print("⏳ Endpoint is being deployed... This may take 5-10 minutes.")
except Exception as e:
    if "already exists" in str(e).lower():
        print(f"ℹ️ Endpoint '{endpoint_name}' already exists.")
        print("Checking if it needs to be updated...")
        # Optionally, update the endpoint with a new config here if needed (see the sketch below)
        existing_endpoint = client.get_endpoint(endpoint=endpoint_name)
        print(f"Current status: {existing_endpoint.get('state', {}).get('ready', 'Unknown')}")
    else:
        print(f"❌ Error creating endpoint: {e}")
        raise
We specified workload_size: "Small", the lowest compute tier for custom model serving; the exact vCPU and memory allocation varies by cloud and region. For real workloads or heavier traffic, you might use a larger size. We also allow the endpoint to scale to zero when idle (to save cost). The endpoint is configured to serve 100% of traffic with model version 1. The create_endpoint call returns asynchronously, and the endpoint typically takes a few minutes to actually load the model. We handle the case where the endpoint might already exist (to avoid duplication).
Note: You can alternatively create and manage serving endpoints via the Databricks UI (Models > Serving). Here we show the programmatic approach for completeness.
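For completeness, here is a rough sketch of how you could update an existing endpoint to serve a newer model version with the same deployment client, as mentioned in the except branch above (for example after retraining). The version number and config shape here are assumptions to adapt; they mirror the create_endpoint call:
# Optional sketch: point an existing endpoint at a new model version (e.g., version 2 after retraining)
new_version = "2"  # hypothetical newer version of the registered model
client.update_endpoint(
    endpoint=endpoint_name,
    config={
        "served_entities": [
            {
                "entity_name": f"{catalog_name}.{schema_name}.{model_name}",
                "entity_version": new_version,
                "workload_size": "Small",
                "scale_to_zero_enabled": True
            }
        ]
    }
)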
Since deploying the model takes a few minutes, we implement a helper function to poll the endpoint status until it’s ready (or a timeout is reached):
def wait_for_endpoint(endpoint_name, max_wait_minutes=15):
    """Wait for the serving endpoint to be ready."""
    max_wait_time = max_wait_minutes * 60
    start_time = time.time()
    print(f"⏳ Waiting for endpoint '{endpoint_name}' to be ready...")
    while time.time() - start_time < max_wait_time:
        try:
            endpoint_info = client.get_endpoint(endpoint=endpoint_name)
            state = endpoint_info.get('state', {}).get('ready', 'UNKNOWN')
            elapsed_minutes = (time.time() - start_time) / 60
            print(f"[{elapsed_minutes:.1f} min] Endpoint status: {state}")
            if state == "READY":
                print("✅ Endpoint is ready!")
                return True
            elif state in ["FAILED", "CANCELLED"]:
                print(f"❌ Endpoint deployment failed with state: {state}")
                return False
            time.sleep(30)  # wait 30 seconds before the next status check
        except Exception as e:
            print(f"Error checking status: {e}")
            time.sleep(30)
    print("❌ Timeout waiting for endpoint to be ready.")
    return False

# Wait for the endpoint to finish deploying
is_ready = wait_for_endpoint(endpoint_name)
This function checks the endpoint’s ready status every 30 seconds and prints updates. Once the status becomes "READY", we know the model is live and can accept queries. We invoke it and store is_ready (which will be True if the endpoint is up successfully).
In practice, you could also just wait a few minutes or check the UI for when the serving endpoint is ready.
Before using ai_query(), it’s good to test the model via the standard REST API. We can do this using the client.predict method on a small sample to ensure the endpoint is working as expected:
if is_ready:
    try:
        # Prepare a small test batch (e.g., 3 records from X_test)
        test_sample = X_test.head(3)
        # Format the input as expected by Model Serving (the 'dataframe_split' format)
        test_input = {
            "dataframe_split": {
                "columns": test_sample.columns.tolist(),
                "data": test_sample.values.tolist()
            }
        }
        # Invoke the model serving endpoint for predictions
        response = client.predict(endpoint=endpoint_name, inputs=test_input)
        print("✅ Direct endpoint test successful!")
        print("Sample predictions:", response)
        print("Actual labels:", y_test.iloc[:3].tolist())
    except Exception as e:
        print(f"❌ Error testing endpoint: {e}")
        is_ready = False
In the code above, we took the first 3 test records and formatted them as a JSON payload. The client.predict returns the model's predictions (for XGBoost, by default this will return the predicted probabilities or classes depending on how the model is logged – typically a probability for binary classification in our case). We print the predictions alongside the actual labels for a quick sanity check (they should correspond reasonably, e.g., higher probability for actual churners).
This step simulates what an external application could do via REST. If this call is successful, it confirms our model serving endpoint is functioning. Now we’re ready to use ai_query() from SQL.
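For reference, this is roughly what an external application would do against the endpoint's REST API, outside of Spark. A minimal sketch using the requests library; the workspace URL and token are placeholders you would supply yourself (for example from a Databricks secret):
import requests

# Placeholders - supply your own workspace URL and a personal access token or service principal token
workspace_url = "https://<your-workspace>.cloud.databricks.com"
api_token = "<databricks-token>"

payload = {
    "dataframe_split": {
        "columns": test_sample.columns.tolist(),
        "data": test_sample.values.tolist()
    }
}
resp = requests.post(
    f"{workspace_url}/serving-endpoints/{endpoint_name}/invocations",
    headers={"Authorization": f"Bearer {api_token}"},
    json=payload,
    timeout=60
)
print(resp.status_code, resp.json())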
ai_query() Function
The ai_query(endpoint, request) function in Databricks SQL lets us invoke the model serving endpoint from within a query. We’ll start with a simple example: getting the churn prediction for a few specific customers by calling ai_query() inside a SELECT statement.
# Test ai_query() on individual records via SQL
if is_ready:
    print("🚀 Testing ai_query() function...")
    # Example 1: Single-row predictions for customers 1-5
    test_query_1 = f"""
    SELECT
        customer_id,
        age,
        tenure_months,
        monthly_charges,
        satisfaction_score,
        churn AS actual_churn,
        ai_query(
            '{endpoint_name}',
            named_struct(
                'age', CAST(age AS DOUBLE),
                'tenure_months', CAST(tenure_months AS DOUBLE),
                'monthly_charges', CAST(monthly_charges AS DOUBLE),
                'total_charges', CAST(total_charges AS DOUBLE),
                'contract_length', CAST(contract_length AS DOUBLE),
                'payment_delay_days', CAST(payment_delay_days AS DOUBLE),
                'support_tickets', CAST(support_tickets AS DOUBLE),
                'service_calls', CAST(service_calls AS DOUBLE),
                'satisfaction_score', CAST(satisfaction_score AS DOUBLE),
                'has_tech_support', CAST(has_tech_support AS DOUBLE),
                'has_online_security', CAST(has_online_security AS DOUBLE),
                'paperless_billing', CAST(paperless_billing AS DOUBLE),
                'auto_pay', CAST(auto_pay AS DOUBLE)
            ),
            returnType => 'DOUBLE'
        ) AS churn_probability
    FROM {full_table_name}
    WHERE customer_id <= 5
    ORDER BY customer_id
    """
    try:
        result_1 = spark.sql(test_query_1)
        print("✅ ai_query() test successful! Results:")
        display(result_1)
    except Exception as e:
        print(f"❌ Error with ai_query(): {e}")
        print("Make sure the 'AI_Query for Custom Models' preview is enabled.")
In the SQL query above:
We select some customer info and the actual churn label.
We call ai_query('<endpoint_name>', named_struct(...), returnType => 'DOUBLE') to get the predicted churn probability from our model for each row.
The named_struct(...) is how we pass feature values from each row into the model. We need to cast each value to the correct type (Double in this case) to match the model’s input schema.
We specify returnType => 'DOUBLE' because our model returns a probability (a double). This hints to Databricks what type to expect so it can handle the result properly.
This query will return a table of customers (1 through 5) with a new column churn_probability containing the model’s prediction. We ordered by customer_id just for deterministic output. The display(result_1) will show something like:
| customer_id | age | tenure_months | monthly_charges | satisfaction_score | actual_churn | churn_probability |
|---|---|---|---|---|---|---|
| 1 | ... | ... | ... | ... | 0 | 0.08 |
| 2 | ... | ... | ... | ... | 1 | 0.75 |
| ... | ... | ... | ... | ... | ... | ... |
Each row’s churn_probability is computed by sending that row’s features to the XGBoost model via the serving endpoint. The fact that we can do this in a SQL query is powerful – we could join this with other tables or use it in any SQL-based analysis.
ai_query() Usage Examples
So far, we did a point query for a few customers. Next, let's demonstrate more complex, batch usage of ai_query() within SQL. We’ll cover two scenarios:
(a) Batch filtering and aggregation: Get predictions for a subset of customers and compute aggregate statistics.
(b) Risk segmentation: Use ai_query() to categorize customers by churn risk level and compute summary stats for each segment.
if is_ready:
    print("🔥 Advanced ai_query() examples...")

    # Example 2: Batch predictions with filtering and aggregation
    test_query_2 = f"""
    WITH predictions AS (
        SELECT
            customer_id,
            satisfaction_score,
            CAST(ai_query(
                '{endpoint_name}',
                named_struct(
                    'age', CAST(age AS DOUBLE),
                    'tenure_months', CAST(tenure_months AS DOUBLE),
                    'monthly_charges', CAST(monthly_charges AS DOUBLE),
                    'total_charges', CAST(total_charges AS DOUBLE),
                    'contract_length', CAST(contract_length AS DOUBLE),
                    'payment_delay_days', CAST(payment_delay_days AS DOUBLE),
                    'support_tickets', CAST(support_tickets AS DOUBLE),
                    'service_calls', CAST(service_calls AS DOUBLE),
                    'satisfaction_score', CAST(satisfaction_score AS DOUBLE),
                    'has_tech_support', CAST(has_tech_support AS DOUBLE),
                    'has_online_security', CAST(has_online_security AS DOUBLE),
                    'paperless_billing', CAST(paperless_billing AS DOUBLE),
                    'auto_pay', CAST(auto_pay AS DOUBLE)
                ),
                returnType => 'DOUBLE'
            ) AS DOUBLE) AS churn_probability
        FROM {full_table_name}
        WHERE satisfaction_score < 5
        LIMIT 100
    )
    SELECT
        COUNT(*) AS total_customers,
        AVG(churn_probability) AS avg_churn_probability,
        MIN(churn_probability) AS min_churn_probability,
        MAX(churn_probability) AS max_churn_probability
    FROM predictions
    """
    try:
        result_2 = spark.sql(test_query_2)
        print("✅ Batch analysis successful!")
        display(result_2)
    except Exception as e:
        print(f"❌ Error with batch analysis: {e}")

    # Example 3: Risk segmentation using ai_query
    test_query_3 = f"""
    WITH predictions AS (
        SELECT
            customer_id,
            age,
            satisfaction_score,
            ai_query(
                '{endpoint_name}',
                named_struct(
                    'age', CAST(age AS DOUBLE),
                    'tenure_months', CAST(tenure_months AS DOUBLE),
                    'monthly_charges', CAST(monthly_charges AS DOUBLE),
                    'total_charges', CAST(total_charges AS DOUBLE),
                    'contract_length', CAST(contract_length AS DOUBLE),
                    'payment_delay_days', CAST(payment_delay_days AS DOUBLE),
                    'support_tickets', CAST(support_tickets AS DOUBLE),
                    'service_calls', CAST(service_calls AS DOUBLE),
                    'satisfaction_score', CAST(satisfaction_score AS DOUBLE),
                    'has_tech_support', CAST(has_tech_support AS DOUBLE),
                    'has_online_security', CAST(has_online_security AS DOUBLE),
                    'paperless_billing', CAST(paperless_billing AS DOUBLE),
                    'auto_pay', CAST(auto_pay AS DOUBLE)
                ),
                returnType => 'DOUBLE'
            ) AS churn_risk
        FROM {full_table_name}
        LIMIT 100
    )
    SELECT
        CASE
            WHEN churn_risk >= 0.7 THEN 'High Risk'
            WHEN churn_risk >= 0.4 THEN 'Medium Risk'
            ELSE 'Low Risk'
        END AS risk_category,
        COUNT(*) AS customer_count,
        AVG(churn_risk) AS avg_risk_score,
        AVG(satisfaction_score) AS avg_satisfaction
    FROM predictions
    GROUP BY 1
    ORDER BY avg_risk_score DESC
    """
    try:
        result_3 = spark.sql(test_query_3)
        print("✅ Risk segmentation successful!")
        display(result_3)
    except Exception as e:
        print(f"❌ Error with risk segmentation: {e}")
Let’s break down these queries:
Batch prediction with filtering (Query 2): We use a common table expression (CTE) predictions to first select 100 customers with satisfaction_score < 5 (i.e., dissatisfied customers), along with their churn probability via ai_query(). Wrapping this in a CTE is a good practice when doing aggregates on ai_query() results, to avoid potential issues with aggregations directly around the UDF call. In the outer query, we simply compute the number of customers and the average, min, max churn probability for that subset. This can help answer questions like "What's the average churn risk among low-satisfaction customers?" The result (result_2) will be a single row of aggregated stats.
Risk segmentation (Query 3): Here we score 100 customers (LIMIT 100 without an ORDER BY returns an arbitrary sample) and then categorize each as High, Medium, or Low risk based on the churn probability (churn_risk). Using a CASE expression in SQL, we assign a risk category and then aggregate by that category. We count how many customers fall into each risk bucket and also compute the average churn risk and average satisfaction score per group. The result (result_3) would look like:
| risk_category | customer_count | avg_risk_score | avg_satisfaction |
|---|---|---|---|
| High Risk | 15 | 0.85 | 3.2 |
| Medium Risk | 40 | 0.55 | 5.1 |
| Low Risk | 45 | 0.20 | 7.8 |
This is just an illustrative example; the actual numbers depend on the random data. But it demonstrates how ai_query() enables powerful analyses: you can join model predictions with other data and use the full expressiveness of SQL (filters, GROUP BY, CASE statements, etc.) to derive insights, all in one query.
Note: The use of WITH predictions AS (...) in both queries is intentional. Currently, ai_query() may not be allowed directly inside certain aggregations or may cause the query planner to error if not isolated. The CTE pattern ensures we first materialize the predictions, then aggregate or further process them.
Finally, let's assess the performance of ai_query() when scoring a larger batch of data. We’ll invoke the model on 1000 records and measure the time taken:
if is_ready:
    print("⚡ Performance testing ai_query()...")
    # Score 1000 records and count how many are high risk
    perf_query = f"""
    WITH predictions AS (
        SELECT
            customer_id,
            CAST(ai_query(
                '{endpoint_name}',
                named_struct(
                    'age', CAST(age AS DOUBLE),
                    'tenure_months', CAST(tenure_months AS DOUBLE),
                    'monthly_charges', CAST(monthly_charges AS DOUBLE),
                    'total_charges', CAST(total_charges AS DOUBLE),
                    'contract_length', CAST(contract_length AS DOUBLE),
                    'payment_delay_days', CAST(payment_delay_days AS DOUBLE),
                    'support_tickets', CAST(support_tickets AS DOUBLE),
                    'service_calls', CAST(service_calls AS DOUBLE),
                    'satisfaction_score', CAST(satisfaction_score AS DOUBLE),
                    'has_tech_support', CAST(has_tech_support AS DOUBLE),
                    'has_online_security', CAST(has_online_security AS DOUBLE),
                    'paperless_billing', CAST(paperless_billing AS DOUBLE),
                    'auto_pay', CAST(auto_pay AS DOUBLE)
                ),
                returnType => 'DOUBLE'
            ) AS DOUBLE) AS churn_probability
        FROM {full_table_name}
        LIMIT 1000
    )
    SELECT
        COUNT(*) AS processed_records,
        COUNT(CASE WHEN churn_probability > 0.5 THEN 1 END) AS high_risk_customers
    FROM predictions
    """
    try:
        start_time = time.time()
        perf_result = spark.sql(perf_query)
        perf_data = perf_result.collect()[0]  # collect the single result row
        end_time = time.time()
        duration = end_time - start_time
        print(f"✅ Processed {perf_data['processed_records']} records in {duration:.2f} seconds")
        print(f"Found {perf_data['high_risk_customers']} high-risk customers")
    except Exception as e:
        print(f"❌ Performance test failed: {e}")
This query scores 1000 customers (without any filter in this case) and counts how many of them have a churn probability over 0.5, all in one shot. By collecting the result to the driver, we measure the total time taken.
The output might be, for example: “Processed 1000 records in 2.37 seconds, Found 512 high-risk customers.” This gives a rough sense of throughput. A couple of seconds for 1000 predictions implies ai_query() handled about 400–500 rows per second on a single Small endpoint (your actual performance may vary).
Important Performance Notes:
ai_query() calls out to the model serving endpoint for each partition of data. Under the hood, it will batch requests, but the throughput will depend on the endpoint’s capacity and the cluster’s parallelism. For higher volumes, consider using a more powerful serving endpoint (e.g., Medium or Large) and a larger cluster or SQL warehouse; if you want to persist scores for downstream use, see the batch-scoring sketch after these notes.
This approach is great for interactive and batch inference (as we’ve shown). However, for extremely low-latency requirements (e.g., per-request latency of a few milliseconds in an external app), the overhead of Spark and the ai_query() function call might be too high. In those cases, using the endpoint’s REST API directly from the application (bypassing Spark) or deploying the model in a real-time serving system would be more appropriate.
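Building on the batch-inference guidance above, a common pattern is to materialize predictions into a Delta table on a schedule rather than recomputing them per query. A sketch, reusing the table and endpoint names from this tutorial (the output table name is hypothetical):
# Sketch: materialize churn scores into a Delta table for downstream use (e.g., a nightly job)
scored_table = f"{catalog_name}.{schema_name}.customer_churn_scores"  # hypothetical output table
spark.sql(f"""
CREATE OR REPLACE TABLE {scored_table} AS
SELECT
    customer_id,
    ai_query(
        '{endpoint_name}',
        named_struct(
            'age', CAST(age AS DOUBLE),
            'tenure_months', CAST(tenure_months AS DOUBLE),
            'monthly_charges', CAST(monthly_charges AS DOUBLE),
            'total_charges', CAST(total_charges AS DOUBLE),
            'contract_length', CAST(contract_length AS DOUBLE),
            'payment_delay_days', CAST(payment_delay_days AS DOUBLE),
            'support_tickets', CAST(support_tickets AS DOUBLE),
            'service_calls', CAST(service_calls AS DOUBLE),
            'satisfaction_score', CAST(satisfaction_score AS DOUBLE),
            'has_tech_support', CAST(has_tech_support AS DOUBLE),
            'has_online_security', CAST(has_online_security AS DOUBLE),
            'paperless_billing', CAST(paperless_billing AS DOUBLE),
            'auto_pay', CAST(auto_pay AS DOUBLE)
        ),
        returnType => 'DOUBLE'
    ) AS churn_probability,
    current_timestamp() AS scored_at
FROM {full_table_name}
""")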
To recap, in this walkthrough we:
Created sample data – A synthetic 10,000-row customer churn dataset with rich features.
Trained an XGBoost model – A tree-based classifier to predict churn, achieving reasonable accuracy.
Registered the model – Used MLflow and Unity Catalog to register the model for serving.
Deployed a serving endpoint – Set up Databricks Model Serving for the XGBoost model.
Used ai_query() – Queried the model via SQL, retrieving predictions in both single-record and batch modes, and integrated results with SQL analytics.
ai_query() Capabilities Demonstrated:
Single-record predictions – Embedding model inference in point queries (like looking up a specific customer’s churn probability).
Batch processing – Scoring multiple rows and even entire tables within a SQL query.
Aggregations on predictions – Computing metrics (averages, counts) on model outputs using SQL (WITH queries and aggregations).
Segmentation analysis – Combining predictions with SQL logic (CASE statements, group by) to derive business insights (e.g., risk tiers).
Performance at scale – Running 1000 predictions in a query to gauge throughput and demonstrate ai_query() for moderate batch sizes.
Production Setup: If you plan to use this in production pipelines, consider using a more robust serving endpoint (Medium/Large) for high-volume inference, and ensure your cluster/SQL warehouse has sufficient resources.
Model Monitoring: Set up monitoring for the deployed model – track drift, latency, and accuracy metrics. Over time, retrain the model as data evolves.
Integration: Use ai_query() in DLT pipelines, scheduled workflows, or dashboards to deliver ML-driven insights. For example, a dashboard could show real-time churn risk for incoming support tickets by querying the model with ai_query().
Optimization: Tune the serving endpoint’s scaling and your query patterns. If concurrency is high, you might enable scaling or even deploy multiple endpoints for different use cases. If certain queries are slow, consider indexing or caching frequent results (via materialized views or Delta Live Tables with ai_query() logic).
Note: For external applications that need real-time predictions (e.g., a web app calling an API on each user action with sub-second latency), you should call the model serving endpoint directly via its REST API rather than through ai_query(). The ai_query() function shines in interactive analytics and batch jobs within Databricks, but adding Spark SQL in the loop can introduce latency that wouldn’t be acceptable in a tightly latency-sensitive context.
Finally, you can find the full notebook and code for this example on GitHub. Feel free to clone it and adapt it to your own models and use cases. With ai_query(), traditional ML models like XGBoost can be seamlessly integrated into your SQL workflows, enabling data scientists to leverage machine learning results directly in analyses and reports.