
What is the most efficient way of running sentence-transformers on a Spark DataFrame column?

excavator-matt
New Contributor III

We're trying to run the sentence-transformers library from SBERT in a notebook running Databricks ML 16.4 on an AWS g4dn.2xlarge (T4 GPU) instance.

However, we're experiencing out-of-memory crashes and are wondering what the optimal way to run sentence vector encoding in Databricks is.

We have tried three different approaches, but none of them really works.

1. Skip Spark entirely

In this naive approach, we skip Spark entirely and run everything in plain Python after converting the Spark DataFrame with toPandas().

from sentence_transformers import SentenceTransformer

# Collect the Spark DataFrame to the driver as a pandas DataFrame
projects_pdf = df_projects.toPandas()
max_seq_length = 256

sentence_model_name = "paraphrase-multilingual-mpnet-base-v2"
sentence_model = SentenceTransformer(sentence_model_name)
sentence_model.max_seq_length = max_seq_length

# Encode all texts on the driver (uses the GPU if available)
text_to_encode = projects_pdf["project_text"].tolist()
np_text_embeddings = sentence_model.encode(text_to_encode, batch_size=128, show_progress_bar=True, convert_to_numpy=True)

This runs and renders the progress bar nicely, but the problem is converting the result back into a Delta table.

projects_pdf["text_embeddings"] = np_text_embeddings.tolist()

projects_pdf.to_delta("europe_prod_catalog.ad_hoc.project_recommendation_stage", mode="overwrite")

This part crashes with a memory error ("The spark driver has stopped unexpectedly and is restarting. Your notebook will be automatically reattached.").
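
As a rough, untested sketch (not something from our actual notebook), one alternative write path would be to hand the pandas result back to Spark and let Spark write the Delta table, reusing the column and table names above:

# Sketch only: convert the pandas result back into a Spark DataFrame and let
# Spark write the Delta table, instead of writing it from the driver
spark_embeddings_df = spark.createDataFrame(
    projects_pdf[["project_text", "text_embeddings"]]
)
spark_embeddings_df.write.mode("overwrite").saveAsTable(
    "europe_prod_catalog.ad_hoc.project_recommendation_stage"
)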

2. Use Pandas UDF

The second approach is stolen from StackOverflow and is based on Spark's pandas_udf, but it does not work for our volume of data.

 
from sentence_transformers import SentenceTransformer
import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql.types import ArrayType, DoubleType

sentence_model_name = "paraphrase-multilingual-mpnet-base-v2"
max_seq_length = 256
mpnet_sentence_model = SentenceTransformer(sentence_model_name)
mpnet_sentence_model.max_seq_length = max_seq_length


@F.pandas_udf(returnType=ArrayType(DoubleType()))
def mpnet_encode(x: pd.Series) -> pd.Series:
    # Encode one Arrow batch of texts and return the embeddings as lists
    return pd.Series(mpnet_sentence_model.encode(x.tolist(), batch_size=128).tolist())


# apply udf and save
project_df_2 = df_projects.withColumn("project_text_embedding", mpnet_encode("project_text"))
project_df_2.write.mode("overwrite").saveAsTable("my_table")
 
This defers execution, but once we try to save it with saveAsTable, we get the same memory error ("The spark driver has stopped unexpectedly and is restarting. Your notebook will be automatically reattached."). I also couldn't get the progress bar to work here.
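
One variant we have not actually benchmarked, sketched here only as an idea, is an iterator-style pandas UDF that loads the model once per executor (instead of capturing the driver-side model in the closure) and caps the Arrow batch size:

from typing import Iterator

import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql.types import ArrayType, DoubleType

# Keep Arrow batches small enough that a single batch fits in GPU memory
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")


@F.pandas_udf(returnType=ArrayType(DoubleType()))
def mpnet_encode_iter(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # Load the model once per executor Python worker, not on the driver
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
    model.max_seq_length = 256
    for batch in batches:
        yield pd.Series(model.encode(batch.tolist(), batch_size=128).tolist())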
 
3. Use MLflow Spark UDF
I am not entirely sure what MLflow adds here or whether it is any different from the previous approach, but I also tried mlflow.pyfunc.spark_udf.

from sentence_transformers import SentenceTransformer
import mlflow

sentence_model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
sentence_model.max_seq_length = 256
data = "MLflow is awesome!"
signature = mlflow.models.infer_signature(
    model_input=data,
    model_output=sentence_model.encode(data),
)

with mlflow.start_run() as run:
    mlflow.sentence_transformers.log_model(
        artifact_path="paraphrase-multilingual-mpnet-base-v2-256",
        model=sentence_model,
        signature=signature,
        input_example=data,
    )

model_uri = f"runs:/{run.info.run_id}/paraphrase-multilingual-mpnet-base-v2-256"
print(model_uri)

udf = mlflow.pyfunc.spark_udf(
    spark,
    model_uri=model_uri,
)

# Apply the Spark UDF to the DataFrame. This performs batch predictions across all rows in a distributed manner.
df_project_embedding = df_projects.withColumn("prediction", udf(df_projects["project_text"]))

This runs, but you can't see whether it is making any progress.
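
One detail worth flagging, as an assumption based on the MLflow documentation rather than something we verified: mlflow.pyfunc.spark_udf defaults its result_type to a scalar double, so for embeddings it presumably needs an array return type, roughly like this:

from pyspark.sql.types import ArrayType, DoubleType

# Assumption: request an array return type so the whole embedding vector is
# kept, instead of the default scalar "double" result type
udf = mlflow.pyfunc.spark_udf(
    spark,
    model_uri=model_uri,
    result_type=ArrayType(DoubleType()),
)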
 
Conclusion
The current workaround is to go with the first approach and skip the Spark part by storing the result as a file in a Databricks Volume instead. However, this is fundamentally tabular data (even though it includes a vector column), and keeping it in a volume loses all the benefits of Databricks.
 
Another option we considered was building our own batching solution, but the point of Spark is that it should abstract away big-data handling, so that also seems wrong.
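
For completeness, a rough, untested sketch of what such a hand-rolled batching loop might look like, reusing the model and pandas DataFrame from the first approach and appending chunk by chunk:

chunk_size = 10_000
table_name = "europe_prod_catalog.ad_hoc.project_recommendation_stage"

for start in range(0, len(projects_pdf), chunk_size):
    chunk = projects_pdf.iloc[start:start + chunk_size].copy()
    embeddings = sentence_model.encode(chunk["project_text"].tolist(), batch_size=128)
    chunk["text_embeddings"] = [e.tolist() for e in embeddings]
    # Append each chunk via Spark so the full embedding matrix never has to
    # be converted and written in one go
    spark.createDataFrame(chunk).write.mode("append").saveAsTable(table_name)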
 
What is the ideal approach here?
1 REPLY

BigRoux
Databricks Employee

Spark is designed to handle very large datasets by distributing processing across a cluster, which is why working with Spark DataFrames unlocks these scalability benefits. In contrast, Python and Pandas are not inherently distributed; Pandas dataframes are eagerly evaluated and executed locally, so you can encounter memory issues when working with large datasets. For instance, exceeding around 95 GB of data in Pandas often leads to out-of-memory errors because only the driver node handles all computation, regardless of cluster size.


To bridge this gap, consider using the Pandas API on Spark, which is part of the Spark ecosystem. This API provides Pandas-equivalent syntax and functionality, while leveraging Spark's distributed processing to handle larger data volumes efficiently. You can learn more here: https://docs.databricks.com/aws/en/pandas/pandas-on-spark.


In short, the Pandas API on Spark lets you write familiar Pandas-style code but benefit from distributed computation. It greatly reduces memory bottlenecks and scales to bigger datasets than native Pandas workflows allow.
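
As a minimal sketch (assuming df_projects and the table name from your question), the Pandas API on Spark looks roughly like this; note that the encoding step itself would still need one of the UDF approaches above, this only illustrates the distributed pandas-style write path:

# Work on the same Spark DataFrame through the Pandas API on Spark:
# pandas-style syntax, but execution stays lazy and distributed
projects_psdf = df_projects.pandas_api()

# Example of a pandas-style transformation that runs on the cluster
projects_psdf["project_text"] = projects_psdf["project_text"].str.strip()

# Write back to a Delta table without collecting anything to the driver
projects_psdf.to_table(
    "europe_prod_catalog.ad_hoc.project_recommendation_stage", mode="overwrite"
)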


Hope this helps, Louis.
