🩺 Dr. GRPO loss #3256

Merged
merged 18 commits into from
Apr 9, 2025
40 changes: 35 additions & 5 deletions docs/source/grpo_trainer.md
@@ -76,7 +76,7 @@ This approach gives the method its name: **Group Relative Policy Optimization (G

<Tip>

It was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://github.com/sail-sg/understand-r1-zero/blob/main/understand-r1-zero.pdf) that scaling by \\( \text{std}(\mathbf{r}) \\) may cause a question-level difficulty bias. You can disable this scaling by setting `scale_rewards=False` in [`GRPOConfig`].
It was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.14476) that scaling by \\( \text{std}(\mathbf{r}) \\) may cause a question-level difficulty bias. You can disable this scaling by setting `scale_rewards=False` in [`GRPOConfig`].

</Tip>

@@ -92,26 +92,56 @@ $$
The objective is to maximize the advantage while ensuring that the model remains close to the reference policy. Consequently, the loss is defined as follows:

$$
\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{G} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
$$

where the first term represents the scaled advantage and the second term penalizes deviations from the reference policy through KL divergence.
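
To make the detached-ratio trick concrete, here is a minimal PyTorch sketch of this per-token loss (illustrative only, not the exact TRL implementation; the tensor names `per_token_logps`, `ref_per_token_logps`, and `advantages` are assumed):

```python
import torch

def grpo_per_token_loss(per_token_logps, ref_per_token_logps, advantages, beta):
    """Sketch of the per-token GRPO loss in the mu = 1 case.

    per_token_logps:     (B, T) log-probs of the sampled tokens under the current policy
    ref_per_token_logps: (B, T) log-probs of the same tokens under the reference policy
    advantages:          (B,)   group-relative advantage of each completion
    beta:                weight of the KL penalty
    """
    # pi_theta / [pi_theta]_no_grad: equals 1 in the forward pass but keeps the gradient of log pi_theta
    ratio = torch.exp(per_token_logps - per_token_logps.detach())
    # Non-negative estimator of KL(pi_theta || pi_ref), computed per token
    per_token_kl = (
        torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
    )
    # Negate because we minimize the loss while maximizing the objective
    return -(ratio * advantages.unsqueeze(1) - beta * per_token_kl)
```

The result still has to be masked with the completion mask and normalized; the normalization choices are discussed in [loss types](#loss-types).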

<Tip>

Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we don't scale by \\( \frac{1}{|o_i|} \\) because it was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://github.com/sail-sg/understand-r1-zero/blob/main/understand-r1-zero.pdf) that this introduces a response-level length bias.
Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we don't scale by \\( \frac{1}{|o_i|} \\) because it was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.14476) that this introduces a response-level length bias. More details in [loss types](#loss-types).

</Tip>

In the original paper, this formulation is generalized to account for multiple updates after each generation (denoted \\( \mu \\), which can be set with `num_iterations` in [`GRPOConfig`]) by leveraging the **clipped surrogate objective**:

$$
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
$$

where \\(\text{clip}(\cdot, 1 - \epsilon, 1 + \epsilon) \\) ensures that updates do not deviate excessively from the old policy \\( \pi_{\theta_{\text{old}}} \\) by bounding the policy ratio between \\( 1 - \epsilon \\) and \\( 1 + \epsilon \\).
When \\( \mu = 1 \\) (default in TRL), the clipped surrogate objective simplifies to the original objective.
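
For \\( \mu > 1 \\), a minimal sketch of the clipped surrogate term (again illustrative, without the KL penalty, masking, or normalization; `old_per_token_logps` stands for log-probabilities under \\( \pi_{\theta_{\text{old}}} \\), and the clipping range `epsilon` is an illustrative value):

```python
import torch

def clipped_surrogate_per_token_loss(per_token_logps, old_per_token_logps, advantages, epsilon=0.2):
    """Sketch of the clipped surrogate objective (per token, sign flipped for minimization)."""
    ratio = torch.exp(per_token_logps - old_per_token_logps)  # pi_theta / pi_theta_old
    adv = advantages.unsqueeze(1)                             # broadcast over the token dimension
    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * adv
    return -torch.min(unclipped, clipped)
```

When \\( \mu = 1 \\), `old_per_token_logps` equals `per_token_logps.detach()`, and this reduces to the detached ratio used above.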

#### Loss Types

Several formulations of the objective have been proposed in the literature. Initially, the objective of GRPO was defined as follows:

$$
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} l_{i,t},
$$

where

$$
l_{i,t} = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right].
$$

The DAPO paper highlights the limitations of the GRPO algorithm’s sample-level loss in long-CoT scenarios, where longer responses are under-penalized, leading to poorer quality outputs. The proposed solution is a Token-level Policy Gradient Loss, which better handles longer sequences by assigning more precise rewards to individual tokens, regardless of response length:

$$
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},
$$

This formulation is used by default in TRL. To further reduce bias, normalization is done by the total number of active tokens in the batch, rather than within each group.
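
To make the difference concrete, here is a small numerical sketch (illustrative, not the exact TRL code) contrasting the original sequence-level aggregation with the token-level aggregation above, for a group of two completions of lengths 3 and 1:

```python
import torch

per_token_loss = torch.tensor([[0.5, 0.5, 0.5, 0.0],
                               [1.0, 0.0, 0.0, 0.0]])    # (G, T), padded positions set to 0
completion_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0],
                                [1.0, 0.0, 0.0, 0.0]])   # completion lengths: 3 and 1

# Original GRPO: average within each sequence, then across the group -> long sequences are diluted
grpo_loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1)).mean()

# Token-level aggregation (TRL's default, `loss_type="bnpo"`): normalize by the number of active tokens
bnpo_loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

print(grpo_loss.item(), bnpo_loss.item())  # 0.75 vs 0.625
```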

Furthermore, it was demonstrated in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.14476) that the initial GRPO formulation introduces a response length bias. They show that while the DAPO formulation reduces this bias, it does not eliminate it completely. To fully remove this bias, they propose dividing by a constant instead of the sequence length, resulting in the following formulation:

$$
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{LG} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},
$$

This constant is recommended to be the maximum completion length. To use this formulation, set `loss_type="drgrpo"` in the [`GRPOConfig`].
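
A matching sketch (same illustrative tensors as above) of the Dr. GRPO aggregation, which divides by a constant rather than by the number of active tokens:

```python
import torch

per_token_loss = torch.tensor([[0.5, 0.5, 0.5, 0.0],
                               [1.0, 0.0, 0.0, 0.0]])    # (G, T)
completion_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0],
                                [1.0, 0.0, 0.0, 0.0]])
max_completion_length = 4  # the constant L, taken here as the maximum completion length

# Dr. GRPO: the denominator is fixed, so the scale of the update does not depend on completion lengths
drgrpo_loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * max_completion_length)

print(drgrpo_loss.item())  # 2.5 / (2 * 4) = 0.3125
```

Because the denominator is fixed, the scale of the update no longer depends on how long the sampled completions happen to be.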

## Logged metrics

- `num_tokens`: The total number of tokens processed so far, including both prompts and completions.
@@ -121,7 +151,7 @@ When \\( \mu = 1 \\) (default in TRL), the clipped surrogate objective simplifi
- `completions/mean_terminated_length`: The average length of generated completions that terminate with EOS.
- `completions/min_terminated_length`: The minimum length of generated completions that terminate with EOS.
- `completions/max_terminated_length`: The maximum length of generated completions that terminate with EOS.
- `completions/clipped_ratio`: The ratio of truncated (clipped) completions.
- `reward/{reward_func_name}/mean`: The average reward from a specific reward function.
- `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function.
- `reward`: The overall average reward after applying reward weights.
32 changes: 32 additions & 0 deletions tests/test_grpo_trainer.py
@@ -174,6 +174,38 @@ def test_training(self, config_name):
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    @parameterized.expand([("bnpo",), ("drgrpo",)])
    def test_training_loss_types(self, loss_type):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=32,  # reduce the completion length to reduce memory usage
                loss_type=loss_type,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_training_with_eval(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")

27 changes: 24 additions & 3 deletions trl/trainer/grpo_config.py
@@ -112,9 +112,19 @@ class GRPOConfig(TrainingArguments):
scale_rewards (`bool`, *optional*, defaults to `True`):
Whether to scale the rewards by dividing them by their standard deviation. If `True` (default), the rewards
are normalized by the standard deviation, ensuring they have unit variance. If `False`, no scaling is
applied. The [Dr. GRPO](https://github.com/sail-sg/understand-r1-zero/blob/main/understand-r1-zero.pdf)
paper recommends not scaling the rewards, as scaling by the standard deviation introduces a question-level
difficulty bias.
applied. The [Dr. GRPO paper](https://huggingface.co/papers/2503.14476) recommends not scaling the rewards,
as scaling by the standard deviation introduces a question-level difficulty bias.
loss_type (`str`, *optional*, defaults to `"bnpo"`):

Member:
What is bnpo? Would be good to have a reference to where it's defined (I thought we had DAPO as the default loss)

@qgallouedec (Member Author) commented on Apr 8, 2025:
In fact, I realized while doing this PR that it wasn't exactly DAPO that was being used, but a variant of BNPO as defined here:

[screenshot: BNPO definition from the paper]

Let me try to clarify here. Per-token losses are normalized by:

- GRPO: the length of the sequence
- DAPO: the average sequence length in the group
- BNPO: the average sequence length in the batch
- TRL's BNPO: the average sequence length in the local batch*; this is what I call bnpo in the code, but it's not 100% correct
- Dr GRPO: the maximum possible length of the completion

*A batch is made up of num_devices * gradient_accumulation_steps local batches.

@qgallouedec (Member Author):

Special cases:

- When per_device_batch_size==num_generations, TRL's BNPO is equivalent to DAPO
- When per_device_batch_size==1, TRL's BNPO is equivalent to GRPO
- When gradient_accumulation_steps==1 and num_devices==1, TRL's BNPO is equivalent to the actual BNPO

@qgallouedec Thanks for the comprehensive support! A minor comment for your future consideration: Dr. GRPO does not constrain the constant normalizer to be MAX_LEN (although it's easier to just use that). This can affect the update scale (related to your recent tweet https://x.com/QGallouedec/status/1908741708021457357). In fact, different constants of x in the setting of your tweet can be absorbed into the constant normalizer we propose in the paper, and MAX_LEN is a convenient example.

Specifies the loss formulation to use. Supported values are:

- `"bnpo"`: Token-level losses are aggregated by normalizing with the completion length
within the local batch. Note that normalization is performed over the local batch only, so results may
slightly vary depending on the local batch size, despite a constant effective batch size.
- `"drgrpo"`: Token-level losses are aggregated by normalizing with a global constant. This method was
introduced in the [Dr. GRPO paper](https://huggingface.co/papers/2503.14476) to eliminate length bias.
The value of the constant corresponds to `max_completion_length`.

Member:
If I understand correctly, @edbeeching was trying something slightly different in #3231 that did local scaling per batch instead of a global constant. Do you know if there's much difference between the two?

Collaborator:
They are roughly equivalent, I have closed my PR in favor of this one.


The original GRPO loss is not supported due to severe length bias that favors short completions.
sync_ref_model (`bool`, *optional*, defaults to `False`):
Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
the `ref_model_mixup_alpha` parameter. This synchronization originates from the
@@ -302,6 +312,17 @@ class GRPOConfig(TrainingArguments):
            "deviation introduces a question-level difficulty bias."
        },
    )
    loss_type: str = field(
        default="bnpo",
        metadata={
            "help": "Specifies the loss formulation to use. Supported values are `bnpo` and `drgrpo`. `'bnpo'`: "
            "Token-level losses are aggregated by normalizing with the completion length within the local batch. Note "
            "that normalization is performed over the local batch only, so results may slightly vary depending on the "
            "local batch size, despite a constant effective batch size. `'drgrpo'`: Token-level losses are aggregated "
            "by normalizing with a global constant. This method was introduced in the Dr. GRPO paper to eliminate "
            "length bias. The value of the constant corresponds to `max_completion_length`."
        },
    )
    sync_ref_model: bool = field(
        default=False,
        metadata={
19 changes: 14 additions & 5 deletions trl/trainer/grpo_trainer.py
Expand Up @@ -409,6 +409,8 @@ def data_collator(features): # No data collation is needed in GRPO
self.repetition_penalty = args.repetition_penalty
self.use_vllm = args.use_vllm
self.use_liger_loss = args.use_liger_loss
self.loss_type = args.loss_type
self.scale_rewards = args.scale_rewards

# Datasets
if (
Expand Down Expand Up @@ -472,6 +474,7 @@ def data_collator(features): # No data collation is needed in GRPO
self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
self._total_train_tokens = 0
self.log_completions = args.log_completions
self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
self.num_completions_to_print = args.num_completions_to_print
# maxlen is set to the total number of forward passes per step. This value of `maxlen` ensures we log only the
# final optimization step.
@@ -749,7 +752,7 @@ def _generate_and_score_completions(
            prompt_mask = prompt_mask[:, -self.max_prompt_length :]

        # Generate completions using either vLLM or regular generation
        if self.args.use_vllm:
        if self.use_vllm:
            # First, have main process load weights if needed
            if self.state.global_step != self._last_loaded_step:
                self._move_model_to_vllm()
@@ -906,7 +909,7 @@ def _generate_and_score_completions(
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        advantages = rewards - mean_grouped_rewards
        if self.args.scale_rewards:
        if self.scale_rewards:
            advantages = advantages / (std_grouped_rewards + 1e-4)

        # Slice to keep only the local part of the data
@@ -1047,7 +1050,13 @@ def _compute_loss(self, model, inputs):
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
        if self.beta != 0.0:
            per_token_loss = per_token_loss + self.beta * per_token_kl
        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

        if self.loss_type == "bnpo":
            loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
        elif self.loss_type == "drgrpo":
            loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")

        # Log the metrics
        mode = "eval" if self.control.should_evaluate else "train"
@@ -1088,7 +1097,7 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
        super().log(logs)
        self._metrics[mode].clear()

        if self.accelerator.is_main_process:
        if self.accelerator.is_main_process and self.log_completions:
            if is_rich_available():
                print_prompt_completions_sample(
                    self._textual_logs["prompt"],
@@ -1108,7 +1117,7 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
                    **self._textual_logs["rewards"],
                }
                df = pd.DataFrame(table)
                if self.args.wandb_log_unique_prompts:
                if self.wandb_log_unique_prompts:
                    df = df.drop_duplicates(subset=["prompt"])
                wandb.log({"completions": wandb.Table(dataframe=df)})
