Louis_Frolio
Databricks Employee

Greetings @Jayachithra, I did some digging and came up with a few tips to help you along.

On Issue 1 (explicit MLflow context): this is expected behavior once you realize that mapInPandas runs your function in separate Python worker processes on the executors, not in threads. Nothing configured on the driver (tracking URI, experiment, autolog patches) is visible there. The silent failure is the real trap: no exception, no warning, just empty traces with no indication of why. Your pattern is the right fix.

Two refinements worth adding. If you're on a Unity Catalog-enabled workspace, make sure the identity the workers run under can access the UC experiment path, not just the tracking URI. Permissions don't carry over automatically, and the failure mode is the same: silent, no traces, no clues.

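Also, Spark can reuse a Python worker process across partitions (spark.python.worker.reuse, on by default), and your current setup re-runs the autolog configuration for every partition. A per-process guard prevents that. One catch: a function defined in a notebook is shipped with cloudpickle and deserialized again for each task, so a plain module-level global starts out False every time and never actually skips anything. Keep the marker in something that lives as long as the process, such as an environment variable (the _MLFLOW_WORKER_INITIALIZED name below is arbitrary), and key it on the config so a reused worker that later serves a different experiment re-initializes:

def process_partition(batch_iter):
    import os, mlflow

    # os.environ outlives each deserialized copy of this function, so the
    # guard holds across tasks that run in the same reused worker process.
    init_key = f"{_tracking_uri}|{_experiment_name}"
    if os.environ.get("_MLFLOW_WORKER_INITIALIZED") != init_key:
        os.environ["MLFLOW_ENABLE_ASYNC_TRACE_LOGGING"] = "false"
        mlflow.set_tracking_uri(_tracking_uri)
        mlflow.set_experiment(_experiment_name)
        mlflow.openai.autolog()
        # Set the marker last so a failed setup is retried on the next task
        os.environ["_MLFLOW_WORKER_INITIALIZED"] = init_key
    ...

No semantic change; it just avoids redundant setup work on a reused worker.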

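On Issue 2 (async export): your diagnosis and fix are correct. atexit handlers only run on a normal interpreter shutdown. PySpark's Python workers either stay alive to be reused after the partition finishes or go away without one (forked workers exit through os._exit, and idle workers can simply be killed), so the async queue never gets its final flush. Synchronous logging is the right call; LLM latency makes the 100-500 ms overhead basically irrelevant.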
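If you want to see the atexit behavior for yourself, here's a small stdlib-only sketch (no Spark or MLflow). It runs a child interpreter that registers an atexit "flush" handler and exits it two ways: a normal sys.exit() and os._exit(), an abrupt exit that skips interpreter cleanup. The handler is only a stand-in for an async trace flush, not MLflow's actual mechanism.

```python
import subprocess
import sys

# Child program: registers an atexit handler standing in for an async trace
# flush, reports that the "partition" finished, then runs the exit statement.
CHILD_TEMPLATE = """
import atexit, os, sys
atexit.register(lambda: print("flushed traces", flush=True))
print("partition done", flush=True)
{exit_stmt}
"""


def run_child(exit_stmt: str) -> str:
    """Run the child program with the given exit statement; return its stdout."""
    code = CHILD_TEMPLATE.format(exit_stmt=exit_stmt)
    proc = subprocess.run([sys.executable, "-c", code], capture_output=True, text=True)
    return proc.stdout


clean_out = run_child("sys.exit(0)")   # normal shutdown: atexit handlers run
abrupt_out = run_child("os._exit(0)")  # abrupt exit: atexit handlers are skipped
print(repr(clean_out))   # 'partition done\nflushed traces\n'
print(repr(abrupt_out))  # 'partition done\n'
```

Anything still sitting in an in-memory queue when the abrupt path is taken is simply gone, which is why synchronous logging is the safer default inside Spark workers.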

One thing worth knowing: you can set MLFLOW_ENABLE_ASYNC_TRACE_LOGGING=false at the cluster or job environment variable level so it applies to all workers without touching the partition function. Setting it in both places is harmless, just redundant.

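On Issue 3 (no parent-child linking): autolog spans only attach to a parent span that is active in the same process, and the driver's trace context doesn't propagate into Spark's Python workers, so each autologged call lands as its own root trace. That isn't changing without an explicit API addition for cross-process context. Three realistic options, depending on what matters most to you.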

If flat traces with logical grouping are good enough, attach a stable identifier to each trace (the same doc_id or batch_id for every call made for a given document), for example as a trace tag via MlflowClient().set_trace_tag() or as an attribute on a span you open around the LLM call. That gives you something to filter on in the trace UI without relying on timestamp proximity.

If you actually need trace hierarchy, you have to open the parent spans yourself inside the worker. Going fully manual looks like this:

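def process_partition(batch_iter):
    import mlflow
    from mlflow.entities import SpanType

    # MLflow setup as before (tracking URI, experiment, synchronous logging),
    # but without mlflow.openai.autolog(); `client` is assumed to be the OpenAI client.

    for batch_df in batch_iter:
        for _, row in batch_df.iterrows():
            # One parent span per document; spans opened inside it nest under it
            with mlflow.start_span(name="doc_processing") as doc_span:
                doc_span.set_attribute("doc_id", str(row.doc_id))
                messages = [...]
                # Child span for the LLM call, filled in by hand since autolog is off
                with mlflow.start_span(name="llm_call", span_type=SpanType.LLM) as llm_span:
                    llm_span.set_inputs({"messages": messages})
                    response = client.chat.completions.create(
                        model="endpoint",
                        messages=messages,
                    )
                    llm_span.set_outputs(response.model_dump())
        # Yield once per batch, after all of its rows are processed
        # (yielding inside the row loop would emit the batch once per row)
        yield batch_df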

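You lose the automatic OpenAI-specific parsing that autolog gives you, but you get full control over nesting and attributes. A middle ground worth testing: autolog spans attach to whatever span is active in the same process, so keeping mlflow.openai.autolog() on and wrapping the call in just the doc_processing span may give you the same tree with the parsing intact.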

If strict hierarchy is non-negotiable and you want to keep autolog, there's an architectural option worth considering: centralize LLM calls to a single process. Spark handles data prep and partitioning; a driver-side or external service owns all OpenAI calls and MLflow tracing; results flow back to Spark. You give up some parallelism for a clean trace tree. Makes sense if provider rate limits are already your real bottleneck anyway.
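A minimal sketch of the "one process owns every LLM call" shape, assuming the prepared inputs fit in that process's memory. `call_llm` and `process_centrally` are hypothetical names, and the uppercase transform is a placeholder for the real model call; the point is the bounded-concurrency funnel, not the model logic.

```python
from concurrent.futures import ThreadPoolExecutor


def call_llm(text: str) -> str:
    """Hypothetical stand-in for the real model call.

    In the real service this would be client.chat.completions.create(...),
    with MLflow tracing running in this one process.
    """
    return text.upper()


def process_centrally(
    docs: list[tuple[str, str]], max_concurrency: int = 4
) -> list[tuple[str, str]]:
    """Run every LLM call from a single process with bounded concurrency.

    docs are (doc_id, text) pairs prepared by Spark. Returns (doc_id, response)
    pairs in input order, ready to be turned back into a DataFrame.
    """
    texts = [text for _, text in docs]
    # max_concurrency is the knob you'd tune to the provider's rate limit.
    with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
        responses = list(pool.map(call_llm, texts))  # map preserves input order
    return [(doc_id, resp) for (doc_id, _), resp in zip(docs, responses)]


results = process_centrally([("d1", "hello"), ("d2", "world")])
print(results)  # [('d1', 'HELLO'), ('d2', 'WORLD')]
```

On Databricks the inputs could come from something like `df.select("doc_id", "text").collect()` and go back via `spark.createDataFrame(results, ["doc_id", "response"])` (column names hypothetical). If you open one parent span on the main thread and want these calls nested under it, keep in mind that pool threads generally don't inherit the caller's tracing context on their own, so either make the calls on that thread or propagate the context explicitly.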

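On the re-evaluation issue: this one deserves more attention because it's where real money quietly disappears. mapInPandas is lazy, so the partition function (and every LLM call in it) re-runs each time an action evaluates a plan that depends on it. createOrReplaceTempView only registers the logical plan, not a materialized result, so every query against the view starts over. cache() is lazy too: nothing is stored until an action runs, and cached blocks can be evicted and recomputed, so cache() alone doesn't protect you.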
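To make the laziness concrete, here is a toy model in plain Python, not Spark: `ToyPlan` and the call counter are invented for illustration, and `collect()` stands in for any action. Each uncached action re-runs the expensive step, and `cache()` only pays off once an action has actually materialized the result.

```python
from typing import Callable, List, Optional


class ToyPlan:
    """Toy model of a lazy DataFrame plan (not Spark itself).

    Nothing runs until an action; cache() only marks the plan, and the first
    action after it materializes the rows for reuse.
    """

    def __init__(self, compute: Callable[[], List[str]]):
        self._compute = compute
        self._cache_requested = False
        self._materialized: Optional[List[str]] = None

    def cache(self) -> "ToyPlan":
        self._cache_requested = True  # lazy: no computation happens here
        return self

    def collect(self) -> List[str]:
        if self._materialized is not None:
            return self._materialized  # served from the cache, no recompute
        rows = self._compute()  # every uncached action re-runs the whole plan
        if self._cache_requested:
            self._materialized = rows
        return rows


def make_counted_plan():
    """Return a plan whose compute step counts its (expensive) invocations."""
    counter = {"calls": 0}

    def expensive_llm_step() -> List[str]:
        counter["calls"] += 1
        return ["response-1", "response-2"]

    return ToyPlan(expensive_llm_step), counter


# Two actions on an uncached plan: the "LLM step" runs twice.
plan, counter = make_counted_plan()
plan.collect()
plan.collect()
uncached_calls = counter["calls"]

# cache() plus a forcing action, then more actions: it runs once.
plan, counter = make_counted_plan()
plan.cache()
plan.collect()  # plays the role of result_df.count()
plan.collect()
plan.collect()
cached_calls = counter["calls"]

print(uncached_calls, cached_calls)  # 2 1
```

The toy ignores eviction and executor loss, which is exactly where a real cache can still trigger a recompute.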

Cache and force:

result_df = input_df.mapInPandas(process_partition, schema=schema)
result_df = result_df.cache()
result_df.count()  # forces single evaluation

Or, for a guarantee that survives cache eviction and executor loss, write once and re-read:

result_df = input_df.mapInPandas(process_partition, schema=schema)
result_df.write.mode("overwrite").saveAsTable("my_llm_results")
stable_df = spark.table("my_llm_results")

Cheers, Lou
