Commit

[LoRA deprecation] handle rest of the stuff related to deprecated lora stuff. (huggingface#6426)

* handle rest of the stuff related to deprecated lora stuff.

* fix: copies

* don't modify the UNet in-place.

* fix: temporal autoencoder.

* manually remove lora layers.

* don't copy unet.

* alright

* remove lora attn processors from unet3d

* fix: unet3d.

* style

* Empty-Commit
sayakpaul authored and Jimmy committed Apr 26, 2024
1 parent 58c5aa8 commit e09dd2a
Showing 16 changed files with 83 additions and 101 deletions.
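To make the API change concrete, here is a minimal before/after sketch of the call sites this commit touches. The deprecation message removed in `attention_processor.py` already points users to `pipe.unload_lora_weights()`; the checkpoint id below is illustrative and not part of the diff.

    from diffusers import StableDiffusionPipeline
    from diffusers.models.attention_processor import AttnProcessor

    # Illustrative checkpoint; any pipeline whose UNet exposes `set_attn_processor` behaves the same.
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # Before this commit, LoRA removal could piggyback on the processor setter:
    #     pipe.unet.set_attn_processor(AttnProcessor(), _remove_lora=True)  # deprecated, now removed
    # After this commit, the two concerns are separate:
    pipe.unload_lora_weights()                     # drop any loaded LoRA layers first
    pipe.unet.set_attn_processor(AttnProcessor())  # then choose the attention processor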
6 changes: 2 additions & 4 deletions examples/research_projects/controlnetxs/controlnetxs.py

@@ -494,9 +494,7 @@ def attn_processors(self) -> Dict[str, AttentionProcessor]:
         """
         return self.control_model.attn_processors
 
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.
@@ -509,7 +507,7 @@ def set_attn_processor(
                 processor. This is strongly recommended when setting trainable attention processors.
         """
-        self.control_model.set_attn_processor(processor, _remove_lora)
+        self.control_model.set_attn_processor(processor)
 
     def set_default_attn_processor(self):
         """
2 changes: 1 addition & 1 deletion src/diffusers/loaders/lora.py

@@ -980,7 +980,7 @@ def unload_lora_weights(self):
 
         if not USE_PEFT_BACKEND:
             if version.parse(__version__) > version.parse("0.23"):
-                logger.warn(
+                logger.warning(
                     "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
                     "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
                 )
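The warning above points users to the PEFT-backed toggling API. A rough sketch of that flow, assuming `peft` is installed so `USE_PEFT_BACKEND` is true; the checkpoint id, weights path, and adapter name are placeholders:

    from diffusers import StableDiffusionPipeline

    # Illustrative checkpoint; LoRA toggling requires the PEFT backend.
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe.load_lora_weights("path/to/lora_weights", adapter_name="demo")  # placeholder path and name
    pipe.disable_lora()          # temporarily switch the LoRA layers off
    pipe.enable_lora()           # ...and back on, without reloading anything
    pipe.unload_lora_weights()   # remove the LoRA layers entirely when finished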
17 changes: 1 addition & 16 deletions src/diffusers/models/attention_processor.py

@@ -373,29 +373,14 @@ def set_attention_slice(self, slice_size: int) -> None:
 
         self.set_processor(processor)
 
-    def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None:
+    def set_processor(self, processor: "AttnProcessor") -> None:
         r"""
         Set the attention processor to use.
 
         Args:
             processor (`AttnProcessor`):
                 The attention processor to use.
-            _remove_lora (`bool`, *optional*, defaults to `False`):
-                Set to `True` to remove LoRA layers from the model.
         """
-        if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
-            deprecate(
-                "set_processor to offload LoRA",
-                "0.26.0",
-                "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
-            )
-            # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
-            # We need to remove all LoRA layers
-            # Don't forget to remove ALL `_remove_lora` from the codebase
-            for module in self.modules():
-                if hasattr(module, "set_lora_layer"):
-                    module.set_lora_layer(None)
 
         # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
         if (
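With the `_remove_lora` flag gone, `set_processor` is a plain setter again. A minimal sketch on a standalone `Attention` module; the layer dimensions are made up for illustration, and `AttnProcessor2_0` assumes PyTorch 2.0 is available:

    import torch
    from diffusers.models.attention_processor import Attention, AttnProcessor2_0

    attn = Attention(query_dim=320, heads=8, dim_head=40)  # hypothetical layer config
    attn.set_processor(AttnProcessor2_0())                 # no `_remove_lora` argument anymore

    hidden_states = torch.randn(1, 64, 320)  # (batch, sequence, query_dim)
    out = attn(hidden_states)                # self-attention with the newly set processor
    print(out.shape)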
10 changes: 4 additions & 6 deletions src/diffusers/models/autoencoders/autoencoder_kl.py

@@ -182,9 +182,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
         return processors
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.
@@ -208,9 +206,9 @@ def set_attn_processor(
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -232,7 +230,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)
 
     @apply_forward_hook
     def encode(
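For reference, a short sketch of resetting a VAE's attention processors with the simplified call; the checkpoint id is illustrative:

    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # illustrative checkpoint
    vae.set_default_attn_processor()  # now delegates to set_attn_processor(processor) with no extra flag
    print({type(p).__name__ for p in vae.attn_processors.values()})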
(changed file; path not captured in this rendering)

@@ -267,9 +267,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
         return processors
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.
@@ -293,9 +291,9 @@ def set_attn_processor(
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -314,7 +312,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)
 
     @apply_forward_hook
     def encode(
10 changes: 4 additions & 6 deletions src/diffusers/models/autoencoders/consistency_decoder_vae.py

@@ -212,9 +212,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
         return processors
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.
@@ -238,9 +236,9 @@ def set_attn_processor(
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -262,7 +260,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )
 
-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)
 
     @apply_forward_hook
     def encode(
10 changes: 4 additions & 6 deletions src/diffusers/models/controlnet.py

@@ -534,9 +534,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
         return processors
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.
@@ -560,9 +558,9 @@ def set_attn_processor(
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -584,7 +582,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
10 changes: 4 additions & 6 deletions src/diffusers/models/prior_transformer.py

@@ -192,9 +192,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
         return processors
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.
@@ -218,9 +216,9 @@ def set_attn_processor(
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -242,7 +240,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)
 
     def forward(
         self,
10 changes: 4 additions & 6 deletions src/diffusers/models/unet_2d_condition.py

@@ -643,9 +643,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
 
         return processors
 
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.
@@ -669,9 +667,9 @@ def set_attn_processor(
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -692,7 +690,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)
 
     def set_attention_slice(self, slice_size):
         r"""
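Since the recursive helper also accepts a dict keyed by the names reported by `attn_processors`, here is a sketch of the per-layer form; the checkpoint id and the attn1/attn2 split are illustrative choices, not taken from the diff:

    from diffusers import UNet2DConditionModel
    from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"  # illustrative checkpoint
    )

    # One entry per attention module, keyed exactly like `unet.attn_processors`.
    procs = {
        name: AttnProcessor2_0() if "attn2" in name else AttnProcessor()
        for name in unet.attn_processors.keys()
    }
    unet.set_attn_processor(procs)  # the `_remove_lora` kwarg no longer exists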
10 changes: 4 additions & 6 deletions src/diffusers/models/unet_3d_condition.py

@@ -375,9 +375,7 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.
@@ -401,9 +399,9 @@ def set_attn_processor(
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -465,7 +463,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)
 
     def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
         if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
10 changes: 4 additions & 6 deletions src/diffusers/models/unet_motion_model.py

@@ -549,9 +549,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
         return processors
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.
@@ -575,9 +573,9 @@ def set_attn_processor(
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -641,7 +639,7 @@ def set_default_attn_processor(self) -> None:
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)
 
     def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
         if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
10 changes: 4 additions & 6 deletions src/diffusers/models/uvit_2d.py

@@ -237,9 +237,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
         return processors
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.
@@ -263,9 +261,9 @@ def set_attn_processor(
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -287,7 +285,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)
 
 
 class UVit2DConvEmbed(nn.Module):
10 changes: 4 additions & 6 deletions src/diffusers/pipelines/audioldm2/modeling_audioldm2.py

@@ -538,9 +538,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
         return processors
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.
@@ -564,9 +562,9 @@ def set_attn_processor(
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -588,7 +586,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size):