diff --git a/src/transformers/generation_tf_logits_process.py b/src/transformers/generation_tf_logits_process.py
index 0a5ac8318296..7c0f75290697 100644
--- a/src/transformers/generation_tf_logits_process.py
+++ b/src/transformers/generation_tf_logits_process.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import inspect
-from typing import List
+from typing import List, Tuple
 
 import numpy as np
 import tensorflow as tf
@@ -38,7 +38,10 @@
            [What are input IDs?](../glossary#input-ids)
        scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
-            search or log softmax for each vocabulary token when using beam search
+            search or log softmax for each vocabulary token when using beam search.
+        cur_len (`int`):
+            The current length of valid input sequence tokens. In the TF implementation, the input_ids' sequence length
+            is the maximum length generate can produce, and we need to know which of its tokens are valid.
        kwargs:
            Additional logits processor specific kwargs.
 
@@ -51,7 +54,7 @@ class TFLogitsProcessor:
     """Abstract base class for all logit processors that can be applied during generation."""
 
     @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         """TF method for processing logits."""
         raise NotImplementedError(
             f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
         )
@@ -62,7 +65,7 @@ class TFLogitsWarper:
     """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
 
     @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         """TF method for warping logits."""
         raise NotImplementedError(
             f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
         )
@@ -77,18 +80,18 @@ class TFLogitsProcessorList(list):
     """
 
     @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, **kwargs) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int, **kwargs) -> tf.Tensor:
         for processor in self:
             function_args = inspect.signature(processor.__call__).parameters
-            if len(function_args) > 2:
+            if len(function_args) > 3:
                 if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
                     raise ValueError(
                         f"Make sure that all the required parameters: {list(function_args.keys())} for "
                         f"{processor.__class__} are passed to the logits processor."
                     )
-                scores = processor(input_ids, scores, **kwargs)
+                scores = processor(input_ids, scores, cur_len, **kwargs)
             else:
-                scores = processor(input_ids, scores)
+                scores = processor(input_ids, scores, cur_len)
         return scores
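Since most of the churn above is the new third positional argument, here is a minimal sketch (not part of the patch) of what a downstream processor looks like against the updated interface, and of how `TFLogitsProcessorList` now forwards `cur_len`. The `MaskTokenZeroProcessor` class and all shapes are made up for illustration; the import path is the module modified here.

```python
import tensorflow as tf

from transformers.generation_tf_logits_process import TFLogitsProcessor, TFLogitsProcessorList


class MaskTokenZeroProcessor(TFLogitsProcessor):
    """Toy processor: forces the score of token id 0 to -inf (illustrative only)."""

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        banned = tf.equal(tf.range(scores.shape[-1]), 0)  # True only at token id 0
        return tf.where(banned[tf.newaxis, :], -float("inf"), scores)


processors = TFLogitsProcessorList([MaskTokenZeroProcessor()])

input_ids = tf.zeros((2, 8), dtype=tf.int32)  # padded to the maximum generation length
scores = tf.zeros((2, 5))
new_scores = processors(input_ids, scores, cur_len=3)  # cur_len is now always passed along
```

Processors that need extra arguments beyond these three still receive them through `**kwargs`, which is what the `len(function_args) > 3` check above dispatches on.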
@@ -107,7 +110,7 @@ def __init__(self, temperature: float):
 
         self.temperature = temperature
 
-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         scores = scores / self.temperature
         return scores
 
@@ -133,7 +136,7 @@ def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_t
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep
 
-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1])  # Safety check
         # Boolean mask containing all tokens with a probability less than the last token of the top-k
         indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
@@ -163,7 +166,7 @@ def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep
 
-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])
 
         mask_scores = tf.fill(scores.shape, self.filter_value)
@@ -305,58 +308,75 @@ def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
                 f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
             )
 
-        self.bad_words_ids = bad_words_ids
-
-    def calc_banned_bad_words_ids(self, prev_input_ids):
-        banned_tokens = []
-
-        def _tokens_match(prev_tokens, tokens):
-            if len(tokens) == 0:
-                # if bad word tokens is just one token always ban it
-                return True
-            if len(tokens) > len(prev_tokens):
-                # if bad word tokens are longer than prev tokens they can't be equal
-                return False
-
-            if prev_tokens[-len(tokens) :] == tokens:
-                # if tokens match
-                return True
-            else:
-                return False
-
-        for prev_input_ids_slice in prev_input_ids:
-            banned_tokens_slice = []
-
-            for banned_token_seq in self.bad_words_ids:
-                assert (
-                    len(banned_token_seq) > 0
-                ), f"Banned words token sequences {self.bad_words_ids} cannot have an empty list"
+        # stores the information about bad words in three tensors:
+        # 1. a rectangular tensor with the forbidden sequences (padded with `-1`), for full data comparisons
+        self.bad_word_seqs_ids = tf.ragged.constant(bad_words_ids).to_tensor(default_value=-1)
+        # 2. a tensor with the unpadded length of each forbidden sequence, for quick length comparisons
+        bad_word_seqs_len = [len(bad_words) for bad_words in bad_words_ids]
+        if any([word_len == 0 for word_len in bad_word_seqs_len]):
+            raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")
+        self.bad_word_seqs_len = tf.convert_to_tensor(bad_word_seqs_len, dtype=tf.int32)
+        # 3. a tensor containing the last token for each sequence, for easy access to the tokens that may be banned
+        self.seq_forbidden_tokens = tf.convert_to_tensor([bad_words[-1] for bad_words in bad_words_ids])
+
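To make the `__init__` comment above concrete, this is roughly what the three tensors hold for a small, made-up `bad_words_ids` (a standalone sketch, not part of the patch):

```python
import tensorflow as tf

bad_words_ids = [[5], [3, 8], [3, 9, 2]]

# 1. ragged -> rectangular, padded on the right with -1
bad_word_seqs_ids = tf.ragged.constant(bad_words_ids).to_tensor(default_value=-1)
# [[ 5, -1, -1],
#  [ 3,  8, -1],
#  [ 3,  9,  2]]

# 2. unpadded length of each forbidden sequence
bad_word_seqs_len = tf.convert_to_tensor([len(seq) for seq in bad_words_ids], dtype=tf.int32)
# [1, 2, 3]

# 3. the token that gets banned once the rest of its sequence already matches the row
seq_forbidden_tokens = tf.convert_to_tensor([seq[-1] for seq in bad_words_ids])
# [5, 8, 2]
```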
+    def _calc_row_banned_bad_tokens(self, row_input_ids: tf.Tensor) -> tf.Tensor:
+        def _tokens_match(bad_word_seq_number):
+            def _len_one():
+                # If the bad sequence only has one token, always mask it
+                return tf.cond(
+                    tf.math.equal(self.bad_word_seqs_len[bad_word_seq_number], 1),
+                    lambda: tf.ones((), dtype=tf.bool),
+                    _len_greater_than_cur_len,
+                )
-                if _tokens_match(prev_input_ids_slice.numpy().tolist(), banned_token_seq[:-1]) is False:
-                    # if tokens do not match continue
-                    continue
+            def _len_greater_than_cur_len():
+                # Otherwise, if the bad sequence is longer than the current length, it can't ever match
+                return tf.cond(
+                    tf.math.greater(self.bad_word_seqs_len[bad_word_seq_number], row_input_ids.shape[0]),
+                    lambda: tf.zeros((), dtype=tf.bool),
+                    _match_found,
+                )
-                banned_tokens_slice.append(banned_token_seq[-1])
+            def _match_found():
+                # Finally, runs the actual comparison. Can only be called if the previous comparisons do not yield
+                # an answer (otherwise we get indexing exceptions)
+                compare_len = self.bad_word_seqs_len[bad_word_seq_number] - 1
+                return tf.cond(
+                    tf.math.reduce_all(
+                        tf.math.equal(
+                            row_input_ids[-compare_len:], self.bad_word_seqs_ids[bad_word_seq_number, :compare_len]
+                        )
+                    ),
+                    lambda: tf.ones((), dtype=tf.bool),
+                    lambda: tf.zeros((), dtype=tf.bool),
+                )
-            banned_tokens.append(banned_tokens_slice)
+            match = _len_one()
+            return match
-        return banned_tokens
+        # Compares the current row against all bad word sequences, obtaining a mask with the matches.
+        match_mask = tf.map_fn(_tokens_match, tf.range(self.bad_word_seqs_ids.shape[0]), fn_output_signature=tf.bool)
+        row_banned_tokens = self.seq_forbidden_tokens[match_mask]
+        return row_banned_tokens
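The chain of `tf.cond` branches above encodes a simple per-sequence rule. An eager, plain-Python equivalent of that rule (illustrative sketch only, reusing the made-up tensors from the previous snippet) would be:

```python
import tensorflow as tf

# tensors as built in `__init__` for bad_words_ids=[[5], [3, 8], [3, 9, 2]]
bad_word_seqs_ids = tf.constant([[5, -1, -1], [3, 8, -1], [3, 9, 2]])
bad_word_seqs_len = tf.constant([1, 2, 3])
seq_forbidden_tokens = tf.constant([5, 8, 2])

row_input_ids = tf.constant([7, 3, 9])  # the valid tokens generated so far for one row

banned = []
for i in range(bad_word_seqs_ids.shape[0]):
    seq_len = int(bad_word_seqs_len[i])
    if seq_len == 1:
        match = True  # a single-token bad word is always banned
    elif seq_len > int(row_input_ids.shape[0]):
        match = False  # the bad sequence is longer than what has been generated so far
    else:
        # does the row end with the bad sequence minus its last token?
        prefix = bad_word_seqs_ids[i, : seq_len - 1]
        match = bool(tf.reduce_all(tf.equal(row_input_ids[-(seq_len - 1) :], prefix)))
    if match:
        banned.append(int(seq_forbidden_tokens[i]))

print(banned)  # [5, 2]: token 5 (single-token bad word) and token 2 (the row ends with [3, 9])
```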
 
     def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
-
-        vocab_size = scores.shape[-1]
-
-        # calculate a list of banned tokens according to bad words
-        banned_tokens = self.calc_banned_bad_words_ids(input_ids[:, :cur_len])
-
-        banned_tokens_indices_mask = []
-        for banned_tokens_slice in banned_tokens:
-            banned_tokens_indices_mask.append(
-                [True if token in banned_tokens_slice else False for token in range(vocab_size)]
+        # We want to mask some banned tokens, at a score level. Since the banned tokens depend on the previous
+        # `input_ids`, they may have a different length for each row, and they may even be empty for some rows.
+        # To remain simple and XLA-compatible, we work in a per-row fashion.
+        # TODO (Joao): this function might trigger XLA retracing as `cur_len` increases. Fix it if it becomes
+        # a frequent choke point. (make `cur_len` a tensor?)
+        def _get_row_updated_score(row_inputs: Tuple[tf.Tensor]) -> tf.Tensor:
+            row_input_ids, row_score = row_inputs
+            banned_tokens = self._calc_row_banned_bad_tokens(row_input_ids[:cur_len])
+            banned_tokens_mask = tf.scatter_nd(
+                indices=tf.expand_dims(banned_tokens, axis=-1),
+                updates=tf.ones_like(banned_tokens, dtype=tf.bool),
+                shape=row_score.shape,
             )
+            row_score = tf.where(banned_tokens_mask, -float("inf"), row_score)
+            return row_score
 
-        scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
-
+        scores = tf.map_fn(_get_row_updated_score, (input_ids, scores), fn_output_signature=tf.float32)
         return scores
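For a single row, the `tf.scatter_nd` call above turns the variable-length list of banned token ids into a boolean mask over the whole vocabulary. A standalone sketch with made-up values (not part of the patch):

```python
import tensorflow as tf

vocab_size = 5
row_score = tf.zeros((vocab_size,))
banned_tokens = tf.constant([0, 3])  # token ids to ban for this row (illustrative)

banned_tokens_mask = tf.scatter_nd(
    indices=tf.expand_dims(banned_tokens, axis=-1),  # shape (num_banned, 1)
    updates=tf.ones_like(banned_tokens, dtype=tf.bool),
    shape=row_score.shape,
)
# [ True, False, False,  True, False]

row_score = tf.where(banned_tokens_mask, -float("inf"), row_score)
# [-inf, 0., 0., -inf, 0.]
```

The resulting mask always has `vocab_size` entries, no matter how many tokens end up banned for the row.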
@@ -401,6 +421,11 @@ def _get_generated_ngrams(hypo_idx):
 
     def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
+        # TODO (joao): enable XLA on this logits processor. See discussion and attempts in
+        # https://github.com/huggingface/transformers/pull/16974
+        if not tf.executing_eagerly():
+            raise NotImplementedError("TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.")
+
         batch_size, vocab_size = scores.shape
 
         banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)
diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py
index fedc3fdf98d7..2a9251eeb5a0 100644
--- a/src/transformers/generation_tf_utils.py
+++ b/src/transformers/generation_tf_utils.py
@@ -2030,7 +2030,7 @@ def greedy_search_body_fn(generated, finished_sequences, next_tokens, current_po
             if not use_xla:
                 input_ids = tf.reshape(generated.concat(), (-1, batch_size))
                 input_ids = tf.transpose(input_ids[: current_pos[0]])
-            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=current_pos[0])
+            next_tokens_scores = logits_processor(input_ids, next_token_logits, current_pos[0])
 
             # argmax
             next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
@@ -2301,8 +2301,8 @@ def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kw
             if not use_xla:
                 input_ids = tf.reshape(generated.concat(), (-1, batch_size))
                 input_ids = tf.transpose(input_ids[:cur_len])
-            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
-            next_tokens_scores = logits_warper(input_ids, next_tokens_scores)
+            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
+            next_tokens_scores = logits_warper(input_ids, next_tokens_scores, cur_len)
 
             # sample
             if seed is not None:
@@ -2726,7 +2726,7 @@ def beam_search_body_fn(
             # add new logprobs to existing running logprobs scores.
             log_probs = tf.nn.log_softmax(logits)
             log_probs = logits_processor(
-                flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len=cur_len
+                flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len
             )
             log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
             log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
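The new guard in `TFNoRepeatNGramLogitsProcessor` relies on `tf.executing_eagerly()` flipping to `False` once the processor is traced inside a `tf.function`. A minimal sketch of that pattern (illustrative only, not part of the patch):

```python
import tensorflow as tf


def only_eager_op(x):
    if not tf.executing_eagerly():
        raise NotImplementedError("this op is only implemented for eager execution")
    return x + 1


only_eager_op(tf.constant(1))  # fine: runs eagerly

compiled = tf.function(only_eager_op, jit_compile=True)
try:
    compiled(tf.constant(1))  # raises while tracing, since eager execution is off
except NotImplementedError as err:
    print(err)
```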
diff --git a/tests/generation/test_generation_tf_logits_process.py b/tests/generation/test_generation_tf_logits_process.py
index 9fb8e83fccd7..be60335ef2f8 100644
--- a/tests/generation/test_generation_tf_logits_process.py
+++ b/tests/generation/test_generation_tf_logits_process.py
@@ -75,6 +75,7 @@ def test_min_length_dist_processor(self, use_xla):
     @parameterized.expand([(False,), (True,)])
     def test_temperature_dist_warper(self, use_xla):
         input_ids = None
+        cur_len = None
         length = 20
 
         scores = self._get_uniform_logits(batch_size=2, length=length)
@@ -94,8 +95,8 @@ def test_temperature_dist_warper(self, use_xla):
             temp_dist_warper_sharper = tf.function(temp_dist_warper_sharper, jit_compile=True)
             temp_dist_warper_smoother = tf.function(temp_dist_warper_smoother, jit_compile=True)
 
-        warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores)), axis=-1)
-        warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores)), axis=-1)
+        warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores), cur_len), axis=-1)
+        warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores), cur_len), axis=-1)
 
         # uniform distribution stays uniform
         tf.debugging.assert_near(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)
@@ -142,6 +143,7 @@ def test_repetition_penalty_dist_process(self, use_xla):
     @parameterized.expand([(False,), (True,)])
     def test_top_k_dist_warper(self, use_xla):
         input_ids = None
+        cur_len = None
         vocab_size = 10
         batch_size = 2
@@ -153,7 +155,7 @@ def test_top_k_dist_warper(self, use_xla):
         if use_xla:
             top_k_warp = tf.function(top_k_warp, jit_compile=True)
 
-        scores = top_k_warp(input_ids, ramp_logits)
+        scores = top_k_warp(input_ids, ramp_logits, cur_len)
 
         # check that correct tokens are filtered
         self.assertListEqual(tf.math.is_inf(scores[0]).numpy().tolist(), 7 * [True] + 3 * [False])
@@ -167,12 +169,12 @@ def test_top_k_dist_warper(self, use_xla):
         if use_xla:
             top_k_warp_safety_check = tf.function(top_k_warp_safety_check, jit_compile=True)
 
-        scores = top_k_warp_safety_check(input_ids, logits)
+        scores = top_k_warp_safety_check(input_ids, logits, cur_len)
 
         # uniform dist is not changed
         self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [0, 0])
 
         ramp_logits = np.broadcast_to(np.arange(length, dtype=np.float32), (batch_size, length)).copy()
-        scores = top_k_warp_safety_check(input_ids, ramp_logits)
+        scores = top_k_warp_safety_check(input_ids, ramp_logits, cur_len)
 
         # min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
         self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [2, 2])
@@ -180,6 +182,7 @@ def test_top_p_dist_warper(self, use_xla):
     @parameterized.expand([(False,), (True,)])
     def test_top_p_dist_warper(self, use_xla):
         input_ids = None
+        cur_len = None
         vocab_size = 10
         batch_size = 2
@@ -189,7 +192,7 @@ def test_top_p_dist_warper(self, use_xla):
         top_p_warp = TFTopPLogitsWarper(0.7)
         if use_xla:
             top_p_warp = tf.function(top_p_warp, jit_compile=True)
-        filtered_dist = tf.exp(top_p_warp(input_ids, dist))
+        filtered_dist = tf.exp(top_p_warp(input_ids, dist, cur_len))
 
         # dist should be filtered to keep min num values so that sum is >= 0.7
         # exp (-inf) => 0
@@ -208,7 +211,7 @@ def test_top_p_dist_warper(self, use_xla):
         top_p_warp = TFTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
         if use_xla:
             top_p_warp = tf.function(top_p_warp, jit_compile=True)
-        filtered_dist = top_p_warp(input_ids, ramp_logits)
+        filtered_dist = top_p_warp(input_ids, ramp_logits, cur_len)
 
         # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps
         # 2.
@@ -242,7 +245,8 @@ def test_no_repeat_ngram_dist_processor(self):
             tf.math.is_inf(filtered_scores_3_gram).numpy().tolist(), [[False, False, False], [True, False, False]]
         )
 
-    def test_no_bad_words_dist_processor(self):
+    @parameterized.expand([(False,), (True,)])
+    def test_no_bad_words_dist_processor(self, use_xla):
         vocab_size = 5
         batch_size = 2
         eos_token_id = 4
@@ -255,6 +259,8 @@ def test_no_bad_words_dist_processor(self):
         scores = self._get_uniform_logits(batch_size, vocab_size)
 
         no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id)
+        if use_xla:
+            no_bad_words_dist_proc = tf.function(no_bad_words_dist_proc, jit_compile=True)
 
         filtered_scores = no_bad_words_dist_proc(input_ids, tf.identity(scores), cur_len)
 
@@ -322,7 +328,9 @@ def test_forced_eos_token_logits_processor(self, use_xla):
         scores = logits_processor(input_ids, scores, cur_len)
         self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
 
-    def test_processor_list(self):
+    @parameterized.expand([(False,), (True,)])
+    def test_processor_list(self, use_xla):
+        # TODO (Joao): reintroduce TFNoRepeatNGramLogitsProcessor when it gets compatible with XLA
         batch_size = 4
         cur_len = 10
         vocab_size = 15
@@ -341,16 +349,24 @@ def test_processor_list(self):
         rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
         top_k_warp = TFTopKLogitsWarper(3)
         top_p_warp = TFTopPLogitsWarper(0.8)
-        no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
+        # no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
         no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)
+        if use_xla:
+            min_dist_proc = tf.function(min_dist_proc, jit_compile=True)
+            temp_dist_warp = tf.function(temp_dist_warp, jit_compile=True)
+            rep_penalty_proc = tf.function(rep_penalty_proc, jit_compile=True)
+            top_k_warp = tf.function(top_k_warp, jit_compile=True)
+            top_p_warp = tf.function(top_p_warp, jit_compile=True)
+            # no_repeat_proc = tf.function(no_repeat_proc, jit_compile=True)
+            no_bad_words_dist_proc = tf.function(no_bad_words_dist_proc, jit_compile=True)
 
         # no processor list
         scores = min_dist_proc(input_ids, scores, cur_len)
-        scores = temp_dist_warp(input_ids, scores)
+        scores = temp_dist_warp(input_ids, scores, cur_len)
         scores = rep_penalty_proc(input_ids, scores, cur_len)
-        scores = top_k_warp(input_ids, scores)
-        scores = top_p_warp(input_ids, scores)
-        scores = no_repeat_proc(input_ids, scores, cur_len)
+        scores = top_k_warp(input_ids, scores, cur_len)
+        scores = top_p_warp(input_ids, scores, cur_len)
+        # scores = no_repeat_proc(input_ids, scores, cur_len)
         scores = no_bad_words_dist_proc(input_ids, scores, cur_len)
 
         # with processor list
@@ -361,11 +377,11 @@ def test_processor_list(self):
                 rep_penalty_proc,
                 top_k_warp,
                 top_p_warp,
-                no_repeat_proc,
+                # no_repeat_proc,
                 no_bad_words_dist_proc,
             ]
         )
-        scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
+        scores_comp = processor(input_ids, scores_comp, cur_len)
 
         # remove inf
         scores = tf.where(tf.math.is_inf(scores), -1e9, scores)
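Mirroring what the new parameterized tests do, the XLA path of the rewritten bad-words processor can be smoke-tested along these lines (a sketch only; token ids, shapes and `cur_len` are arbitrary, and the import path is the module touched by this patch):

```python
import tensorflow as tf

from transformers.generation_tf_logits_process import TFNoBadWordsLogitsProcessor

vocab_size = 5
cur_len = 4
input_ids = tf.constant([[0, 1, 3, 1, 0], [4, 0, 1, 0, 0]])  # padded to max length, first 4 tokens valid
scores = tf.zeros((2, vocab_size))

proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1], [2, 3]], eos_token_id=4)
xla_proc = tf.function(proc, jit_compile=True)

filtered = xla_proc(input_ids, scores, cur_len)
# token id 1 (a single-token bad word) should be the only -inf column in both rows
print(tf.math.is_inf(filtered))
```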