TF: XLA bad words logits processor and list of processors #16974
Conversation
The documentation is not available anymore as the PR was closed or merged.
@Rocketknight1 @patrickvonplaten I'm stuck on the ngram logits processor, so I'd like to request your suggestions regarding what to try out next :D The bad words logits processor is ready and XLA-compatible. Context:
Things I've tried (without any symptom change):
I'd be very much in favor of just not converting the ngram logits processor. I think it's now more important to think about how to advertise and document XLA TF generate well, and not lose too much time on this.
Also, not that many models use this processor (I only know of BART and T5, for some summarization tasks).
Agree that it's not necessary to convert this one, but examining it, I suspect that there are some sneaky changes in output size depending on inputs, and XLA is struggling to deal with it. It seems very tough to convert to XLA, but if we decide we need it later, let me know and I'll do my best to dig into it.
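(As a hypothetical minimal repro of this failure mode, not code from the PR: XLA needs static shapes at compile time, so an op whose output size depends on the values of its inputs cannot be compiled.)

```python
import tensorflow as tf

# Hypothetical repro: tf.boolean_mask produces an output whose length
# depends on the data, which XLA cannot lower to a static-shape program.
@tf.function(jit_compile=True)
def dynamic_filter(x):
    return tf.boolean_mask(x, x > 0)  # output size is data-dependent

# Typically fails to compile, with an error about dynamic/data-dependent shapes.
dynamic_filter(tf.constant([1.0, -2.0, 3.0]))
```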
Great 👍 I'm going to revert that one, add a TODO pointing at this PR, add a few final tests for the list of logits processors with XLA, and will ping you back.
@Rocketknight1 @patrickvonplaten ready for review
@@ -401,6 +421,11 @@ def _get_generated_ngrams(hypo_idx):

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        # TODO (joao): enable XLA on this logits processor. See discussion and attempts in
        # https://github.com/huggingface/transformers/pull/16974
        if not tf.executing_eagerly():
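(The diff cuts off before the guard's body; below is a minimal sketch of such an eager-only guard. The class name, exception type, and message are assumptions for illustration, not the PR's actual code.)

```python
import tensorflow as tf

class TFNoRepeatNGramLogitsProcessor:  # name assumed for illustration
    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        # Fail loudly under graph/XLA tracing instead of silently producing
        # wrong results inside a compiled graph.
        if not tf.executing_eagerly():
            raise NotImplementedError(
                "This logits processor is only usable in eager execution for now."
            )
        # ... eager-only ngram-banning logic would follow here ...
        return scores
```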
100% with leaving as is for now!
Nice! Great job on getting no_bad_word_tokens to work!
            return banned_tokens

        # Compares the current row against all bad word sequences, obtaining a mask with the matches.
        match_mask = tf.map_fn(_tokens_match, tf.range(self.bad_word_seqs_ids.shape[0]), fn_output_signature=tf.bool)
If performance is slow, allowing more parallel_iterations here might improve things, since this is a lightweight comparison run over a potentially large number of bad_words.
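(A sketch of that suggestion; tf.map_fn accepts a parallel_iterations argument, which defaults to 10 in graph mode and 1 in eager mode. The value 32 below is illustrative, not from the PR.)

```python
# Same call as in the diff, allowing more parallel iterations: each
# per-sequence comparison is cheap, so dispatching more of them in
# parallel across bad-word rows may speed things up.
match_mask = tf.map_fn(
    _tokens_match,
    tf.range(self.bad_word_seqs_ids.shape[0]),
    fn_output_signature=tf.bool,
    parallel_iterations=32,  # illustrative; tune to the number of bad-word sequences
)
```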
Overall this looks good, especially since the tests are present and passing! The need to filter an arbitrary number of words, each of which can span multiple tokens, is very challenging to implement in XLA, so the fact that it's working at all seems almost miraculous, lol.
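(To make the challenge concrete, here is a hypothetical illustration of the usual XLA workaround, not the PR's actual code: pad all bad-word sequences into one fixed-shape tensor so every comparison stays shape-stable.)

```python
import tensorflow as tf

# Hypothetical illustration: bad-word sequences of different token lengths
# are padded into a single dense tensor with a static shape.
bad_words = [[42], [7, 19], [7, 19, 3]]  # made-up token ids
max_len = max(len(seq) for seq in bad_words)
padded = tf.constant([seq + [-1] * (max_len - len(seq)) for seq in bad_words])
# padded has static shape [num_bad_words, max_len]; -1 marks padding and can
# never equal a real (non-negative) token id, so comparison masks keep a
# fixed shape under XLA no matter how many tokens each banned word spans.
print(padded.shape)  # (3, 3)
```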
What does this PR do?
This PR converts the bad_words logits processor to be XLA-compatible. As per the discussion below, I was unable to convert the ngrams one -- added an exception and a TODO.

Also makes a change to the list of processors -- XLA raised issues when the processors had different arguments, so I had to add cur_len to all processors. After the change, the list wrapper is also compatible with XLA.
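(A hedged sketch of the resulting calling convention; the wrapper name follows transformers' TF naming, but the implementation details here are assumptions, not the PR's exact code.)

```python
import tensorflow as tf

class TFLogitsProcessorList(list):
    # Sketch: once every processor shares the (input_ids, scores, cur_len)
    # signature, the list wrapper can apply them in a uniform loop, and the
    # whole chain can be traced/compiled by XLA as a single function.
    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        for processor in self:
            scores = processor(input_ids, scores, cur_len)
        return scores
```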