TF: XLA bad words logits processor and list of processors #16974

Merged · 10 commits into huggingface:main · Apr 29, 2022

Conversation

gante (Member) commented Apr 27, 2022

What does this PR do?

This PR makes the bad_words logits processor XLA-compatible. As per the discussion below, I was unable to convert the ngrams one -- I added an exception and a TODO.

It also changes the list of processors -- XLA raised issues when the processors had different signatures, so cur_len was added as an argument to all of them. After this change, the list wrapper is also XLA-compatible.
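
For illustration, here is a minimal sketch of the uniform-signature pattern described above. It mirrors, but is not necessarily identical to, the merged transformers code:

import tensorflow as tf

class TFLogitsProcessorList(list):
    """Applies each processor in order. Because every processor now takes the
    same (input_ids, scores, cur_len) arguments, the wrapper can be traced
    by XLA as a single call."""

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        for processor in self:
            scores = processor(input_ids, scores, cur_len)
        return scores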

HuggingFaceDocBuilderDev commented Apr 27, 2022

The documentation is not available anymore as the PR was closed or merged.

gante (Member, Author) commented Apr 27, 2022

@Rocketknight1 @patrickvonplaten I'm stuck on the ngram logits processor, so I'd like to request your suggestions regarding what to try out next :D The bad words logits processor is ready and XLA-compatible.

Context:

  1. Without XLA, it works well;
  2. With XLA, it yields incorrect outputs (it masks the wrong tokens in some cases). It is not a CPU/GPU issue -- the output is the same regardless of hardware;
  3. The XLA/non-XLA mismatch is at the output of _calc_row_banned_ngram_tokens, which computes the tokens that should be banned for each row;
  4. All intermediate variables I was able to pull out had the same contents. However, if I try to pull out all ngrams, I get a core dump on XLA 🤔

Things I've tried (without any symptom change):

  1. The current implementation is a tf.while_loop with a tf.TensorArray (see the sketch after this list). In ddc8911, you can see my original implementation with tf.map_fn (which is closer to the original code). Both versions show the exact same symptoms described above, and return the same errors for the same inputs when XLA is on (!);
  2. Pulling the initialization of the tf.TensorArray to the start of __call__, passing ngram_size as an argument, and using tf.function as a decorator on __call__. The first two changes attempt to trigger a retrace; the last rules out problems with compiling a class instance (as opposed to a function);
  3. Using tf.shape instead of tensor.shape, as the former is better suited to symbolic tensors;
  4. Using batches with a single row as input;
  5. Looking for other ways to implement the sliding window on the inputs (i.e. getting the ngrams), with no success.
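
To make attempt 1 concrete, here is a minimal, self-contained sketch of the tf.while_loop + tf.TensorArray sliding window over a single row (illustrative only -- the real processor operates on whole batches and then bans the matched tokens):

import tensorflow as tf

def get_ngrams(row: tf.Tensor, ngram_size: int) -> tf.Tensor:
    # Slide a window of length `ngram_size` over a 1-D row of token ids,
    # writing each ngram into a TensorArray (attempt 1 above). tf.shape is
    # used instead of row.shape, per attempt 3.
    num_ngrams = tf.shape(row)[0] - ngram_size + 1
    ngrams = tf.TensorArray(dtype=row.dtype, size=num_ngrams)

    def body(i, ngrams):
        return i + 1, ngrams.write(i, row[i : i + ngram_size])

    _, ngrams = tf.while_loop(
        cond=lambda i, _: i < num_ngrams,
        body=body,
        loop_vars=(tf.constant(0), ngrams),
    )
    return ngrams.stack()  # shape: (num_ngrams, ngram_size)

print(get_ngrams(tf.constant([1, 2, 3, 4]), 2))  # [[1 2] [2 3] [3 4]]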

patrickvonplaten (Contributor) commented Apr 28, 2022

I'd be very much in favor of just not converting the ngram processor. I don't think it's a requirement for publishing the new TF generate method. Let's maybe leave this as a hard second issue in case the community is very interested in this feature.

I think it's now more important to think about how to advertise and document XLA TF generate well, and not lose too much time on this.

patrickvonplaten (Contributor)

Also, not that many models use this processor (I only know of BART and T5, for some summarization tasks).

Rocketknight1 (Member)

Agree that it's not necessary to convert this one. Examining it, though, I suspect there are some sneaky changes in output size depending on the inputs, which XLA struggles to deal with. It seems very tough to convert to XLA, but if we decide we need it later, let me know and I'll do my best to dig into it.
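
As a toy illustration of that failure mode (not code from this PR; the exact error type and message vary across TF versions): ops whose output size depends on tensor values, rather than only on input shapes, generally cannot be compiled by XLA:

import tensorflow as tf

@tf.function(jit_compile=True)
def keep_positive(x):
    # The output length depends on the *values* in x, not just its shape --
    # a data-dependent output size, which XLA cannot handle statically.
    return tf.boolean_mask(x, x > 0)

try:
    keep_positive(tf.constant([1, -2, 3]))
except Exception as err:
    print(f"XLA rejected the data-dependent shape: {type(err).__name__}")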

gante (Member, Author) commented Apr 28, 2022

Great 👍 I'm going to revert that one, add a TODO pointing to this PR, add a few final tests for the list of logits processors with XLA, and then ping you back.

gante changed the title from "TF: XLA bad words and ngram logits processors" to "TF: XLA bad words logits processor and list of processors" on Apr 28, 2022
gante marked this pull request as ready for review on Apr 28, 2022 19:01
gante (Member, Author) commented Apr 28, 2022

@Rocketknight1 @patrickvonplaten ready for review

@@ -401,6 +421,11 @@ def _get_generated_ngrams(hypo_idx):

def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:

    # TODO (joao): enable XLA on this logits processor. See discussion and attempts in
    # https://github.com/huggingface/transformers/pull/16974
    if not tf.executing_eagerly():

100% with leaving as is for now!
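
For context, here is a self-contained sketch of the guard pattern shown in the diff above (the function name and message are illustrative, not the exact merged code):

import tensorflow as tf

def guard():
    # tf.executing_eagerly() returns False while a tf.function is being
    # traced, so the graph/XLA path fails fast instead of silently
    # producing wrong outputs.
    if not tf.executing_eagerly():
        raise NotImplementedError("Only implemented for eager execution, for now.")
    return tf.constant(0)

guard()  # eager mode: runs fine
try:
    tf.function(guard)()  # tracing trips the guard
except NotImplementedError as err:
    print(err)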

patrickvonplaten (Contributor) left a comment

Nice! Great job on getting no_bad_word_tokens to work


return banned_tokens
# Compares the current row against all bad word sequences, obtaining a mask with the matches.
match_mask = tf.map_fn(_tokens_match, tf.range(self.bad_word_seqs_ids.shape[0]), fn_output_signature=tf.bool)

If performance is slow, allowing more parallel_iterations here might improve things, since this is a lightweight comparison run over a potentially large number of bad_words.
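
As a hedged sketch of that suggestion, this is the line from the diff with the extra argument added (parallel_iterations is a real tf.map_fn argument, defaulting to 10 under graph execution; 32 is an illustrative value):

match_mask = tf.map_fn(
    _tokens_match,
    tf.range(self.bad_word_seqs_ids.shape[0]),
    fn_output_signature=tf.bool,
    parallel_iterations=32,  # allow more per-bad-word comparisons in flight
)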

Rocketknight1 (Member) left a comment

Overall this looks good, especially since the tests are present and passing! The need to filter an arbitrary number of words, each of which can span multiple tokens, is very challenging to implement in XLA, so the fact that it's working at all seems almost miraculous, lol.
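
One standard way to make that tractable under XLA -- and plausibly close in spirit to this PR, which stores the bad words in a single bad_word_seqs_ids tensor, though details may differ -- is to pad all sequences into one static-shape tensor and keep their true lengths in a side tensor. Variable names below are illustrative:

import tensorflow as tf

# Bad words can span different numbers of tokens...
bad_word_seqs = [[5, 6], [7], [8, 9, 10]]

# ...so pad them into one static-shape tensor and track true lengths aside.
max_len = max(len(seq) for seq in bad_word_seqs)
padded_seqs = tf.constant([seq + [-1] * (max_len - len(seq)) for seq in bad_word_seqs])
seq_lens = tf.constant([len(seq) for seq in bad_word_seqs])

# padded_seqs has shape (num_bad_words, max_len) -- fully static, so per-row
# comparisons against the generated tokens can compile under XLA.
print(padded_seqs.shape, seq_lens.numpy())  # (3, 3) [2 1 3]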

gante merged commit fb0ae12 into huggingface:main on Apr 29, 2022
gante deleted the xla_bad_words_ngrams branch on Apr 29, 2022 14:55
stevhliu pushed a commit to stevhliu/transformers that referenced this pull request May 3, 2022
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022