Hi,
I have a PySpark model, together with a preprocessing pipeline, that I want to register with MLflow, so I am using a pyfunc wrapper.
Steps I followed:
1. Pipeline and model serialization and logging (serialized to a Volume; the logging itself goes to DBFS)
```python
import mlflow

# pipeline serialization and logging
pipeline_path = '/Volumes/my_catalog/my_db/my_volume/my_pipeline'
my_pipeline.write().overwrite().save(pipeline_path)
with mlflow.start_run():
    mlflow.log_artifact(pipeline_path, 'pipeline')

# pyspark model serialization and logging
model_path = '/Volumes/my_catalog/my_db/my_volume/my_model'
my_model.write().overwrite().save(model_path)
with mlflow.start_run():
    mlflow.log_artifact(model_path, 'model')
```

2. Wrapper definition, using the serialized pipeline and model through the context

```python
import mlflow.pyfunc
from pyspark.ml import PipelineModel
from pyspark.sql import SparkSession


class ModelWrapper(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        self.spark = SparkSession.builder.getOrCreate()
        self.pipeline = PipelineModel.load(context.artifacts["my_pipeline"])
        self.model = PipelineModel.load(context.artifacts["my_model"])

    def predict(self, context, model_input):
        input_df = self.pipeline.transform(model_input)
        trans = self.model.transform(input_df)
        return trans.collect()[0]['prediction']
```

3. Log the pyfunc wrapper with MLflow

```python
import cloudpickle
import pyspark
import mlflow
from mlflow.models import infer_signature

mlflow.set_registry_uri("databricks-uc")
catalog = 'my_catalog'
db = 'my_db'
model_name = f"{catalog}.{db}.my_spark_model"

with mlflow.start_run(run_name="my_run") as run:
    input_ex = input  # example input defined earlier in the notebook
    output_ex = 5.0
    signature = infer_signature(input_ex, output_ex)
    artifacts = {
        "my_pipeline": "dbfs:/databricks/mlflow-tracking/<run_id_1>/2b398/artifacts/my_pipeline/pipeline",
        "my_model": "dbfs:/databricks/mlflow-tracking/<run_id_2>/9a4f51/artifacts/my_model/model"
    }
    model_info = mlflow.pyfunc.log_model(
        python_model=ModelWrapper(),
        artifacts=artifacts,
        artifact_path="my_spark_model",
        registered_model_name=model_name,
        pip_requirements=[
            "mlflow==" + mlflow.__version__,
            "cloudpickle==" + cloudpickle.__version__,
            "pyspark==" + pyspark.__version__
        ],
        input_example=input_ex,
        signature=signature
    )
```

All three steps above ran OK.
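(As a side note, the two dbfs URIs hard-coded in step 3 were copied by hand from the run pages. A quick listing like the one below can be used to double-check those paths; it is only a sketch, with `<run_id_1>` / `<run_id_2>` as placeholders for the real run IDs.)

```python
from mlflow import MlflowClient

client = MlflowClient()

# sketch: inspect the artifact layout of the two runs from step 1
# (<run_id_1> / <run_id_2> are placeholders for the real run IDs)
for run_id, root in [("<run_id_1>", "pipeline"), ("<run_id_2>", "model")]:
    for f in client.list_artifacts(run_id, root):
        print(run_id, f.path, f.is_dir)
```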
After the last step, `pipeline` was correctly logged at
dbfs:/databricks/mlflow-tracking/<run_id_3>/123a45/artifacts/my_spark_model/artifacts/pipeline
(including metadata, etc.)
and `model` at
dbfs:/databricks/mlflow-tracking/<run_id_3>/123a45/artifacts/my_spark_model/artifacts/model
(including metadata, etc.)
The issue happens when trying to load the pyfunc wrapper:

```python
import mlflow
from mlflow import MlflowClient

client = MlflowClient()
latest_model = client.get_model_version_by_alias(model_name, "prod")
logged_model = f'runs:/{latest_model.run_id}/my_spark_model'
loaded_model = mlflow.pyfunc.load_model(logged_model)
```

which fails with the error:
```
Py4JJavaError: An error occurred while calling o11140.partitions.
: org.apache.hadoop.mapred.InvalidInputException: Input path does not exist: /local_disk0/repl_tmp_data/ReplId-46fea-18625-1d618-7/tmpkxct3u4m/my_spark_model/artifacts/model/metadata
```
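For what it's worth, the path in the error looks like the temporary directory into which `mlflow.pyfunc.load_model` downloads the model on the driver. A sketch like the following (the `dst_path` is just an example) should show whether `artifacts/model/metadata` is actually present in the downloaded copy:

```python
import os

import mlflow

# sketch: download the logged pyfunc model to an explicit local directory
# and inspect what ended up under artifacts/ (dst_path is just an example)
local_dir = mlflow.artifacts.download_artifacts(
    artifact_uri=logged_model,  # same runs:/<run_id>/my_spark_model URI as above
    dst_path="/tmp/my_spark_model_check"
)
for dirpath, dirnames, filenames in os.walk(os.path.join(local_dir, "artifacts")):
    print(dirpath, dirnames, filenames)
```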
Is the problem that the workers cannot access the local path where the artifacts are downloaded, or that the artifact path has not been distributed to them? Or is it something else entirely?
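To make that second part of the question concrete, by "making the artifact path visible to the workers" I mean something along these lines inside `load_context` (an untested sketch; the scratch Volume path is only an example, and the pipeline would be handled the same way):

```python
import shutil

import mlflow.pyfunc
from pyspark.ml import PipelineModel
from pyspark.sql import SparkSession


class ModelWrapper(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        self.spark = SparkSession.builder.getOrCreate()
        # untested sketch: context.artifacts["my_model"] points at a driver-local
        # download directory (e.g. /local_disk0/...); copy it to a shared location
        # such as a UC Volume so that the executors can read it too
        shared_model_path = "/Volumes/my_catalog/my_db/my_volume/tmp_model"  # example path
        shutil.copytree(context.artifacts["my_model"], shared_model_path, dirs_exist_ok=True)
        self.model = PipelineModel.load(shared_model_path)
```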
Does anybody have a clue how to solve it?
Basically, the goal is to be able to use serialized Spark model artifacts inside a pyfunc wrapper registered with MLflow on Databricks, using Unity Catalog.
Thanks!