Generative AI
Explore discussions on generative artificial intelligence techniques and applications within the Databricks Community. Share ideas, challenges, and breakthroughs in this cutting-edge field.

Custom sentence transformer for indexing

Ulfzerk
New Contributor

Hi!

I would like to use my own sentence transformer to create a vector index.

Logging it with the MLflow sentence-transformers flavour is not a problem; it works fine with:

import mlflow

with mlflow.start_run() as run:
    mlflow.sentence_transformers.log_model(
        model,
        artifact_path="model",
        signature=signature,
        input_example=sentences,
        registered_model_name=registered_model_name)

    model_uri = f"runs:/{run.info.run_id}/model"
    registered_model = mlflow.register_model(
        model_uri=model_uri,
        name=registered_model_name)

What I want to use is the pyfunc flavour, because I want to add an optional preprocessing step as additional functionality glued to the model.

Unfortunately, I can't find any documentation or reference on what methods a custom mlflow.pyfunc.PythonModel should implement.
I tried something like this:

import mlflow.pyfunc
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
class MyDataModel(BaseModel):
    field1: str
    field2: int
    field3: float

def process_object(obj: MyDataModel) -> str:
    return f"{obj.field1} {obj.field2} {obj.field3}"

class CustomSentenceTransformerModel(mlflow.pyfunc.PythonModel):

    def load_context(self, context):
        # Load the Sentence Transformer model
        self.model = SentenceTransformer('all-MiniLM-L6-v2')

    def process_object(self, obj: MyDataModel):
        # Define your custom processing here
        return f"Processed object with value: {obj}"

    def predict(self, context, model_input):
        # This method is required for MLflow's pyfunc models
        return self.model.encode(model_input)
    
    def encode(self, input):
        return self.model.encode(input)

Yet it is not possible to use it for indexing tables.
I know that I could just run a notebook that creates a new column with vector embeddings, but that's not the point here.

I just get this error:

Index creation failed
Failed to call Model Serving endpoint: embedding_pyfunc.

Without any justification or logs at all!

 

1 REPLY 1

mark_ott
Databricks Employee

To use a custom MLflow pyfunc model for sentence-transformers with preprocessing, you need to comply with the expected interface of mlflow.pyfunc.PythonModel, especially the predict method. The method signature, data handling, and serialization are the key points. Below is a direct answer with a practical explanation and guidelines.

Required Methods for mlflow.pyfunc.PythonModel

The only method you must implement is predict(self, context, model_input).

  • context: MLflow-provided info (artifacts, configs, etc.).

  • model_input: The input passed during inference (usually Pandas DataFrame, NumPy array, or Python native types).

Guidelines and Typical Pattern

  • Load everything needed in load_context, which runs once when the model is loaded by MLflow.

  • Accept both batch (DataFrame/array) and single-input cases in predict.

  • The output of predict should be directly serializable (ideally array-like or DataFrame).

Example Template

python
import mlflow.pyfunc
import pandas as pd
from sentence_transformers import SentenceTransformer

class CustomSentenceTransformerModel(mlflow.pyfunc.PythonModel):

    def load_context(self, context):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')

    def preprocess(self, row):
        # Custom preprocessing - join columns, etc.
        return f"{row['field1']} {row['field2']} {row['field3']}"

    def predict(self, context, model_input):
        # Accept DataFrame, Series, or list
        # If DataFrame, apply preprocessing row by row
        if isinstance(model_input, pd.DataFrame):
            texts = model_input.apply(self.preprocess, axis=1).tolist()
        elif isinstance(model_input, list):
            texts = [self.preprocess(x) if isinstance(x, dict) else x for x in model_input]
        else:
            texts = [str(model_input)]
        return self.model.encode(texts)

Key Points for Indexing Tables

  • When serving/inferencing, the input must be a DataFrame, array, or compatible structure; MLflow Model Serving expects this.

  • If you want to process tables, accept a DataFrame in predict, preprocess each row, and then encode.

  • All logic for optional preprocessing must be inside predict.
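The row-wise handling described in these points can be sketched without any MLflow or model dependencies. This is a minimal sketch, assuming the serving layer hands predict a pandas DataFrame; the field1/field2/field3 column names are taken from the original question:

```python
import pandas as pd

def preprocess(row) -> str:
    # Join the columns into a single text string, mirroring the
    # question's process_object; works for a dict or a pandas row
    return f"{row['field1']} {row['field2']} {row['field3']}"

def to_texts(model_input) -> list:
    # Normalize the input structures Model Serving may pass to predict()
    if isinstance(model_input, pd.DataFrame):
        return model_input.apply(preprocess, axis=1).tolist()
    if isinstance(model_input, list):
        return [preprocess(x) if isinstance(x, dict) else str(x) for x in model_input]
    return [str(model_input)]

df = pd.DataFrame([{"field1": "hello", "field2": 3, "field3": 4.2}])
print(to_texts(df))  # ['hello 3 4.2']
```

Once the input is normalized to a flat list of strings, the actual model call reduces to self.model.encode(texts), regardless of which structure the caller sent.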

Troubleshooting the "Index creation failed" Error

  • The error likely means predict does not consume the input structure as expected, or the output is not serializable.

  • Ensure you return standard Python objects (lists, arrays, DataFrames); avoid returning custom objects or types that cannot be serialized easily.

  • Check that your model serving environment has all dependencies (sentence-transformers, etc.).
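To make the serializability point concrete, here is one sketch, assuming encode() returns a NumPy array (as sentence-transformers does by default): convert it to nested Python lists before returning from predict, so the serving layer can JSON-encode the response.

```python
import numpy as np

def serialize_embeddings(embeddings: np.ndarray) -> list:
    # Model Serving JSON-encodes the response; ndarray.tolist() yields
    # plain nested Python lists of floats, which serialize cleanly
    return embeddings.tolist()

vecs = serialize_embeddings(np.array([[0.1, 0.2], [0.3, 0.4]]))
print(vecs)  # [[0.1, 0.2], [0.3, 0.4]]
```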

Final Recommendations

  • Implement only load_context and predict, where predict handles any preprocessing.

  • Return vector outputs in formats compatible with downstream tooling (usually NumPy arrays or lists).

  • Test your model locally first:

    python
    import pandas as pd

    data = pd.DataFrame([{"field1": "hello", "field2": 3, "field3": 4.2}])
    model = CustomSentenceTransformerModel()
    model.load_context(None)  # load the encoder before predicting locally
    model.predict(None, data)
