We're trying to run the bundled sentence-transformers library (SBERT) in a notebook on Databricks Runtime 16.4 ML, on an AWS g4dn.2xlarge (T4 GPU) instance.
However, we're experiencing out-of-memory crashes and are wondering what the optimal way to run sentence vector encoding in Databricks is.
We have tried three different approaches, but none of them really works.
1. Skip Spark entirely
In this naive approach, we skip Spark entirely and run everything in plain Python, after pulling the data onto the driver with toPandas() on the Spark DataFrame:
from sentence_transformers import SentenceTransformer

# Collect the Spark DataFrame onto the driver as a pandas DataFrame
projects_pdf = df_projects.toPandas()

max_seq_length = 256
sentence_model_name = "paraphrase-multilingual-mpnet-base-v2"
sentence_model = SentenceTransformer(sentence_model_name)
sentence_model.max_seq_length = max_seq_length

# Encode all project texts on the driver GPU
text_to_encode = projects_pdf["project_text"].tolist()
np_text_embeddings = sentence_model.encode(text_to_encode, batch_size=128, show_progress_bar=True, convert_to_numpy=True)
This runs and renders the progress bar nicely; the problem is converting the result back into a Delta table.
# Attach the embeddings and write them out as a Delta table
projects_pdf["text_embeddings"] = np_text_embeddings.tolist()
projects_pdf.to_delta("europe_prod_catalog.ad_hoc.project_recommendation_stage", mode="overwrite")
This part crashes with a memory issue ("The spark driver has stopped unexpectedly and is restarting. Your notebook will be automatically reattached.").
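For clarity, what we are trying to achieve in that write-back step is roughly the sketch below. It assumes converting the pandas DataFrame back into a Spark DataFrame first (which may well hit the same driver memory limit); the table name is the same one as above.
# Sketch only: convert the pandas DataFrame (with embeddings attached) back into a Spark DataFrame
df_with_embeddings = spark.createDataFrame(projects_pdf)

# Write it out as a managed Delta table
(
    df_with_embeddings.write
    .format("delta")
    .mode("overwrite")
    .saveAsTable("europe_prod_catalog.ad_hoc.project_recommendation_stage")
)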
2. Use Pandas UDF
The second approach is borrowed from StackOverflow and is based on logging the model to MLflow and applying it via mlflow.pyfunc.spark_udf (a Spark pandas UDF under the hood), but it does not work for our volume of data either.
from sentence_transformers import SentenceTransformer
import mlflow

sentence_model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
sentence_model.max_seq_length = 256

# Infer a model signature from a small example input
data = "MLflow is awesome!"
signature = mlflow.models.infer_signature(
    model_input=data,
    model_output=sentence_model.encode(data),
)

# Log the model to MLflow so it can be wrapped as a Spark UDF
with mlflow.start_run() as run:
    mlflow.sentence_transformers.log_model(
        artifact_path="paraphrase-multilingual-mpnet-base-v2-256",
        model=sentence_model,
        signature=signature,
        input_example=data,
    )

model_uri = f"runs:/{run.info.run_id}/paraphrase-multilingual-mpnet-base-v2-256"
print(model_uri)

# Wrap the logged model in a Spark UDF
udf = mlflow.pyfunc.spark_udf(
    spark,
    model_uri=model_uri,
)

# Apply the Spark UDF to the DataFrame. This performs batch predictions across all rows in a distributed manner.
df_project_embedding = df_projects.withColumn("prediction", udf(df_projects["project_text"]))
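For completeness, a sketch of the kind of write-out step we would want at the end is below; the table name matches approach 1, and the repartition count is only an illustrative guess at spreading the encoding work across tasks.
# Sketch only: materialize the embeddings by writing the DataFrame out as a Delta table
(
    df_project_embedding
    .repartition(64)  # illustrative guess; controls how many tasks run the UDF
    .write
    .format("delta")
    .mode("overwrite")
    .saveAsTable("europe_prod_catalog.ad_hoc.project_recommendation_stage")
)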