[LoRA PEFT] fix LoRA loading so that correct alphas are parsed #6135
Conversation
Co-authored-by: pacman100 <13534540+pacman100@users.noreply.github.com>
Implementation-wise this looks good; I only have a few non-critical comments.

Regarding the overall problem: with some recent updates we made to PEFT, it should now be possible to pass `lora_alpha` to `forward`, similar to the non-PEFT LoRA layers of diffusers. So we could theoretically remove the "workaround" we have in diffusers and use the same mechanism as for non-PEFT to load and pass alphas. However, at this point I'm not sure whether making that change is worth it or whether we should live with the current situation.
@BenjaminBossan thanks for your comments. If you could provide an example for me here, that'd be helpful for us to gauge whether it's worthwhile to move forward with the proposed changes.
Makes sense to me, thanks very much @sayakpaul for leading this effort!
> If you could provide an example for me here, that'd be helpful for us to gauge whether it's worthwhile to move forward with the proposed changes.
So what I mean here is that when the PEFT integration was added, we had to treat the scaling differently for PEFT because we couldn't pass the `scale` argument to PEFT layers. This led to cases like these:
diffusers/src/diffusers/models/transformer_2d.py (lines 404 to 417 at 93ea26f):

```python
if not self.use_linear_projection:
    hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
    hidden_states = (
        self.proj_out(hidden_states, scale=lora_scale)
        if not USE_PEFT_BACKEND
        else self.proj_out(hidden_states)
    )
else:
    hidden_states = (
        self.proj_out(hidden_states, scale=lora_scale)
        if not USE_PEFT_BACKEND
        else self.proj_out(hidden_states)
    )
    hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
```

and:

```python
args = () if USE_PEFT_BACKEND else (scale,)
```
Now, in PEFT, we allow passing arbitrary extra arguments to `forward`, and the PEFT layer will pass them on to the base layer that's being adapted, as e.g. here:
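The linked snippet isn't reproduced above; purely as an illustrative sketch of that pass-through behavior (the toy module, the `scale` kwarg, and the config values are assumptions, not diffusers code, and a recent PEFT version with arg pass-through is assumed):

```python
# Illustrative sketch only: PEFT's LoRA layer forwards unknown positional and
# keyword arguments to the base layer it wraps.
import torch
import torch.nn as nn
from peft import LoraConfig, inject_adapter_in_model


class ScaledLinear(nn.Linear):
    """Toy base layer that accepts an extra `scale` kwarg, similar in spirit
    to the non-PEFT LoRA-compatible layers in diffusers."""

    def forward(self, x, scale: float = 1.0):
        return super().forward(x) * scale


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_out = ScaledLinear(8, 8)

    def forward(self, x, scale: float = 1.0):
        return self.proj_out(x, scale=scale)


model = ToyBlock()
model = inject_adapter_in_model(
    LoraConfig(r=4, lora_alpha=4, target_modules=["proj_out"]), model
)

# `scale` is not consumed by the PEFT wrapper; it is passed through to the
# adapted base layer's ScaledLinear.forward.
out = model(torch.randn(2, 8), scale=0.5)
```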
Therefore, I think we could remove all the workarounds we have where we have to scale and unscale the PEFT layer weights.
That said, I haven't tested it and it would be a lot of work to unwind these changes. Maybe it's just not worth it and we can keep the current implementation, as it works.
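For reference, a rough sketch (not actual diffusers code; `call_with_lora_scale` is a made-up helper name) of the scale/unscale workaround being discussed, using the `scale_lora_layers`/`unscale_lora_layers` utilities from `diffusers.utils`:

```python
# Rough sketch of the current workaround: with the PEFT backend, the LoRA
# scale is applied by temporarily scaling the PEFT layer weights around the
# call instead of passing `scale` into forward.
from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers


def call_with_lora_scale(module, hidden_states, lora_scale=1.0):
    if USE_PEFT_BACKEND:
        scale_lora_layers(module, lora_scale)    # multiply LoRA weights in place
        out = module(hidden_states)              # plain forward, no `scale` kwarg
        unscale_lora_layers(module, lora_scale)  # undo the scaling afterwards
        return out
    # Non-PEFT LoRA-compatible layers accept the scale directly in forward.
    return module(hidden_states, scale=lora_scale)
```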
@BenjaminBossan I agree with your assessment here. Also, I think I have addressed your comments. Let me know.
Thanks. Looks all good from my side.
@younesbelkada @BenjaminBossan @patrickvonplaten I would appreciate another round of review here. Benjamin, feel free to ignore the changes introduced in the training-related parts. I would wait for all three of you to approve this PR, as this change is quite impactful.
Hello @sayakpaul, great work on fixing the loading of LoRA weights with PEFT training support. This preserves the status quo of having a single safetensors weight file while correctly saving the related config. LGTM 🔥🚀✨! Also, thank you for taking the time to go through my suggestions and related code.
Impressive and inspiring work, @sayakpaul! Thanks for taking the lead on this!
One minor suggestion: it would be nice to detail this approach in the docs for people who are curious about how PEFT configs are saved internally in the safetensors checkpoints. What do you think?
I have edited the PR description so the community can refer to the details there. I think that might be better. WDYT @younesbelkada?
Sounds great @sayakpaul, thanks!
Closing this PR in favor of a simpler alternative, as described below. Simplified PR: #6225. We didn't have a concept of But that is NOT the case with So, @pacman100 suggested a simpler alternative. Just set the
What does this PR do?
Fixes #6087.
This PR ensures that the relevant `LoraConfig` is also serialized when the state dict is serialized. Otherwise, even if the LoRA state dict is passed properly with the `peft` backend, its underlying config might not be parsed correctly. So, we explicitly pass the config dictionary to `save_lora_weights()`, and `load_lora_weights()` takes care of parsing the config accordingly in a backward-compatible way.

To check, first generate a `peft` LoRA and then load it back with `load_lora_weights()`; a rough sketch of that round trip is given below.
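Not the PR's original snippets: a minimal sketch of the round trip, assuming the usual diffusers/PEFT training-script APIs (`add_adapter`, `get_peft_model_state_dict`, `convert_state_dict_to_diffusers`). The repo id, output directory, and LoRA config values are illustrative placeholders, and the extra config argument this PR adds to `save_lora_weights()` is not shown.

```python
# Minimal sketch (not from the PR): create a PEFT LoRA on the UNet with
# lora_alpha != rank, save it via the pipeline, then load it back.
from diffusers import StableDiffusionPipeline
from diffusers.utils import convert_state_dict_to_diffusers
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Attach a PEFT LoRA adapter to the UNet. lora_alpha differs from r, which is
# exactly the case where the alpha must survive serialization.
unet_lora_config = LoraConfig(
    r=4,
    lora_alpha=8,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
pipe.unet.add_adapter(unet_lora_config)

# ... train the adapter here ...

# Serialize the LoRA weights in the diffusers format.
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(pipe.unet))
StableDiffusionPipeline.save_lora_weights(
    save_directory="my-peft-lora",  # hypothetical output directory
    unet_lora_layers=unet_lora_state_dict,
)

# To load:
pipe.load_lora_weights("my-peft-lora")
```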
TODO
Inspired by @pacman100's https://github.com/pacman100/peft-dreambooth-ui/blob/main/train_dreambooth_peft.py#L136-L212 (hence he is also a co-author here).