DeepSeek-R1
Introduction
DeepSeek-R1 is an open-source model developed by DeepSeek AI, a Chinese AI company backed by the quantitative hedge fund High-Flyer. The model has taken the AI community by storm because it is the first open-source model to achieve performance comparable to OpenAI’s premium reasoning models (e.g., OpenAI-o1) at a fraction of the training and inference cost. Its weights are also free to use under an MIT license.
Panic in Silicon Valley because of DeepSeek
This is no joke: not only Silicon Valley but the whole tech industry is panicking about DeepSeek. Forbes even ran a live-updates page tracking the latest stock market losses.
Nvidia’s stock price dropped by 17%, wiping out $589 billion in market cap, the biggest single-day loss of market value in US stock market history (hint: NVIDIA doesn’t like test-time computing).
And the CEO of Scale AI, a company that provides training data for LLMs and presumably doesn’t like DeepSeek’s cost and data efficiency either, speculated that DeepSeek might have more GPUs than it has disclosed.
President Trump called it a “wake-up call” for U.S. industries.
Models Released
- DeepSeek-R1-Zero: This model, trained through large-scale reinforcement learning (RL) without supervised fine-tuning (SFT) as a preliminary step, demonstrates remarkable reasoning capabilities. Through RL, DeepSeek-R1-Zero naturally develops numerous powerful and intriguing reasoning behaviors. However, it encounters challenges such as poor readability and language mixing.
- DeepSeek-R1: Incorporating multi-stage training and cold-start data before RL, DeepSeek-R1 achieves performance comparable to OpenAI-o1-1217 on reasoning tasks.
- Distill-R1: A series of six dense models (1.5B, 7B, 8B, 14B, 32B, 70B) distilled from DeepSeek-R1 based on Qwen and Llama. Notably, the distilled 14B model outperforms the state-of-the-art open-source QwQ-32B-Preview by a large margin. The 32B and 70B models set new records on reasoning benchmarks among dense models.
Research Questions
- Can language model reasoning capabilities be improved purely through reinforcement learning without supervised fine-tuning?
Key Story Line
- Base Model: The team uses DeepSeek-V3-Base as the base model and employs Group Relative Policy Optimization (GRPO) as the RL framework to enhance reasoning performance.
- Performance Gains: DeepSeek-R1-Zero achieves impressive reasoning benchmarks. For instance, the pass@1 score on AIME 2024 improves from 15.6% to 71.0%. With majority voting, the score further rises to 86.7%, matching OpenAI-o1-0912’s performance.
- Challenges and Solutions: While RL-only training produces strong reasoning capabilities, it introduces issues such as poor readability and language mixing. DeepSeek-R1 addresses these by incorporating cold-start data and a multi-stage training pipeline.
- Pipeline Highlights:
  - Collection of cold-start data for initial fine-tuning.
  - Reasoning-oriented RL to refine reasoning skills.
  - SFT on new datasets generated through rejection sampling and DeepSeek-V3 outputs.
  - A final RL phase to align the model with human preferences across all scenarios.
Resources
- Paper: DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning
- Code: DeepSeek-R1
- All papers from DeepSeek-AI on Hugging Face and DeepSeek’s GitHub
- Understanding Multi-Head Latent Attention from Eryk Banatt
Approach
Overview
- DeepSeek-R1-Zero: Applies RL directly to the base model without supervised fine-tuning. GRPO serves as the RL framework.
- DeepSeek-R1: Employs a multi-stage process combining RL and SFT to address readability and language issues while enhancing performance.
- Distill-R1: Features six dense models (1.5B, 7B, 8B, 14B, 32B, 70B) distilled from DeepSeek-R1 based on Qwen and Llama, setting new records in reasoning benchmarks.
DeepSeek-R1-Zero: RL on the Base Model
Group Relative Policy Optimization
How does GRPO differ from PPO?
Traditional RL methods like PPO require a critic (value) model, typically of comparable size to the policy model, which is trained alongside the policy to estimate a baseline for the advantage. In addition, in RLHF the reward comes from a learned reward model, which is trained on pairs of winning and losing outputs for the same input, normally labeled by human evaluators. These pairs are expensive to obtain and hard to scale. Moreover, if the task is complex, human evaluations may be subjective or biased.
GRPO, on the other hand, drops the critic model and instead estimates the baseline from the rewards of a group of outputs sampled for the same question. DeepSeek-R1-Zero further replaces the learned reward model with rule-based rewards (described below), which removes the need for human preference pairs.
GRPO Objective Function Specifically, for each question \(q\), GRPO samples a group of outputs \(\{o_1, o_2, \cdots, o_G\}\) from the old policy model \(\pi_{\theta_\text{old}}\). It then optimizes the policy model \(\pi_{\theta}\) by maximizing the following objective function:
\[\mathcal{J}_\text{GRPO}(\theta) = \mathbb{E}_{q \sim P(Q), \{o_i\}_{i=1}^G \sim \pi_{\theta_\text{old}}(O|q)} \left[ \frac{1}{G} \sum_{i=1}^G \left(\min\left(\frac{\pi_\theta(o_i|q)}{\pi_{\theta_\text{old}}(o_i|q)}A_i, \text{clip}\left(\frac{\pi_\theta(o_i|q)}{\pi_{\theta_\text{old}}(o_i|q)}, 1-\epsilon, 1+\epsilon\right)A_i\right) - \beta\mathbb{D}_\text{KL}(\pi_\theta||\pi_\text{ref})\right)\right]\]where the KL divergence term \(\mathbb{D}_\text{KL}\) is defined as:
\[\mathbb{D}_\text{KL}(\pi_\theta||\pi_\text{ref}) = \frac{\pi_\text{ref}(o_i|q)}{\pi_\theta(o_i|q)} - \log\frac{\pi_\text{ref}(o_i|q)}{\pi_\theta(o_i|q)} - 1\]and \(A_i\) is the advantage function defined as:
\[A_i = \frac{r_i - \text{mean}(\{r_1, r_2, \cdots, r_G\})}{ \text{std}(\{r_1, r_2, \cdots, r_G\})}\]where \(r_i\) is the reward of the output \(o_i\), and \(\text{mean}\) and \(\text{std}\) are the mean and standard deviation of the rewards in the group. The reward \(r_i\) comes from a rule-based reward system rather than a human evaluator, so it is scalable and less subjective.
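The advantage computation is simple enough to sketch directly. Below is a minimal Python version for one group of \(G\) outputs. It assumes the population standard deviation and adds a small `eps` so a group where every output gets the same reward does not divide by zero; both details are assumptions, since the formula above does not pin them down.

```python
import statistics


def group_advantages(rewards, eps=1e-8):
    """Group-relative advantages A_i = (r_i - mean(r)) / std(r) for one question.

    Assumption: population std (statistics.pstdev); `eps` guards against a
    group in which all outputs received the same reward (std = 0).
    """
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


# Example: 4 sampled outputs, two correct (reward 1) and two wrong (reward 0).
advantages = group_advantages([1.0, 0.0, 1.0, 0.0])
print([round(a, 3) for a in advantages])  # [1.0, -1.0, 1.0, -1.0]
```

One consequence worth noting: if every output in a group receives the same reward, all advantages are zero and that group contributes no policy-improvement signal.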
The rule-based reward system mainly consists of two types of rewards:
- Accuracy rewards: evaluate whether the output is correct. Plenty of existing datasets have known answers, for example, math problems with deterministic answers or LeetCode problems with predefined test cases.
- Format rewards: the output is rewarded if it follows a predefined format. For example, the thinking process must be enclosed between <think> and </think> tags.
Template for DeepSeek-R1-Zero:
A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. User: prompt. Assistant:
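To make the two reward types concrete, here is a minimal sketch of a rule-based reward for responses that follow the think/answer template. The regex, the exact-string answer match, and summing the two rewards are all simplifying assumptions for illustration; the paper does not publish its reward code.

```python
import re

# Hypothetical pattern for the R1-Zero template: reasoning inside <think>...</think>,
# followed by the final answer inside <answer>...</answer>.
FORMAT_PATTERN = re.compile(r"^\s*<think>.*?</think>\s*<answer>(.*?)</answer>\s*$", re.DOTALL)


def format_reward(response: str) -> float:
    """1.0 if the response follows the think/answer template, else 0.0."""
    return 1.0 if FORMAT_PATTERN.match(response) else 0.0


def accuracy_reward(response: str, ground_truth: str) -> float:
    """1.0 if the extracted answer exactly matches the reference (a simplification)."""
    match = FORMAT_PATTERN.match(response)
    if match is None:
        return 0.0
    return 1.0 if match.group(1).strip() == ground_truth.strip() else 0.0


def rule_based_reward(response: str, ground_truth: str) -> float:
    # Assumption: the accuracy and format rewards are simply summed.
    return accuracy_reward(response, ground_truth) + format_reward(response)


good = "<think>7 * 6 = 42</think> <answer>42</answer>"
print(rule_based_reward(good, "42"))                 # 2.0
print(rule_based_reward("The answer is 42.", "42"))  # 0.0
```

A real math checker would normalize answers (e.g., compare numerically or parse a boxed expression) rather than compare raw strings.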
Why is the rule-based reward system effective? To me, the use of a rule-based reward system is another example of how self-supervised learning, where labeled data can be generated automatically and at massive scale, drives the success of large-scale deep learning models. ControlNet in image generation is a similar case: it uses traditional CV techniques such as edge detection to create additional control signals, so the model can exploit existing rule-based knowledge to improve its learning. Likewise, the rule-based reward system in this paper is a simple yet effective way to create large amounts of structured, labeled training signal, which is crucial for training large-scale models and keeps the scaling laws valid.
However, the rule-based reward system is not perfect. In my understanding, it is the reason why DeepSeek-R1-Zero suffers from poor readability and language mixing: the rewards check only correctness and format, not how readable the reasoning is.
Breaking down the GRPO objective function
The expectation term The expectation term \(\mathbb{E}_{q \sim P(Q), \{o_i\}_{i=1}^G \sim \pi_{\theta_\text{old}}(O \mid q)}\) says that for each question \(q\) sampled from a distribution of questions \(P(Q)\), we sample a group of outputs \(\{o_i\}_{i=1}^G\) from the old policy model \(\pi_{\theta_\text{old}}\).
The KL divergence term Minimizing the KL divergence term ensures that the policy model \(\pi_\theta\) does not deviate too much from the reference model \(\pi_\text{ref}\). Specifically, let \(t=\frac{\pi_\text{ref}(o_i \mid q)}{\pi_\theta(o_i \mid q)}\), then the KL divergence term can be rewritten as:
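\[\mathbb{D}_\text{KL}(\pi_\theta||\pi_\text{ref}) = t - \log t - 1\]Since \(f(t) = t - \log t - 1\) has derivative \(f'(t) = 1 - 1/t\), which is negative for \(t < 1\) and positive for \(t > 1\), we get \(\mathbb{D}_\text{KL}(\pi_\theta||\pi_\text{ref}) \geq 0 \; \forall t > 0\), with its minimum of 0 reached at \(t = 1\), i.e., when \(\pi_\theta(o_i \mid q) = \pi_\text{ref}(o_i \mid q)\).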
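Besides being nonnegative, this per-sample term has a useful property: its average over outputs sampled from \(\pi_\theta\) equals the exact KL divergence \(\mathbb{D}_\text{KL}(\pi_\theta||\pi_\text{ref})\), because \(\mathbb{E}_{o \sim \pi_\theta}[\pi_\text{ref}/\pi_\theta] = 1\) and \(\mathbb{E}_{o \sim \pi_\theta}[-\log(\pi_\text{ref}/\pi_\theta)]\) is the KL divergence itself. The quick numerical check below uses made-up three-outcome distributions with full support (an assumption the identity needs).

```python
import math


def kl_estimator(p_theta: float, p_ref: float) -> float:
    """Per-sample GRPO KL term: t - log(t) - 1 with t = pi_ref / pi_theta."""
    t = p_ref / p_theta
    return t - math.log(t) - 1.0


def exact_kl(p, q):
    """KL(p || q) for two categorical distributions with full support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


# Toy distributions over 3 possible outputs (illustrative numbers only).
pi_theta = [0.5, 0.3, 0.2]
pi_ref = [0.4, 0.4, 0.2]

# Averaging the per-sample term under pi_theta recovers the exact KL divergence.
expected = sum(p * kl_estimator(p, q) for p, q in zip(pi_theta, pi_ref))
print(abs(expected - exact_kl(pi_theta, pi_ref)) < 1e-12)  # True
print(kl_estimator(0.3, 0.3))  # 0.0 (t = 1)
```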
The advantage function This term reflects how much better the output \(o_i\) is compared to the average output in the group, e.g., if \(A_i > 0\), then \(o_i\) is better than the average output in the group or if \(A_i > A_j\), then \(o_i\) is better than \(o_j\).
Therefore, maximizing the scaled advantage \(\frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_\text{old}}(o_i \mid q)}A_i\) encourages the policy model \(\pi_\theta\) to generate outputs that are better than the group average (those with \(A_i > 0\)) while discouraging worse outputs (those with \(A_i < 0\)). The clip term removes the incentive to push the ratio outside \([1-\epsilon, 1+\epsilon]\), which keeps each update close to \(\pi_{\theta_\text{old}}\).
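Putting the pieces together, here is a minimal sequence-level sketch of the GRPO objective for one group. It treats each output \(o_i\) as a single unit with a log-probability under the current, old, and reference policies; practical implementations usually work per token and batch over many questions. The values `clip_eps=0.2` and `beta=0.04` are illustrative assumptions, not settings taken from the paper.

```python
import math
import statistics


def grpo_objective(logp_new, logp_old, logp_ref, rewards, clip_eps=0.2, beta=0.04):
    """Sequence-level GRPO objective for one question's group of G outputs.

    logp_* hold log pi(o_i | q) under the current, old, and reference policies.
    clip_eps and beta are illustrative values (assumptions).
    Returns the objective to MAXIMIZE (negate it to get a loss).
    """
    mean, std = statistics.fmean(rewards), statistics.pstdev(rewards)
    total = 0.0
    for ln, lo, lr, r in zip(logp_new, logp_old, logp_ref, rewards):
        adv = (r - mean) / (std + 1e-8)              # group-relative advantage A_i
        ratio = math.exp(ln - lo)                    # pi_theta / pi_theta_old
        clipped = min(max(ratio, 1 - clip_eps), 1 + clip_eps)
        surrogate = min(ratio * adv, clipped * adv)  # PPO-style clipped term
        t = math.exp(lr - ln)                        # pi_ref / pi_theta
        kl = t - math.log(t) - 1.0                   # per-sample KL penalty
        total += surrogate - beta * kl
    return total / len(rewards)


rewards = [1.0, 0.0, 1.0, 0.0]
old = [-5.0, -5.0, -5.0, -5.0]
# Before any update (new == old == ref) the ratios are 1, the KL is 0,
# and the advantages average to 0.
print(grpo_objective(old, old, old, rewards))  # 0.0
# Shifting probability toward the correct outputs increases the objective.
new = [-4.9, -5.1, -4.9, -5.1]
print(grpo_objective(new, old, old, rewards) > 0)  # True
```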
Performance, Self-evolution Process and Aha Moment
As described in Section 2.2.4 of the paper, DeepSeek-R1-Zero is evaluated on the AIME 2024 benchmark (see AIME 2024), where it impressively reaches a pass@1 score comparable to OpenAI-o1-0912, a premium OpenAI reasoning model. Its performance can be further improved with majority voting: on AIME 2024, majority voting raises DeepSeek-R1-Zero’s score from 71.0% to 86.7%, exceeding the performance of OpenAI-o1-0912.
Self-evolution Process
Besides its impressive performance, DeepSeek-R1-Zero also exhibits a fascinating self-evolution process, shown in Figure 3 of the paper: the average response length per question grows over training (from several hundred tokens to 10k+ tokens), again with RL only. DeepSeek-R1-Zero naturally acquires the ability to solve increasingly complex reasoning tasks by leveraging extended test-time computation (see Test time computing).
One of the most remarkable aspects of this self-evolution is the emergence of sophisticated behaviors as the test-time computation increases. Behaviors such as reflection—where the model revisits and reevaluates its previous steps—and the exploration of alternative approaches to problem-solving arise spontaneously. These behaviors are not explicitly programmed but instead emerge as a result of the model’s interaction with the reinforcement learning environment. This spontaneous development significantly enhances DeepSeek-R1-Zero’s reasoning capabilities, enabling it to tackle more challenging tasks with greater efficiency and accuracy.
Aha Moment
Another interesting phenomenon observed in DeepSeek-R1-Zero is the aha moment (for the model, and for the authors and me as well), where the model suddenly realizes that it needs to allocate more thinking time to a problem by reevaluating its initial approach. This reminds me of another aha moment in the history of RL, when DeepMind’s DQN agent discovered a way to win the Atari game Breakout with minimal effort: digging a tunnel through the side of the wall so the ball bounces around behind it. Or AlphaGo’s move 37, a move that no human player would ever have made.
This behavior is not only a testament to the model’s growing reasoning abilities but also a captivating example of how reinforcement learning can lead to unexpected and sophisticated outcomes. It underscores the power and beauty of reinforcement learning: rather than explicitly teaching the model how to solve a problem, we simply provide it with the right incentives, and it autonomously develops advanced problem-solving strategies.
DeepSeek-R1 - RL with Cold Start
While DeepSeek-R1-Zero’s performance is impressive, it still suffers from poor readability and language mixing. To address these issues and further enhance reasoning performance, the authors introduce DeepSeek-R1, which incorporates a small amount of cold-start data and a four-stage training pipeline.
Cold Start with CoT data
Unlike DeepSeek-R1-Zero, which begins with pure RL on the base model, DeepSeek-R1 incorporates a cold-start phase. This stage collects thousands of long Chain-of-Thought (CoT) examples to fine-tune the base model (DeepSeek-V3-Base). The data is generated using methods such as few-shot prompting, direct prompting for detailed answers with reflection and verification, gathering DeepSeek-R1-Zero outputs, and refining the results with human annotators. The purpose of this step is to avoid an unstable start to the RL process and to make the model’s responses more readable and coherent. The output format is designed to include a summary at the end of each response: |special_token|<reasoning_process>|special_token|<summary>.
Reasoning-oriented RL
After the cold-start fine-tuning, the model undergoes a reasoning-oriented RL training process similar to the one used for DeepSeek-R1-Zero. This stage focuses on enhancing the model’s ability to handle tasks in areas such as coding, mathematics, science, and logic. To mitigate language mixing, a language consistency reward is added during RL training, computed as the proportion of target-language words in the CoT, although this may slightly degrade performance. The final reward is a combination of reasoning accuracy and language consistency. GRPO is again used to optimize the policy model, which reduces training costs by estimating the baseline from group scores instead of a critic model.
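A minimal sketch of the language consistency reward, assuming English is the target language. The word-splitting rule and the "all ASCII letters means English" test are hypothetical heuristics chosen for illustration; note that a run of CJK characters without spaces counts as a single word here. The final reward sums accuracy and language consistency, which is how the paper describes combining them.

```python
import re


def language_consistency_reward(cot: str) -> float:
    """Fraction of words in the CoT that look like English (hypothetical heuristic).

    A 'word' is a maximal run of Unicode letters; it counts as English if it is
    pure ASCII. Digits and symbols are ignored. An empty CoT scores 0.0 (assumption).
    """
    words = re.findall(r"[^\W\d_]+", cot)
    if not words:
        return 0.0
    english = sum(1 for w in words if w.isascii())
    return english / len(words)


def final_reward(accuracy: float, cot: str) -> float:
    # Accuracy reward plus language consistency reward, combined by summation.
    return accuracy + language_consistency_reward(cot)


mixed = "First, compute 3 + 4 = 7, 所以 the answer is 7"
print(round(language_consistency_reward(mixed), 3))  # 0.833
print(final_reward(1.0, "only english words here"))  # 2.0
```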
SFT with new data
Once the reasoning-oriented RL has converged, the resulting checkpoint is used to collect SFT data for the next round. This stage incorporates both reasoning and non-reasoning data. Rejection sampling (see Rejection Sampling) is used to generate reasoning trajectories from the checkpoint: the model is prompted to generate multiple responses, and only the correct and coherent ones are kept as SFT data. This is also where a generative reward model is used, feeding both the ground truth and the model’s predictions into DeepSeek-V3 for judgment. Non-reasoning data, such as writing, factual QA, self-cognition, and translation, is added by adopting the DeepSeek-V3 pipeline and reusing portions of the DeepSeek-V3 SFT dataset. The DeepSeek-V3-Base model is then fine-tuned on this combined dataset (about 800k samples in total).
RL with all scenarios
The final stage is a second RL phase that aligns the model with human preferences. It aims to improve the model’s helpfulness and harmlessness while refining its reasoning skills. For reasoning data, the process is similar to DeepSeek-R1-Zero and uses rule-based rewards. For general data, reward models are used to capture human preferences: the final summary is assessed for helpfulness, while the entire response (including reasoning and summary) is evaluated for harmlessness.
Conclusion
To me, the most interesting part of this paper is the invention of DeepSeek-R1-Zero, whose introduction has had a profound impact on our understanding of RL and LLM training. More specifically, pure RL with rule-based rewards might represent a new paradigm for LLM training. The use of a rule-based reward system strikes me as another example of how self-supervised learning—where data can be generated automatically and on a massive scale—continues to underpin the success of large-scale deep learning models.
Similar to the success of ControlNet in image generation, which leverages traditional computer vision techniques like edge detection to provide additional control signals, the rule-based reward system in this paper offers a simple yet effective method to generate large amounts of structured, labeled data. This, in turn, plays a crucial role in training large-scale models, ensuring that the scaling laws remain valid.
The aha moment in DeepSeek-R1-Zero perfectly encapsulates the elegance and power of reinforcement learning: instead of explicitly teaching the model how to solve a problem, we simply design the right incentives, allowing the model to autonomously develop sophisticated problem-solving strategies.
Appendix
AIME 2024
The American Invitational Mathematics Examination (AIME) is a prestigious mathematics competition in the United States, serving as an intermediary between the AMC 10/12 exams and the USA Mathematical Olympiad (USAMO). The AIME consists of 15 questions, each with an integer answer between 0 and 999, to be completed in 3 hours. Participants qualify for the AIME based on their performance in the AMC 10 or AMC 12 exams.
In 2024, the AIME I was administered on January 31, and the AIME II on February 7. The mean score for AIME I was 5.89, with a median of 5, while AIME II had a mean score of 5.45 and a median of 5.
The AIME 2024 benchmark employs two metrics:
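- The pass@1 score is the average fraction of correct responses per question. Following the paper’s Evaluation Setup (page 12), \(k\) responses (typically between 4 and 64) are sampled for each question with temperature 0.6 and top-p 0.95, and \(\text{pass@1} = \frac{1}{k}\sum_{i=1}^{k} p_i\), where \(p_i\) denotes the correctness of the \(i\)-th response.
- The cons@64 score is the accuracy of the consensus (majority-vote) answer over 64 sampled responses.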
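Both metrics are easy to compute once the sampled answers are extracted. The sketch below uses made-up answers and only 4 samples per question instead of 64 for brevity; benchmark scores are the averages over questions. Tie-breaking in the majority vote is a detail this sketch ignores.

```python
from collections import Counter


def pass_at_1(sampled_answers, reference):
    """pass@1 = (1/k) * sum_i p_i, where p_i = 1 if the i-th sample is correct."""
    return sum(a == reference for a in sampled_answers) / len(sampled_answers)


def consensus_correct(sampled_answers, reference):
    """cons@k for one question: is the majority-vote answer correct?"""
    majority, _ = Counter(sampled_answers).most_common(1)[0]
    return majority == reference


# Two toy questions: (sampled answers, reference answer).
questions = [
    (["42", "42", "17", "42"], "42"),  # 3/4 correct, majority correct
    (["5", "7", "7", "3"], "5"),       # 1/4 correct, majority wrong
]
pass1 = sum(pass_at_1(s, ref) for s, ref in questions) / len(questions)
cons = sum(consensus_correct(s, ref) for s, ref in questions) / len(questions)
print(pass1, cons)  # 0.5 0.5
```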
Rejection Sampling
Purpose: Rejection sampling is employed to generate reasoning trajectories from the model’s checkpoint after reasoning-oriented reinforcement learning (RL) has converged. The goal is to create a dataset that can improve the model’s ability in various areas, including writing, role-playing, and other general-purpose tasks, alongside its reasoning capabilities.
Process:
- A set of reasoning prompts is curated.
- The model generates multiple responses for each prompt.
- Only correct responses are retained, while incorrect or less desirable responses are rejected. This filtering step ensures that the SFT data consists of high-quality examples.
- The responses are also filtered to remove issues like mixed languages, long paragraphs, and code blocks, to ensure readability and relevance.
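The process above can be sketched as a generate-then-filter loop. Everything here is hypothetical scaffolding: `generate` stands in for sampling from the RL checkpoint, `is_correct` for the rule-based or DeepSeek-V3-based judge, and `is_readable` is a crude stand-in for the mixed-language, long-paragraph, and code-block filters (it assumes English as the target language).

```python
import itertools

CODE_FENCE = "`" * 3  # markdown code-block delimiter


def is_readable(response, max_paragraph_chars=500):
    """Hypothetical stand-in for the readability filters described above."""
    if CODE_FENCE in response:                     # drop responses with code blocks
        return False
    if any(len(p) > max_paragraph_chars for p in response.split("\n\n")):
        return False                               # drop overly long paragraphs
    return response.isascii()                      # crude "no mixed languages" check


def rejection_sample(prompt, reference, generate, is_correct, num_samples=4):
    """Generate `num_samples` responses and keep only correct, readable ones."""
    kept = []
    for _ in range(num_samples):
        response = generate(prompt)
        if is_correct(response, reference) and is_readable(response):
            kept.append(response)
    return kept


# Toy stand-ins (hypothetical): a "model" that cycles through canned responses
# and a checker that compares the final answer tag with the reference.
canned = itertools.cycle([
    "<think>2 + 2 = 4</think><answer>4</answer>",
    "<think>2 + 2 = 5</think><answer>5</answer>",
    "<think>二加二等于四</think><answer>4</answer>",
])


def generate(prompt):
    return next(canned)


def is_correct(response, reference):
    return response.endswith(f"<answer>{reference}</answer>")


print(rejection_sample("What is 2 + 2?", "4", generate, is_correct, num_samples=3))
```

Only the first canned response survives: the second is wrong and the third, although correct, is filtered out by the language check.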
Expansion of Dataset: In the rejection sampling stage, the dataset expands beyond those that can be evaluated using rule-based rewards by including data that use a generative reward model. This is done by feeding the ground-truth and model predictions into DeepSeek-V3 for judgment.
Output Quality: The overall goal is to produce higher-quality training samples by filtering out low-quality responses, ensuring that the model trains on consistent and reliable data.
In summary, rejection sampling plays a crucial role in the DeepSeek-R1 pipeline by generating a refined and expanded dataset for the second round of supervised fine-tuning. This process contributes to enhancing the model’s overall capabilities.
Test time computing
Test Time Computing (TTC) refers to computational processes performed during the inference phase of machine learning models—that is, when the model is used to make predictions or solve problems after being trained. Unlike traditional inference, which usually involves a straightforward application of a pre-trained model, TTC allows for additional computations or adjustments to improve performance on specific tasks.
Key Concepts in Test Time Computing:
- Adaptation at Inference: Some models dynamically adapt their behavior based on new inputs or environmental conditions. This can involve fine-tuning parts of the model or leveraging meta-learning techniques.
- Iterative Reasoning: Instead of producing a single output, models perform multiple reasoning steps (e.g., generating intermediate explanations or calculations) to refine their predictions. This is common in large language models when solving complex problems.
- On-the-Fly Learning: The model might use previously unseen data to improve its predictions in real time. This is particularly useful in tasks like personalization or domain adaptation.
- Resource Allocation: TTC allows models to allocate varying amounts of computational resources to different inputs, depending on task complexity or uncertainty. For example, a model may run deeper reasoning loops for harder questions.
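As one simple illustration of the Resource Allocation idea, the sketch below keeps sampling answers until they agree strongly enough or a budget runs out, so easy questions use few samples and hard ones use many. The `sample` callable, the thresholds, and this early-stopping rule are hypothetical choices for illustration, not how any particular system allocates compute.

```python
import itertools
from collections import Counter


def adaptive_majority_vote(sample, min_samples=4, max_samples=64, agreement=0.8):
    """Draw answers from `sample()` until the leading answer holds at least
    `agreement` of the votes (after `min_samples`), or the budget runs out.

    Returns (majority_answer, number_of_samples_used).
    """
    counts = Counter()
    for n in range(1, max_samples + 1):
        counts[sample()] += 1
        answer, votes = counts.most_common(1)[0]
        if n >= min_samples and votes / n >= agreement:
            return answer, n                           # answers agree: stop early
    return counts.most_common(1)[0][0], max_samples    # full budget used


# "Easy" question: the hypothetical sampler always returns the same answer.
print(adaptive_majority_vote(lambda: "42"))  # ('42', 4)

# "Hard" question: the sampler keeps disagreeing with itself.
disagreeing = itertools.cycle(["17", "42"])
print(adaptive_majority_vote(lambda: next(disagreeing))[1])  # 64
```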
Applications:
- Natural Language Processing (NLP): Iterative reasoning to solve logic or math problems.
- Computer Vision: Adjusting filters or segmentations for specific images.
- Personalization: Adapting user recommendations based on recent interactions.
- Robotics: Dynamically adjusting movements based on environmental feedback.
Benefits:
- Improved Accuracy: By refining outputs at test time, models often achieve higher performance on difficult tasks.
- Task-Specific Customization: Allows models to handle nuanced problems more effectively.
- Efficient Use of Resources: Computational effort can be adjusted based on task complexity.
Challenges:
- Increased Latency: Additional computations can slow down predictions.
- Higher Costs: Real-time adjustments require more computational resources.
- Complexity: Implementing TTC mechanisms can complicate model architecture.
This approach is increasingly used in advanced AI systems, such as OpenAI’s o1 models and DeepSeek-R1, which rely on iterative reasoning such as long chain-of-thought to tackle complex tasks effectively.
Why NVIDIA doesn’t like TTC
NVIDIA’s GPUs are designed for massively parallel computation, whereas TTC often involves sequential, iterative computation on individual inputs, which can underutilize the GPU’s parallel architecture. TTC also introduces variable and potentially higher latency, which does not fit traditional GPU pipelines well.
Monte Carlo Tree Search
Monte Carlo Tree Search (MCTS) is an advanced search algorithm used primarily in decision-making processes, especially for games, simulations, and optimization problems. It is a method for making decisions by simulating many possible outcomes and using statistical analysis to find the most promising path.
Key Components of MCTS MCTS works by iteratively building a search tree, where nodes represent game states (or decision points) and edges represent actions. The process involves four main steps:
1. Selection
- Starting from the root node, the algorithm selects child nodes recursively until it reaches a node that is not fully expanded (i.e., not all possible moves are explored).
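- The selection is often guided by a strategy like the Upper Confidence Bound for Trees (UCT), which balances exploration (trying less-visited nodes) and exploitation (focusing on nodes with high average rewards): \[\text{UCT}_i = \frac{w_i}{n_i} + c\sqrt{\frac{\ln N}{n_i}}\]where \(w_i\) is the total reward of child \(i\), \(n_i\) is its visit count, \(N\) is the visit count of the parent, and \(c\) is an exploration constant.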
2. Expansion
- When a leaf node is reached, new child nodes are added for all possible moves from the current state.
- This step grows the search tree by exploring unvisited nodes.
3. Simulation (Rollout)
- From the newly added node, a simulation is run to the end of the game (or a predefined depth). The simulation often involves random or heuristic-based moves.
- The outcome (e.g., win, loss, or score) of this rollout provides an estimate of the value of the node.
4. Backpropagation
- The result of the simulation is propagated back up the tree, updating the statistics (e.g., win rate or average reward) for each node along the path to the root.
- This helps the algorithm prioritize the most promising branches in future iterations.
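The four steps map almost line-for-line onto code. Below is a minimal sketch on a toy game chosen for illustration (not from the original text): a single pile of stones where players alternately take 1 to 3 stones, and whoever takes the last stone wins. Assumptions: uniformly random rollouts, the UCT rule with `c=1.4`, one new child per iteration (a common variant of the expansion step), and the most-visited root move as the final recommendation.

```python
import math
import random


class Node:
    """A game state: `stones` left and whose turn it is (`player` 0 or 1)."""

    def __init__(self, stones, player, parent=None, move=None):
        self.stones = stones
        self.player = player
        self.parent = parent
        self.move = move                 # move (1-3 stones) that led to this node
        self.children = []
        self.untried = [m for m in (1, 2, 3) if m <= stones]
        self.visits = 0
        self.wins = 0.0                  # wins for the player who made `move`

    def uct_child(self, c=1.4):
        # Selection rule: average reward plus an exploration bonus (UCT).
        return max(
            self.children,
            key=lambda ch: ch.wins / ch.visits + c * math.sqrt(math.log(self.visits) / ch.visits),
        )


def rollout(stones, player, rng):
    """Simulation: play uniformly random moves and return the winning player."""
    if stones == 0:
        return 1 - player                # the previous player took the last stone
    while True:
        stones -= rng.randint(1, min(3, stones))
        if stones == 0:
            return player                # this player took the last stone
        player = 1 - player


def mcts(stones, iterations=3000, seed=0):
    rng = random.Random(seed)
    root = Node(stones, player=0)
    for _ in range(iterations):
        node = root
        # 1. Selection: descend while the node is fully expanded and non-terminal.
        while not node.untried and node.children:
            node = node.uct_child()
        # 2. Expansion: add one untried child.
        if node.untried:
            move = node.untried.pop()
            child = Node(node.stones - move, 1 - node.player, parent=node, move=move)
            node.children.append(child)
            node = child
        # 3. Simulation from the new node.
        winner = rollout(node.stones, node.player, rng)
        # 4. Backpropagation: update visit and win counts up to the root.
        while node is not None:
            node.visits += 1
            if node.parent is not None and winner == node.parent.player:
                node.wins += 1
            node = node.parent
    # Recommend the most-visited move at the root.
    return max(root.children, key=lambda ch: ch.visits).move


print(mcts(5), mcts(6), mcts(7))
```

Optimal play in this game is to leave the opponent a multiple of 4 stones, so the search is expected to recommend taking 1, 2, and 3 stones from piles of 5, 6, and 7 respectively.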
Applications of MCTS
1. Games:
- Widely used in game-playing AI, especially for games with large decision spaces (e.g., Go, Chess, Poker).
- Integral to the success of systems like AlphaGo, which combined MCTS with deep neural networks.
2. Robotics and Planning:
- Used to plan sequences of actions in dynamic environments where outcomes are uncertain.
3. Optimization:
- Applied in optimization problems where exploring the solution space is challenging due to its complexity or size.
4. Simulations:
- Used in Monte Carlo simulations to estimate probabilities or solve probabilistic decision-making problems.
Strengths of MCTS
- Scalable: Handles very large state spaces effectively.
- Adaptive: Focuses computational resources on the most promising parts of the tree.
- Flexible: Can work without a full model of the game or problem and adapt as new information is added.
Limitations
- Computationally Expensive: Requires many simulations, especially for complex problems.
- Dependence on Rollout Policy: The quality of results depends heavily on how the simulations (rollouts) are performed.
- Suboptimal for Short Decision Horizons: Less effective for problems requiring quick, shallow decisions.
MCTS combines principles from reinforcement learning, probability, and decision-making, making it a powerful tool for complex tasks that involve uncertainty and large decision spaces.
Reinforcement Learning with Human Feedback (RLHF)
Goal: Fine-tuning a model using reinforcement learning where the reward signal is derived from human feedback to align the model with human preferences and values beyond task-specific objectives.
Process:
- Supervised Pretraining: Start with a pretrained and optionally fine-tuned model.
- Feedback Collection: Collect human-labeled data ranking model outputs (e.g., ranking completions for prompts).
- Reward Model (RM): Train a reward model to predict rankings based on human feedback. The reward model is trained to mimic human preferences on a specific task.
- Policy Optimization: Fine-tune the model (policy) using reinforcement learning (e.g., Proximal Policy Optimization, PPO) to maximize the reward from the RM.
Loss Function:
- Combines reinforcement learning objectives (e.g., PPO loss) with supervised objectives to balance exploration and alignment.
Advantages:
- Aligns model behavior with human values, preferences, and ethical considerations.
- Reduces harmful or inappropriate responses.
- Improves usability in real-world scenarios (e.g., chatbot interactions).
Limitations:
- Requires expensive and time-consuming human feedback.
- May introduce biases based on the preferences of the feedback providers.
- Balancing alignment with generalization is challenging.
Comparison Table
Aspect | Unsupervised Pretraining | Supervised Fine-Tuning | RLHF |
---|---|---|---|
Purpose | General-purpose language understanding | Task-specific performance improvement | Aligning outputs with human preferences |
Data Requirement | Large-scale unlabeled text corpora | Labeled datasets for specific tasks | Human-labeled feedback or rankings |
Objective | Learn language patterns and knowledge | Optimize for specific task objectives | Optimize for human alignment |
Training Cost | High (large datasets, long training) | Moderate (smaller labeled datasets) | Very high (feedback collection + RL tuning) |
Model Usage | Provides a base model | Task-specific models | Refines models for safer, more useful output |
Challenges | Task-agnostic, needs fine-tuning | Dependent on labeled data quality | Expensive and subject to bias in feedback |
Example | Training GPT, BERT, T5 from scratch | Fine-tuning BERT for sentiment analysis | Fine-tuning GPT with human ranking data |
PPO and DPO
Proximal Policy Optimization (PPO)
The Bradley-Terry model is a probabilistic model used to rank items based on pairwise comparisons.
\[p^*(y_1 \succ y_2 \mid x) = \frac{\exp(r^*(x,y_1))}{\exp(r^*(x,y_1)) + \exp(r^*(x,y_2))}\]where \(p^*\) is the true human preference distribution, i.e., the probability that \(y_1\) is preferred over \(y_2\) given the input \(x\), and \(r^*\) is the latent reward function that generates those preferences. \(y_1\) and \(y_2\) are the two responses being compared.
If \(r^*(x,y_1) > r^*(x,y_2)\), then \(p^*(y_1 \succ y_2 \mid x) > 0.5\), which means \(y_1\) is more likely to be better than \(y_2\) given \(x\).
To learn the reward function from data, we parameterize it as a neural network \(r_{\phi}\) and minimize the negative log-likelihood of the observed preferences:
\[\mathcal{L}_{R}(r_{\phi}, \mathcal{D}) = - \mathbb{E}_{(x,y_w,y_l) \sim \mathcal{D}} \left[ \log \sigma (r_{\phi}(x,y_w) - r_{\phi}(x,y_l)) \right]\]where \(\mathcal{D}\) is the dataset of human preferences, and \(y_w\) and \(y_l\) are the winning and losing responses for the same input \(x\). The sigmoid appears because the Bradley-Terry probability can be rewritten as \(\frac{\exp(r_1)}{\exp(r_1) + \exp(r_2)} = \sigma(r_1 - r_2)\). Minimizing this objective maximizes the probability that the winning response is preferred over the losing one.
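A small sketch of this loss, with plain floats standing in for the scores \(r_{\phi}(x, y_w)\) and \(r_{\phi}(x, y_l)\) that a reward network would produce (the numbers are made up). It also checks numerically that the Bradley-Terry probability equals the sigmoid of the reward difference.

```python
import math


def bt_probability(r1: float, r2: float) -> float:
    """Bradley-Terry: p(y1 > y2 | x) = exp(r1) / (exp(r1) + exp(r2))."""
    return math.exp(r1) / (math.exp(r1) + math.exp(r2))


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def reward_model_loss(pairs):
    """Mean of -log sigmoid(r_w - r_l) over (r_w, r_l) score pairs.

    Not numerically hardened: very large negative margins would overflow exp().
    """
    return -sum(math.log(sigmoid(rw - rl)) for rw, rl in pairs) / len(pairs)


print(abs(bt_probability(2.0, 0.5) - sigmoid(1.5)) < 1e-12)  # True
# A larger margin in favor of the winning response gives a smaller loss.
print(reward_model_loss([(3.0, 0.0)]) < reward_model_loss([(1.0, 0.0)]))  # True
```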
Note that the reward model \(r_{\phi}\) is usually initialized from the same supervised fine-tuning (SFT) model as the policy model.
Fine-tuning the policy model
Once we have the reward model \(r_{\phi}\), we can fine-tune the policy model \(\pi_{\theta}\) by minimizing the following objective, typically optimized with PPO:
\[\mathcal{L}_{P}(\theta, \mathcal{D}) = - \mathbb{E}_{x \sim \mathcal{D}, y \sim \pi_{\theta}(y \mid x)} \left[ r_{\phi}(x,y) \right] + \beta D_{KL}(\pi_{\theta}(y \mid x) || \pi_{ref}(y \mid x))\]where \(\pi_{ref}\) is the reference model and \(\beta\) is a hyperparameter. Minimizing the first term pushes the policy model to generate responses that the reward model prefers, while minimizing the second term keeps the policy model from deviating too much from the reference model.
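In practice, a common way to apply the KL term is to fold it into the reward of each sampled response, using the log-ratio \(\log \pi_{\theta}(y \mid x) - \log \pi_{ref}(y \mid x)\) as a single-sample estimate of the KL divergence (its expectation over \(y \sim \pi_{\theta}\) is exactly that divergence). The helper below is a hypothetical sketch with an assumed `beta=0.1`; the shaped reward would then be maximized by an RL algorithm such as PPO.

```python
def kl_penalized_reward(reward: float, logp_policy: float, logp_ref: float, beta: float = 0.1) -> float:
    """r_phi(x, y) - beta * [log pi_theta(y|x) - log pi_ref(y|x)] for one sampled y.

    beta=0.1 is an illustrative value (assumption).
    """
    return reward - beta * (logp_policy - logp_ref)


# Policy agrees with the reference: no penalty.
print(kl_penalized_reward(1.0, logp_policy=-2.0, logp_ref=-2.0))  # 1.0
# Policy puts much more mass on y than the reference: the reward is reduced.
print(kl_penalized_reward(1.0, logp_policy=-1.0, logp_ref=-3.0))  # 0.8
```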
Advantages:
- Aligns the model with human preferences.
- Reduces harmful or inappropriate responses.
- Improves usability in real-world scenarios (e.g., chatbot interactions).
Limitations:
- Requires expensive and time-consuming human feedback.
- May introduce biases based on the preferences of the feedback providers.
- Balancing alignment with generalization is challenging.
Direct Preference Optimization (DPO)
DPO simplifies the PPO pipeline by optimizing the policy model directly on human preference data, without training a separate reward model or running RL.
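Objective Function: DPO starts from the fact that the KL-constrained objective above has a closed-form optimal policy:
\[\pi_{r}(y \mid x) = \frac{1}{Z(x)} \pi_{\text{ref}}(y \mid x) \exp\left(\frac{1}{\beta} r(x,y)\right)\]where \(\pi_r\) is the optimal policy for the reward \(r\), \(\pi_{\text{ref}}\) is the reference model, \(\beta\) is the KL coefficient, and \(Z(x) = \sum_y \pi_{\text{ref}}(y \mid x) \exp\left(\frac{1}{\beta} r(x,y)\right)\) is the partition function. Rearranging expresses the reward in terms of the policy:
\[r(x,y) = \beta \log \frac{\pi_r(y \mid x)}{\pi_{\text{ref}}(y \mid x)} + \beta \log Z(x)\]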
Substituting this reward into the Bradley-Terry model, the \(\beta \log Z(x)\) terms cancel because both responses share the same input \(x\). The final objective function is:
\[\mathcal{L}_{DPO}(\theta, \mathcal{D}) = - \mathbb{E}_{(x,y_w,y_l) \sim \mathcal{D}} \left[ \log \sigma \left( \beta \log \frac{\pi_{\theta}(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_{\theta}(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)} \right) \right]\]There is no need for a separate reward model \(r_{\phi}\), because the policy model \(\pi_{\theta}\) is optimized directly on the preference data. Minimizing this objective increases the likelihood of the winning response relative to the losing one, both measured against the reference model.
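For one preference pair, the DPO loss needs only four sequence log-probabilities: the winning and losing responses under the trained policy and under the frozen reference model. The sketch below takes them as plain floats (in a real setup they would be summed token log-probabilities) and assumes `beta=0.1` for illustration.

```python
import math


def dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    """DPO loss for one preference pair (x, y_w, y_l).

    Inputs are sequence log-probabilities log pi(y|x) under the policy being
    trained and the frozen reference model. beta=0.1 is an illustrative value.
    """
    margin = beta * ((logp_w - ref_logp_w) - (logp_l - ref_logp_l))
    return math.log1p(math.exp(-margin))  # equals -log sigmoid(margin)


# Policy identical to the reference: margin 0, loss log(2).
print(round(dpo_loss(-5.0, -5.0, -5.0, -5.0), 4))  # 0.6931
# Raising the winner's and lowering the loser's log-probability reduces the loss.
print(dpo_loss(-4.0, -6.0, -5.0, -5.0) < dpo_loss(-5.0, -5.0, -5.0, -5.0))  # True
```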
SFT vs RLHF
Reinforcement Learning from Human Feedback (RLHF) is a technique used to align large language models (LLMs) more closely with human preferences and expectations. While supervised fine-tuning (SFT) is a critical step in improving a model’s capabilities, it has limitations that RLHF can address.
When is RLHF needed?
Ambiguity in Objectives
- SFT aligns the model with a dataset of “correct” outputs, but human preferences often involve subjective judgment or context-specific nuance that cannot be fully captured in a static dataset.
- Example: Chatbots generating empathetic or polite responses where the tone and context matter significantly.
Unclear or Complex Evaluation Metrics
- For some tasks, it’s difficult to define explicit evaluation metrics, but humans can intuitively judge quality.
- Example: Creative writing or generating humorous content.
Long-Term or Multi-Step Reasoning
- SFT trains models to produce correct outputs based on immediate context but might fail in scenarios requiring multi-step decision-making or long-term coherence.
- Example: Writing a coherent multi-paragraph essay or guiding a user through a series of troubleshooting steps.
Avoiding Harmful Outputs
- Static datasets used for SFT may not include all edge cases or potential pitfalls, and RLHF can help refine the model to avoid harmful or toxic outputs based on human preferences.
- Example: Ensuring a conversational agent avoids generating offensive or biased content.
Improving User Experience
- SFT often focuses on correctness, but RLHF can optimize for user satisfaction, such as balancing informativeness, politeness, and conciseness.
- Example: Personal assistants generating concise and helpful responses tailored to user needs.
Why is Supervised Fine-Tuning Not Enough?
Static Nature of Datasets
- SFT relies on pre-collected datasets that might not represent all real-world scenarios or evolving user preferences.
- Limitation: If the dataset lacks certain examples, the model cannot generalize well.
Difficulty in Capturing Preferences
- Human preferences are often complex and not directly labeled in datasets.
- Limitation: A model trained on SFT might produce technically correct but undesirable outputs (e.g., overly verbose or lacking empathy).
Overfitting to Training Data
- SFT can lead to overfitting, where the model learns to replicate the training data without adapting to unseen scenarios.
- Limitation: This can result in poor performance on out-of-distribution examples.
Reward Optimization
- RLHF optimizes a reward function designed to capture human preferences, which allows fine-grained control over the model’s behavior.
- Limitation: SFT does not involve direct optimization based on human evaluations.