Fix and improve CTRL doctests #16573
Conversation
The documentation is not available anymore as the PR was closed or merged.
Thank you for adding the examples, @jeremyadamsfisher!
I ran it locally and the tests all pass 🚀
BTW, you forgot to add the ctrl model to `utils/documentation_tests.txt`.
LGTM! I left a few tiny comments.
I would also like my colleagues to review this PR too 🙂.
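For reference, `utils/documentation_tests.txt` is just a plain list of files whose doctests get run in CI, so the fix is a one-line addition (exact path assumed from the usual repo layout):

```
src/transformers/models/ctrl/modeling_ctrl.py
```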
>>> inputs = tokenizer("Opinion my dog is cute", return_tensors="pt") | ||
>>> outputs = model(**inputs, labels=inputs["input_ids"]) | ||
>>> loss = outputs.loss |
We can actually provide the expected value.
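For instance, something like the following, which pins the scalar in the doctest (the 5.79 here is only a placeholder; the real value comes from actually running the snippet, and rounding keeps the comparison robust, as comes up again further down):

```python
>>> loss = outputs.loss
>>> round(loss.item(), 2)
5.79
```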
>>> inputs = tokenizer("Opinion my dog is cute", return_tensors="pt") | ||
>>> outputs = model(**inputs, labels=inputs["input_ids"]) | ||
>>> loss = outputs.loss | ||
>>> logits = outputs.logits |
We can actually provide the expected value (the shape of the logits).
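For example (using the `list(...)` form suggested further down; the shape is the one quoted later in this thread):

```python
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 5, 246534]
```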
```python
>>> tokenizer = CTRLTokenizer.from_pretrained("sshleifer/tiny-ctrl")
>>> model = CTRLForSequenceClassification.from_pretrained("sshleifer/tiny-ctrl")

>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
```
Any reason here for not having `Opinion` at the beginning?
Nope, this was an oversight
```python
>>> from transformers import CTRLTokenizer, CTRLModel
>>> import torch

>>> tokenizer = CTRLTokenizer.from_pretrained("sshleifer/tiny-ctrl")
```
We can actually use `ctrl` as the checkpoint for the base model.
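That is, something like:

```python
>>> tokenizer = CTRLTokenizer.from_pretrained("ctrl")
>>> model = CTRLModel.from_pretrained("ctrl")
```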
Could you also take a look when you have some time? In particular, I don't find any existing documentation mentioning this usage of control codes (still don't feel very comfortable without seeing this usage).
Let's ping @LysandreJik as he might know more on CTRL ;-)
Thanks for the review! I'll address these comments ASAP. As for the control code coming first, there's actually an example right here:
Thank you for this info, @jeremyadamsfisher.
Force-pushed from 558f4af to cce1ab7
Thanks again for the feedback. I've added assertions on lines 383 and 562-565 and changed the model from `sshleifer/tiny-ctrl` to `ctrl`.
Thanks for the update. I added a few more nit comments.
Let's wait a bit to see if Lysandre has any comments, but it is very much appreciated that you provided the information about the control code usage :-)
>>> inputs = tokenizer("Opinion my dog is cute", return_tensors="pt") | ||
>>> outputs = model(**inputs, labels=inputs["input_ids"]) | ||
>>> outputs.loss.item() | ||
5.788386821746826 |
This is too precise 😅, and very likely to fail the test (especially when running on other machines). We use the following instead:

```python
>>> round(outputs.loss.item(), 2)
```
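With the expected value pinned, the doctest would then read (5.79 is simply the loss above rounded to two decimals):

```python
>>> round(outputs.loss.item(), 2)
5.79
```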
```python
>>> outputs.logits.shape
torch.Size([1, 5, 246534])
```
Let's use `list(outputs.logits.shape)` and put the output as `[1, 5, 246534]` (simpler than `torch.Size([1, 5, 246534])`).
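That is:

```python
>>> list(outputs.logits.shape)
[1, 5, 246534]
```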
```python
>>> last_hidden_states = outputs.last_hidden_state
>>> last_hidden_states.shape
torch.Size([1, 5, 1280])
```
Let's use `list(last_hidden_states.shape)` and put the output as `[1, 5, 1280]` (simpler than `torch.Size([1, 5, 1280])`).
Sure thing, those are easy changes. To clarify the control code coming first, would it make sense to add something like this?

```python
>>> # CTRL was trained with control codes as the first token
>>> inputs = tokenizer("Opinion my dog is cute", return_tensors="pt")
>>> assert inputs[0] in tokenizer.control_codes.values()
```
Force-pushed from 7c5ed04 to b9424d4
That doesn't seem to work; will tinker with this a bit more.
Aha! This works:
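Since the tokenizer returns a `BatchEncoding` rather than a plain list, the assertion has to index into `input_ids` first; a sketch of the working form (the exact line in the PR may differ):

```python
>>> inputs = tokenizer("Opinion my dog is cute", return_tensors="pt")
>>> # take the id of the first token and check it is a known control code
>>> first_token_id = inputs["input_ids"][0, 0].item()
>>> assert first_token_id in tokenizer.control_codes.values()
```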
Added this to the doctest wherever there was a
@ydshieh heads up -- I've addressed your second set of comments and the checks have passed :) Still waiting on @LysandreJik; would love to hear your thoughts!
```python
>>> tokenizer = CTRLTokenizer.from_pretrained("ctrl")
>>> model = CTRLLMHeadModel.from_pretrained("ctrl")

>>> # CTRL was trained with control codes as the first token
```
Indeed!
Very nice - exactly right to use the CTRL control codes here :-) Could we maybe also add a `generate` example to the model?
Very cool to leverage control codes for the examples!
Force-pushed from b45df8a to dddc87c
Super nice - looks good to me!
I'll let @ydshieh take a final look here.
LGTM 💯 Thank you, @jeremyadamsfisher! I ran it locally (GCP VM, 32GB RAM) and all tests pass!
One thing I am a bit worried about: on CPU, some tests (sequence classification) require ~24GB of memory to run. I am not sure if we will get GPU OOM when it runs on CI.
Any comment here, @patrickvonplaten?
```python
>>> sequence_ids = model.generate(inputs["input_ids"])
>>> sequences = tokenizer.batch_decode(sequence_ids)
>>> sequences
['Wikipedia The llama is a member of the family Bovidae. It is native to the Andes of Peru,']
```
(nit) I would use `sequences[0]` and show the output as a string instead of a list.
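That is, the doctest would end with:

```python
>>> sequences[0]
'Wikipedia The llama is a member of the family Bovidae. It is native to the Andes of Peru,'
```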
```python
>>> from transformers import CTRLTokenizer, CTRLModel
>>> import torch
```
This line could be removed (`torch` imported but not used).
@patrickvonplaten Do you have a strong opinion here? (Should we do the same for `doc.py`?)
Yes, it would be cleaner to remove `torch`, but happy to leave it for a future PR as well.
Let's keep it in the test for now even if it takes 24GB of RAM - if it fails we can adapt afterward
Hi @jeremyadamsfisher, could you try to resolve the conflicts?
Then we are ready to merge :-) Thanks! (I can help on this if you need, just let me know)
Hi @jeremyadamsfisher, just to let you know: I resolved the conflict lines and pushed to this PR branch. If you need to make any more changes, don't forget to pull the latest changes first.
Merged! Thank you again, @jeremyadamsfisher!
Thank you @ydshieh! Apologies I wasn't able to fix the merge conflicts myself, but it is much appreciated!
* Improve CTRL doctests
* Fix `CTRLForSequenceClassification` flakiness with inconsistent losses
* Remove unused
* Fixup
* Add CTRL to documentation_tests.txt
* Fix control code not being first
* Add output assertions
* Change from sshleifer/tiny-ctrl -> ctrl
* Run `make fixup`
* apply `list` to output logits shape for clarity
* Reduce output loss precision to make assertion more robust
* Add assertion of control code being first
* Fix docstyle
* upper case sentence following control code
* Weird bug fixes
* Add a better generation example

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
What does this PR do?
This PR addresses the CTRL doctest failures and replaces the example text with one that is more appropriate for CTRL specifically (i.e., by prefacing it with a control code).
Motivated as part of the doctest sprint: #16292
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@patrickvonplaten @ydshieh @patil-suraj