
🚨 feat: add non-breaking support to serialize metadata in loras. #9143

Closed
wants to merge 19 commits

Conversation

sayakpaul
Member

What does this PR do?

To avoid bugs like:

In my view, it's a non-breaking change. Otherwise, training tests would have failed because of unpacking mismatches.

I attempted this a while back in #6135, but we had to close it. Now its necessity is evident.

@sayakpaul sayakpaul requested review from BenjaminBossan and DN6 August 9, 2024 13:18
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member) left a comment

Nice quick work. Looks almost perfect to me already :) I can't really say if all places have been adjusted or some were missed; I'll leave that to someone who has better knowledge of the diffusers code base.

Apart from that, I only have a few comments, please check.

I also noticed that there is already a mechanism for providing alphas via network_alpha_dict. IIUC, this needs to be user provided for now and cannot be inferred automatically from the checkpoint. Still, I wonder if we can set network_alpha_dict from config instead of having two somewhat independent code paths to achieve the same goal.

The docstrings for config are still empty; please add them, as it's not immediately obvious what values are expected.

# Try to retrieve config.
alpha_retrieved = False
if config is not None:
    lora_alpha = config["lora_alpha"]
Member

I think this also needs to consider alpha_pattern -- if it's not intended to be supported for now, at least raise an error if alpha_pattern is given? If support is added, the unit test should include different alpha values (or a separate test could be added for this).

Member Author

alpha_pattern cannot be provided through get_peft_kwargs.

Member

Hmm, maybe I'm misunderstanding, but we're passing a LoraConfig instance, how can we know that this does not have alpha_pattern?

I wonder if this should be:

    # Try to retrieve config.
    alpha_retrieved = False
    if config is not None:
        lora_alpha = config["lora_alpha"] if "lora_alpha" in config else lora_alpha
        alpha_retrieved = True
+       if config.get("alpha_pattern", None):
+           alpha_pattern = config["alpha_pattern"]

Similar argument for rank_pattern. In general, it's not clear to me how we should handle it if rank_pattern/alpha_pattern differ from rank_dict/network_alpha_dict (or is it not possible)?

Member Author

Good point. Let me accommodate those changes and have it ready for your review.


Member

No, IIUC there is still an issue. Say get_peft_kwargs is called with both network_alpha_dict and config passed. As is, network_alpha_dict is completely ignored and alpha_pattern remains an empty dict. I think what should happen is that either

  1. alpha_pattern is taken from network_alpha_dict (same as previously)
  2. alpha_pattern is taken from config if config.alpha_pattern is not None. If it is None, network_alpha_dict should be used.

So either network_alpha_dict or config.alpha_pattern should take precedence. And if both are given, potentially warn about the one being ignored. WDYT?

Depending on the choice here, rank_pattern should also be adjusted for consistency.
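
For illustration, a minimal sketch (not the PR's actual code) of the precedence option described above: prefer config.alpha_pattern when it is set, otherwise fall back to network_alpha_dict, and warn if both are given. The helper name resolve_alpha_pattern and treating config as a plain dict are assumptions made for this example.

import warnings

def resolve_alpha_pattern(config=None, network_alpha_dict=None):
    # Prefer the pattern carried in the serialized PEFT config, if any.
    if config is not None and config.get("alpha_pattern"):
        if network_alpha_dict:
            warnings.warn(
                "Both `config.alpha_pattern` and `network_alpha_dict` were provided; "
                "ignoring `network_alpha_dict`."
            )
        return config["alpha_pattern"]
    # Otherwise fall back to the user-provided dict (the previous behavior).
    return dict(network_alpha_dict) if network_alpha_dict else {}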

Member Author

> alpha_pattern is taken from config if config.alpha_pattern is not None. If it is None, network_alpha_dict should be used.

This is not what I am doing. That is because network_alpha_dict can only be populated for non-diffusers checkpoints, and for those checkpoints the metadata won't have a PEFT config.

Long story short, rank_pattern and alpha_pattern from config (if found) will simply be ignored for now.

Member

Thanks for clarifying. So it is assumed that rank_pattern/alpha_pattern are never passed when config is passed, and vice versa. In that case, this could be checked and an error raised, or at least a comment added. WDYT?

Member Author

Will add a comment. For now, a warning like the one we already emit should suffice.

Member Author

Added a comment in 178a459.

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
@sayakpaul sayakpaul changed the title feat: add non-breaking support to serialize metadata in loras. 🚨 feat: add non-breaking support to serialize metadata in loras. Aug 9, 2024
@sayakpaul
Member Author

Thanks for the suggestions, @BenjaminBossan! I have added the docstrings. LMK what you think.

@bghira
Contributor

bghira commented Aug 11, 2024

For now we've set up the default alpha to equal the rank, since I was somewhat confused about, or misunderstood, what the purpose of the alpha parameter was. As it turns out, some trainers implement it a bit differently, e.g. permanently scaling the weights before exporting/saving. But rank=alpha works for us for now; I'll implement this when it's merged in.
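
For context on the parameter being discussed, here is a toy illustration of how lora_alpha and the rank typically combine into the LoRA scaling factor; with alpha == rank the factor is 1.0, which is why that default is convenient, while some trainers instead bake the scaling into the exported weights as noted above. All shapes and values below are made up.

import torch

rank, lora_alpha = 16, 16
in_features, out_features = 64, 64

# Low-rank factors as they appear in a typical LoRA layer.
lora_A = torch.randn(rank, in_features) * 0.01
lora_B = torch.zeros(out_features, rank)

scaling = lora_alpha / rank  # 1.0 when alpha == rank, i.e. no extra scaling

x = torch.randn(2, in_features)
delta = (x @ lora_A.T @ lora_B.T) * scaling  # added on top of the base layer's output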

@sayakpaul
Member Author

Right, thanks!

Comment on lines +315 to +317
if file_extension == SAFETENSORS_FILE_EXTENSION:
    with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
        metadata = f.metadata()
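
As a hedged follow-up to the excerpt above: f.metadata() returns a flat dict of strings (or None if the file was written without metadata), so a config serialized into it would typically be recovered with json.loads. The file name and the "lora_config" key below are assumptions for illustration, not necessarily the PR's actual choices.

import json

import safetensors

with safetensors.safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
    metadata = f.metadata()  # None if no metadata was stored

config = json.loads(metadata["lora_config"]) if metadata and "lora_config" in metadata else None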
Contributor

Will we have any way or desire to warn users loading a LoRA from a pth file that we can't scale it?

Member Author

We currently error out when trying to save metadata to a pth file (or more generally, when use_safetensors is False).
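
A rough sketch of that behavior, not the PR's actual function: safetensors files carry a string-to-string metadata header, while a .pth file has nowhere to put one, so saving metadata with use_safetensors=False raises. The helper name, the "lora_config" key, and the error message are assumptions for this example.

import json

import safetensors.torch
import torch

def save_lora(state_dict, path, lora_config_dict=None, use_safetensors=True):
    if lora_config_dict is not None and not use_safetensors:
        raise ValueError("LoRA metadata can only be serialized when `use_safetensors=True`.")
    if use_safetensors:
        metadata = None
        if lora_config_dict is not None:
            # safetensors metadata values must be strings, hence the JSON dump.
            metadata = {"lora_config": json.dumps(lora_config_dict)}
        safetensors.torch.save_file(state_dict, path, metadata=metadata)
    else:
        torch.save(state_dict, path)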

Contributor

Good! I never use them, but I was concerned for others.

@sayakpaul sayakpaul requested a review from bghira August 12, 2024 13:10
@BenjaminBossan (Member) left a comment

Thanks for clarifying my questions about get_peft_kwargs. LGTM

@bghira (Contributor) left a comment

Haven't tested it, but it looks like it does the job.

@sayakpaul sayakpaul requested a review from yiyixuxu August 21, 2024 10:17
@sayakpaul
Member Author

@yiyixuxu @DN6 could you please give this a look?

@sayakpaul
Member Author

Closing this PR for now in light of #9236

@sayakpaul sayakpaul closed this Aug 22, 2024