infer_signature for a dictionary of datasets during MLflow registration

skosaraju
New Contributor III

Hello community,

Could you please guide me here? I am trying to build a custom ensemble model whose fit() and predict() take a dictionary of datasets, where each key is a model name and each value is the dataset for that model. The idea is to register a single ensemble model rather than 5 separate models.

I will instantiate the models from a config and pass each one its own dataset. I am currently stuck at the infer_signature() step because I am unable to build the input structure it expects.
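To make the structure concrete, here is a minimal sketch of the two inputs. The hyperparameter values and feature names are hypothetical, and plain lists of rows stand in for the real datasets:

```python
# Hypothetical config: one entry per sub-model, using the same keys
# ('model_name', 'hyper_params') that the class reads in __init__.
model_config = [
    {"model_name": "local_outlier_factor", "hyper_params": {"n_neighbors": 20}},
    {"model_name": "isolation_forest", "hyper_params": {"n_estimators": 100}},
]

# Stand-in datasets: plain lists of rows instead of the real DataFrames.
# Each key is a model name, each value is that model's own dataset.
input_data_dict = {
    "local_outlier_factor": [{"feature_a": 0.1, "feature_b": 1.2}],
    "isolation_forest": [{"feature_c": 5.0}],
}

# Every configured model must find a dataset under its own name.
missing = [
    m["model_name"] for m in model_config if m["model_name"] not in input_data_dict
]
print(missing)
```

Note that the datasets do not share a schema, which is why a single flat table does not describe the input.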

Below is the code snippet for my fit and predict. Can you please help me construct the model_input for infer_signature? I know I could do this if I registered the models separately, but I want to register only one model.

class CustomEnsembleModel(mlflow.pyfunc.PythonModel):

    def __init__(self, model_config, dbx_params):
        if len(model_config) < 1:
            raise ValueError("The model_config must contain at least one model configuration.")

        self.model_config = model_config
        self.models = {}  # Dictionary to store model instances
        self.dbx_params = dbx_params

        for model in self.model_config:
            model_name = model['model_name']
            hyper_params = model['hyper_params']
            self.models[model_name] = self._get_model_instance(model_name, hyper_params)

    def _get_model_instance(self, model_name, hyper_params):
        if model_name == 'local_outlier_factor':
            return LocalOutlierFactorModel(hyper_params)
        elif model_name == 'isolation_forest':
            return IsolationForestModel(hyper_params)
        else:
            raise ValueError(f"Unsupported model name: {model_name}")

    def fit(self, input_data_dict):
        model_outputs = {}
        for model_name, model_instance in self.models.items():
            input_df = input_data_dict[model_name]
            model_outputs[model_name] = model_instance.fit(input_df)
        return model_outputs

    def predict(self, input_data_dict):
        predictions = {}
        for model_name, model_instance in self.models.items():
            logger.info(f"Loading input datasets for model: {model_name}")
            input_df = input_data_dict[model_name]
            logger.info(f"Loaded input datasets for model: {model_name} with shape {input_df.count()}")
            predictions[model_name] = model_instance.predict(input_df)
        return predictions
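For anyone who wants to reproduce the shape of the problem without my model classes, here is a stripped-down sketch of the dict-in, dict-out contract. StubModel and ensemble_predict are hypothetical stand-ins (not my real classes), the MLflow base class is left out, and the rows are made-up data:

```python
class StubModel:
    """Hypothetical stand-in for LocalOutlierFactorModel / IsolationForestModel."""

    def __init__(self, hyper_params):
        self.hyper_params = hyper_params

    def predict(self, rows):
        # Placeholder logic: label every row as an inlier (1).
        return [1 for _ in rows]


def ensemble_predict(models, input_data_dict):
    # Same loop as predict() above: each model reads the dataset stored
    # under its own name and writes its result under that same name.
    return {
        name: model.predict(input_data_dict[name])
        for name, model in models.items()
    }


models = {
    "local_outlier_factor": StubModel({}),
    "isolation_forest": StubModel({}),
}
inputs = {
    "local_outlier_factor": [{"feature_a": 0.1}, {"feature_a": 0.4}],
    "isolation_forest": [{"feature_c": 5.0}],
}

predictions = ensemble_predict(models, inputs)
print(predictions)
```

So both the input and the output are dictionaries keyed by model name, and this nested structure is what I need the signature to describe.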