From 4c7e983bb5929320bab08d70333eeb93f047de40 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Tue, 26 Dec 2023 11:39:28 +0100
Subject: [PATCH] [Peft] fix saving / loading when unet is not "unet" (#6046)

* [Peft] fix saving / loading when unet is not "unet"

* Update src/diffusers/loaders/lora.py

Co-authored-by: Sayak Paul

* undo stablediffusion-xl changes

* use unet_name to get unet for lora helpers

* use unet_name

---------

Co-authored-by: Sayak Paul
---
 src/diffusers/loaders/ip_adapter.py |  6 ++--
 src/diffusers/loaders/lora.py       | 46 ++++++++++++++++++-----------
 2 files changed, 32 insertions(+), 20 deletions(-)

diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index 158bde436374..3df0492380e5 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -149,9 +149,11 @@ def load_ip_adapter(
             self.feature_extractor = CLIPImageProcessor()
 
         # load ip-adapter into unet
-        self.unet._load_ip_adapter_weights(state_dict)
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet._load_ip_adapter_weights(state_dict)
 
     def set_ip_adapter_scale(self, scale):
-        for attn_processor in self.unet.attn_processors.values():
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        for attn_processor in unet.attn_processors.values():
             if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
                 attn_processor.scale = scale

diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index fc50c52e412b..2ceff743daca 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -912,10 +912,10 @@ def pack_weights(layers, prefix):
             )
 
         if unet_lora_layers:
-            state_dict.update(pack_weights(unet_lora_layers, "unet"))
+            state_dict.update(pack_weights(unet_lora_layers, cls.unet_name))
 
         if text_encoder_lora_layers:
-            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+            state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
 
         if transformer_lora_layers:
             state_dict.update(pack_weights(transformer_lora_layers, "transformer"))
@@ -975,6 +975,8 @@ def unload_lora_weights(self):
         >>> ...
         ```
         """
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+
         if not USE_PEFT_BACKEND:
             if version.parse(__version__) > version.parse("0.23"):
                 logger.warn(
@@ -982,13 +984,13 @@ def unload_lora_weights(self):
                     "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
                 )
 
-            for _, module in self.unet.named_modules():
+            for _, module in unet.named_modules():
                 if hasattr(module, "set_lora_layer"):
                     module.set_lora_layer(None)
         else:
-            recurse_remove_peft_layers(self.unet)
-            if hasattr(self.unet, "peft_config"):
-                del self.unet.peft_config
+            recurse_remove_peft_layers(unet)
+            if hasattr(unet, "peft_config"):
+                del unet.peft_config
 
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
@@ -1027,7 +1029,8 @@ def fuse_lora(
         )
 
         if fuse_unet:
-            self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
+            unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+            unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
 
         if USE_PEFT_BACKEND:
             from peft.tuners.tuners_utils import BaseTunerLayer
@@ -1080,13 +1083,14 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True
                 Whether to unfuse the text encoder LoRA parameters.
                 If the text encoder wasn't monkey-patched with the LoRA parameters then it won't have any effect.
         """
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         if unfuse_unet:
             if not USE_PEFT_BACKEND:
-                self.unet.unfuse_lora()
+                unet.unfuse_lora()
             else:
                 from peft.tuners.tuners_utils import BaseTunerLayer
 
-                for module in self.unet.modules():
+                for module in unet.modules():
                     if isinstance(module, BaseTunerLayer):
                         module.unmerge()
 
@@ -1202,8 +1206,9 @@ def set_adapters(
         adapter_names: Union[List[str], str],
         adapter_weights: Optional[List[float]] = None,
     ):
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         # Handle the UNET
-        self.unet.set_adapters(adapter_names, adapter_weights)
+        unet.set_adapters(adapter_names, adapter_weights)
 
         # Handle the Text Encoder
         if hasattr(self, "text_encoder"):
@@ -1216,7 +1221,8 @@ def disable_lora(self):
             raise ValueError("PEFT backend is required for this method.")
 
         # Disable unet adapters
-        self.unet.disable_lora()
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet.disable_lora()
 
         # Disable text encoder adapters
         if hasattr(self, "text_encoder"):
@@ -1229,7 +1235,8 @@ def enable_lora(self):
             raise ValueError("PEFT backend is required for this method.")
 
         # Enable unet adapters
-        self.unet.enable_lora()
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet.enable_lora()
 
         # Enable text encoder adapters
         if hasattr(self, "text_encoder"):
@@ -1251,7 +1258,8 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
             adapter_names = [adapter_names]
 
         # Delete unet adapters
-        self.unet.delete_adapters(adapter_names)
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet.delete_adapters(adapter_names)
 
         for adapter_name in adapter_names:
             # Delete text encoder adapters
@@ -1284,8 +1292,8 @@ def get_active_adapters(self) -> List[str]:
         from peft.tuners.tuners_utils import BaseTunerLayer
 
         active_adapters = []
-
-        for module in self.unet.modules():
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        for module in unet.modules():
             if isinstance(module, BaseTunerLayer):
                 active_adapters = module.active_adapters
                 break
@@ -1309,8 +1317,9 @@ def get_list_adapters(self) -> Dict[str, List[str]]:
         if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"):
             set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys())
 
-        if hasattr(self, "unet") and hasattr(self.unet, "peft_config"):
-            set_adapters["unet"] = list(self.unet.peft_config.keys())
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        if hasattr(self, self.unet_name) and hasattr(unet, "peft_config"):
+            set_adapters[self.unet_name] = list(unet.peft_config.keys())
 
         return set_adapters
 
@@ -1331,7 +1340,8 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
         from peft.tuners.tuners_utils import BaseTunerLayer
 
         # Handle the UNET
-        for unet_module in self.unet.modules():
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        for unet_module in unet.modules():
             if isinstance(unet_module, BaseTunerLayer):
                 for adapter_name in adapter_names:
                     unet_module.lora_A[adapter_name].to(device)
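
Reviewer note: below is a minimal sketch of the lookup pattern this patch inlines at each call site. It is not diffusers code; `ToyLoraMixin`, `RenamedUnetPipeline`, and `_resolve_unet` are hypothetical names, and only the `getattr`/`hasattr` expression mirrors the patch itself. The idea: prefer `self.unet` when the pipeline has a literal `unet` attribute, otherwise resolve the component named by the class-level `unet_name`.

    # Hypothetical stand-in for the mixin; only the lookup expression is from the patch.
    class ToyLoraMixin:
        unet_name = "unet"  # subclasses may point this at a differently named attribute

        def _resolve_unet(self):
            # The same expression the patch adds before each former `self.unet` use.
            return getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet

    class RenamedUnetPipeline(ToyLoraMixin):
        unet_name = "prior"  # the denoiser lives under `self.prior`, not `self.unet`

        def __init__(self, prior):
            self.prior = prior

    pipe = RenamedUnetPipeline(prior=object())
    assert pipe._resolve_unet() is pipe.prior  # lookup follows `unet_name`

Note that the patch repeats this expression at every call site rather than factoring it into a shared helper, which keeps each hunk self-contained; extracting a helper method would be an equivalent follow-up refactor.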