Custom sentence transformer for indexing
03-13-2025 09:39 AM
Hi!
I would like to use my own sentence transformer to create a vector index.
Using the MLflow sentence-transformers flavour is not a problem; it works fine with:
```python
import mlflow

with mlflow.start_run() as run:
    mlflow.sentence_transformers.log_model(
        model,
        artifact_path="model",
        signature=signature,
        input_example=sentences,
        registered_model_name=registered_model_name,
    )

model_uri = f"runs:/{run.info.run_id}/model"
registered_model = mlflow.register_model(
    model_uri=model_uri,
    name=registered_model_name,
)
```

What I want to use is the pyfunc flavour, because I want to add an optional preprocessing step as additional functionality that is glued to the model.
Unfortunately, I can't find any documentation or reference on which methods a custom mlflow.pyfunc.PythonModel should implement.
I tried something like this:
```python
import mlflow.pyfunc
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer


class MyDataModel(BaseModel):
    field1: str
    field2: int
    field3: float


def process_object(obj: MyDataModel) -> str:
    return f"{obj.field1} {obj.field2} {obj.field3}"


class CustomSentenceTransformerModel(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        # Load the Sentence Transformer model
        self.model = SentenceTransformer('all-MiniLM-L6-v2')

    def process_object(self, obj: MyDataModel):
        # Define your custom processing here
        return f"Processed object with value: {obj}"

    def predict(self, context, model_input):
        # This method is required for MLflow's pyfunc models
        return self.model.encode(model_input)

    def encode(self, texts):
        return self.model.encode(texts)
```

Yet it is not possible to use it for indexing tables.
I know that I can just run a notebook that creates a new column with vector embeddings, but that's not the point here.
I just get this error:

Index creation failed
Failed to call Model Serving endpoint: embedding_pyfunc.

There is no justification, no logs, nothing!
11-07-2025 08:49 AM
To use a custom MLflow pyfunc model for sentence-transformers with preprocessing, your class must follow the interface of mlflow.pyfunc.PythonModel, above all the predict method. The key points are the method signature, how the input data is handled, and how the output is serialized. Below are a direct answer and some practical guidelines.
Required Methods for mlflow.pyfunc.PythonModel
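The only method you must implement is predict(self, context, model_input). Recent MLflow versions also pass an optional params argument, so predict(self, context, model_input, params=None) is a safe signature.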
- context: MLflow-provided information (artifacts, configs, etc.).
- model_input: the input passed at inference time (usually a pandas DataFrame, a NumPy array, or native Python types).
Guidelines and Typical Pattern
- Load everything you need in load_context, which runs once when MLflow loads the model.
- Accept both batch (DataFrame/array) and single-input cases in predict.
- The output of predict should be directly serializable (ideally array-like or a DataFrame).
Example Template
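```python
import mlflow.pyfunc
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer


class CustomSentenceTransformerModel(mlflow.pyfunc.PythonModel):
    FIELDS = ("field1", "field2", "field3")

    def load_context(self, context):
        # Runs once when MLflow loads the model
        self.model = SentenceTransformer('all-MiniLM-L6-v2')

    def preprocess(self, row):
        # Custom preprocessing: join columns, etc.
        return f"{row['field1']} {row['field2']} {row['field3']}"

    def predict(self, context, model_input, params=None):
        # Accept a DataFrame, Series, list/array, or a single value
        if isinstance(model_input, pd.DataFrame):
            if set(self.FIELDS).issubset(model_input.columns):
                # Structured rows: apply the preprocessing step
                texts = model_input.apply(self.preprocess, axis=1).tolist()
            else:
                # Plain text input: use the first column unchanged
                texts = model_input.iloc[:, 0].astype(str).tolist()
        elif isinstance(model_input, pd.Series):
            texts = model_input.astype(str).tolist()
        elif isinstance(model_input, (list, np.ndarray)):
            texts = [self.preprocess(x) if isinstance(x, dict) else str(x) for x in model_input]
        else:
            texts = [str(model_input)]
        # Nested lists of Python floats serialize cleanly in the response
        return self.model.encode(texts).tolist()
```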
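To serve a class like this, it has to be logged with the pyfunc flavour instead of the sentence_transformers call from the question. The following is a minimal sketch, assuming MLflow 2.x-style arguments (the same artifact_path/registered_model_name style the question uses); the helper name log_embedding_model, the input example, and the pip_requirements list are illustrative assumptions, not a prescribed setup. Note that the inferred signature fixes the expected input columns: if the caller sends something else (for example a single text column), schema enforcement will reject the request, so choose an input example that matches what your caller actually sends.

```python
def log_embedding_model(python_model, registered_model_name):
    """Log a pyfunc embedding model, mirroring the sentence_transformers call in the question."""
    # Imports are local so the helper can be pasted into any cell as-is
    import mlflow
    import numpy as np
    import pandas as pd
    from mlflow.models import infer_signature

    # Hypothetical input example; the inferred signature will require these columns
    input_example = pd.DataFrame([{"field1": "hello", "field2": 3, "field3": 4.2}])

    # Run the model once locally so the signature reflects its real output
    python_model.load_context(None)
    embeddings = np.asarray(python_model.predict(None, input_example))
    signature = infer_signature(input_example, embeddings)

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            artifact_path="model",
            python_model=python_model,
            signature=signature,
            input_example=input_example,
            # Serving builds its Python environment from this list instead of inferring it
            pip_requirements=["mlflow", "sentence-transformers", "pandas"],
            # As in the question, this also registers a new model version
            registered_model_name=registered_model_name,
        )
    return model_info
```

Afterwards, mlflow.pyfunc.load_model(model_info.model_uri).predict(...) goes through MLflow's own loading and schema checks, which is much closer to what a serving endpoint does than calling the class directly.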
Key Points for Indexing Tables
- When serving, the input must be a DataFrame, an array, or a compatible structure; MLflow Model Serving expects this.
- If you want to process tables, accept a DataFrame in predict, preprocess each row, and then encode.
- All logic for the optional preprocessing must run inside predict (directly or through helper methods it calls).
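The first point can be made concrete. MLflow's scoring protocol accepts JSON bodies such as dataframe_split and dataframe_records, and for these a pyfunc model receives a pandas DataFrame. The sketch below imitates that parsing with plain pandas so you can see the DataFrame predict will get; body_to_frame is a hypothetical helper, not MLflow's server code, the column names are illustrative, and the protocol's other keys (instances, inputs) are not handled.

```python
import json

import pandas as pd

# Two equivalent request bodies in MLflow's scoring formats (column names are hypothetical)
split_body = json.dumps({
    "dataframe_split": {
        "columns": ["field1", "field2", "field3"],
        "data": [["hello", 3, 4.2], ["world", 5, 1.0]],
    }
})
records_body = json.dumps({
    "dataframe_records": [
        {"field1": "hello", "field2": 3, "field3": 4.2},
        {"field1": "world", "field2": 5, "field3": 1.0},
    ]
})


def body_to_frame(body):
    """Rough stand-in for server-side parsing: JSON request body -> DataFrame."""
    payload = json.loads(body)
    if "dataframe_split" in payload:
        split = payload["dataframe_split"]
        return pd.DataFrame(split["data"], columns=split["columns"])
    if "dataframe_records" in payload:
        return pd.DataFrame(payload["dataframe_records"])
    raise ValueError("unsupported request format")
```

Both bodies produce the same two-row DataFrame, so a predict that handles DataFrames covers either request style.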
Troubleshooting the "Index creation failed" Error
- The error likely means predict does not consume the input structure as expected, or the output is not serializable.
- Ensure you return standard Python objects (lists, arrays, DataFrames); avoid returning custom objects or types that cannot be serialized easily.
- Check that your Model Serving environment has all dependencies (sentence-transformers, etc.).
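Because the index only reports "Failed to call Model Serving endpoint", it can help to call the endpoint yourself: the HTTP error body is often far more specific (for example a schema mismatch or an exception from predict). A minimal sketch using only the standard library, assuming the usual Databricks invocations URL and a personal access token; the helper names, the single "text" column, and the host/token values are placeholders you must adapt to your model's signature.

```python
import json
import urllib.error
import urllib.request


def build_invocation_request(host, endpoint_name, token, texts):
    """Build a POST to a Databricks Model Serving endpoint's invocations URL."""
    # Single text column named "text": a hypothetical schema, adjust to your signature
    body = {"dataframe_split": {"columns": ["text"], "data": [[t] for t in texts]}}
    return urllib.request.Request(
        url=f"{host.rstrip('/')}/serving-endpoints/{endpoint_name}/invocations",
        data=json.dumps(body).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        method="POST",
    )


def invoke(request):
    """Send the request; on an HTTP error, return the server's body instead of raising."""
    try:
        with urllib.request.urlopen(request, timeout=60) as resp:
            return resp.status, json.loads(resp.read())
    except urllib.error.HTTPError as err:
        # The response body often explains the failure (schema or predict() error)
        return err.code, err.read().decode("utf-8", errors="replace")
```

For example, invoke(build_invocation_request("https://<workspace-host>", "embedding_pyfunc", "<token>", ["hello world"])) returns the status code together with either the predictions or the error text.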
Final Recommendations
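- From the PythonModel interface, implement only load_context and predict; predict (or helpers it calls) handles any preprocessing.
- Return vector outputs in formats compatible with downstream tooling (usually NumPy arrays or lists).
- Test your model locally first:

```python
import pandas as pd

# Instantiate the class and call load_context yourself, as MLflow would on load
model = CustomSentenceTransformerModel()
model.load_context(None)

data = pd.DataFrame([{"field1": "hello", "field2": 3, "field3": 4.2}])
embeddings = model.predict(None, data)
```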
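A local call only proves that predict runs. To check that the output also has the shape a vector index needs, you can validate it: one fixed-length vector of finite floats per input row, in a JSON-friendly form. This is a sketch under that assumption; validate_embeddings is a hypothetical helper, and it expects the list form (for example the result of .tolist()), so it deliberately rejects raw NumPy arrays.

```python
import json
import math


def validate_embeddings(vectors, n_inputs, expected_dim=None):
    """Check that predict() output is one fixed-length, JSON-friendly float vector per input."""
    if n_inputs < 1:
        raise ValueError("need at least one input row")
    if len(vectors) != n_inputs:
        raise ValueError(f"expected {n_inputs} vectors, got {len(vectors)}")
    dims = {len(v) for v in vectors}
    if len(dims) != 1:
        raise ValueError(f"vectors have inconsistent lengths: {sorted(dims)}")
    dim = dims.pop()
    if expected_dim is not None and dim != expected_dim:
        raise ValueError(f"expected dimension {expected_dim}, got {dim}")
    if not all(isinstance(x, float) and math.isfinite(x) for v in vectors for x in v):
        raise ValueError("vectors must contain finite Python floats (try .tolist())")
    # Raises TypeError if anything is still not JSON-serializable
    json.dumps(vectors)
    return dim
```

With all-MiniLM-L6-v2, whose embeddings have 384 dimensions, a check would look like validate_embeddings(embeddings, len(data), expected_dim=384).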