
[LoRA PEFT] fix LoRA loading so that correct alphas are parsed #6135

Closed
wants to merge 32 commits

Conversation

@sayakpaul (Member) commented on Dec 11, 2023:

What does this PR do?

Fixes #6087.

This PR ensures that the relevant LoraConfig is serialized alongside the LoRA state dict. Otherwise, even if the LoRA state dict is loaded properly with the peft backend, its underlying config (and hence lora_alpha) might not be parsed correctly.

So, we explicitly pass the config dictionary to save_lora_weights(), and load_lora_weights() takes care of parsing the config accordingly in a backward-compatible way.

To check, first generate a peft LoRA:

from diffusers import UNet2DConditionModel, StableDiffusionXLPipeline
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict

# Load the SDXL UNet and attach a LoRA adapter via peft
# (init_lora_weights=False yields random LoRA weights, handy for testing).
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
rank = 64
lora_config = LoraConfig(
    r=rank, target_modules=["to_k", "to_q", "to_v", "to_out.0"], init_lora_weights=False
)
unet.add_adapter(lora_config)

# Grab the adapter state dict and its LoraConfig so that both get serialized together.
output_dir = "my_lora"
unet_lora_layers_to_save = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]

StableDiffusionXLPipeline.save_lora_weights(
    output_dir,
    unet_lora_layers=unet_lora_layers_to_save,
    unet_lora_config=unet_lora_config,
)
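
For reference, a quick way to peek at what was saved. The PR keeps the LoRA weights and the related config in a single safetensors checkpoint (see the review discussion below), so the config should be recoverable from the file itself; the default file name and the exact metadata layout are assumptions here:

from safetensors import safe_open

# Assumption: save_lora_weights wrote "pytorch_lora_weights.safetensors" into output_dir
# and attached the serialized LoraConfig as file-level metadata.
with safe_open("my_lora/pytorch_lora_weights.safetensors", framework="pt") as f:
    print(f.metadata())        # file-level metadata (e.g. the serialized config)
    print(list(f.keys())[:5])  # a few of the LoRA weight tensor names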

To load:

from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipeline.load_lora_weights("my_lora")
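
To verify that the alphas were parsed as expected, the injected configs can be inspected on the UNet. This is a hedged sketch: it assumes the peft backend is active and that peft exposes the adapter configs via peft_config on the model; the adapter names are whatever peft assigned, so we iterate instead of hard-coding one:

# Print rank and alpha of every LoRA adapter that was injected into the UNet.
for adapter_name, config in pipeline.unet.peft_config.items():
    print(adapter_name, "r =", config.r, "lora_alpha =", config.lora_alpha)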

TODO

  • Propagate to the other examples
  • Add tests
  • Document (docstrings)

Inspired by @pacman100's https://github.com/pacman100/peft-dreambooth-ui/blob/main/train_dreambooth_peft.py#L136-L212 (hence he is also a co-author here).


@BenjaminBossan (Member) left a comment:

Implementation-wise, this looks good; I only have a few non-critical comments.

Regarding the overall problem, with some recent updates we made to PEFT, it should now be possible to pass lora_alpha to forward, similar to the non-PEFT LoRA layers of diffusers. So we could theoretically remove the "workaround" we have in diffusers and use the same mechanism as for non-PEFT to load and pass alphas. However, at this point, I'm not sure if it's worth it to make the change or live with the current situation.

@sayakpaul (Member, Author) replied:

> Regarding the overall problem, with some recent updates we made to PEFT, it should now be possible to pass lora_alpha to forward, similar to the non-PEFT LoRA layers of diffusers. So we could theoretically remove the "workaround" we have in diffusers and use the same mechanism as for non-PEFT to load and pass alphas. However, at this point, I'm not sure if it's worth it to make the change or live with the current situation.

@BenjaminBossan thanks for your comments. If you could provide an example for me here, that'd be helpful for us to gauge if it's worthwhile to move forward with the proposed changes.

@younesbelkada (Contributor) left a comment:

Makes sense to me, thanks very much @sayakpaul for leading this effort!

@BenjaminBossan (Member) left a comment:

> If you could provide an example for me here, that'd be helpful for us to gauge if it's worthwhile to move forward with the proposed changes.

So what I mean here is that when the PEFT integration was added, we had to treat the scaling differently for PEFT because we couldn't pass the scale argument to PEFT layers. This led to cases like these:

if not self.use_linear_projection:
    hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
    hidden_states = (
        self.proj_out(hidden_states, scale=lora_scale)
        if not USE_PEFT_BACKEND
        else self.proj_out(hidden_states)
    )
else:
    hidden_states = (
        self.proj_out(hidden_states, scale=lora_scale)
        if not USE_PEFT_BACKEND
        else self.proj_out(hidden_states)
    )
    hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

args = () if USE_PEFT_BACKEND else (scale,)

Now, in PEFT, we allow passing arbitrary extra arguments to forward, and the PEFT layer will pass them on to the base layer that's being adapted, as e.g. here:

https://github.com/huggingface/peft/blob/997e6ec5ab4bfbfbb79d13a7e51c8fd3874635fa/src/peft/tuners/lora/layer.py#L364-L374
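
For illustration only, here is a minimal, self-contained sketch of that pass-through idea (a toy wrapper, not PEFT's actual implementation): the wrapper's forward accepts arbitrary extra arguments and hands them to the wrapped base layer, so call sites would no longer need a USE_PEFT_BACKEND branch.

import torch
import torch.nn as nn

class PassThroughLoRALinear(nn.Module):
    # Toy stand-in for a PEFT LoRA layer, used only to illustrate argument pass-through.
    def __init__(self, base_layer: nn.Linear, rank: int = 4, alpha: int = 4):
        super().__init__()
        self.base_layer = base_layer
        self.lora_A = nn.Linear(base_layer.in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, base_layer.out_features, bias=False)
        self.scaling = alpha / rank

    def forward(self, x, *args, **kwargs):
        # Any extra arguments (e.g. a `scale` kwarg understood by diffusers' own
        # LoRA-compatible layers) are forwarded to the base layer untouched.
        result = self.base_layer(x, *args, **kwargs)
        return result + self.lora_B(self.lora_A(x)) * self.scaling

# The wrapped layer is called exactly like the base layer.
layer = PassThroughLoRALinear(nn.Linear(16, 16))
out = layer(torch.randn(2, 16))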

Therefore, I think we could remove all the workarounds we have where we have to scale and unscale the PEFT layer weights.

That said, I haven't tested it and it would be a lot of work to unwind these changes. Maybe it's just not worth it and we can keep the current implementation, as it works.

@sayakpaul (Member, Author) replied:

@BenjaminBossan I agree with your assessment here. Also, I think I have addressed your comments. Let me know.

@BenjaminBossan (Member) left a comment:

Thanks. Looks all good from my side.

@sayakpaul marked this pull request as ready for review on December 15, 2023 at 12:54.
@sayakpaul changed the title from "[WIP][LoRA PEFT] fix LoRA loading so that correct alphas are parsed" to "[LoRA PEFT] fix LoRA loading so that correct alphas are parsed" on Dec 15, 2023.
@sayakpaul (Member, Author) commented:

@younesbelkada @BenjaminBossan @patrickvonplaten I would appreciate another round of review here.

Benjamin, feel free to ignore the changes introduced in the training-related parts.

I would wait for the three of you to approve this PR, as this change is quite impactful.

@pacman100 (Contributor) left a comment:

Hello @sayakpaul, great work on fixing the loading of LoRA weights with PEFT training support. This preserves the status quo of having a single safetensors weight file while correctly saving the related config. LGTM 🔥🚀✨! Also, thank you for taking the time to go through my suggestions and related code.

@younesbelkada (Contributor) left a comment:

Impressive and inspiring work @sayakpaul! Thanks for taking the lead on this!
One minor suggestion would be to detail this approach in the docs for people who are curious about how peft configs are saved internally in the safetensors checkpoints. What do you think?

@sayakpaul (Member, Author) replied:

> One minor suggestion would be to detail this approach in the docs for people who are curious about how peft configs are saved internally in the safetensors checkpoints. What do you think?

I have edited the PR description so the community can refer to the details there. I think that might be better. WDYT @younesbelkada?

@younesbelkada (Contributor) replied:

Sounds great @sayakpaul, thanks!

@sayakpaul (Member, Author) commented:

Closing this PR in favor of a simpler alternative, described below. Simplified PR: #6225.

We didn't have a concept of alpha in our non-peft diffusers LoRA training scripts. Hence, it used to be set to None and therefore had no impact.

But that is NOT the case with peft (especially with how lora_alpha is initialized when a LoraConfig is created), as per #6087.

So, @pacman100 suggested a simpler alternative: just set lora_alpha within LoraConfig to args.rank, and that should cut it for us (see the sketch below). Even though this alternative is simpler and does the job, all of us (@younesbelkada, @BenjaminBossan, @pacman100, and myself) agree that it restricts users from benefiting from lora_alpha. We'll see how it goes, and if there are requests from the community we can always refer back to this PR (the branch will be kept alive for reference).
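
For reference, a minimal sketch of that simpler alternative as it would look in a training script; apart from r and lora_alpha, the remaining arguments are illustrative:

from peft import LoraConfig

rank = 64  # e.g. args.rank in the training scripts

# Setting lora_alpha equal to the rank makes the effective scaling (lora_alpha / r) equal to 1,
# matching the behavior of the older non-peft diffusers LoRA training scripts.
unet_lora_config = LoraConfig(
    r=rank,
    lora_alpha=rank,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)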

Linked issue (#6087): diffusers doesn't save and load the LoraConfig, resulting in wrong lora_alpha during inference.