Commit b3f5c78

TF: XLA bad words logits processor and list of processors (huggingfac…

gante authored and elusenji committed Jun 12, 2022
1 parent c3dc365 commit b3f5c78
Showing 3 changed files with 116 additions and 75 deletions.
135 changes: 80 additions & 55 deletions src/transformers/generation_tf_logits_process.py
@@ -14,7 +14,7 @@
# limitations under the License.

import inspect
from typing import List
from typing import List, Tuple

import numpy as np
import tensorflow as tf
@@ -38,7 +38,10 @@
[What are input IDs?](../glossary#input-ids)
scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
search or log softmax for each vocabulary token when using beam search
search or log softmax for each vocabulary token when using beam search.
cur_len (`int`):
The current length of valid input sequence tokens. In the TF implementation, the `input_ids`' sequence length
is the maximum length `generate` can produce, and we need to know which of its tokens are valid.
kwargs:
Additional logits processor specific kwargs.
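
As an editorial aside (not part of the commit), a minimal illustration of the padding convention this docstring describes, with made-up values: in the XLA-friendly TF generation loop, `input_ids` already has the full output length, and only the first `cur_len` positions hold real tokens.

import tensorflow as tf

input_ids = tf.constant([[101, 7592, 0, 0, 0, 0]])  # hypothetical batch of 1, padded to max_length = 6
cur_len = 2                                          # only the first two positions are valid so far
valid_tokens = input_ids[:, :cur_len]                # [[101, 7592]] -- what a processor should reason about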
@@ -51,7 +54,7 @@ class TFLogitsProcessor:
"""Abstract base class for all logit processors that can be applied during generation."""

@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
"""TF method for processing logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
@@ -62,7 +65,7 @@ class TFLogitsWarper:
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
"""TF method for warping logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
@@ -77,18 +80,18 @@ class TFLogitsProcessorList(list):
"""

@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, **kwargs) -> tf.Tensor:
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int, **kwargs) -> tf.Tensor:
for processor in self:
function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 2:
if len(function_args) > 3:
if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
raise ValueError(
f"Make sure that all the required parameters: {list(function_args.keys())} for "
f"{processor.__class__} are passed to the logits processor."
)
scores = processor(input_ids, scores, **kwargs)
scores = processor(input_ids, scores, cur_len, **kwargs)
else:
scores = processor(input_ids, scores)
scores = processor(input_ids, scores, cur_len)
return scores
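
For reference, a minimal usage sketch of such a list run eagerly (not part of the commit; `TFTemperatureLogitsWarper` is the library name of the temperature warper whose `__init__` appears a few hunks below, and the shapes are illustrative):

import tensorflow as tf

warpers = TFLogitsProcessorList([TFTemperatureLogitsWarper(temperature=0.7)])
input_ids = tf.constant([[5, 17, 0, 0]])   # padded up to the maximum length `generate` can produce
scores = tf.random.normal((1, 100))        # (batch_size, vocab_size)
scores = warpers(input_ids, scores, cur_len=2)  # every member now also receives `cur_len`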


@@ -107,7 +110,7 @@ def __init__(self, temperature: float):

self.temperature = temperature

def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
scores = scores / self.temperature
return scores

@@ -133,7 +136,7 @@ def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_t
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep

def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1]) # Safety check
# Boolean mask containing all tokens with a probability less than the last token of the top-k
indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
@@ -163,7 +166,7 @@ def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep

def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])

mask_scores = tf.fill(scores.shape, self.filter_value)
@@ -305,58 +308,75 @@ def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
)

self.bad_words_ids = bad_words_ids

def calc_banned_bad_words_ids(self, prev_input_ids):
banned_tokens = []

def _tokens_match(prev_tokens, tokens):
if len(tokens) == 0:
# if bad word tokens is just one token always ban it
return True
if len(tokens) > len(prev_tokens):
# if bad word tokens are longer than prev tokens they can't be equal
return False

if prev_tokens[-len(tokens) :] == tokens:
# if tokens match
return True
else:
return False

for prev_input_ids_slice in prev_input_ids:
banned_tokens_slice = []

for banned_token_seq in self.bad_words_ids:
assert (
len(banned_token_seq) > 0
), f"Banned words token sequences {self.bad_words_ids} cannot have an empty list"
# stores the information about bad words in three tensors:
# 1. a rectangular tensor with the forbidden sequences (padded with `-1`), for full data comparisons
self.bad_word_seqs_ids = tf.ragged.constant(bad_words_ids).to_tensor(default_value=-1)
# 2. a tensor with the unpadded length of each forbidden sequence, for quick length comparisons
bad_word_seqs_len = [len(bad_words) for bad_words in bad_words_ids]
if any([word_len == 0 for word_len in bad_word_seqs_len]):
raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")
self.bad_word_seqs_len = tf.convert_to_tensor(bad_word_seqs_len, dtype=tf.int32)
# 3. a tensor containing the last token for each sequence, for easy access to the tokens that may be banned
self.seq_forbidden_tokens = tf.convert_to_tensor([bad_words[-1] for bad_words in bad_words_ids])
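
As a concrete illustration (made-up values, not part of the commit), for `bad_words_ids = [[3], [17, 42]]` the three tensors above would hold:

import tensorflow as tf

bad_words_ids = [[3], [17, 42]]
bad_word_seqs_ids = tf.ragged.constant(bad_words_ids).to_tensor(default_value=-1)              # [[3, -1], [17, 42]]
bad_word_seqs_len = tf.convert_to_tensor([len(seq) for seq in bad_words_ids], dtype=tf.int32)  # [1, 2]
seq_forbidden_tokens = tf.convert_to_tensor([seq[-1] for seq in bad_words_ids])                # [3, 42]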

def _calc_row_banned_bad_tokens(self, row_input_ids: tf.Tensor) -> tf.Tensor:
def _tokens_match(bad_word_seq_number):
def _len_one():
# If the bad sequence only has one token, always mask it
return tf.cond(
tf.math.equal(self.bad_word_seqs_len[bad_word_seq_number], 1),
lambda: tf.ones((), dtype=tf.bool),
_len_greater_than_cur_len,
)

if _tokens_match(prev_input_ids_slice.numpy().tolist(), banned_token_seq[:-1]) is False:
# if tokens do not match continue
continue
def _len_greater_than_cur_len():
# Otherwise, if the bad sequence is longer than the current length, it can never match
return tf.cond(
tf.math.greater(self.bad_word_seqs_len[bad_word_seq_number], row_input_ids.shape[0]),
lambda: tf.zeros((), dtype=tf.bool),
_match_found,
)

banned_tokens_slice.append(banned_token_seq[-1])
def _match_found():
# Finally, runs the actual comparison. Can only be called if the previous comparisons do not yield
# an answer (otherwise we get indexing exceptions)
compare_len = self.bad_word_seqs_len[bad_word_seq_number] - 1
return tf.cond(
tf.math.reduce_all(
tf.math.equal(
row_input_ids[-compare_len:], self.bad_word_seqs_ids[bad_word_seq_number, :compare_len]
)
),
lambda: tf.ones((), dtype=tf.bool),
lambda: tf.zeros((), dtype=tf.bool),
)

banned_tokens.append(banned_tokens_slice)
match = _len_one()
return match

return banned_tokens
# Compares the current row against all bad word sequences, obtaining a mask with the matches.
match_mask = tf.map_fn(_tokens_match, tf.range(self.bad_word_seqs_ids.shape[0]), fn_output_signature=tf.bool)
row_banned_tokens = self.seq_forbidden_tokens[match_mask]
return row_banned_tokens
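
Put differently, here is roughly what the nested `tf.cond` branches compute for a single row, restated eagerly with the hypothetical values from the illustration above (the real method keeps everything as TF ops so it stays XLA-traceable):

import tensorflow as tf

# Bad sequences [[3], [17, 42]], and a row whose valid prefix is [5, 17].
row_input_ids = tf.constant([5, 17])
# Sequence [3] has length 1, so its last token (3) is always banned.
# Sequence [17, 42] is not longer than the row, so its last token (42) is banned only if the row ends with [17].
prefix_matches = tf.reduce_all(tf.equal(row_input_ids[-1:], tf.constant([17])))
match_mask = tf.stack([tf.constant(True), prefix_matches])
print(tf.boolean_mask(tf.constant([3, 42]), match_mask).numpy())  # -> [ 3 42]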

def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:

vocab_size = scores.shape[-1]

# calculate a list of banned tokens according to bad words
banned_tokens = self.calc_banned_bad_words_ids(input_ids[:, :cur_len])

banned_tokens_indices_mask = []
for banned_tokens_slice in banned_tokens:
banned_tokens_indices_mask.append(
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
# We want to mask some banned tokens, at the score level. Since the banned tokens depend on the previous
# `input_ids`, they may have a different length for each row, and they may even be empty for some rows.
# To remain simple and XLA-compatible, we work in a per-row fashion.
# TODO (Joao): this function might trigger XLA retracing as `cur_len` increases. Fix it if it becomes
# a frequent choke point. (make `cur_len` a tensor?)
def _get_row_updated_score(row_inputs: Tuple[tf.Tensor]) -> tf.Tensor:
row_input_ids, row_score = row_inputs
banned_tokens = self._calc_row_banned_bad_tokens(row_input_ids[:cur_len])
banned_tokens_mask = tf.scatter_nd(
indices=tf.expand_dims(banned_tokens, axis=-1),
updates=tf.ones_like(banned_tokens, dtype=tf.bool),
shape=row_score.shape,
)
row_score = tf.where(banned_tokens_mask, -float("inf"), row_score)
return row_score

scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)

scores = tf.map_fn(_get_row_updated_score, (input_ids, scores), fn_output_signature=tf.float32)
return scores
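
A hedged end-to-end sketch of the call above, run eagerly with illustrative values (not part of the commit): each row's banned tokens are turned into a boolean vocabulary mask via `tf.scatter_nd`, and the masked positions are set to `-inf`.

import tensorflow as tf

processor = TFNoBadWordsLogitsProcessor(bad_words_ids=[[3], [17, 42]], eos_token_id=2)
input_ids = tf.constant([[5, 17, 0, 0]])   # padded; only the first `cur_len` tokens are valid
scores = tf.zeros((1, 100))                # (batch_size, vocab_size)
scores = processor(input_ids, scores, cur_len=2)
# Token 3 is banned unconditionally (single-token bad word) and token 42 is banned because the
# valid prefix [5, 17] ends with the bad sequence's prefix [17]; both positions now score -inf.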


@@ -401,6 +421,11 @@ def _get_generated_ngrams(hypo_idx):

def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:

# TODO (joao): enable XLA on this logits processor. See discussion and attempts in
# https://github.com/huggingface/transformers/pull/16974
if not tf.executing_eagerly():
raise NotImplementedError("TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.")

batch_size, vocab_size = scores.shape
banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)

8 changes: 4 additions & 4 deletions src/transformers/generation_tf_utils.py
@@ -2030,7 +2030,7 @@ def greedy_search_body_fn(generated, finished_sequences, next_tokens, current_po
if not use_xla:
input_ids = tf.reshape(generated.concat(), (-1, batch_size))
input_ids = tf.transpose(input_ids[: current_pos[0]])
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=current_pos[0])
next_tokens_scores = logits_processor(input_ids, next_token_logits, current_pos[0])

# argmax
next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
@@ -2301,8 +2301,8 @@ def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kw
if not use_xla:
input_ids = tf.reshape(generated.concat(), (-1, batch_size))
input_ids = tf.transpose(input_ids[:cur_len])
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
next_tokens_scores = logits_warper(input_ids, next_tokens_scores)
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
next_tokens_scores = logits_warper(input_ids, next_tokens_scores, cur_len)

# sample
if seed is not None:
@@ -2726,7 +2726,7 @@ def beam_search_body_fn(
# add new logprobs to existing running logprobs scores.
log_probs = tf.nn.log_softmax(logits)
log_probs = logits_processor(
flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len=cur_len
flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len
)
log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
log_probs = log_probs + tf.expand_dims(running_scores, axis=2)