[LoRA] Improve warning messages when LoRA loading becomes a no-op #10187

Open

sayakpaul wants to merge 30 commits into base: main from improve-lora-warning-msg
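To illustrate the user-facing behavior: when a LoRA checkpoint carries no keys for a given component, loading it is a no-op for that component. Previously this was silent; with this PR an informative log message is emitted instead. A minimal sketch, assuming a transformer-only LoRA (the LoRA repo id below is a hypothetical placeholder, not from this PR):

    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

    # Hypothetical LoRA that only ships transformer weights, i.e. it contains no
    # `text_encoder.*` keys, so loading it is a no-op for the text encoder.
    pipe.load_lora_weights("your-username/flux-transformer-only-lora")

    # Before this PR the text-encoder part was skipped silently. After this PR a log
    # message along the lines of the one added in `lora_base.py` is emitted, e.g.:
    #   "No LoRA keys associated to CLIPTextModel found with the prefix='text_encoder'. ..."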
Commits (30)
cd88a4b
updates
sayakpaul Dec 11, 2024
b694ca4
updates
sayakpaul Dec 11, 2024
1db7503
updates
sayakpaul Dec 11, 2024
6134491
updates
sayakpaul Dec 11, 2024
db827b5
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Dec 11, 2024
ac29785
notebooks revert
sayakpaul Dec 11, 2024
3f4a3fc
resolve conflicts
sayakpaul Dec 15, 2024
b6db978
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Dec 19, 2024
c44f7a3
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Dec 20, 2024
876132a
fix-copies.
sayakpaul Dec 20, 2024
bb7c09a
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Dec 25, 2024
e6043a0
seeing
sayakpaul Dec 25, 2024
7ca7493
fix
sayakpaul Dec 25, 2024
ec44f9a
revert
sayakpaul Dec 25, 2024
343b2d2
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Dec 25, 2024
615e372
fixes
sayakpaul Dec 25, 2024
e2e3ea0
fixes
sayakpaul Dec 25, 2024
f9dd64c
fixes
sayakpaul Dec 25, 2024
a01cb45
remove print
sayakpaul Dec 25, 2024
da96621
fix
sayakpaul Dec 25, 2024
a91138d
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Dec 27, 2024
83ad82b
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Jan 2, 2025
be187da
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Jan 5, 2025
726e492
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Jan 7, 2025
3efdc58
fix conflicts
sayakpaul Jan 13, 2025
cf50148
conflicts ii.
sayakpaul Jan 13, 2025
b2afc10
updates
sayakpaul Jan 13, 2025
96eced3
fixes
sayakpaul Jan 13, 2025
b4be719
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Jan 13, 2025
8bf1173
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Feb 9, 2025
154 changes: 77 additions & 77 deletions src/diffusers/loaders/lora_base.py
@@ -339,93 +339,93 @@ def _load_lora_into_text_encoder(
     # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
     # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
     # their prefixes.
-    keys = list(state_dict.keys())
     prefix = text_encoder_name if prefix is None else prefix

-    # Safe prefix to check with.
-    if any(text_encoder_name in key for key in keys):
-        # Load the layers corresponding to text encoder and make necessary adjustments.
-        text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
-        text_encoder_lora_state_dict = {
-            k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
-        }
+    # Load the layers corresponding to text encoder and make necessary adjustments.
+    if prefix is not None:
+        state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}

+    if len(state_dict) > 0:
+        logger.info(f"Loading {prefix}.")
+        rank = {}
+        state_dict = convert_state_dict_to_diffusers(state_dict)

+        # convert state dict
+        state_dict = convert_state_dict_to_peft(state_dict)

+        for name, _ in text_encoder_attn_modules(text_encoder):
+            for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+                rank_key = f"{name}.{module}.lora_B.weight"
+                if rank_key not in state_dict:
+                    continue
+                rank[rank_key] = state_dict[rank_key].shape[1]

+        for name, _ in text_encoder_mlp_modules(text_encoder):
+            for module in ("fc1", "fc2"):
+                rank_key = f"{name}.{module}.lora_B.weight"
+                if rank_key not in state_dict:
+                    continue
+                rank[rank_key] = state_dict[rank_key].shape[1]

+        if network_alphas is not None:
+            alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+            network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

+        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)

+        if "use_dora" in lora_config_kwargs:
+            if lora_config_kwargs["use_dora"]:
+                if is_peft_version("<", "0.9.0"):
+                    raise ValueError(
+                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                    )
+            else:
+                if is_peft_version("<", "0.9.0"):
+                    lora_config_kwargs.pop("use_dora")

+        if "lora_bias" in lora_config_kwargs:
+            if lora_config_kwargs["lora_bias"]:
+                if is_peft_version("<=", "0.13.2"):
+                    raise ValueError(
+                        "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                    )
+            else:
+                if is_peft_version("<=", "0.13.2"):
+                    lora_config_kwargs.pop("lora_bias")

-        if len(text_encoder_lora_state_dict) > 0:
-            logger.info(f"Loading {prefix}.")
-            rank = {}
-            text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)

-            # convert state dict
-            text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)

-            for name, _ in text_encoder_attn_modules(text_encoder):
-                for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
-                    rank_key = f"{name}.{module}.lora_B.weight"
-                    if rank_key not in text_encoder_lora_state_dict:
-                        continue
-                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

-            for name, _ in text_encoder_mlp_modules(text_encoder):
-                for module in ("fc1", "fc2"):
-                    rank_key = f"{name}.{module}.lora_B.weight"
-                    if rank_key not in text_encoder_lora_state_dict:
-                        continue
-                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

-            if network_alphas is not None:
-                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
-                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)

-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"]:
-                    if is_peft_version("<", "0.9.0"):
-                        raise ValueError(
-                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<", "0.9.0"):
-                        lora_config_kwargs.pop("use_dora")

-            if "lora_bias" in lora_config_kwargs:
-                if lora_config_kwargs["lora_bias"]:
-                    if is_peft_version("<=", "0.13.2"):
-                        raise ValueError(
-                            "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<=", "0.13.2"):
-                        lora_config_kwargs.pop("lora_bias")
+        lora_config = LoraConfig(**lora_config_kwargs)

-            lora_config = LoraConfig(**lora_config_kwargs)
+        # adapter_name
+        if adapter_name is None:
+            adapter_name = get_adapter_name(text_encoder)

-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(text_encoder)
+        is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)

-            is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
+        # inject LoRA layers and load the state dict
+        # in transformers we automatically check whether the adapter name is already in use or not
+        text_encoder.load_adapter(
+            adapter_name=adapter_name,
+            adapter_state_dict=state_dict,
+            peft_config=lora_config,
+            **peft_kwargs,
+        )

-            # inject LoRA layers and load the state dict
-            # in transformers we automatically check whether the adapter name is already in use or not
-            text_encoder.load_adapter(
-                adapter_name=adapter_name,
-                adapter_state_dict=text_encoder_lora_state_dict,
-                peft_config=lora_config,
-                **peft_kwargs,
-            )
+        # scale LoRA layers with `lora_scale`
+        scale_lora_layers(text_encoder, weight=lora_scale)

-            # scale LoRA layers with `lora_scale`
-            scale_lora_layers(text_encoder, weight=lora_scale)
+        text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

-            text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+        # Offload back.
+        if is_model_cpu_offload:
+            _pipeline.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            _pipeline.enable_sequential_cpu_offload()
+        # Unsafe code />

-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-            # Unsafe code />
+    if prefix is not None and not state_dict:
+        logger.info(
+            f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. This is safe to ignore if LoRA state dict didn't originally have any {text_encoder.__class__.__name__} related params. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
+        )


 def _func_optionally_disable_offloading(_pipeline):
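For reference, the control flow introduced above reduces to a small standalone sketch: filter the state dict down to the component prefix, and when nothing survives, log instead of returning silently. The helper name and the sample keys below are illustrative, not part of the diffusers API:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)


    def filter_by_prefix(state_dict, prefix="text_encoder"):
        # Keep only `{prefix}.*` keys and strip the prefix, mirroring the filtering in the diff.
        if prefix is not None:
            state_dict = {
                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")
            }
        return state_dict


    # A LoRA state dict that only targets the denoiser, so it has no `text_encoder.` keys.
    lora_state_dict = {
        "transformer.blocks.0.attn.to_q.lora_A.weight": "tensor placeholder",
        "transformer.blocks.0.attn.to_q.lora_B.weight": "tensor placeholder",
    }

    filtered = filter_by_prefix(lora_state_dict)
    if not filtered:
        # The no-op case this PR now surfaces with an informative message.
        logger.info("No LoRA keys associated to the text encoder found with prefix='text_encoder'.")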