🩺 Dr. GRPO loss #3256
What is `bnpo`? It would be good to have a reference to where it's defined (I thought we had DAPO as the default loss).
In fact, I realized while doing this PR that it wasn't exactly DAPO that was being used, but a variant of BNPO as defined here:

Let me try to clarify. Per-token losses are normalized by `bnpo` in the code, but it's not 100% correct.*

*A batch is made up of `num_devices * gradient_accumulation_steps` local batches.
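For context, here is a minimal sketch of how these reduction choices typically look in a GRPO-style trainer. The names `per_token_loss`, `completion_mask`, and `loss_type` follow TRL's conventions, but the snippet is an assumption for illustration, not a verbatim excerpt from this PR:

```python
import torch

def reduce_loss(per_token_loss, completion_mask, loss_type, max_completion_length=None):
    """Reduce a (batch, seq_len) per-token loss under different normalizations (sketch).

    - "grpo":    mean over tokens within each sequence, then mean over sequences
    - "bnpo":    sum over all unmasked tokens in the local batch / number of such tokens
    - "dr_grpo": sum over all unmasked tokens / (batch_size * max_completion_length),
                 i.e. a constant normalizer that does not depend on sampled lengths
    """
    completion_mask = completion_mask.to(per_token_loss.dtype)
    if loss_type == "grpo":
        per_seq = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
        return per_seq.mean()
    if loss_type == "bnpo":
        return (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
    if loss_type == "dr_grpo":
        return (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * max_completion_length)
    raise ValueError(f"Unknown loss_type: {loss_type}")
```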
Special cases:

- When `per_device_batch_size == num_generations`, TRL's BNPO is equivalent to DAPO.
- When `per_device_batch_size == 1`, TRL's BNPO is equivalent to GRPO.
- When `gradient_accumulation_steps == 1` and `num_devices == 1`, TRL's BNPO is equivalent to the actual BNPO.
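As a quick sanity check of the `per_device_batch_size == 1` case, here is a toy example reusing the hypothetical `reduce_loss` sketch above (the DAPO and multi-device cases are structural, so they are only noted in comments):

```python
import torch

# Toy local batch: one sequence with 3 valid tokens (per_device_batch_size == 1).
per_token_loss = torch.tensor([[0.5, 1.0, 1.5, 0.0]])
completion_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])

# With a single sequence, BNPO's token-level normalization (sum / token count)
# coincides with GRPO's per-sequence mean followed by the mean over sequences.
assert torch.isclose(
    reduce_loss(per_token_loss, completion_mask, "bnpo"),
    reduce_loss(per_token_loss, completion_mask, "grpo"),
)

# When per_device_batch_size == num_generations, the local batch is exactly one prompt
# group, so normalizing by the batch's token count matches DAPO's group-level
# normalization; the "actual BNPO" case additionally requires that no extra averaging
# happens across devices or gradient-accumulation steps.
```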
@qgallouedec Thanks for the comprehensive support! A minor comment for your future consideration: Dr. GRPO does not constrain the constant normalizer to be MAX_LEN (although it's easier to just use that). This can affect the update scale (related to your recent tweet https://x.com/QGallouedec/status/1908741708021457357). In fact, a different constant `x` in the setting of your tweet can be absorbed into the constant normalizer we propose in the paper; MAX_LEN is just a convenient example.
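To illustrate that the normalizer need not be MAX_LEN, here is a hypothetical sketch (function and variable names are mine, not from the paper or this PR): swapping the constant only rescales the loss by a fixed factor, which can be absorbed into the learning rate while the removal of the length bias is unchanged.

```python
import torch

def dr_grpo_loss(per_token_loss, completion_mask, constant_normalizer):
    # Dr. GRPO-style reduction with a generic constant normalizer instead of MAX_LEN.
    return (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * constant_normalizer)

per_token_loss = torch.tensor([[0.5, 1.0, 1.5, 0.0]])
completion_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])

# Using a different constant rescales the loss (and hence the gradients) by a fixed
# factor; here, halving the normalizer exactly doubles the loss.
assert torch.isclose(
    dr_grpo_loss(per_token_loss, completion_mask, constant_normalizer=512),
    dr_grpo_loss(per_token_loss, completion_mask, constant_normalizer=1024) * 2.0,
)
```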
If I understand correctly, @edbeeching was trying something slightly different in #3231 that did local scaling per batch instead of a global constant. Do you know if there's much difference between the two?
They are roughly equivalent; I have closed my PR in favor of this one.