We're trying to run the bundled sentence-transformers library from SBERT in a notebook on Databricks Runtime 16.4 ML, on an AWS g4dn.2xlarge (NVIDIA T4) instance.
However, we keep hitting out-of-memory crashes and are wondering what the best way to run sentence embedding encoding on Databricks is.
We have tried three different approaches, but none of them really works.
1. Skip Spark entirely
In this naive approach, we skip Spark entirely and run the encoding in plain Python, using toPandas() on the Spark DataFrame:
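```python
from sentence_transformers import SentenceTransformer

# Collect the whole Spark DataFrame onto the driver as a pandas DataFrame
projects_pdf = df_projects.toPandas()

max_seq_length = 256
sentence_model_name = "paraphrase-multilingual-mpnet-base-v2"
sentence_model = SentenceTransformer(sentence_model_name)
sentence_model.max_seq_length = max_seq_length

# Encode all texts in batches of 128; returns an (n_rows, dim) NumPy array
text_to_encode = projects_pdf["project_text"].tolist()
np_text_embeddings = sentence_model.encode(text_to_encode, batch_size=128, show_progress_bar=True, convert_to_numpy=True)
```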
This runs and renders the progress bar nicely, but the problem is converting the result back into a Delta table:
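```python
projects_pdf["text_embeddings"] = np_text_embeddings.tolist()
# Plain pandas DataFrames have no to_delta(); convert back to Spark (on the driver) and write as a table
spark.createDataFrame(projects_pdf).write.mode("overwrite").saveAsTable("europe_prod_catalog.ad_hoc.project_recommendation_stage")
```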
This step crashes with a memory error ("The spark driver has stopped unexpectedly and is restarting. Your notebook will be automatically reattached.").
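A rough way to see why the write step, rather than the encoding itself, runs out of memory: `np_text_embeddings` is a compact float32 array (768 dimensions for this model, to our understanding), but `.tolist()` turns every value into a separate Python float object inside a per-row list, and that structure then has to be serialized again on the driver for the write. The instance has 32 GiB of RAM. The sketch below estimates both footprints with the standard library only; `EMBEDDING_DIM` is an assumption and the row count is a hypothetical placeholder, not our real data size.

```python
import sys

EMBEDDING_DIM = 768  # assumed output dimension of paraphrase-multilingual-mpnet-base-v2


def numpy_float32_bytes(n_rows: int, dim: int = EMBEDDING_DIM) -> int:
    """Bytes used by the raw float32 embedding matrix (ignores the small array header)."""
    return n_rows * dim * 4


def python_list_bytes_per_row(dim: int = EMBEDDING_DIM) -> int:
    """Approximate bytes for one row after .tolist(): a list object plus `dim` float objects."""
    row = [float(i) for i in range(dim)]  # distinct float objects, as ndarray.tolist() produces
    return sys.getsizeof(row) + sum(sys.getsizeof(x) for x in row)


if __name__ == "__main__":
    n_rows = 1_000_000  # hypothetical row count, for illustration only
    compact = numpy_float32_bytes(n_rows)
    boxed = python_list_bytes_per_row() * n_rows
    print(f"float32 array:  {compact / 1e9:.1f} GB")
    print(f"list of floats: {boxed / 1e9:.1f} GB (about {boxed / compact:.0f}x larger)")
```

This ignores pandas and Spark overhead, so it is a lower bound for the boxed representation.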
2. Use a pandas UDF
The second approach is taken from Stack Overflow and is based on Spark's pandas_udf, but it does not work for our volume of data either:
```python
import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql.types import ArrayType, DoubleType
from sentence_transformers import SentenceTransformer

sentence_model_name = "paraphrase-multilingual-mpnet-base-v2"
max_seq_length = 256

# The model is created on the driver and referenced from inside the UDF
mpnet_sentence_model = SentenceTransformer(sentence_model_name)
mpnet_sentence_model.max_seq_length = max_seq_length

@F.pandas_udf(returnType=ArrayType(DoubleType()))
def mpnet_encode(x: pd.Series) -> pd.Series:
    # Each call receives one batch of texts and returns one embedding list per row
    return pd.Series(mpnet_sentence_model.encode(x.tolist(), batch_size=128).tolist())

# Apply the UDF (lazily) and write the result
project_df_2 = df_projects.withColumn("project_text_embedding", mpnet_encode("project_text"))
project_df_2.write.mode("overwrite").saveAsTable("my_table")
```
Execution is lazy, so nothing runs until we call saveAsTable, and at that point we get the same memory error ("The spark driver has stopped unexpectedly and is restarting. Your notebook will be automatically reattached."). I also couldn't get a progress bar to work here.
3. Use an MLflow Spark UDF
I am not entirely sure what MLflow does differently from the previous approach, if anything, but I also tried mlflow.pyfunc.spark_udf:
```python
from sentence_transformers import SentenceTransformer
import mlflow

sentence_model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
sentence_model.max_seq_length = 256

# Infer the model signature from a single example input/output pair
data = "MLflow is awesome!"
signature = mlflow.models.infer_signature(
    model_input=data,
    model_output=sentence_model.encode(data),
)

# Log the model so that it can be loaded again on the workers
with mlflow.start_run() as run:
    mlflow.sentence_transformers.log_model(
        artifact_path="paraphrase-multilingual-mpnet-base-v2-256",
        model=sentence_model,
        signature=signature,
        input_example=data,
    )

model_uri = f"runs:/{run.info.run_id}/paraphrase-multilingual-mpnet-base-v2-256"
print(model_uri)

# Wrap the logged model as a Spark UDF
udf = mlflow.pyfunc.spark_udf(
    spark,
    model_uri=model_uri,
)

# Apply the Spark UDF to the DataFrame. This performs batch predictions across all rows in a distributed manner.
df_project_embedding = df_projects.withColumn("prediction", udf(df_projects["project_text"]))
```
The cell keeps running, but there is no way to tell whether it is making any progress.
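On the progress question: as we understand it, both the pandas UDF and `mlflow.pyfunc.spark_udf` feed the model batches of at most `spark.sql.execution.arrow.maxRecordsPerBatch` rows (10,000 by default, as far as we know), while the Spark job progress shown under a notebook cell counts finished tasks (one per partition), not rows. A back-of-the-envelope sketch of how much GPU work hides behind one such unit of progress, assuming encode() is called with `batch_size=128` on each batch it receives; the partition sizes are hypothetical:

```python
import math
from typing import Tuple


def calls_per_partition(
    rows: int,
    max_records_per_batch: int = 10_000,
    encode_batch_size: int = 128,
) -> Tuple[int, int]:
    """Return (UDF invocations, model forward passes) for one partition of `rows` rows.

    Assumes the partition is sliced into batches of at most `max_records_per_batch`
    rows and that each UDF call runs encode(batch_size=encode_batch_size) on its batch.
    """
    full_batches, remainder = divmod(rows, max_records_per_batch)
    udf_calls = full_batches + (1 if remainder else 0)
    forward_passes = full_batches * math.ceil(max_records_per_batch / encode_batch_size)
    forward_passes += math.ceil(remainder / encode_batch_size)
    return udf_calls, forward_passes


if __name__ == "__main__":
    # Hypothetical partition sizes, only to show the scale of one "unit" of progress
    for rows in (5_000, 25_000, 250_000):
        print(rows, calls_per_partition(rows))
```

If this model of the progress bar is right, more and smaller partitions should make it advance in finer steps.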
Conclusion
Our current workaround is the first approach, skipping Spark for the write and storing the result as a file in a Databricks Volume instead. However, this is fundamentally tabular data (even though one column holds a vector), and keeping it in a Volume loses all the benefits of Databricks.
Another option we considered was writing our own batching logic, but the point of Spark is to abstract away big-data handling, so that also seems wrong.
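For reference, the hand-rolled batching we considered would look roughly like the sketch below: slice the texts into fixed-size chunks, encode one chunk at a time, and pass each result to a writer (on Databricks, presumably an append to the Delta table) so that only one chunk of embeddings needs to be alive at a time. `encode`, `write_chunk`, and the default `chunk_size` are hypothetical placeholders, not Databricks APIs.

```python
from typing import Callable, Iterator, List, Sequence, Tuple


def chunked(items: Sequence[str], size: int) -> Iterator[Tuple[int, Sequence[str]]]:
    """Yield (start_offset, slice) pairs of at most `size` items."""
    if size <= 0:
        raise ValueError("size must be positive")
    for start in range(0, len(items), size):
        yield start, items[start:start + size]


def encode_and_write(
    texts: Sequence[str],
    encode: Callable[[Sequence[str]], List[List[float]]],
    write_chunk: Callable[[int, List[List[float]]], None],
    chunk_size: int = 50_000,  # hypothetical value
) -> int:
    """Encode `texts` chunk by chunk, handing each chunk to `write_chunk`; returns chunks written."""
    written = 0
    for start, chunk in chunked(texts, chunk_size):
        # The previous chunk's embeddings can be freed once write_chunk has returned
        write_chunk(start, encode(chunk))
        written += 1
    return written
```

Since `texts` would still come from `toPandas()`, this only bounds how much is converted back to Spark at once, not the initial collect onto the driver.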
What is the ideal approach here?