
Slow Delta write when creating embeddings with mapPartitions

andcch552
New Contributor

I’m trying to generate 35k+ embeddings in Databricks. What I’ve tried so far:

  • Per-row UDF (very slow).
  • Replaced UDF with rdd.mapPartitions to batch API calls, create one Azure client per partition, and call client.embed_documents(texts) in batches. This avoids per-row Python UDF overhead and improves embedding throughput.
  • Measured embedding execution vs Delta write time: embedding materialization is fine, but the Delta write (saveAsTable / commit) is now the dominant, slow step. I used persist() to avoid double computation when calling count() before the write.

Minimal embedding function I tested (simplified):

```python
import os

from langchain_openai import AzureOpenAIEmbeddings
from pyspark.sql import DataFrame, Row
from pyspark.sql.types import ArrayType, FloatType, StructField, StructType


def embed_with_map_partitions_simple(
    df: DataFrame,
    column_names: str | list[str],
    batch_size: int = 128,
    repartition: int | None = None,
    embedding_dim: int = 3072,
) -> DataFrame:
    if isinstance(column_names, str):
        column_names = [column_names]
    if repartition:
        df = df.repartition(repartition)
    spark = df.sparkSession
    new_fields = list(df.schema.fields) + [
        StructField(f"{c}_embedding", ArrayType(FloatType()), True) for c in column_names
    ]
    new_schema = StructType(new_fields)

    # Read the Azure settings once on the driver; the closure ships these values
    # to the executors, which may not have the same environment variables set.
    client_kwargs = dict(
        azure_endpoint=os.getenv("openai_api_base"),
        azure_deployment=os.getenv("openai_deployment_name"),
        api_key=os.getenv("openai_api_key"),
        api_version=os.getenv("openai_api_version"),
    )
    zero_vector = [0.0] * embedding_dim

    def partition_embed(rows_iter):
        rows = list(rows_iter)
        if not rows:
            return  # empty partition: the generator simply yields nothing
        client = AzureOpenAIEmbeddings(**client_kwargs)  # one client per partition
        n = len(rows)
        embeddings_per_column = {c: [None] * n for c in column_names}
        for col in column_names:
            for i in range(0, n, batch_size):
                texts = [getattr(r, col, "") or "" for r in rows[i:i + batch_size]]
                try:
                    batch_emb = client.embed_documents(texts)
                except Exception:
                    # NOTE: failed batches silently become zero vectors; consider logging or retrying.
                    batch_emb = [zero_vector] * len(texts)
                for j, emb in enumerate(batch_emb):
                    embeddings_per_column[col][i + j] = emb
        for idx, r in enumerate(rows):
            d = r.asDict()
            for c in column_names:
                d[f"{c}_embedding"] = embeddings_per_column[c][idx] or zero_vector
            # On Spark 3.x, Row(**d) keeps insertion order, which matches new_schema.
            yield Row(**d)

    return spark.createDataFrame(df.rdd.mapPartitions(partition_embed), schema=new_schema)
```

Question: can Databricks advise on best practices to reduce Delta write/commit time for this workflow (recommended write options, file sizing/number of files, transaction tuning, or cluster/IO settings)? Any guidance on safely persisting the large transformed DataFrame before writing, and on Stitch/OPTIMIZE usage, would also be helpful.

Thanks.

1 REPLY

bianca_unifeye
New Contributor III

Hi

You’ve optimised the embedding side really nicely already: batching in mapPartitions and creating one Azure client per partition is exactly what we recommend.

For 35k rows, if embedding is fast but the Delta write/commit is slow, it’s usually due to one or more of:

  • too many small output files, and/or

  • extra passes over the DataFrame, and/or

  • a cluster that’s over-parallelised for the amount of data.
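
A quick way to check for the first cause before writing is to look at how many partitions the DataFrame has: without optimized writes, each non-empty partition normally becomes at least one output file. A minimal sketch, assuming `df` is a PySpark DataFrame; `report_partitions` is a hypothetical helper name, and the only Spark call it makes is `df.rdd.getNumPartitions()`:

```python
def report_partitions(df, max_expected_files: int = 8) -> int:
    """Return the partition count of `df` and warn if it implies many small files.

    `df` is assumed to be a PySpark DataFrame; only df.rdd.getNumPartitions() is used.
    """
    num_partitions = df.rdd.getNumPartitions()
    if num_partitions > max_expected_files:
        print(
            f"{num_partitions} partitions: a direct write may create up to "
            f"{num_partitions} files; consider coalesce() before writing."
        )
    return num_partitions
```

Calling it on the result of embed_with_map_partitions_simple just before the write shows whether a coalesce is worth adding.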

I would suggest looking into these:

1. Control the number of output files

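The number of files a write produces roughly equals the number of partitions of the DataFrame being written. mapPartitions keeps whatever partitioning your input had (or the value you pass as repartition), and any shuffle upstream uses spark.sql.shuffle.partitions, which defaults to 200. Either way, your createDataFrame(...).write can easily produce dozens or hundreds of tiny files for just 35k rows, and the overhead of creating those files and recording them in the Delta commit often dominates the runtime.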

For a dataset of this size, you typically want a small number of files (1–4, maybe 8 max).
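
To make the file-count range concrete: 35k vectors × 3072 dimensions × 4 bytes per FloatType value is about 430 MB before compression, and float data compresses poorly. A minimal sketch of that arithmetic; `estimate_num_files` is a hypothetical helper, and the 128 MB target file size is an assumption for illustration, not a Databricks default:

```python
import math


def estimate_num_files(num_rows: int, dim: int, bytes_per_value: int = 4,
                       target_file_bytes: int = 128 * 1024 * 1024) -> int:
    """Rough count of output files for a table of fixed-length float vectors."""
    total_bytes = num_rows * dim * bytes_per_value  # vector payload only; ignores other columns
    return max(1, math.ceil(total_bytes / target_file_bytes))


print(estimate_num_files(35_000, 3072))  # → 4
```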

Key points:

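  • Use coalesce(), not repartition(), right before the write, once the embedded DataFrame has been persisted and materialized.
    coalesce(n) avoids a shuffle and just reduces the number of output partitions. Because it is a narrow transformation, it also shrinks the parallelism of any uncached upstream work: without persist(), coalesce(1) would run every embedding call in a single task.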

  • For 35k vectors, 1–4 files is absolutely fine and usually much faster to commit.
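
A minimal sketch of the write step, assuming `embedded_df` is the output of embed_with_map_partitions_simple and has already been persisted and materialized. `write_compact` and the table name in the usage note are hypothetical; `coalesce`, `write.format("delta")`, `mode` and `saveAsTable` are standard PySpark DataFrame APIs:

```python
def write_compact(embedded_df, table_name: str, num_files: int = 4) -> None:
    """Write a small, already-cached DataFrame as a few Delta files.

    `embedded_df` is assumed to be a persisted PySpark DataFrame.
    """
    (
        embedded_df.coalesce(num_files)  # narrow op: no shuffle, fewer output files
        .write.format("delta")
        .mode("overwrite")  # assumption: the table is fully rewritten on each run
        .saveAsTable(table_name)
    )
```

For example, `write_compact(embedded_df, "my_catalog.my_schema.embeddings", num_files=4)`; switch the mode to "append" if the table is built up incrementally.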

2. Avoid double computation / extra actions

You mentioned using persist() to avoid recomputing when you call count(). That’s good. Two extra tips:

  • Persist after your embedding transform, not on the original DF.

  • Only trigger one action before the final write (e.g. count() or maybe display() for debugging). Don’t call count(), show(), and then write without persistence, or Spark will recompute the whole pipeline multiple times.
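
A minimal sketch of the persist-once pattern, assuming `embedded_df` is the PySpark DataFrame returned by the embedding step. `materialize_once` is a hypothetical helper; it relies only on `persist()` and `count()`:

```python
def materialize_once(embedded_df):
    """Cache the embedded DataFrame and run exactly one action to fill the cache.

    Returns the cached DataFrame and its row count. Call .unpersist() on the
    returned DataFrame once the Delta write has finished.
    """
    cached = embedded_df.persist()  # default DataFrame storage level spills to disk if memory is short
    row_count = cached.count()  # the only action before the write: embeddings are computed here
    return cached, row_count
```

Typical use: `cached, n = materialize_once(embed_with_map_partitions_simple(df, "text"))`, then write `cached`, then call `cached.unpersist()` so executor memory is released.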

3. Tuning cluster size & IO for this workload

For 35k rows of 3072-dim embeddings:

  • You don’t need a huge cluster.
    Too many workers mean too many tiny output tasks and more small files.

  • Often a small cluster (e.g. 1–2 workers with decent memory) is faster end-to-end than a large autoscaling cluster for this kind of “wide but not huge” dataset.

  • Make sure you’re writing to a performant storage account (Premium / general purpose v2). In most managed Databricks setups, DBFS is already backed by appropriate storage, so usually this is fine.

If you see a lot of tiny tasks in the Spark UI, that’s a sign to:

  • lower spark.sql.shuffle.partitions for the job (e.g. 32 or even 8 for this size), which only matters if the plan includes a shuffle, and/or

  • coalesce before writing as shown above.
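
A minimal sketch of lowering the shuffle partition count for one session, assuming `spark` is an active SparkSession. `tune_small_job` is a hypothetical helper; `spark.conf.get` and `spark.conf.set` are the standard runtime-config calls. With adaptive query execution enabled (the default on recent runtimes), Spark may already coalesce small shuffle partitions, so treat this as a knob to try rather than a guaranteed fix:

```python
def tune_small_job(spark, shuffle_partitions: int = 8) -> str:
    """Lower the shuffle partition count for this session and return the previous value.

    `spark` is assumed to be an active SparkSession; the setting only affects
    stages that actually shuffle (joins, aggregations, repartition by column).
    """
    previous = spark.conf.get("spark.sql.shuffle.partitions")
    spark.conf.set("spark.sql.shuffle.partitions", str(shuffle_partitions))
    return previous
```

After the job, `spark.conf.set("spark.sql.shuffle.partitions", previous)` restores the earlier value for other work on the same cluster.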

4. Data layout & Delta options (Stitch / OPTIMIZE)

For a one-off creation of 35k embeddings, the Delta housekeeping features (Stitch / OPTIMIZE) don’t speed up the write itself and usually aren’t needed, but they matter if:

  • you will repeatedly append to this table,

  • you will query it a lot (e.g. for vector search candidates), or

  • you accidentally created many small files in early runs.
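
A minimal sketch of compacting only when the table is actually fragmented, assuming `spark` is an active SparkSession and the table is a Delta table. `optimize_if_fragmented` is hypothetical; `DESCRIBE DETAIL` (whose output includes a numFiles column) and `OPTIMIZE` are Delta SQL commands. The default threshold mirrors the "8 max" suggestion above, and the table name is interpolated into SQL, so it should come from trusted code, not user input:

```python
def optimize_if_fragmented(spark, table_name: str, max_files: int = 8) -> bool:
    """Run Delta OPTIMIZE when the table has more data files than `max_files`.

    Returns True if OPTIMIZE was run.
    """
    detail = spark.sql(f"DESCRIBE DETAIL {table_name}").collect()[0]
    if detail["numFiles"] > max_files:
        spark.sql(f"OPTIMIZE {table_name}")  # rewrites small files into larger ones
        return True
    return False
```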

5. Data type choice for embeddings

You’re using ArrayType(FloatType()), which is fine. A few extra notes:

  • For Databricks Vector Search, the embedding column of the source Delta table is expected to be an array of floats, so ArrayType(FloatType()) is already a suitable long-term format for similarity search. I’m not aware of a VECTOR(3072) column type in Databricks SQL, so check your runtime’s documentation before relying on one.

  • If you stick to arrays, make sure the schema is stable between runs (same type, same dimension). Schema evolution (new columns or type changes) can add extra overhead due to Delta metadata handling.
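
ArrayType(FloatType()) does not enforce a fixed length, so keeping the dimension stable is up to your code. A minimal sketch of a guard that could run inside partition_embed before rows are yielded; `check_embedding_dims` is a hypothetical helper, and raising (rather than padding) is a design choice so that bad batches fail loudly instead of landing in the table:

```python
from __future__ import annotations


def check_embedding_dims(embeddings: list[list[float] | None], expected_dim: int) -> None:
    """Raise ValueError if any embedding is missing or has the wrong length."""
    for i, emb in enumerate(embeddings):
        if emb is None or len(emb) != expected_dim:
            got = None if emb is None else len(emb)
            raise ValueError(f"embedding {i}: expected {expected_dim} values, got {got}")
```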
