diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 7309c3fc709c..1adf9203b6ec 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -28,6 +28,7 @@
     USE_PEFT_BACKEND,
     _get_model_file,
     delete_adapter_layers,
+    get_adapter_name,
    is_accelerate_available,
     logging,
     set_adapter_layers,
@@ -49,6 +50,10 @@
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
 
+PEFT_WEIGHT_NAME = "adapter_model.bin"
+PEFT_WEIGHT_NAME_SAFE = "adapter_model.safetensors"
+PEFT_CONFIG_NAME = "adapter_config.json"
+
 CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
 CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
 
@@ -345,7 +350,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
 
-        # For PEFT backend the Unet is already offloaded at this stage as it is handled inside `lora_lora_weights_into_unet`
+        # For PEFT backend the Unet is already offloaded at this stage as it is handled inside `load_lora_weights_into_unet`
         if not USE_PEFT_BACKEND:
             if _pipeline is not None:
                 for _, component in _pipeline.components.items():
@@ -823,4 +828,99 @@ def _load_ip_adapter_weights(self, state_dict):
         self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
         self.config.encoder_hid_dim_type = "ip_image_proj"
 
-    delete_adapter_layers
+    def load_lora(self, pretrained_model_name_or_path: str, **kwargs):
+        r"""
+        Load LoRA checkpoints with PEFT.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for `load_lora()`.")
+
+        from peft import PeftConfig, inject_adapter_in_model, set_peft_model_state_dict
+
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        adapter_name = kwargs.pop("adapter_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+        allow_pickle = False
+
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "load_lora_peft",
+            "framework": "pytorch",
+        }
+
+        # Load the state dict.
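+        # Prefer the safetensors serialization; if it is not found and pickle loading is
+        # allowed, fall back to the ".bin" weights below.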
+        model_file = None
+        try:
+            model_file = _get_model_file(
+                pretrained_model_name_or_path,
+                weights_name=weight_name or PEFT_WEIGHT_NAME_SAFE,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+            )
+            state_dict = safetensors.torch.load_file(model_file, device="cpu")
+        except IOError as e:
+            if not allow_pickle:
+                raise e
+            # try loading non-safetensors weights
+            pass
+
+        if model_file is None:
+            model_file = _get_model_file(
+                pretrained_model_name_or_path,
+                weights_name=weight_name or PEFT_WEIGHT_NAME,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+            )
+            state_dict = torch.load(model_file, map_location="cpu")
+
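+        # PEFT saves adapter weights under a "base_model.model." prefix; strip it so the keys
+        # match this model's module names.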
" + ) diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py index 6d3ac8b4592a..fcd80cb99d59 100644 --- a/tests/lora/test_lora_layers_peft.py +++ b/tests/lora/test_lora_layers_peft.py @@ -58,7 +58,7 @@ from accelerate.utils import release_memory if is_peft_available(): - from peft import LoraConfig + from peft import LoraConfig, get_peft_model from peft.tuners.tuners_utils import BaseTunerLayer from peft.utils import get_peft_model_state_dict @@ -1393,6 +1393,72 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): } +@require_peft_backend +class UNet2DConditionModelLoRATests(unittest.TestCase): + def get_dummy_components(self): + unet_kwargs = { + "block_out_channels": (32, 64), + "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), + "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"), + "cross_attention_dim": 32, + "attention_head_dim": 8, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 2, + "sample_size": 32, + } + torch.manual_seed(0) + unet = UNet2DConditionModel(**unet_kwargs) + unet_lora_config = LoraConfig( + r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False + ) + return unet, unet_lora_config + + def get_dummy_inputs(self): + batch_size = 2 + num_channels = 4 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device) + + return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + + def test_inference(self): + unet, unet_lora_config = self.get_dummy_components() + inputs = self.get_dummy_inputs() + outputs = unet(**inputs).sample + + unet = get_peft_model(unet, unet_lora_config) + outputs_with_lora = unet(**inputs).sample + + with tempfile.TemporaryDirectory() as tmpdirname: + unet.save_pretrained(tmpdirname) + # `peft` stable release doesn't default to safetensors yet. + # we run checks both with `peft` main and stable release. + try: + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "adapter_model.safetensors"))) + except Exception: + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "adapter_model.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "adapter_config.json"))) + + unet, _ = self.get_dummy_components() + unet.load_lora(tmpdirname) + + has_peft_layer = any(isinstance(unet_module, BaseTunerLayer) for unet_module in unet.modules()) + assert has_peft_layer, "No PEFT layer found" + + outputs_with_lora_loaded = unet(**inputs).sample + + assert not torch.allclose( + outputs, outputs_with_lora, atol=1e-3, rtol=1e-3 + ), "LoRA layers should affect the outputs." + assert torch.allclose( + outputs_with_lora_loaded, outputs_with_lora, atol=1e-3, rtol=1e-3 + ), "Loaded LoRA layers should match the outputs." + + @slow @require_torch_gpu class LoraIntegrationTests(unittest.TestCase):