Make Trainer compatible with sharded checkpoints #17053
Conversation
The documentation is not available anymore as the PR was closed or merged.
Looks great! I only added a few comments regarding docs, as I think we haven't yet put enough emphasis on what a sharded checkpoint is exactly and how it should be used, so that new users can pick it up easily.
""" | ||
This is the same as | ||
[`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict) | ||
but for a sharded checkpoint. |
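For illustration, a minimal sketch of the parallel the docstring draws (this assumes the function is importable from `transformers.modeling_utils`, as in this PR's diff; the model name and checkpoint paths are placeholders):

```python
import torch
from transformers import AutoModel
from transformers.modeling_utils import load_sharded_checkpoint

model = AutoModel.from_pretrained("bert-base-cased")

# Regular (single-file) checkpoint: load the state dict by hand.
state_dict = torch.load("checkpoint/pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)

# Sharded checkpoint: the folder holds an index file plus several weight shards;
# load_sharded_checkpoint reads the index and loads each shard into the model.
load_sharded_checkpoint(model, "checkpoint", strict=True)
```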
I think it would be nice to have documentation for sharded checkpoints so that users understand what they are. For example, here the mention of a sharded checkpoint could redirect to a small blurb covering things like the following:

- What is it?
  - Weight files that are split into multiple checkpoint files
  - An index showing how the weights are linked to the files
- Why is it important?
  - Better to work with smaller files for memory
  - Simpler to push to the Hub
- How to work with it?
  - Showing how to use `from_pretrained` and `save_pretrained` for sharding (see the sketch after this list)
  - Push to the Hub
  - Now `Trainer`
Let me know if that's something that already exists, and if not I'm happy to help contribute it (or to contribute it altogether).
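Such a blurb could include a short sketch along these lines (the model name, folder, and `max_shard_size` value are placeholders; this assumes `save_pretrained` exposes the `max_shard_size` argument that controls when sharding kicks in):

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-cased")

# A small max_shard_size forces the checkpoint to be split into several weight
# files plus a pytorch_model.bin.index.json that maps each weight to its file.
model.save_pretrained("sharded-bert", max_shard_size="200MB")

# from_pretrained detects the index file and reloads the shards transparently.
model = AutoModel.from_pretrained("sharded-bert")

# Pushing to the Hub works the same way: each shard is uploaded as its own file.
# model.push_to_hub("my-username/sharded-bert")
```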
```diff
@@ -327,6 +327,63 @@ def get_checkpoint_shard_files(
     return cached_filenames, sharded_metadata


+def load_sharded_checkpoint(model, folder, strict=True):
```
Should this be in the docs somewhere?
LGTM!
* Make Trainer compatible with sharded checkpoints
* Add doc
What does this PR do?
The `Trainer` is currently incompatible with the new sharded checkpoint feature in two places. In both cases, the model state dict is loaded back into the model, but there is no single model save file when the model is above the default size threshold for sharding, resulting in errors (as was pointed out by #16976).

This PR addresses this by adding a new function `load_sharded_checkpoint` that does the same thing as `model.load_state_dict` does for regular model files, but loads a sharded checkpoint instead (and errors on missing/unexpected keys when `strict=True`).

A test is added to make sure resuming works from a sharded checkpoint.
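As a rough sketch of what this enables (the model, dataset, output directory, and checkpoint path below are placeholders, not part of this PR):

```python
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

# Placeholders: any model/dataset pair works; the point is the resume call below.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
args = TrainingArguments(output_dir="out", save_steps=500)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # train_dataset assumed defined

# If the model was above the sharding threshold when the checkpoint was written,
# out/checkpoint-500 contains pytorch_model-00001-of-0000N.bin shards plus an
# index file instead of a single pytorch_model.bin. With this PR, resuming loads
# the shards via load_sharded_checkpoint instead of erroring.
trainer.train(resume_from_checkpoint="out/checkpoint-500")
```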
Fixes #16976