
shift_tokens_right function missing for mt5 models #15771

Closed
bduclaux opened this issue Feb 22, 2022 · 8 comments · Fixed by #17188

Comments

@bduclaux

Environment info

  • transformers version: 4.17.0.dev0
  • Platform: Linux-5.11.0-1028-gcp-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • PyTorch version (GPU?): 1.11.0a0+git06838ce (False)
  • Tensorflow version (GPU?): 2.9.0-dev20220221 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.4.0 (cpu)
  • Jax version: 0.3.1
  • JaxLib version: 0.3.0

Who can help

@patil-suraj

Information

Hello,

I am trying to fine-tune an mt5-small model with Flax on TPU for a summarization task, using the examples/flax/summarization/run_summarization_flax.py script.
When I run the script, I get an error because shift_tokens_right is not defined for mT5 models:

File "/home/prod/transformers/examples/flax/summarization/run_summarization_flax.py", line 521, in main
    shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
AttributeError: module 'transformers.models.mt5.modeling_flax_mt5' has no attribute 'shift_tokens_right'

Moreover, the current Flax summarization script has a typo on line 516 (https://github.com/huggingface/transformers/blob/master/examples/flax/summarization/run_summarization_flax.py#L516):

model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])

("shift_tokens_tight" should be "shift_tokens_right"!)
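For what it's worth, the typo on line 516 is probably cosmetic. When fromlist is non-empty, __import__ returns the named submodule itself, and for a plain (non-package) module such as modeling_flax_mt5 the names listed in fromlist do not appear to be validated. The AttributeError therefore comes from the getattr call on line 521. The sketch below illustrates this with the standard-library json.decoder module standing in for model.__module__ (an assumption made only so it runs without transformers installed). resolve_function is a hypothetical helper, not part of the script, that does the same lookup with importlib.import_module and a clearer error message:

```python
import importlib
from typing import Callable

# Stand-in for model.__module__ (e.g. "transformers.models.mt5.modeling_flax_mt5").
module_name = "json.decoder"

# With a non-empty fromlist, __import__ returns the leaf module itself. For a plain
# (non-package) module the fromlist entries are not checked, so the misspelled
# "shift_tokens_tight" does not raise here.
mod = __import__(module_name, fromlist=["shift_tokens_tight"])
print(mod.__name__)  # json.decoder


def resolve_function(module_name: str, func_name: str) -> Callable:
    """Hypothetical helper: import a module by name and fetch a function from it."""
    module = importlib.import_module(module_name)
    fn = getattr(module, func_name, None)
    if fn is None:
        # Fail with a message that names both the module and the missing function.
        raise AttributeError(
            f"{module_name!r} does not define {func_name!r}; "
            "this model's Flax module may not provide it"
        )
    return fn
```

With this, resolve_function(model.__module__, "shift_tokens_right") would still fail for mT5 until the function is added, but the error points directly at the missing definition.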

I was able to fix the issue by copying the shift_tokens_right function defined in src/transformers/models/t5/modeling_flax_t5.py into src/transformers/models/mt5/modeling_flax_mt5.py.
With that change, the Flax summarization script works fine.
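For reference, here is a minimal NumPy sketch of what shift_tokens_right does for T5-style seq2seq models: it builds decoder input ids from the labels by prepending decoder_start_token_id, dropping the last token, and replacing the -100 loss-ignore value with pad_token_id. This is an illustration under those assumptions, not the exact code in modeling_flax_t5.py or in the eventual fix. For T5/mT5 checkpoints, pad_token_id and decoder_start_token_id are typically both 0.

```python
import numpy as np


def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    """Shift a batch of label ids one position right to build decoder input ids.

    Sketch of the T5/BART-style convention; not copied from any transformers release.
    """
    input_ids = np.asarray(input_ids)
    shifted = np.zeros_like(input_ids)
    # Every token moves one slot to the right; the last token is dropped.
    shifted[:, 1:] = input_ids[:, :-1]
    # Each decoder sequence starts with the configured start token.
    shifted[:, 0] = decoder_start_token_id
    # -100 marks label positions ignored by the loss; the decoder needs real ids there.
    return np.where(shifted == -100, pad_token_id, shifted)


# Usage: labels padded with -100, as a data collator might produce them.
labels = np.array([[37, 1, -100, -100]])
decoder_input_ids = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=0)
print(decoder_input_ids)
```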

I hope you can fix the mT5 code accordingly!

@patil-suraj
Contributor

patil-suraj commented Feb 22, 2022

Good catch, thanks for opening the issue!
Would you be interested in opening a PR to add this function to Flax mT5?

@bduclaux
Author

Hey Suraj,

Sure! Will do it tomorrow and keep you posted. Thanks!

@bduclaux
Author

I also noticed that the adafactor parameter is not used to instantiate the optimizer in the run_summarization_flax.py script.
I will add it to my PR, so that both AdamW and Adafactor are supported.

@patil-suraj
Contributor

"Also noticed that adafactor parameter is not used to instantiate the optimizer in the run_summarization_flax.py script. Will add it in my PR, to have support for both adamW and adafactor."

Ahh, good catch! It would be better to open a separate PR for this, since these are two different changes.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@bduclaux
Author

Bump: will open the PR soon.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot closed this as completed on May 2, 2022
patil-suraj reopened this on May 11, 2022
@patil-suraj
Contributor

Will make a PR for this shortly :)
