Fix decoding score comparison when using logits processors or warpers (huggingface#10638)

* Normalize using a logits warper

* Add a flag in `generate` to support the logit renormalization

* Add in RAG
bryant1410 authored and elusenji committed Jun 12, 2022
1 parent d23dffc commit a0a9270
Showing 4 changed files with 58 additions and 2 deletions.
13 changes: 13 additions & 0 deletions src/transformers/generation_logits_process.py
@@ -679,3 +679,16 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Float
self.regulation_factor, cur_len - self.regulation_start
)
return scores


class LogitNormalization(LogitsProcessor, LogitsWarper):
r"""
[`LogitsWarper`] and [`LogitsProcessor`] for normalizing the scores using log-softmax. During beam search, the
scores should be renormalized after all logits processors or warpers have been applied, because the search
algorithm in this library only normalizes them beforehand, yet it still assumes the scores are normalized when
comparing hypotheses.
"""

def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
scores = scores.log_softmax(dim=-1)
return scores
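
A minimal sketch (not part of the commit) of what the new warper does in isolation, assuming the module path shown above (`transformers.generation_logits_process`); after log-softmax, exponentiating the scores yields rows that sum to 1:

```python
import torch

from transformers.generation_logits_process import LogitNormalization

# Two rows of unnormalized scores (logits); values are arbitrary.
scores = torch.tensor([[2.0, 1.0, -3.0], [0.5, 0.5, 0.0]])

# The warper simply applies log-softmax over the vocabulary dimension.
normalized = LogitNormalization()(input_ids=None, scores=scores)

# Exponentiating the normalized scores gives proper probability distributions.
print(normalized.exp().sum(dim=-1))  # tensor([1.0000, 1.0000])
```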
29 changes: 27 additions & 2 deletions src/transformers/generation_utils.py
@@ -32,6 +32,7 @@
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitNormalization,
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
@@ -636,6 +637,7 @@ def _get_logits_warper(
typical_p: Optional[float] = None,
temperature: Optional[float] = None,
num_beams: Optional[int] = None,
renormalize_logits: Optional[bool] = None,
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
@@ -660,6 +662,9 @@
warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
if typical_p is not None and typical_p < 1.0:
warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
# `LogitNormalization` should always be the last logit processor, when present
if renormalize_logits is True:
warpers.append(LogitNormalization())
return warpers

def _get_logits_processor(
@@ -682,6 +687,7 @@
remove_invalid_values: bool,
exponential_decay_length_penalty: Tuple,
logits_processor: Optional[LogitsProcessorList],
renormalize_logits: Optional[bool],
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
@@ -754,6 +760,9 @@
ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)
)
processors = self._merge_criteria_processor_list(processors, logits_processor)
# `LogitNormalization` should always be the last logit processor, when present
if renormalize_logits is True:
processors.append(LogitNormalization())
return processors

def _get_stopping_criteria(
@@ -858,6 +867,7 @@ def generate(
diversity_penalty: Optional[float] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
renormalize_logits: Optional[bool] = None,
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
constraints: Optional[List[Constraint]] = None,
output_attentions: Optional[bool] = None,
@@ -986,6 +996,10 @@
Custom logits processors that complement the default logits processors built from arguments and a
model's config. If a logit processor is passed that is already created with the arguments or a model's
config an error is thrown. This feature is intended for advanced users.
renormalize_logits (`bool`, *optional*, defaults to `False`):
Whether to renormalize the logits after applying all the logits processors or warpers (including the
custom ones). It's highly recommended to set this flag to `True`, as the search algorithms assume that
the score logits are normalized, but some logits processors or warpers break the normalization.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
model's config. If a stopping criteria is passed that is already created with the arguments or a
@@ -1241,6 +1255,7 @@ def generate(
remove_invalid_values=remove_invalid_values,
exponential_decay_length_penalty=exponential_decay_length_penalty,
logits_processor=logits_processor,
renormalize_logits=renormalize_logits,
)

# 8. prepare stopping criteria
@@ -1271,7 +1286,12 @@
elif is_sample_gen_mode:
# 10. prepare logits warper
logits_warper = self._get_logits_warper(
top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
num_beams=num_beams,
renormalize_logits=renormalize_logits,
)

# 11. expand input_ids with `num_return_sequences` additional sequences per batch
@@ -1333,7 +1353,12 @@
elif is_beam_sample_gen_mode:
# 10. prepare logits warper
logits_warper = self._get_logits_warper(
top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
num_beams=num_beams,
renormalize_logits=renormalize_logits,
)

if stopping_criteria.max_length is None:
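A hedged usage sketch of the new flag (not part of the commit); the checkpoint name is only illustrative, and any model supported by `generate` would work the same way:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: How are you?", return_tensors="pt")

# With `renormalize_logits=True`, `LogitNormalization` is appended as the last
# processor/warper, so beam hypotheses are compared on normalized scores.
outputs = model.generate(**inputs, num_beams=4, renormalize_logits=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
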
2 changes: 2 additions & 0 deletions src/transformers/models/rag/modeling_rag.py
@@ -1400,6 +1400,7 @@ def generate(
n_docs: Optional[int] = None,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
renormalize_logits: Optional[bool] = None,
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
@@ -1624,6 +1625,7 @@ def extend_enc_output(tensor, num_beams=None):
remove_invalid_values=remove_invalid_values,
exponential_decay_length_penalty=exponential_decay_length_penalty,
logits_processor=logits_processor,
renormalize_logits=renormalize_logits,
)

if num_beams == 1:
16 changes: 16 additions & 0 deletions tests/generation/test_generation_logits_process.py
@@ -33,6 +33,7 @@
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitNormalization,
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
@@ -537,3 +538,18 @@ def test_exponential_decay_length_penalty(self):
scores_after_start[penalty_start + 1 :, eos_token_id], scores[penalty_start + 1 :, eos_token_id]
).all()
)

def test_normalization(self):
input_ids = None

scores = torch.tensor(
[[-23.18, -29.96, -43.54, 47.77], [-33.58, -26.87, -32.96, 22.51]], device=torch_device, dtype=torch.float
)

logit_normalization = LogitNormalization()
normalized_scores = logit_normalization(input_ids, scores).exp()

ones = torch.ones(scores.shape[0], device=torch_device, dtype=torch.float)
self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones))

self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))
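
Beyond the unit test above, a short illustration (not from the commit) of why the renormalization matters: once a warper such as `TopKLogitsWarper` masks tokens with `-inf`, the surviving log-probabilities no longer sum to 1 in probability space, so beam scores built from them are not directly comparable until they are renormalized:

```python
import torch

from transformers.generation_logits_process import LogitNormalization, TopKLogitsWarper

# Start from properly normalized log-probabilities over a toy vocabulary.
log_probs = torch.log_softmax(torch.randn(1, 10), dim=-1)

# Top-k filtering sets all but the 3 best tokens to -inf, discarding probability mass.
warped = TopKLogitsWarper(top_k=3)(input_ids=None, scores=log_probs)
print(warped.exp().sum(dim=-1))  # < 1.0: mass was discarded

# Appending LogitNormalization (what `renormalize_logits=True` does) restores a
# valid distribution over the surviving tokens.
renormalized = LogitNormalization()(input_ids=None, scores=warped)
print(renormalized.exp().sum(dim=-1))  # ~1.0
```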
