-
Notifications
You must be signed in to change notification settings - Fork 28.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix decoding score comparison when using logits processors or warpers #10638
Fix decoding score comparison when using logits processors or warpers #10638
Conversation
The failing test is I honestly don't know what's the expected behavior there, so not sure if it's flaky or not. The weird thing is that this test seems to be greedy search, not beam search. |
Actually, I just looked more closely and the failing test does use beam search (the beam size is specified in the config). This is an example of something that changes since it uses a |
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
I'm gonna address it, it's been in my mind. Please don't mark it as stale! |
I've added the |
@patrickvonplaten sorry for the big delay. I changed the normalization to be a logit warper now. What do you think of it, and its documentation? Also, what if we set a deprecation for it? And take advantage of some breaking change in the future and make it the default? |
The failing tests are flaky, right? |
Could we add one tests for the new logits processor as well? :-) |
3122f99
to
066b52c
Compare
35a072a
to
b9bd8a1
Compare
@patrickvonplaten can you remove the WIP label? This should be done. Also, the latest time a test failed, it seemed to be flaky. It should be good to go 🚀 |
The documentation is not available anymore as the PR was closed or merged. |
@patrickvonplaten friendly reminder on this! |
Also, should we add a flag in |
PR looks good to go for me - thanks @bryant1410. Yes indeed could you maybe add a flag |
Okay, @patrickvonplaten I did this change. What do you think about also making |
Oh, and btw, note I also applied it to the warpers (so it's applied to both the processors and warpers). |
Should the attribute be added to the configs such that the following can be applied? renormalize_logits if renormalize_logits is not None else self.config.renormalize_logits |
No need for this I think since it's quite a specific logit processor |
@bryant1410, could you also update RAG's generate method to incorporate you changes? The test currently fails with It should be easy to adapt here:
|
Done. What about this?
|
Good for merge for me! Let's see what @gante says |
Okay! What about the comment/idea on making it |
Don't really think that's possible due to backwards breaking changes tbh |
I understand. However, eventually, the breaking change is gonna happen because of some accumulated "debt" that gets big enough, after many different fixes or wanted features. Like it happens in other libraries. It could happen after some major version change (e.g., v5), which it's a great opportunity to change a lot of desired changes that are breaking. One approach to track this is to deprecate the value and say when it's gonna be changed (e.g., v5). It could be with a warning, some comment in the docstring, or maybe just a doc that tracks down which is gonna be changed. I guess what I'm saying is to add this change to that list (is it worth it, in your opinion?). BTW, do you have in this repo such a list of things that are eventually gonna be changed (maybe implicitly tracked in various comments)? What are your thoughts? Maybe you think differently? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The rationale is sensible and I'm in favor of the approved changes 👍
To ensure this change stays future-proof, I'd like to discuss an additional change. The new logit processor, when it exists in the list of logit processors to be applied, must be the last one. Should we raise an exception when it isn't? (e.g. it has to be the last one in this list, when it exists) cc @patrickvonplaten
Makes sense to me. However, what if the user wants to do something custom, by manually adding this processor logit somewhere? If we add a check and an exception, then the user would face it in this custom scenario. Or maybe it's a bit far-fetched? |
Uhmm I see. We can go with the low effort, low cost, and low consequence alternative (see the following suggestion) |
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
@bryant1410 regarding the Since there are no other outstanding requests and CI is green, I'm merging the PR 💪 |
…huggingface#10638) * Normalize using a logits warper * Add a flag in `generate` to support the logit renormalization * Add in RAG
When doing beam search or other decoding search strategies, the logit scores are normalized (with
log_softmax
) so the comparisons between the beams (hypotheses) are meaningful. However, the logit processors or warpers may change the scores, and thus may not be normalized anymore.For example, say you have a beam size of 2. During beam search at some point, beam A is better than B (higher score). You use
prefix_allowed_tokens_fn
, which in turn through a logit processor narrows down the options of the next tokens to only one. Then masks out all tokens with-inf
but one. The score vector may look like[-inf, ..., -2.13, ..., -inf]
. This is output and now the scores are not normalized anymore. This filter is not applied to B. Now beam search selects B, which actually keeping the hypothesis A meant having the same probability since the normalized vector should have been[-inf, ..., 0, ..., -inf]
. In that case, hypothesis A would have been kept (and that's what actually should happen). This erroneous behavior can happen with any logit processor that doesn't normalize its output, which I see it's often the case.So that's why I moved the
log_softmax
to after the logit processor/warper application. I also checked if any logit processor needed the normalization for its input. It doesn't seem to be the case (though I'm not 100% sure). They can still individually apply a normalization if they need to. Maybe the documentation could be changed, by the way:transformers/src/transformers/generation_logits_process.py
Lines 37 to 39 in 26a33cf
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
I feel I should tag @patrickvonplaten, @patil-suraj