One of the most common use cases in any long-running business is the ability to create a searchable catalog of a company’s collective body of knowledge. Much of that knowledge base is wrapped up in files of various types, such as PowerPoint, Word, PDF files, and sometimes plain images. To gain the most insight, it’s necessary not only to catalog the text in these files, but also to be able to analyze and classify the images that accompany the text.
At the time of this blog’s writing, Databricks doesn’t include any image classification models by default. In this brief article, I’ll walk through hosting an image classification model within Databricks and provide some basic skeleton code for doing so.
While you could use Databricks to host almost any model from any vendor, for the purposes of this demonstration I’m going to use Microsoft’s ResNet-50 model (microsoft/resnet-50 on Hugging Face), mostly because it was trained on the broad ImageNet-1k dataset, which covers 1,000 object classes, and because of its fast inference time.
Let’s get a notebook started!
Open your notebook and run the following code in it. It installs the required packages and then restarts the Python kernel so that the newly installed packages are picked up.
# Install the necessary python packages.
%pip install transformers torch
# Restart python so the above packages are actually used.
dbutils.library.restartPython()
For this, we’ll be using transformers and PyTorch, though feel free to implement this with torch.hub instead. (I’m choosing transformers because its AutoImageProcessor makes it simple to prepare images for the model.)
Now that the requirements are installed, we’re going to define the vision model as follows:
import io, base64, torch, mlflow.pyfunc
import pandas as pd
from mlflow.types.schema import ColSpec
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification
class MicrosoftResnetFifty(mlflow.pyfunc.PythonModel):
    # Initialize the model, setting the image processor and the classification model.
    def __init__(self):
        super(MicrosoftResnetFifty, self).__init__()
        self.processor = AutoImageProcessor.from_pretrained('microsoft/resnet-50')
        self.model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50')
    # Prepares the image for inference by converting any known image to a non-transparent PNG.
    def prepare_image_for_inference(self, image_bytes):
        img = Image.open(io.BytesIO(image_bytes))
        buf = io.BytesIO()
        img.convert("RGB").save(buf, format='PNG')
        img_bytes = buf.getvalue()
        image = Image.open(io.BytesIO(img_bytes))
        return image
    # Pre-processes the image, and executes inference, returning the predicted label.
    def classify(self, image):
        inputs = self.processor(image, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits
        predicted_label = logits.argmax(-1).item()
        return self.model.config.id2label[predicted_label]
    # Takes any image of any known type, converts the image to PNG, and passes the image to the model for inference.
    def classify_image(self, image_bytes):
        image = self.prepare_image_for_inference(image_bytes)
        prediction = self.classify(image)
        return prediction
    # This is the external interface that Databricks will use to pass the input to the classifier.
    def predict(self, model_input: pd.DataFrame):
        return [self.classify_image(base64.b64decode(image)) for image in model_input['image']]
A few notes on the code above:
For user-friendliness, the model normalizes every incoming image to a single format, an RGB PNG with no transparency, before running inference. That transformation lives in the prepare_image_for_inference function.
You may also notice the call to torch.no_grad() within the classify function. This disables PyTorch’s gradient tracking (autograd), which is needed when training a model but only adds memory and compute overhead during inference.
The predict function accepts a pandas DataFrame whose image column holds base64-encoded images, which lets the model classify several images in a single call (batch inference).
Transformers is the Hugging Face library we use to load the model and its image processor, and PyTorch runs the model itself. PyTorch can also run on a GPU, which we’ll get to in a moment.
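To make the last step of classify concrete, here’s a minimal pure-Python sketch of what logits.argmax(-1).item() followed by the id2label lookup does for a single image. The logits and the three-entry label map are hypothetical stand-ins; the real microsoft/resnet-50 config maps 1,000 ImageNet class indices to names.

```python
# Hypothetical logits for a single image; the real model emits 1,000 (one per ImageNet class).
logits = [0.3, 2.7, -1.2]

# Tiny hypothetical stand-in for model.config.id2label.
id2label = {0: "tench, Tinca tinca", 1: "flat-coated retriever", 2: "sloth bear"}


def predict_label(scores, labels):
    """Return the label of the highest score, like logits.argmax(-1).item() plus id2label."""
    # max() over indices returns the first index with the largest score, as argmax does.
    best_index = max(range(len(scores)), key=lambda i: scores[i])
    return labels[best_index]


print(predict_label(logits, id2label))  # flat-coated retriever
```

Because softmax preserves the ordering of the logits, taking the argmax of the raw logits picks the same class as taking it over probabilities, so the model can skip the softmax when only the top label is needed.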
One nice thing about general-purpose foundation models is that you can have some fun with them, so we will. To test the model, I’m going to use a photo of the best pup in the world. Here’s a photo of Ada (Lovelace) Morton. Feel free to use this or any other image you want.
Upload the image to the same folder as your notebook, then add a cell with the following code, changing the filename to match your image:
# Here we'll test the class to ensure it returns proper values.
with open('best-pup-ever.webp', 'rb') as file:
    file_content = base64.b64encode(file.read())
input_example = pd.DataFrame({"image": [file_content]})
MicrosoftResnetFifty().predict(input_example)
If all goes well, you may see a warning or two, but then you should see something similar to the following in the output:
['flat-coated retriever']
This confirms that the model loads correctly and that inference works.
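The returned label comes straight from the model’s id2label map, and some ImageNet-1k class names list comma-separated synonyms (class 0, for example, is "tench, Tinca tinca"). That’s why the client code later in this post splits each prediction on ', '. The helper below is a hypothetical sketch, not part of the model’s API, that separates the primary name from its synonyms:

```python
def split_label(label: str) -> tuple[str, list[str]]:
    """Split an ImageNet-style label into its primary name and any synonyms."""
    # Drop empty fragments so stray commas or whitespace don't produce blank names.
    names = [part.strip() for part in label.split(",") if part.strip()]
    if not names:
        return "", []
    return names[0], names[1:]


primary, synonyms = split_label("tench, Tinca tinca")
print(primary, synonyms)  # tench ['Tinca tinca']
```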
You may notice a warning in the output that looks something like this:
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
This warning comes from Hugging Face’s transformers library, and it isn’t related to CPU versus GPU inference. It means AutoImageProcessor loaded the slower image processor because the use_fast argument wasn’t set. You can safely ignore it, or silence it by passing use_fast explicitly, for example AutoImageProcessor.from_pretrained('microsoft/resnet-50', use_fast=False) to keep the current behavior. Separately, if you want faster inference on serverless compute, you can add an A10 accelerator by clicking the compute dropdown (where it says “Connected”) and editing the serverless environment configuration. Keep in mind that the class above never moves the model or its inputs off the CPU, so you’d also need to place both on the GPU (for example with .to("cuda")) to benefit. A GPU isn’t required, though; the model works fine on CPU.
The next thing we need to do is to register the model. To do this, we first define the input and output schemas that will be used by the consumers of the model, and we’ll combine those schemas into a model signature. Then, we’ll register the model in the Unity Catalog model registry.
Here’s sample code that accomplishes this task:
import pandas as pd
from mlflow.types import Schema, ColSpec
from mlflow.models.signature import ModelSignature
input_schema = Schema([ColSpec("binary", "image")])
output_schema = Schema([ColSpec("string")])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)
# Setting the registry, so this model will be stored in Unity Catalog.
mlflow.set_registry_uri("databricks-uc")
# Starting a run, and registering the model.
with mlflow.start_run() as run:
    mlflow.pyfunc.log_model(
        python_model=MicrosoftResnetFifty(),
        artifact_path="infer_model",
        input_example=input_example,
        signature=signature
    )
model_uri = f"runs:/{run.info.run_id}/infer_model"
registered_model_name = "main.default.resnet50"
mlflow.register_model(model_uri, registered_model_name)
For our final step, we’re going to serve the model from a model serving endpoint. This step is optional, depending on how you plan to use the model, but if you intend to call the model from outside Databricks, you’ll need a model serving endpoint.
To do this, navigate to the Models tab in the sidebar.
Search for the model you intend to use, and then click on its name to get more details about the model.
On the model’s details page, you should see a large blue button at the top right of the screen labeled “Serve this Model”. Click it.
Give your endpoint a name. You can choose either the CPU or the GPU compute type; ResNet-50 can be served on both, and GPU is generally faster but more expensive. The page has plenty of other options, including rate limits, tags, and serverless usage policies and alerts. Although these are outside the scope of this blog, it’s worth looking through the page to see what’s available to you.
One feature I do want you to consider, depending on your use case, is route optimization. If you plan to incorporate this endpoint into a data pipeline with a high volume or velocity of traffic, you’ll want to check this box: route-optimized endpoints support greater concurrency and a much higher queries-per-second limit, which can substantially reduce time to inference.
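Whether or not you enable route optimization, a high-volume pipeline will usually send images in several moderately sized requests rather than one huge payload, since every request carries the full base64-encoded images. The helper below is a hypothetical sketch; the default batch size of 16 is an arbitrary assumption you’d tune against your endpoint’s payload limits and latency.

```python
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def chunked(items: Sequence[T], batch_size: int = 16) -> Iterator[list[T]]:
    """Yield consecutive lists of at most batch_size items, preserving order."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


# Hypothetical usage with the infer() helper from the client notebook:
# for batch in chunked(all_encoded_images, batch_size=16):
#     results = infer(pd.DataFrame({"image": batch}))
```

On Python 3.12 and later, itertools.batched offers the same grouping (yielding tuples instead of lists).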
Once you’re satisfied with the options you’ve selected, click the Save button. At this point, you’ll be taken to a page that will show you the current status of the endpoint.
Once the deployment of the model is complete, we’ll do some inference via the model serving endpoint.
Now that the model serving endpoint is running, copy the endpoint’s invocation URL from the serving endpoint details page. You should see the URL near the top of the page, along with a button to copy it to your clipboard. Below is sample code that prepares your images for inference and then calls the model serving endpoint.
For maintainability, I suggest putting this code in a brand new notebook, separate from the model registration notebook.
import json
import base64
import requests
import pandas as pd
# Converts a pandas DataFrame into a JSON string that will
# be accepted by the model serving endpoint.
def to_dataframe_split_json(df):
    obj = {'dataframe_split': json.loads(df.to_json(orient='split'))}
    return json.dumps(obj, allow_nan=True)
# Fetches the proper authorization headers that will be needed to call the
# model serving endpoint.
def get_headers():
    token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
    headers = {'Authorization': f'Bearer {token}', 'Content-Type': 'application/json'}
    return headers
# This function calls the model serving endpoint and returns the results. Replace the URL
# with your endpoint's invocation URL; the host name format differs by cloud provider.
def infer(dataset):
    endpoint_url = 'https://your-workspace-id.azuredatabricks.net/serving-endpoints/resnet50/invocations'
    headers = get_headers()
    post_data = to_dataframe_split_json(dataset)
    response = requests.request(
        method='POST',
        headers=headers,
        url=endpoint_url,
        data=post_data)
    if response.status_code != 200:
        raise Exception(f'Request failed with status {response.status_code}, {response.text}')
    return response.json()
# (Helper function to base64 encode an image.)
def get_image_bytes(image_path):
    with open(image_path, 'rb') as f:
        return base64.b64encode(f.read()).decode('utf-8')
# Put multiple images in the dataframe to test batch inference.
image_df = pd.DataFrame({"image": [
    get_image_bytes('cartoon-rabbit.jpg'),
    get_image_bytes('sloth.png'),
    get_image_bytes('best-pup-ever.webp')
]})
# Call the infer function.
results = infer(image_df)
# Print the results, splitting each label into its comma-separated synonyms.
print([x.split(', ') for x in results['predictions']])
The code above will work well in a notebook environment within Databricks itself, but if you’re going to be using this from a Databricks app or some other external caller, you’ll likely need to modify parts of this, especially the get_headers() function, as different environments require different token acquisition methods.
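As one example of what those modifications might look like, here’s a standard-library-only sketch for a caller outside a notebook. It assumes the workspace URL and a token are supplied through the environment variables DATABRICKS_HOST and DATABRICKS_TOKEN (the names the Databricks CLI and SDK read, but treat them as assumptions for your environment), and it builds a dataframe_split payload with the same columns/data structure that to_dataframe_split_json produces, without needing pandas. The endpoint name resnet50 matches the URL used above.

```python
import base64
import json
import os
import urllib.request


def build_request(image_bytes_list, endpoint_name="resnet50"):
    """Build (but don't send) a POST request to a Databricks model serving endpoint."""
    # Assumption: DATABRICKS_HOST is the workspace URL, e.g. "https://<workspace-host>".
    host = os.environ["DATABRICKS_HOST"].rstrip("/")
    token = os.environ["DATABRICKS_TOKEN"]
    url = f"{host}/serving-endpoints/{endpoint_name}/invocations"
    # One row per image in a single "image" column, each base64-encoded as text.
    payload = {
        "dataframe_split": {
            "columns": ["image"],
            "data": [[base64.b64encode(b).decode("utf-8")] for b in image_bytes_list],
        }
    }
    return urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"},
        method="POST",
    )


# Sending it (urlopen raises HTTPError on non-2xx responses):
# with urllib.request.urlopen(build_request([open("sloth.png", "rb").read()])) as response:
#     predictions = json.load(response)["predictions"]
```

pandas’ orient='split' output also includes an "index" key; the sketch omits it and sends only columns and data.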
You could also go beyond user-initiated calls to the model and process images in a Databricks Declarative Pipeline flow by using the ai_query function.
While Databricks doesn’t provide any foundational image classification models by default (as of this writing), it’s relatively straightforward to download and register your own image model within Databricks, and this capability can be key to unlocking the value in much of your unstructured data, from simple images to PowerPoint presentations, Word documents, and PDFs.
Happy building!