🚨 feat: add non-breaking support to serialize metadata in loras. #9143
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Nice quick work. Looks almost perfect to me already :) I can't really say whether all places have been adjusted or some were missed; I'll leave that to someone who has better knowledge of the diffusers code base.
Apart from that, I only have a few comments; please check.
I also noticed that there is already a mechanism for providing alphas via `network_alpha_dict`. IIUC, this needs to be user-provided for now and cannot be inferred automatically from the checkpoint. Still, I wonder if we can set `network_alpha_dict` from `config` instead of having two somewhat independent code paths to achieve the same goal.
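Purely as a sketch of what a single code path could look like (the helper name is made up, the keys follow PEFT's `LoraConfig` fields, and this is illustrative rather than diffusers' actual implementation):

    # Hypothetical sketch: when a serialized PEFT config is available, use its
    # alpha_pattern to populate network_alpha_dict so the rest of the loading
    # logic only ever consults one source.
    def fill_network_alpha_dict(config: dict | None, network_alpha_dict: dict | None) -> dict:
        if network_alpha_dict:
            # Explicit user input wins; nothing needs to be inferred.
            return dict(network_alpha_dict)
        if config is not None:
            # Per-module alpha overrides, if any, live under "alpha_pattern".
            return dict(config.get("alpha_pattern") or {})
        return {}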
The docstrings for `config` are still empty; please add them, as it's not immediately obvious what values are expected.
src/diffusers/utils/peft_utils.py (outdated)

    # Try to retrive config.
    alpha_retrieved = False
    if config is not None:
        lora_alpha = config["lora_alpha"]
I think this also needs to consider `alpha_pattern`. If it's not intended to be supported for now, at least raise an error if `alpha_pattern` is given? If support is added, the unit test should include different alpha values (or a separate test could be added for this).
`alpha_pattern` cannot be provided through `get_peft_kwargs`.
Hmm, maybe I'm misunderstanding, but we're passing a `LoraConfig` instance, so how can we know that this does not have `alpha_pattern`?
I wonder if this should be:

    # Try to retrieve config.
    alpha_retrieved = False
    if config is not None:
        lora_alpha = config["lora_alpha"] if "lora_alpha" in config else lora_alpha
        alpha_retrieved = True
    +   if config.get("alpha_pattern", None):
    +       alpha_pattern = config["alpha_pattern"]

Similar argument for `rank_pattern`. In general, it's not clear to me how we should handle it if `rank_pattern`/`alpha_pattern` differ from `rank_dict`/`network_alpha_dict` (or is that not possible)?
Good point. Let me accommodate those changes and have it ready for your review.
@BenjaminBossan does f7d30de work?
No, IIUC there is still an issue. Say `get_peft_kwargs` is called with both `network_alpha_dict` and `config` being passed. As is, `network_alpha_dict` is completely ignored and `alpha_pattern` remains an empty dict. I think what should happen is that either

- `alpha_pattern` is taken from `network_alpha_dict` (same as previously), or
- `alpha_pattern` is taken from `config` if `config.alpha_pattern` is not `None`; if it is `None`, `network_alpha_dict` should be used.

So either `network_alpha_dict` or `config.alpha_pattern` should take precedence. And if both are given, potentially warn about the one being ignored. WDYT?
Depending on the choice here, `rank_pattern` should also be adjusted for consistency.
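A minimal sketch of the precedence rule proposed above (the function name and warning text are assumptions, not the actual diffusers code):

    import warnings

    # Prefer alpha_pattern from the serialized config, fall back to the
    # user-provided network_alpha_dict, and warn when both are supplied.
    def resolve_alpha_pattern(config: dict | None, network_alpha_dict: dict | None) -> dict:
        config_pattern = (config or {}).get("alpha_pattern") or {}
        if config_pattern and network_alpha_dict:
            warnings.warn(
                "Both `config.alpha_pattern` and `network_alpha_dict` were passed; "
                "ignoring `network_alpha_dict`."
            )
        return config_pattern or dict(network_alpha_dict or {})

The same logic could be mirrored for `rank_pattern`/`rank_dict`.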
> alpha_pattern is taken from config if config.alpha_pattern is not None. If it is None, network_alpha_dict should be used.

This is not what I am doing. That is because `network_alpha_dict` can only be truthy for non-diffusers checkpoints, and for those checkpoints the metadata won't have a PEFT config.
Long story short, `rank_pattern` and `alpha_pattern` from `config` (if found) will simply be ignored for now.
Thanks for clarifying. So it is assumed that `rank_pattern`/`alpha_pattern` are never passed when `config` is passed, and vice versa. In that case, this could be checked and an error raised, or at least a comment added, WDYT?
Will add a comment. For now, a warning like the one we already emit should suffice.
Added a comment in 178a459.
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Thanks for the suggestions, @BenjaminBossan! I have added the docstrings. LMK what you think.
For now we've set the default alpha equal to the rank, since I was somewhat confused about (or misunderstood) what the alpha parameter is for. As it turns out, some trainers implement it a bit differently, e.g. permanently scaling the weights before exporting/saving. But rank = alpha works for us for now; I'll implement this once it's merged in.
Right, thanks!
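For readers unfamiliar with the parameter: in standard LoRA the weight delta is scaled by alpha / rank, so alpha == rank yields a factor of 1, and trainers that "permanently scale" simply fold that factor into the matrices before saving. A small self-contained illustration (shapes and values are arbitrary):

    import torch

    rank, alpha = 16, 32
    A = torch.randn(rank, 768)  # LoRA "down" projection
    B = torch.randn(768, rank)  # LoRA "up" projection

    # Scaling applied at load/inference time vs. baked into the saved weights:
    delta_runtime = (alpha / rank) * (B @ A)
    delta_baked = (B * (alpha / rank)) @ A
    assert torch.allclose(delta_runtime, delta_baked, atol=1e-5)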
    if file_extension == SAFETENSORS_FILE_EXTENSION:
        with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
            metadata = f.metadata()
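For context, this is how metadata round-trips through safetensors; the file name, tensor names, and metadata key below are made up for illustration, and safetensors only accepts a flat string-to-string metadata mapping:

    import json

    import safetensors.torch
    import torch
    from safetensors import safe_open

    state_dict = {"lora_A.weight": torch.randn(4, 8), "lora_B.weight": torch.randn(8, 4)}
    # Non-string values (e.g. a LoRA config) have to be JSON-encoded first.
    metadata = {"lora_config": json.dumps({"r": 4, "lora_alpha": 8})}

    safetensors.torch.save_file(state_dict, "lora.safetensors", metadata=metadata)

    with safe_open("lora.safetensors", framework="pt", device="cpu") as f:
        loaded = f.metadata()  # returns the str -> str mapping, or None
        config = json.loads(loaded["lora_config"])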
Will we have any way (or desire) to warn users loading a LoRA from a .pth file that we can't scale it?
We currently error out when trying to save metadata to a .pth file (or, more generally, when `use_safetensors` is False).
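A minimal sketch of that kind of guard, assuming a hypothetical save helper (the signature and error message are not the actual diffusers API):

    def save_lora_weights(state_dict, save_path, lora_metadata=None, use_safetensors=True):
        # Metadata can only be embedded in safetensors files, so reject the
        # combination of metadata plus a .pth/.bin save up front.
        if lora_metadata and not use_safetensors:
            raise ValueError("LoRA metadata can only be serialized with `use_safetensors=True`.")
        ...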
Good! I never use them, but I was concerned for others.
Thanks for clarifying my questions about `get_peft_kwargs`. LGTM.
Haven't tested it, but it looks like it does the job.
Closing this PR for now in light of #9236.
What does this PR do?
To avoid bugs like:
In my view, it's a non-breaking change. Otherwise, training tests would have failed because of unpacking mismatches.
I attempted this back in the day in #6135, but we had to close it. Now its necessity is evident.