@javieryw,
The error you are seeing, "400 Client Error: Bad Request for url: https://westus.azuredatabricks.net/serving-endpoints/llama-guard/invocations. Response text: Bad request: json: unknown field 'dataframe_split'", means that the Llama Guard Model Serving endpoint does not recognize the "dataframe_split" field in your request payload.
Based on the context provided, the endpoint appears to expect a completions-style request rather than the pandas-style "dataframe_split" format. Instead of wrapping the prompt in "dataframe_split", pass it directly under a "prompt" key in the inputs.
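For comparison, here is a minimal sketch of the two request bodies. The first is the pandas "split"-oriented shape that the endpoint rejected. The second is the completions-style shape with the prompt as a top-level field. The helper names `dataframe_split_payload` and `completions_payload` are hypothetical, for illustration only.

```python
import json


def dataframe_split_payload(prompt: str) -> dict:
    """Hypothetical helper: the 'dataframe_split' body that the endpoint rejected."""
    return {"dataframe_split": {"index": [0], "columns": ["prompt"], "data": [[prompt]]}}


def completions_payload(prompt: str) -> dict:
    """Hypothetical helper: a completions-style body with the prompt as a top-level field."""
    return {"prompt": prompt}


prompt = "[INST] ... [/INST]"
# Print both bodies as JSON to compare what is actually sent over the wire
print(json.dumps(dataframe_split_payload(prompt)))
print(json.dumps(completions_payload(prompt)))
```

Only the shape of the `inputs` dict changes. The prompt text itself stays the same.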
```python
import mlflow.deployments

# DEFAULT_UNSAFE_CATEGORY and LLAMAGUARD_ENDPOINT_NAME are assumed to be defined
# earlier in your notebook (the unsafe-category text and the serving endpoint name).


def query_llamaguard(chat, unsafe_categories=DEFAULT_UNSAFE_CATEGORY):
    """
    Queries the Llama Guard model to check whether the chat contains unsafe content
    based on the given categories. Returns (is_safe, violated_categories).
    """
    try:
        client = mlflow.deployments.get_deploy_client("databricks")

        # Build the Llama Guard prompt as a single string
        prompt = f"""[INST] Task: Check for unsafe content in 'Agent' messages in conversations according to our safety policy.
<BEGIN UNSAFE CONTENT CATEGORIES>
{unsafe_categories}
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
{chat}
<END CONVERSATION>
Provide your safety assessment for Agent in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""

        # Send the prompt as a top-level "prompt" field instead of "dataframe_split"
        response = client.predict(
            endpoint=LLAMAGUARD_ENDPOINT_NAME,
            inputs={"prompt": prompt},
        )

        # Completions-style endpoints typically return {"choices": [{"text": ...}]};
        # fall back to the predictions/candidates shape if that is what comes back.
        if "choices" in response:
            prediction = response["choices"][0]["text"].strip()
        else:
            prediction = response["predictions"][0]["candidates"][0]["text"].strip()

        # First line is 'safe' or 'unsafe'; an optional second line lists violated categories
        lines = prediction.split("\n")
        is_safe = lines[0].strip().lower() == "safe"
        violated_categories = lines[1].strip() if len(lines) > 1 else None
        return is_safe, violated_categories
    except Exception as e:
        raise RuntimeError(f"Error in querying LlamaGuard model: {e}") from e


# Example usage
safe_user_chat = [
    {
        "role": "user",
        "content": "I want to love."
    }
]
print(query_llamaguard(safe_user_chat))
```
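In this updated function, the inputs payload contains only a "prompt" field holding the Llama Guard prompt, instead of the "dataframe_split" structure. If the request now succeeds but extracting the text fails, print the raw response to confirm which fields your endpoint returns.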
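Because the prompt asks Llama Guard to answer "safe" or "unsafe" on the first line and list violated categories on the second, the parsing step can be tested locally without calling the endpoint. Below is a minimal sketch. `parse_llamaguard_output` is a hypothetical helper name, and "O1, O3" are placeholder category codes, not output captured from a real endpoint.

```python
from typing import List, Optional, Tuple


def parse_llamaguard_output(text: str) -> Tuple[bool, Optional[List[str]]]:
    """Hypothetical helper: split a Llama Guard reply into (is_safe, violated_categories)."""
    # Drop blank lines and surrounding whitespace
    lines = [line.strip() for line in text.strip().splitlines() if line.strip()]
    if not lines:
        raise ValueError("empty Llama Guard response")
    is_safe = lines[0].lower() == "safe"
    # The optional second line is a comma-separated list of category codes
    categories = None
    if len(lines) > 1:
        categories = [c.strip() for c in lines[1].split(",") if c.strip()]
    return is_safe, categories


print(parse_llamaguard_output("safe"))
print(parse_llamaguard_output(" unsafe\nO1, O3 "))
```

Returning the categories as a list rather than a raw string makes it easier to check for a specific category downstream.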