
Make Trainer compatible with sharded checkpoints #17053

Merged
merged 2 commits into main from resume_from_sharded_checkpoint on May 3, 2022

Conversation

sgugger
Collaborator

@sgugger sgugger commented May 2, 2022

What does this PR do?

The Trainer is currently incompatible with the new sharded checkpoint feature in two places:

  • resuming from a checkpoint
  • loading the best model at the end of training

In both cases, the model state dict is loaded back into the model, but there is no single model save file when the model is above the default sharding size, which results in errors (as pointed out in #16976).

This PR addresses this by:

  1. Creating a new function load_sharded_checkpoint that does the same thing as model.load_state_dict for regular model files, but loads a sharded checkpoint (and errors on missing/unexpected keys when strict=True).
  2. Using that function inside the Trainer in the two places mentioned above.

A test is added to make sure resuming works from a sharded checkpoint.

Fixes #16976
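The resume logic the PR fixes hinges on first detecting which kind of checkpoint a folder holds. Below is a minimal sketch of that check, assuming the filename conventions transformers used at the time (pytorch_model.bin for a single-file checkpoint, pytorch_model.bin.index.json for a sharded one); checkpoint_kind is a hypothetical helper for illustration, not a function from this PR:

```python
import os

# Filenames follow the transformers convention at the time of this PR;
# treat them as assumptions rather than a pinned API.
WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"

def checkpoint_kind(folder):
    """Classify a checkpoint folder (hypothetical helper, not from the PR).

    Returns "full" when a single-file checkpoint is present, "sharded" when
    only the shard index is present, and raises otherwise -- mirroring the
    decision the Trainer must make before loading weights on resume.
    """
    if os.path.isfile(os.path.join(folder, WEIGHTS_NAME)):
        return "full"
    if os.path.isfile(os.path.join(folder, WEIGHTS_INDEX_NAME)):
        return "sharded"
    raise ValueError(f"No checkpoint found in {folder}")
```

Before this PR, the Trainer only handled the "full" branch, which is why resuming from a large, automatically sharded checkpoint failed.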

@sgugger sgugger requested a review from LysandreJik May 2, 2022 18:13
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented May 2, 2022

The documentation is not available anymore as the PR was closed or merged.

Member

@LysandreJik LysandreJik left a comment


Looks great! I only added a few comments regarding docs, as I don't think we've put enough emphasis on what this is exactly and how it should be used, so that new users can pick it up easily.

"""
This is the same as
[`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
but for a sharded checkpoint.
Member


I think it would be nice to have documentation for sharded checkpoints so that users understand what they are. For example, "sharded checkpoint" here could link to a small blurb covering things like the following:

  • What is it?
    • Weight files that are split across multiple checkpoint files
    • An index showing how the weights are linked to those files
  • Why is it important?
    • Working with smaller files is better for memory
    • Simpler to push to the Hub
  • How to work with it?
    • Showing how to use from_pretrained and save_pretrained with sharding
    • Pushing to the Hub
    • And now the Trainer

Let me know if that's something that already exists; if not, I'm happy to help contribute it (or to contribute it entirely).
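To make the "what is it" bullets concrete: a sharded save writes several weight files plus a JSON index mapping every weight name to the shard file that stores it. The toy sketch below splits a state dict into shards under a size budget and builds such an index, operating on byte sizes instead of real tensors; shard_state_dict and the exact filename pattern are illustrative assumptions, not the save_pretrained implementation:

```python
def shard_state_dict(sizes, max_shard_size):
    """Split a {param_name: byte_size} mapping into shards no larger than
    max_shard_size, and build the weight->file index that ties them together.
    Illustrative only: the real save path shards actual tensors.
    """
    shards, current, current_size = [], {}, 0
    for name, size in sizes.items():
        # Start a new shard once the current one would exceed the budget.
        if current and current_size + size > max_shard_size:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = size
        current_size += size
    if current:
        shards.append(current)
    # The index maps every weight name to the shard file holding it.
    index = {
        "metadata": {"total_size": sum(sizes.values())},
        "weight_map": {
            name: f"pytorch_model-{i + 1:05d}-of-{len(shards):05d}.bin"
            for i, shard in enumerate(shards)
            for name in shard
        },
    }
    return shards, index
```

The index file is what makes resuming possible without one giant weights file: a loader can walk weight_map shard by shard instead of loading everything at once.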

@@ -327,6 +327,63 @@ def get_checkpoint_shard_files(
return cached_filenames, sharded_metadata


def load_sharded_checkpoint(model, folder, strict=True):
Member


Should this be in the docs somewhere?
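For intuition about what a function with this signature has to do: read the shard index, load each shard file in turn into the model, then enforce strictness over the union of loaded keys, like load_state_dict does. The sketch below is a toy analogue using plain dicts and JSON files in place of modules and torch.load; load_sharded_state is hypothetical and not the PR's code:

```python
import json
import os

def load_sharded_state(model_state, folder, strict=True):
    """Toy analogue of a sharded-checkpoint loader: model_state and the
    shards are plain dicts rather than nn.Modules and tensors.
    """
    index_file = os.path.join(folder, "pytorch_model.bin.index.json")
    with open(index_file) as f:
        index = json.load(f)
    # Each weight maps to one shard file; load each shard exactly once.
    shard_files = sorted(set(index["weight_map"].values()))
    loaded_keys = set()
    for shard_file in shard_files:
        with open(os.path.join(folder, shard_file)) as f:
            shard = json.load(f)  # the real loader would use torch.load here
        model_state.update({k: v for k, v in shard.items() if k in model_state})
        loaded_keys.update(shard)
    # Strictness check over the union of all shards, like load_state_dict.
    missing = set(model_state) - loaded_keys
    unexpected = loaded_keys - set(model_state)
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing={sorted(missing)} unexpected={sorted(unexpected)}")
    return missing, unexpected
```

The key difference from load_state_dict is that missing/unexpected keys can only be judged after all shards have been visited, since no single shard contains the full state dict.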

Member

@LysandreJik LysandreJik left a comment


LGTM!

@sgugger sgugger merged commit a8fa2f9 into main May 3, 2022
@sgugger sgugger deleted the resume_from_sharded_checkpoint branch May 3, 2022 13:55
stevhliu pushed a commit to stevhliu/transformers that referenced this pull request May 3, 2022
* Make Trainer compatible with sharded checkpoints

* Add doc
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
* Make Trainer compatible with sharded checkpoints

* Add doc
Successfully merging this pull request may close these issues.

Bug: Finetuning large models resume checkpoint error