Technical Blog

Authors: Ryuta Yoshimatsu, Michael Shtelma, Alex Miller

 

Introduction

Alignment of large language models (LLMs) is a critical topic when building production-ready models for industrial use cases. An aligned model understands and complies with the user’s intent. Take a customer support chatbot as an example. Pre-training a model on a large corpus of text may allow it to generate coherent text that follows an input. However, this model is not aligned: we expect it to behave like an experienced customer support agent who follows the policies and best practices the company has defined for customer service. We need to fine-tune the model on an instruction-following dataset of many question-and-answer pairs that show how experienced customer support agents typically respond in different situations. This process is called supervised fine-tuning. Many models available today are already fine-tuned in this way, which is why models like Llama-2-70b-chat-hf and gpt-3.5-turbo-instruct answer your questions right off the bat.

In many situations, a supervised fine-tuned model alone is not enough. This is especially true when we are building a customer-facing application with a strict code of conduct that we want the model to follow. For instance, we may want our model to generate gender-neutral recommendations or to avoid toxic language. Or we may be building a chatbot for a targeted segment of users, such as vegetarians, and want the model to generate only vegetarian content. Many supervised fine-tuned models, including the ones mentioned above, do not perform well when the requirements are this specific, so the model may need further fine-tuning. For this last stretch of alignment, reinforcement learning is often used.

 

Reinforcement Learning from Human Feedback (RLHF)

In this approach, we prepare a dataset that contains millions of input and output pairs. Importantly, each input needs multiple outputs. Labelers then rank these outputs by how well they align with the use case. This dataset is used to train a reward model, which is often yet another LLM. After training, the reward model should be able to assign a score to a generated text indicating how well it aligns with what you, as a user, want to achieve. During fine-tuning, we use this model (between the forward and backward passes) to score the texts generated by the target model and compute the reward. Proximal policy optimization (PPO) is a popular algorithm for this step: it takes the reward and updates the model weights to maximize it. Under the right conditions, we can assume that an increase in reward corresponds to better alignment of the model.

Notably, the problem with this approach is that it requires high-quality human labelers to rank millions of outputs. This is expensive in itself and also introduces many operational complications. It has been the bottleneck of this technique, preventing its wide adoption in industrial use cases.
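To make the reward-model training step more concrete, here is a minimal sketch of one common way to learn from ranked outputs. The article does not specify the loss, so this is an assumption: each labeler ranking is expanded into (preferred, rejected) pairs, and the reward model is trained to minimize the pairwise logistic loss -log σ(r_preferred - r_rejected), a formulation widely used in RLHF pipelines. The helper names are hypothetical, and plain floats stand in for a real reward model's outputs.

```python
import math
from itertools import combinations
from typing import List, Sequence, Tuple


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def ranking_to_pairs(ranked_outputs: Sequence[str]) -> List[Tuple[str, str]]:
    """Expand a labeler ranking (best first) into (preferred, rejected) pairs."""
    # combinations() preserves input order, so the first element of each pair
    # is always the higher-ranked output.
    return list(combinations(ranked_outputs, 2))


def pairwise_ranking_loss(reward_preferred: float, reward_rejected: float) -> float:
    """Small when the reward model scores the preferred output higher."""
    return -log_sigmoid(reward_preferred - reward_rejected)


if __name__ == "__main__":
    pairs = ranking_to_pairs(["best answer", "okay answer", "bad answer"])
    print(len(pairs))                        # 3 pairs from 3 ranked outputs
    print(pairwise_ranking_loss(2.0, -1.0))  # low loss: ranking respected
    print(pairwise_ranking_loss(-1.0, 2.0))  # high loss: ranking violated
```

With K outputs per prompt, one ranking yields K(K-1)/2 training pairs, which helps explain why ranked data is valuable but also costly to collect at scale.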

 

Reinforcement Learning from AI Feedback (RLAIF)

In this article, we propose a solution, inspired by previous work, that uses a powerful LLM as the reward model. Using an off-the-shelf LLM as the reward function lets us skip the manual work of collecting millions of ranked outputs from human labelers and training a reward model. The architecture we present makes the reward model accessible from within the training loop, where it generates scores after each batch forward pass. These scores are then used to calculate the reward, which the PPO algorithm maximizes as discussed above. At the core of the architecture is the low-latency, high-throughput serverless serving provided by Databricks Model Serving. Alternatively, users can prepare a comparison dataset offline using a pre-trained or fine-tuned LLM, which the Direct Preference Optimization (DPO) algorithm can then use to optimize the preference directly.
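For the offline DPO alternative, a minimal sketch follows, assuming each prompt already has several responses scored by the judge LLM. The hypothetical helper `to_preference_pair` keeps the highest- and lowest-scored responses as one chosen/rejected row (the `prompt`/`chosen`/`rejected` keys follow the preference-data format used by TRL's `DPOTrainer`), and `dpo_loss` shows the standard DPO objective for a single pair, given summed response log-probabilities under the policy and the frozen reference model. The value β = 0.1 is an assumed setting, not one taken from this article.

```python
import math
from typing import Dict, List, Optional, Tuple


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def to_preference_pair(
    prompt: str, scored_responses: List[Tuple[str, float]]
) -> Optional[Dict[str, str]]:
    """Hypothetical helper: turn judge-scored responses into one preference row."""
    if len(scored_responses) < 2:
        return None
    ranked = sorted(scored_responses, key=lambda item: item[1], reverse=True)
    (best, best_score), (worst, worst_score) = ranked[0], ranked[-1]
    if best_score == worst_score:
        return None  # tied scores carry no preference signal
    return {"prompt": prompt, "chosen": best, "rejected": worst}


def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,  # assumed value
) -> float:
    """-log sigmoid(beta * (chosen log-ratio - rejected log-ratio)) for one pair."""
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    return -log_sigmoid(beta * (chosen_logratio - rejected_logratio))


if __name__ == "__main__":
    row = to_preference_pair(
        "Suggest a quick dinner.",
        [("Lentil curry with rice.", 0.95), ("Grilled chicken wrap.", 0.05)],
    )
    print(row)
    # Identical policy and reference give log(2); raising the chosen response's
    # likelihood relative to the reference lowers the loss.
    print(dpo_loss(-12.0, -15.0, -12.0, -15.0))
    print(dpo_loss(-10.0, -15.0, -12.0, -15.0))
```

Keeping only the best and worst response per prompt is one simple design choice; other pairings (for example, every pair whose score gap exceeds a margin) are equally plausible.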

The advantage of this solution is that it requires only (1) a list of prompts (10k-100k) similar to the use case and (2) a prompt that lets an LLM score the generated texts on how well they align with the use case. Both inputs can be obtained without much effort using LLMs (discussed shortly). Furthermore, with prompt engineering alone, model alignment can be steered in any direction. We believe this solution will open the door for many companies that previously struggled to justify the high cost of human labelers or wanted to avoid the associated operational issues. Companies can align their LLMs in almost any way they want in a timely and cost-effective manner.

In the following sections, we describe the solution in more detail, along with an actual use case that demonstrates its feasibility. In the final section, we wrap up the article by summarizing the key messages.

 

Use Case: Vegetarian Chatbot

In our fictitious company, we are developing a chatbot for our users, all of whom are vegetarian. Users will interact with the chatbot by asking questions about foods, recipes, and ingredients. We want the model to provide only vegetarian content and avoid anything non-vegetarian, such as meat or fish. However, open-source models available today will generate non-vegetarian content unless explicitly told not to. We want to align our model to provide only vegetarian content even when users don’t specify their dietary preferences. Users are not always as explicit as we would like, but that shouldn’t be the reason they see non-vegetarian content recommended by our chatbot. We will use the PPO algorithm to fine-tune a model. All the code and configuration files are available here.

 

Architecture

Overview of the architecture

Prompt Generation

We used Llama-2-70b-chat-hf to generate 10k prompts before fine-tuning. This can be done either offline in batches or online using Databricks Model Serving. Foundation Model APIs, which provide easy access to various powerful models, can be leveraged here.

 

TOPICS = ["Nutritious", "Plant-Based", "Meal Planning", "Cooking Techniques", "Vegetarianism",...]

SYSTEM_PROMPT = f"""
 You are an AI assistant that specializes in food. Your task is to generate a question related to food preferences, recipes, or ingredients. The question should include topics such as recipe, ingredient, recommendations, and preference questions. Generate 1 question based on the topics provided in the instructions. Do not generate more than 1 question.

  Below is an example of a question. Always format the output in JSON format as follows:
 ```json
 {{
   "question": "What are some ingredients for a quick dinner preparation?"
 }}
 ``` """

QUESTION_START = "Give me a question related to the following topic:"

 

The above is the actual prompt we used to generate the 10k training examples. Note that, when passed to the target model, these prompts should elicit both vegetarian and non-vegetarian content of varying quality. Variance in the distribution of scores is essential for the PPO algorithm to optimize the model weights efficiently: if every response received a similar score, the reward would carry little signal about which responses are better. The 10k prompts and the code that generated them are available here. In a real-world business scenario, real questions asked by users should be used instead of generated data. When a company does not have enough example questions, the existing ones can be blended with a generated set.
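Each generation request pairs the system prompt with one topic, and the model's reply has to be parsed back out of the fenced JSON it was asked to produce. The sketch below shows one way to do that, assuming OpenAI-style chat messages (check the request format your serving endpoint expects); `build_messages`, `parse_question`, and `dedupe` are hypothetical helpers, and the actual model call is left out. Because thousands of questions generated from a handful of topics are likely to contain repeats, it also drops duplicates.

```python
import json
import re
from typing import Dict, Iterable, List, Optional

QUESTION_START = "Give me a question related to the following topic:"
FENCE = "`" * 3  # a literal triple backtick, built this way to keep the snippet fence-safe
# Matches the fenced json wrapper that the system prompt asks the model to use.
FENCED_JSON = re.compile(FENCE + r"(?:json)?\s*(\{.*?\})\s*" + FENCE, re.DOTALL)


def build_messages(system_prompt: str, topic: str) -> List[Dict[str, str]]:
    """One chat request per topic (assumed OpenAI-style message format)."""
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"{QUESTION_START} {topic}"},
    ]


def parse_question(raw_reply: str) -> Optional[str]:
    """Extract the "question" field from a reply; return None if it is malformed."""
    match = FENCED_JSON.search(raw_reply)
    candidate = match.group(1) if match else raw_reply.strip()
    try:
        payload = json.loads(candidate)
    except json.JSONDecodeError:
        return None
    question = payload.get("question") if isinstance(payload, dict) else None
    if isinstance(question, str) and question.strip():
        return question.strip()
    return None


def dedupe(questions: Iterable[str]) -> List[str]:
    """Drop repeats, ignoring case and extra whitespace; keep the first occurrence."""
    seen, unique = set(), []
    for q in questions:
        key = " ".join(q.lower().split())
        if key not in seen:
            seen.add(key)
            unique.append(q)
    return unique


if __name__ == "__main__":
    replies = [
        FENCE + 'json\n{"question": "What can I cook with chickpeas?"}\n' + FENCE,
        '{"question": "what can I cook with chickpeas?"}',
        "Sorry, I cannot help with that.",
    ]
    parsed = [q for q in map(parse_question, replies) if q]
    print(dedupe(parsed))  # one question survives
```

Discarding malformed replies rather than repairing them keeps the prompt set clean at the cost of a few extra generation calls.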

Target Model

We fine-tuned Llama-2-7b-chat-hf. Depending on the capacity of the computing infrastructure, open-source models of various sizes can be used as the target model. We chose this model specifically to demonstrate the solution with relatively modest compute resources.

Reward Model

We used Llama-2-70b-chat-hf as the reward model. Unlike the target model, which is loaded into your local environment, the reward model is hosted on Databricks Model Serving. The underlying infrastructure is therefore managed by Databricks, which provides optimized performance with robust security and governance features. We only need to specify the expected throughput in terms of the number of tokens generated per unit time; this number should increase as the batch size in the training loop increases. As with the target model, other open-source or proprietary models can be used here as well. An important requirement is that the model can score the texts accurately. This is the most critical assumption in the entire solution, and we encourage you to run a thorough analysis when adapting the solution to your use case.

One such analysis is to take a few hundred prompts, each with a good and a bad response, and ask the reward model to score the responses using a prompt like the one below. If the model assigns a score below 0.5 to a bad response, or above 0.5 to a good response, we count that as a correct prediction. Any binary classification metric, such as accuracy, precision, or F1, can then be used to evaluate the quality of the reward model with respect to our alignment problem. See the experiments we ran to test the feasibility of using Llama-2-70b-chat-hf as the reward function for our use case.

 

def prompt_score(text):
 return f"""[INST]<<SYS>>You are an AI assistant that specializes in vegetarian cuisine. Your task is to score the quality of a text related to  food preferences, recipes, or ingredients. Generate 1 score on a scale from 0.01 to 0.99, which indicates how good the text provided in the instruction is. The good answers are strictly vegetarian, accurate and helpful, while the bad answers are not vegetarian (include meat, chicken, beef and fish), incorrect or unhelpful.

  Below is an example of a good text with score 0.99 and a bad text with score 0.01.

  - Good text with score 0.99: "For protein-rich ingredients in vegetarian salads, you can consider using quinoa, chickpeas, black beans, tofu, tempeh, and a variety of nuts and seeds like almonds, sunflower seeds, or pumpkin seeds. These ingredients not only add a satisfying protein boost but also provide a delightful texture and flavor to your salads."

 - Bad text with score 0.01: "You can add some sliced deli meats like turkey or chicken for protein. They are light and won't overpower the taste of your salad. Plus, they're easy to prepare and add to any salad mix. Fish is also a great alternative."

 Give the score at the beginning. Give only the score. Use no more than 10 words.<</SYS>>
 text: {text} [/INST]"""

 

The above is the actual prompt we used to score the texts using Llama-2-70b-chat-hf. The key to successfully eliciting the expected output is to be explicit, concise, and provide a few examples. To learn more about the use of LLMs as judges, refer to this article.
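The validation analysis described earlier needs two small pieces: extracting the score from the judge's free-text reply, and comparing those scores with human labels. The sketch below shows both under stated assumptions. The hypothetical `parse_score` takes the first number in the reply (the prompt asks for the score first), rejects values outside [0, 1] such as a reply on a 0-100 scale, and clamps the rest to [0.01, 0.99]. The hypothetical `reward_model_metrics` treats "good" as the positive class and counts a score above 0.5 as a "good" prediction, so a score of exactly 0.5 counts as "bad".

```python
import re
from typing import Dict, Optional, Sequence

NUMBER = re.compile(r"\d+(?:\.\d+)?|\.\d+")
SCORE_MIN, SCORE_MAX = 0.01, 0.99


def parse_score(reply: str) -> Optional[float]:
    """Return the first number in the judge's reply, clamped to [0.01, 0.99]."""
    match = NUMBER.search(reply)
    if match is None:
        return None  # no score at all; the caller decides how to handle it
    value = float(match.group())
    if not 0.0 <= value <= 1.0:
        return None  # wrong scale, e.g. "85/100"; safer to skip than to guess
    return min(max(value, SCORE_MIN), SCORE_MAX)


def reward_model_metrics(
    scores: Sequence[float], is_good: Sequence[bool], threshold: float = 0.5
) -> Dict[str, float]:
    """Binary-classification metrics for judge scores against human labels."""
    if not scores or len(scores) != len(is_good):
        raise ValueError("need one label per score and at least one example")
    tp = fp = tn = fn = 0
    for score, good in zip(scores, is_good):
        predicted_good = score > threshold
        if predicted_good and good:
            tp += 1
        elif predicted_good:
            fp += 1
        elif not good:
            tn += 1
        else:
            fn += 1
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": (tp + tn) / len(scores), "precision": precision,
            "recall": recall, "f1": f1}


if __name__ == "__main__":
    replies = ["0.95", "0.9 Strictly vegetarian.", "Score: 0.02. Contains chicken.", "0.3", "0.6"]
    labels = [True, True, False, False, False]  # human judgment: is the response good?
    scores = [parse_score(r) for r in replies]
    print(scores)
    print(reward_model_metrics(scores, labels))
```

Replies that fail to parse should be dropped together with their labels (or re-queried) before computing metrics. For this use case, it may also be worth computing the metrics with "bad" (non-vegetarian) as the positive class, since a judge that lets non-vegetarian responses through is the failure that matters most.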

TRL PPO Trainer

We used the PPO implementation from TRL, an open-source library developed by Hugging Face. We also used LoRA to reduce GPU memory requirements during fine-tuning. PPO requires two copies of the target model (the policy being trained and a frozen reference), but with LoRA only one copy is effectively needed, because disabling the adapters recovers the original model; this reduces the memory footprint significantly. TRL is integrated with Accelerate, and we used DeepSpeed, which has a native integration with Databricks, to achieve parallelism and optimize resource utilization.

The actual training was done on a single-node cluster with 8 x A10G GPUs (24 GB of memory each), which is sufficient for fine-tuning a 7B-parameter model. When fine-tuning a larger model (e.g., Llama-2-13b-chat-hf or mpt-30b), we recommend more powerful instances with more GPU memory, such as A100 GPUs, or potentially even a multi-node setup. See the code for the detailed implementation.
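The memory saving from LoRA comes from the fact that the reference model PPO needs is just the frozen base weights. The toy sketch below (plain Python, a single linear layer, hypothetical class and function names) computes y = Wx + (α/r)·B(Ax), where only the small matrices A and B are trained; switching the adapter off recovers the reference output from the same weights, so no second copy is stored. PEFT offers a `disable_adapter()` context manager on its models for the same purpose; this sketch does not use PEFT.

```python
from contextlib import contextmanager
from typing import Iterator, List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(w * x for w, x in zip(row, v)) for row in m]


class LoRALinear:
    """Frozen weight W (d_out x d_in) plus a trainable low-rank update B @ A."""

    def __init__(self, W: Matrix, A: Matrix, B: Matrix, alpha: float) -> None:
        self.W = W  # frozen base weight, shared by the policy and the reference
        self.A = A  # r x d_in, trainable
        self.B = B  # d_out x r, trainable (LoRA initializes B to zeros)
        self.scale = alpha / len(A)  # alpha / r
        self.adapter_enabled = True

    def __call__(self, x: Vector) -> Vector:
        y = matvec(self.W, x)
        if self.adapter_enabled:
            delta = matvec(self.B, matvec(self.A, x))
            y = [yi + self.scale * di for yi, di in zip(y, delta)]
        return y


@contextmanager
def adapter_disabled(layer: LoRALinear) -> Iterator[LoRALinear]:
    """Temporarily run the layer as the frozen reference model."""
    layer.adapter_enabled = False
    try:
        yield layer
    finally:
        layer.adapter_enabled = True


if __name__ == "__main__":
    layer = LoRALinear(W=[[1.0, 0.0], [0.0, 1.0]], A=[[1.0, 1.0]], B=[[0.5], [0.0]], alpha=1.0)
    x = [1.0, 2.0]
    print("policy:   ", layer(x))      # base output plus the low-rank update
    with adapter_disabled(layer):
        print("reference:", layer(x))  # W x only
```

Because B starts at zero, the policy and the reference produce identical outputs at the first step and diverge only as A and B are trained, which is exactly the drift the KL term in PPO measures.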

 

Model Serving

As briefly explained above, we deployed the Llama-2-70b-chat-hf model behind a Databricks Model Serving endpoint with an expected throughput of 635 generated tokens per second and used it as the reward model. In our setup, we prompt the model to evaluate responses by assigning each a score in the range 0.01 to 0.99. This range is designed to mirror the likelihood of a response being considered high quality. We then transform these scores with the logit function, math.log(score/(1.0-score)). The logit maps the open interval (0, 1) onto the entire real line, from negative to positive infinity; with scores confined to [0.01, 0.99], the resulting rewards lie between about -4.6 and +4.6. This transformation lets us distinguish outstanding responses more clearly and penalize inferior ones appropriately.

For example, this approach penalizes the model for generating mediocre texts scored below 0.5 (and rewards texts scored above 0.5), and it places more weight on texts with scores closer to the extremes. This is merely a mathematical trick we used to accelerate the convergence of fine-tuning. See the details of the implementation here.
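The transformation can be sketched in a few lines. This is a minimal version of the idea as described above, not the code from the linked repository; `score_to_reward` is a hypothetical name, and it clamps defensively so that a score of exactly 0 or 1 cannot produce an infinite reward.

```python
import math

SCORE_MIN, SCORE_MAX = 0.01, 0.99


def score_to_reward(score: float) -> float:
    """Map a judge score in (0, 1) to a reward on the real line via the logit."""
    clamped = min(max(score, SCORE_MIN), SCORE_MAX)
    return math.log(clamped / (1.0 - clamped))


if __name__ == "__main__":
    for s in [0.01, 0.4, 0.5, 0.6, 0.9, 0.99]:
        print(f"score {s:.2f} -> reward {score_to_reward(s):+.3f}")
```

Two properties follow directly. A score of 0.5 gives a reward of exactly 0, so scores below 0.5 become penalties and scores above it become bonuses. And the curve steepens toward the ends: moving from 0.5 to 0.6 adds about 0.41 to the reward, while moving from 0.9 to 0.99 adds about 2.4 (logit(0.99) = ln 99 ≈ 4.60, logit(0.9) = ln 9 ≈ 2.20).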

 

Key Metrics

Time evolution of the key training metrics: solid lines are smoothed data points.

The key metrics to watch during training are (1) the mean reward, (2) the reward standard deviation, and (3) the KL divergence. If the mean reward increases and eventually converges, the model is generating higher-quality texts that receive higher scores. For the same reason, the standard deviation of the reward should decrease and converge over time. The KL divergence between the fine-tuned and the original model usually increases rapidly at the beginning of training, indicating that the target model is drifting away from the original model, but it should eventually converge. We observed all of these behaviors in the training for our use case.

We used TensorBoard to track these metrics in real time during training. It’s also essential to store each combination of prompt, generated text, and score so that you can verify the model behaves as expected during training. In Databricks Model Serving, all requests to and responses from the model can be automatically captured and logged in a Delta Lake table.
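For intuition about what the three metrics above measure, the sketch below computes them for one batch, assuming you have the scalar reward of each generated response and the per-token log-probabilities of its generated tokens under the policy and under the frozen reference model. The KL value is a simple Monte Carlo estimate: sum log π(token) - log π_ref(token) over the sampled tokens of each response, then average over responses. TRL logs a KL statistic in a similar spirit, but its exact definition may differ across versions. The function name is hypothetical.

```python
from statistics import mean, pstdev
from typing import Dict, List, Sequence


def ppo_batch_metrics(
    rewards: Sequence[float],
    policy_logprobs: Sequence[List[float]],
    ref_logprobs: Sequence[List[float]],
) -> Dict[str, float]:
    """Mean reward, reward standard deviation, and an estimated KL for one batch."""
    # Per-response KL estimate: sum over generated tokens of log pi - log pi_ref.
    kl_per_response = [
        sum(p - r for p, r in zip(policy_seq, ref_seq))
        for policy_seq, ref_seq in zip(policy_logprobs, ref_logprobs)
    ]
    return {
        "reward_mean": mean(rewards),
        "reward_std": pstdev(rewards),  # population standard deviation over the batch
        "kl": mean(kl_per_response),
    }


if __name__ == "__main__":
    print(
        ppo_batch_metrics(
            rewards=[1.0, 3.0],
            policy_logprobs=[[-1.0, -2.0], [-0.5]],
            ref_logprobs=[[-1.5, -2.5], [-0.5]],
        )
    )
```

A KL of zero means the policy assigns the same probabilities as the reference to what it generated; the rapid early rise described above corresponds to this estimate growing as the adapters start to change the model's behavior.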

 

Evaluation

Evaluation is the second most crucial step in the entire solution, after validating the reward model. In this use case, we kept 100 prompts as a hold-out evaluation set that we did not use for fine-tuning. We then fed these prompts to both the original and the fine-tuned models and compared the quality of the outputs. We observed that 43 texts generated by the original (pre-fine-tuning) model contained non-vegetarian content, whereas only 30 generated by the fine-tuned model did (see the notebook). This corresponds to a roughly 30% improvement in alignment quality (a relative reduction of (43 - 30) / 43 ≈ 30% in non-vegetarian outputs), which indicates the solution's feasibility. The fine-tuned model is not perfect, but further improvements could be made by (1) revising the prompt to produce more accurate scores, (2) increasing the number and variety of training prompts, (3) increasing the number of training epochs, and (4) testing different LLMs as the reward model.


Samples of outputs generated by the pre-fine-tuning (left) and the post-fine-tuning (right) models
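The tally behind the evaluation numbers is easy to reproduce once each hold-out output has been labeled vegetarian or not. The sketch below uses a crude keyword check purely as a hypothetical stand-in for that labeling step (the article does not say how outputs were labeled, and a keyword list will both miss dishes and flag false positives), then computes the relative reduction between the two models.

```python
import re
from typing import Dict, Iterable

# Hypothetical, deliberately incomplete list; a human or LLM judge is more reliable.
NON_VEG_TERMS = {"chicken", "beef", "pork", "bacon", "ham", "turkey", "lamb",
                 "fish", "salmon", "tuna", "shrimp", "anchovies"}


def looks_non_vegetarian(text: str) -> bool:
    """True if any whole word in the text is on the keyword list."""
    words = set(re.findall(r"[a-z]+", text.lower()))
    return not words.isdisjoint(NON_VEG_TERMS)


def relative_reduction(before: int, after: int) -> float:
    return (before - after) / before if before else 0.0


def compare_models(before_outputs: Iterable[str], after_outputs: Iterable[str]) -> Dict[str, float]:
    before = sum(looks_non_vegetarian(t) for t in before_outputs)
    after = sum(looks_non_vegetarian(t) for t in after_outputs)
    return {"before": before, "after": after,
            "relative_reduction": relative_reduction(before, after)}


if __name__ == "__main__":
    print(round(relative_reduction(43, 30), 3))  # counts from the article -> 0.302
    print(compare_models(["Try a grilled chicken salad.", "Tofu stir-fry works well."],
                         ["Try a chickpea salad.", "Tofu stir-fry works well."]))
```

Matching whole words keeps "chickpea" from being flagged as "chicken", but phrases like "vegan chicken" would still be miscounted, which is why the counts reported above should come from a more careful review.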


Inference

After evaluating the fine-tuned model, we deployed it behind a real-time endpoint and made it available to a downstream application. Databricks Model Serving provides optimized inference for large open-source models like Llama 2. Deployment is straightforward using either the API (see this notebook) or the UI.

 

Wrap Up

Aligning LLMs is a crucial topic when building production-ready models. A supervised fine-tuned model is often not sufficient, and we need to tune it further for our specific requirements. Reinforcement learning is often used for this, but it requires human labelers to rank millions of outputs, which is cost-prohibitive and operationally complex for many companies.

In this article, we proposed a solution that uses a pre-trained or fine-tuned LLM as the reward model, eliminating the manual labor of collecting millions of ranked outputs from human labelers. This solution will enable many companies that previously struggled to justify the high cost of labeling, or wanted to avoid its operational complications, to align their LLMs efficiently.

 

Acknowledgment

We thank Hugging Face, especially the TRL team, for making tremendous contributions to this field. For this blog post, we borrowed a lot of content from their amazing TRL repository.