Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Fix TypeError in DreamBooth SDXL when use_dora is False #9879

Merged
merged 7 commits into from
Nov 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
convert_state_dict_to_diffusers,
convert_state_dict_to_kohya,
convert_unet_state_dict_to_peft,
is_peft_version,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
Expand Down Expand Up @@ -1183,26 +1184,33 @@ def main(args):
text_encoder_one.gradient_checkpointing_enable()
text_encoder_two.gradient_checkpointing_enable()

def get_lora_config(rank, use_dora, target_modules):
base_config = {
"r": rank,
"lora_alpha": rank,
"init_lora_weights": "gaussian",
"target_modules": target_modules,
}
if use_dora:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
base_config["use_dora"] = True

return LoraConfig(**base_config)

# now we will add new LoRA weights to the attention layers
unet_lora_config = LoraConfig(
r=args.rank,
use_dora=args.use_dora,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
unet_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=unet_target_modules)
unet.add_adapter(unet_lora_config)

# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank,
use_dora=args.use_dora,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
text_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=text_target_modules)
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)

Expand Down
Loading