Generative AI
Explore discussions on generative artificial intelligence techniques and applications within the Databricks Community. Share ideas, challenges, and breakthroughs in this cutting-edge field.

How to get MLflow OpenAI autolog traces from PySpark mapInPandas workers (and some pitfalls)

Jayachithra
New Contributor

Context

I'm running an LLM pipeline on Databricks that distributes OpenAI API calls across Spark workers via mapInPandas. Getting mlflow.openai.autolog() to work on workers required solving three undocumented issues. Sharing here since I couldn't find this covered anywhere.

Issue 1: Workers need explicit MLflow context

The docs say to call mlflow.autolog() on workers. For mlflow.openai.autolog(), that's insufficient: workers also need the tracking URI and the experiment set explicitly, since they inherit neither from the driver.

import mlflow

# Capture on the driver so the values ship to workers via the closure
_tracking_uri = mlflow.get_tracking_uri()
_experiment_name = "/Shared/my-experiment"

# Inside the mapInPandas partition function
mlflow.set_tracking_uri(_tracking_uri)
mlflow.set_experiment(_experiment_name)
mlflow.openai.autolog()
# Without all three, autolog silently produces zero traces: no error, no warning.

Issue 2: Span artifacts lost due to async export

Even with the correct setup, most traces appeared in the experiment list, but the "detailed trace view" was broken. Investigation showed that AsyncTraceExportQueue uses a daemon thread with `atexit` for flushing. mapInPandas worker processes are terminated (not exited) when the partition completes, so atexit never fires.
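The atexit gap is easy to reproduce outside Spark. A minimal sketch (POSIX only, a subprocess standing in for a Spark worker): a child that exits cleanly runs its atexit handler, while a child that is SIGKILLed never does.

```python
import os, signal, subprocess, sys, tempfile, time

# Child program: registers an atexit handler that writes a marker file,
# then either exits cleanly or waits around to be killed externally.
CHILD = """
import atexit, sys, time
path, mode = sys.argv[1], sys.argv[2]
atexit.register(lambda: open(path, "w").write("flushed"))
if mode == "killed":
    time.sleep(30)      # wait to be terminated externally
sys.exit(0)             # clean exit: atexit handlers run
"""

def child_flushed(mode: str) -> bool:
    """Return True if the child's atexit handler actually ran."""
    fd, path = tempfile.mkstemp()
    os.close(fd)
    proc = subprocess.Popen([sys.executable, "-c", CHILD, path, mode])
    if mode == "killed":
        time.sleep(0.5)
        proc.send_signal(signal.SIGKILL)  # hard kill, like Spark reaping a worker
    proc.wait()
    try:
        return open(path).read() == "flushed"
    finally:
        os.remove(path)
```

child_flushed("clean") returns True; child_flushed("killed") returns False, which is exactly the mode Spark puts worker processes in.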

Result: trace metadata (inputs, outputs, tokens) is written synchronously and persists. Span artifacts are written asynchronously and are lost for most traces. In my test with 6 documents, 5 out of 6 had missing artifacts.

Fix:

import os
os.environ["MLFLOW_ENABLE_ASYNC_TRACE_LOGGING"] = "false"

Overhead is ~100-500ms per trace, negligible next to LLM latency.

Issue 3: No parent-child trace linking

Each chat.completions.create() call produces an independent trace. Autolog uses start_span_no_context() in mlflow/openai/autolog.py (line 287), which always creates root spans. There's no mechanism to attach autolog spans to a user-provided parent, even though start_span_no_context already accepts a parent_span parameter.

Processing 6 documents = 6 disconnected traces. Correlation is only possible by timestamp.

Complete pattern

_tracking_uri = mlflow.get_tracking_uri()
_experiment_name = "/Shared/my-experiment"

def process_partition(batch_iter):
    import os, mlflow
    os.environ["MLFLOW_ENABLE_ASYNC_TRACE_LOGGING"] = "false"
    mlflow.set_tracking_uri(_tracking_uri)
    mlflow.set_experiment(_experiment_name)
    mlflow.openai.autolog()

    # _host and _token are captured on the driver, like the tracking URI above
    client = DatabricksOpenAI(workspace_client=WorkspaceClient(host=_host, token=_token))
    for batch_df in batch_iter:
        for _, row in batch_df.iterrows():
            client.chat.completions.create(model="endpoint", messages=[...])
        yield batch_df

input_df.mapInPandas(process_partition, schema=schema).collect()

Side discovery: Spark re-evaluation

Autolog also exposed that certain mapInPandas materialization patterns cause Spark to re-evaluate the lazy plan multiple times. I saw 24 traces where 6 were expected - each document processed 4x. The createOrReplaceTempView + spark.table().cache() pattern doesn't guarantee single evaluation. Worth checking if you're seeing unexpected LLM costs.
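The re-evaluation trap has a plain-Python analogue: a lazy pipeline built from generators re-runs its side effects every time a downstream "action" consumes it. A sketch with hypothetical names, not Spark APIs:

```python
call_count = 0

def expensive_llm_call(doc: str) -> str:
    # stand-in for a paid OpenAI call inside mapInPandas
    global call_count
    call_count += 1
    return doc.upper()

def build_plan(docs):
    # like a Spark logical plan: nothing executes until consumed
    return (expensive_llm_call(d) for d in docs)

docs = ["a", "b", "c"]

# Two downstream "actions", each rebuilding the plan from scratch:
first = list(build_plan(docs))
second = list(build_plan(docs))
# call_count is now 6, not 3: every action re-ran the expensive calls.

# The fix, as with cache() + count() or write-then-read: materialize once.
materialized = list(build_plan(docs))
reused_a, reused_b = materialized, materialized  # no further calls
```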

Environment

  • MLflow 3.10.1
  • Databricks serverless compute
  • databricks-openai / DatabricksOpenAI client
  • Python 3.12

Curious if others have hit these. Any alternative approaches?

Accepted Solution

Louis_Frolio
Databricks Employee

Greetings @Jayachithra, I did some digging and came up with some tips to help you along.

On Issue 1 (explicit MLflow context): expected behavior once you realize that mapInPandas spawns isolated Python worker processes, not threads. No shared state with the driver at all. The silent failure is the real trap — no exception, no warning, just empty traces with no indication of why. Your pattern is the right fix.

Two refinements worth adding. If you're on a Unity Catalog-enabled workspace, make sure the worker has access to the UC experiment path, not just the tracking URI. Permissions don't carry over automatically and the failure mode is the same — silent, no traces, no clues.

Also, if Spark reuses worker processes across partitions (which it can), your current setup re-runs autolog configuration on every partition. A module-level flag prevents that:

_worker_initialized = False

def process_partition(batch_iter):
    global _worker_initialized
    import os, mlflow

    if not _worker_initialized:
        os.environ["MLFLOW_ENABLE_ASYNC_TRACE_LOGGING"] = "false"
        mlflow.set_tracking_uri(_tracking_uri)
        mlflow.set_experiment(_experiment_name)
        mlflow.openai.autolog()
        _worker_initialized = True
    ...

No semantic change — just avoids redundant setup work per partition.

On Issue 2 (async export): your diagnosis and fix are correct. atexit only fires on clean process exit, and Spark kills workers externally when the partition completes. Synchronous logging is the right call — LLM latency makes the 100-500ms overhead basically irrelevant.

One thing worth knowing: you can set MLFLOW_ENABLE_ASYNC_TRACE_LOGGING=false at the cluster or job environment variable level so it applies to all workers without touching the partition function. Setting it in both places is harmless, just redundant.

On Issue 3 (no parent-child linking): autolog creates root spans by design to stay decoupled from user code, so this isn't changing without an explicit API addition. Three realistic options depending on what matters most to you.

If flat traces with logical grouping are good enough, attach a stable identifier to each trace, for example as a span attribute (span.set_attribute) or a trace tag via mlflow.update_current_trace(tags={...}) around each LLM call. The same doc_id or batch_id per document gives you something to filter on in the trace UI without relying on timestamp proximity.

If you actually need trace hierarchy, the honest answer is to drop autolog and go manual in the worker:

def process_partition(batch_iter):
    # MLflow setup as before (tracking URI, experiment), but without autolog

    for batch_df in batch_iter:
        for _, row in batch_df.iterrows():
            # start_span with no active parent creates a new trace root
            with mlflow.start_span("doc_processing") as span:
                span.set_attribute("doc_id", row.doc_id)
                response = client.chat.completions.create(
                    model="endpoint",
                    messages=[...],
                )
                span.set_inputs({"doc_id": row.doc_id, "request": [...]})
                span.set_outputs(response.model_dump())
        yield batch_df

You lose the automatic OpenAI-specific parsing autolog gives you, but you get full control over nesting and attributes. Reasonable trade depending on your needs.

If strict hierarchy is non-negotiable and you want to keep autolog, there's an architectural option worth considering: centralize LLM calls to a single process. Spark handles data prep and partitioning; a driver-side or external service owns all OpenAI calls and MLflow tracing; results flow back to Spark. You give up some parallelism for a clean trace tree. Makes sense if provider rate limits are already your real bottleneck anyway.
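A minimal shape of that centralized stage, in plain Python with a hypothetical call_llm standing in for the real OpenAI client (Spark would only feed prompts in and take results out):

```python
def prepare_prompts(doc_ids):
    # distributed side (Spark): pure data prep, no LLM calls here
    return [{"doc_id": i, "prompt": f"Summarize document {i}"} for i in doc_ids]

def call_llm(prompt: str) -> str:
    # stand-in for client.chat.completions.create against a real endpoint
    return f"summary of: {prompt}"

def centralized_llm_stage(prompts):
    # a single process owns every LLM call, so one coherent trace tree
    # (one parent span here, one child span per call) becomes possible
    return [
        {"doc_id": p["doc_id"], "response": call_llm(p["prompt"])}
        for p in prompts
    ]

results = centralized_llm_stage(prepare_prompts([0, 1, 2]))
```

The results list then goes back to Spark (e.g. spark.createDataFrame) for any downstream distributed work.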

On the re-evaluation issue: this one deserves more attention because it's where real money quietly disappears. mapInPandas is lazy — the partition function re-runs every time the downstream plan gets evaluated. createOrReplaceTempView creates a reference to the logical plan, not a materialized result, so cache() alone doesn't protect you.

Cache and force:

result_df = input_df.mapInPandas(process_partition, schema=schema)
result_df = result_df.cache()
result_df.count()  # forces single evaluation

Or write once and re-read:

result_df = input_df.mapInPandas(process_partition, schema=schema)
result_df.write.mode("overwrite").saveAsTable("my_llm_results")
stable_df = spark.table("my_llm_results")

Cheers, Lou
