Additionally, I log the model as shown below. MicrosoftResnet50Model is my custom mlflow.pyfunc.PythonModel subclass, which implements the load_context and predict methods:
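import mlflow
import transformers

# Defined elsewhere: MicrosoftResnet50Model, MODEL_PATH, api_input_example,
# signature, CATALOG, SCHEMA, REGISTERED_MODEL_NAME
with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        REGISTERED_MODEL_NAME,  # where the model is stored inside the run
        python_model=MicrosoftResnet50Model(),
        input_example=api_input_example,
        # Copied into the logged model; exposed as context.artifacts["model_path"]
        artifacts={"model_path": MODEL_PATH},
        # Pinned dependencies for the model's serving environment
        pip_requirements=[
            f"transformers=={transformers.__version__}",
            "torch==2.0.1",
        ],
        signature=signature,
        # Three-level Unity Catalog name: catalog.schema.model
        registered_model_name=f"{CATALOG}.{SCHEMA}.{REGISTERED_MODEL_NAME}",
    )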