From f47fd59616a364ac91e0c8e9bd49f3e35de83853 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 11 Apr 2022 15:56:53 +0000 Subject: [PATCH] Handle case without past --- src/transformers/generation_tf_utils.py | 54 ++++++++++++++++++++----- 1 file changed, 44 insertions(+), 10 deletions(-) diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index 8668e6f8dcc2..dbc8ab78c56a 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -2510,6 +2510,7 @@ def gather_fn(tensor): # 3. init tensors to use for "xla-compileable" generate function batch_size, num_beams, cur_len = input_ids.shape + input_ids_length = cur_len # per batch, beam-item holding current token in loop, pre-populated with `pad_token_id` sequences = tf.TensorArray( @@ -2564,7 +2565,14 @@ def gather_fn(tensor): # 4. define "xla-compile-able" stop-condition and auto-regressive function # define stop-condition and auto-regressive function def beam_search_cond_fn( - cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs + cur_len, + running_sequences, + running_scores, + sequences, + scores, + is_sent_finished, + model_kwargs, + input_ids_length, ): """ Beam Search termination condition function -- halts the generation loop if any of these conditions becomes @@ -2593,7 +2601,7 @@ def beam_search_body_fn( scores, is_sent_finished, model_kwargs, - input_ids_length=1, + input_ids_length, intermediary_running_sequences=None, ): """ @@ -2750,9 +2758,11 @@ def beam_search_body_fn( # if we don't cache past key values we need the whole input if model_kwargs.get("past", None) is None: - input_ids_length = cur_len + 1 + next_input_ids_length = cur_len + 1 # let's throw out `past` since we don't want `None` tensors model_kwargs.pop("past", None) + else: + next_input_ids_length = 1 # 9. Prepare the `tf.TensorArray` for the next iteration next_sequences = sequences.unstack(tf.transpose(next_sequences_seq_last, perm=[2, 0, 1])) @@ -2768,6 +2778,7 @@ def beam_search_body_fn( next_scores, next_is_sent_finished, next_model_kwargs, + next_input_ids_length, ) # 5. run generation @@ -2776,8 +2787,7 @@ def beam_search_body_fn( beam_search_body_fn, intermediary_running_sequences=intermediary_running_sequences ) - # 1st generation step has to be run before to initialize `past` - beam_search_body_fn_first_iter = partial(beam_search_body_fn, input_ids_length=cur_len) + # 1st generation step has to be run before to initialize `past` (if active) ( cur_len, running_sequences, @@ -2786,20 +2796,44 @@ def beam_search_body_fn( scores, is_sent_finished, model_kwargs, - ) = beam_search_body_fn_first_iter( - cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs + input_ids_length, + ) = beam_search_body_fn( + cur_len, + running_sequences, + running_scores, + sequences, + scores, + is_sent_finished, + model_kwargs, + input_ids_length, ) # 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does # NOT yield EOS token though) if beam_search_cond_fn( - cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs + cur_len, + running_sequences, + running_scores, + sequences, + scores, + is_sent_finished, + model_kwargs, + input_ids_length, ): maximum_iterations = max_length - cur_len - cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _ = tf.while_loop( + cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _, _ = tf.while_loop( beam_search_cond_fn, beam_search_body_fn, - (cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs), + ( + cur_len, + running_sequences, + running_scores, + sequences, + scores, + is_sent_finished, + model_kwargs, + input_ids_length, + ), maximum_iterations=maximum_iterations, )