shift_tokens_right function missing for mt5 models #15771
Comments
Good catch, thanks for opening the issue!
Hey Suraj, sure! Will do it tomorrow and keep you posted. Thanks!
Also noticed that
Ahh, good catch! It would be better to open a new PR for this since these are two different changes.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Bump, will do a PR soon.
Will make a PR for this shortly :)
Environment info
transformers version: 4.17.0.dev0
Who can help
@patil-suraj
Information
Hello
I am trying to fine-tune an mt5-small model with Flax on TPU for a summarization task, using the examples/flax/summarization/run_summarization_flax.py script.
When I run the script, I get an error saying that shift_tokens_right is not defined for mt5 models:
File "/home/prod/transformers/examples/flax/summarization/run_summarization_flax.py", line 521, in main
    shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
AttributeError: module 'transformers.models.mt5.modeling_flax_mt5' has no attribute 'shift_tokens_right'
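The error comes from the script looking the helper up by name on the module that defines the model class, so any Flax model module without a shift_tokens_right attribute fails at exactly this point. A minimal sketch of that lookup pattern follows; resolve_model_helper is a hypothetical name (not the script's code), and the standard-library json module stands in for a transformers module so the snippet runs on its own.

```python
import importlib
from typing import Callable


def resolve_model_helper(module_name: str, helper_name: str) -> Callable:
    """Hypothetical helper: fetch a module-level function by name, like the script's getattr call."""
    module = importlib.import_module(module_name)
    helper = getattr(module, helper_name, None)
    if helper is None:
        # The module imported fine but does not define the helper: this is the
        # situation reported in the traceback for modeling_flax_mt5.
        raise AttributeError(
            f"module {module_name!r} has no attribute {helper_name!r}; "
            f"{helper_name} must be defined in (or imported into) that module"
        )
    return helper


# Stand-in example using the standard library instead of transformers:
dumps = resolve_model_helper("json", "dumps")
print(dumps({"a": 1}))  # {"a": 1}
```

Calling resolve_model_helper("json", "shift_tokens_right") raises the same kind of AttributeError the script hit.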
Moreover, the current Flax summarization script has a typo on line 516 (see https://github.com/huggingface/transformers/blob/master/examples/flax/summarization/run_summarization_flax.py#L516):
model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
(shift_tokens_tight should be shift_tokens_right!)
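A side note on why the misspelling on line 516 does not by itself cause the AttributeError, at least with CPython's import machinery: for a plain (non-package) module such as modeling_flax_mt5, a non-empty fromlist only makes __import__ return the submodule itself instead of the top-level package, and the listed names are never looked up. The sketch below shows this with the standard-library json.decoder module as a stand-in for model.__module__.

```python
# Stand-in for model.__module__; json.decoder is a plain module (no __path__),
# like transformers.models.mt5.modeling_flax_mt5.
module_name = "json.decoder"

# Empty fromlist: __import__("a.b") returns the top-level package "a".
top_level = __import__(module_name)

# Non-empty fromlist: __import__ returns "a.b" itself. Because a plain module
# has no __path__, the listed names are not checked, so the misspelled
# "shift_tokens_tight" is silently ignored.
leaf = __import__(module_name, fromlist=["shift_tokens_tight"])

print(top_level.__name__)  # json
print(leaf.__name__)       # json.decoder

# The name that actually matters is the one passed to getattr afterwards.
print(hasattr(leaf, "shift_tokens_right"))  # False
```

So the typo is still worth fixing for readability, but the failure itself comes from the missing attribute.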
I was able to fix the issue by copying the shift_tokens_right function defined in src/transformers/models/t5/modeling_flax_t5.py into src/transformers/models/mt5/modeling_flax_mt5.py. With that change, the Flax summarization script works fine.
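For reference, here is a minimal pure-Python sketch of what a shift_tokens_right helper of this kind does, assuming the same (input_ids, pad_token_id, decoder_start_token_id) argument order as the T5 helper: prepend the decoder start token, drop the last position, and replace any -100 label padding with the pad id. It only illustrates the behaviour; it is not the transformers implementation, which operates on arrays.

```python
from typing import List

Batch = List[List[int]]


def shift_tokens_right(input_ids: Batch, pad_token_id: int, decoder_start_token_id: int) -> Batch:
    """Sketch: shift each sequence one position right to build decoder inputs."""
    shifted: Batch = []
    for row in input_ids:
        # Prepend the decoder start token and drop the last token, keeping the length.
        new_row = [decoder_start_token_id] + row[:-1]
        # Label padding marked with -100 (ignored by the loss) becomes the pad id.
        new_row = [pad_token_id if tok == -100 else tok for tok in new_row]
        shifted.append(new_row)
    return shifted


# Illustrative label batch; the second row is padded with -100.
labels = [[5, 6, 7, 1], [8, 1, -100, -100]]
print(shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=0))
# [[0, 5, 6, 7], [0, 8, 1, 0]]
```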
Hope you can fix the mt5 code accordingly!