From 2a9064c46969f136826beb01f8e04dbbf7da2300 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Tue, 2 Jan 2024 18:14:04 +0530
Subject: [PATCH] [LoRA] Remove the use of deprecated LoRA functionalities such
 as `LoRAAttnProcessor` (#6369)

* start deprecating loraattn.
* fix
* wrap into unet_lora_state_dict
* utilize text_encoder_lora_params
* utilize text_encoder_attn_modules
* debug
* debug
* remove print
* don't use text encoder for test_stable_diffusion_lora
* load the procs.
* set_default_attn_processor
* fix: set_default_attn_processor call.
* fix: lora_components[unet_lora_params]
* checking for 3d.
* 3d.
* more fixes.
* debug
* debug
* debug
* debug
* more debug
* more debug
* more debug
* more debug
* more debug
* more debug
* hack.
* remove comments and prep for a PR.
* appropriate set_lora_weights()
* fix
* fix: test_unload_lora_sd
* fix: test_unload_lora_sd
* use default attention processors.
* debug
* debug nan
* debug nan
* debug nan
* use NaN instead of inf
* remove comments.
* fix: test_text_encoder_lora_state_dict_unchanged
* attention processor default
* default attention processors.
* default
* style
---
 tests/lora/test_lora_layers_old_backend.py | 921 ++++++++-------
 1 file changed, 335 insertions(+), 586 deletions(-)

diff --git a/tests/lora/test_lora_layers_old_backend.py b/tests/lora/test_lora_layers_old_backend.py
index 3d3b858fa0fdd..7d6d30169455c 100644
--- a/tests/lora/test_lora_layers_old_backend.py
+++ b/tests/lora/test_lora_layers_old_backend.py
@@ -22,7 +22,6 @@
 import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from huggingface_hub.repocard import RepoCard
 from PIL import Image
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@@ -41,17 +40,15 @@
     UNet2DConditionModel,
     UNet3DConditionModel,
 )
-from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
+from diffusers.loaders import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
 from diffusers.models.attention_processor import (
     Attention,
     AttnProcessor,
     AttnProcessor2_0,
-    LoRAAttnProcessor,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
-from diffusers.models.lora import PatchedLoraProjection, text_encoder_attn_modules
+from diffusers.models.lora import LoRALinearLayer
+from diffusers.training_utils import unet_lora_state_dict
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     deprecate_after_peft_backend,
@@ -64,118 +61,178 @@
 )


-def create_lora_layers(model, mock_weights: bool = True):
-    lora_attn_procs = {}
-    for name in model.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = model.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = model.config.block_out_channels[block_id]
+def text_encoder_attn_modules(text_encoder):
+    attn_modules = []

-        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-        lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
+    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+            name =
f"text_model.encoder.layers.{i}.self_attn" + mod = layer.self_attn + attn_modules.append((name, mod)) + else: + raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}") - if mock_weights: - # add 1 to weights to mock trained weights - with torch.no_grad(): - lora_attn_procs[name].to_q_lora.up.weight += 1 - lora_attn_procs[name].to_k_lora.up.weight += 1 - lora_attn_procs[name].to_v_lora.up.weight += 1 - lora_attn_procs[name].to_out_lora.up.weight += 1 + return attn_modules - return lora_attn_procs +def text_encoder_lora_state_dict(text_encoder): + state_dict = {} -def create_unet_lora_layers(unet: nn.Module): - lora_attn_procs = {} - for name in unet.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim - if name.startswith("mid_block"): - hidden_size = unet.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(unet.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = unet.config.block_out_channels[block_id] - lora_attn_processor_class = ( - LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor - ) - lora_attn_procs[name] = lora_attn_processor_class( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim - ) - unet_lora_layers = AttnProcsLayers(lora_attn_procs) - return lora_attn_procs, unet_lora_layers + for name, module in text_encoder_attn_modules(text_encoder): + for k, v in module.q_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v + for k, v in module.k_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v -def create_text_encoder_lora_attn_procs(text_encoder: nn.Module): - text_lora_attn_procs = {} - lora_attn_processor_class = ( - LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor - ) - for name, module in text_encoder_attn_modules(text_encoder): - if isinstance(module.out_proj, nn.Linear): - out_features = module.out_proj.out_features - elif isinstance(module.out_proj, PatchedLoraProjection): - out_features = module.out_proj.regular_linear_layer.out_features - else: - assert False, module.out_proj.__class__ + for k, v in module.v_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v - text_lora_attn_procs[name] = lora_attn_processor_class(hidden_size=out_features, cross_attention_dim=None) - return text_lora_attn_procs + for k, v in module.out_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v + return state_dict -def create_text_encoder_lora_layers(text_encoder: nn.Module): - text_lora_attn_procs = create_text_encoder_lora_attn_procs(text_encoder) - text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) - return text_encoder_lora_layers +def create_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True): + unet_lora_parameters = [] -def create_lora_3d_layers(model, mock_weights: bool = True): - lora_attn_procs = {} - for name in model.attn_processors.keys(): - has_cross_attention = name.endswith("attn2.processor") and not ( - name.startswith("transformer_in") or "temp_attentions" in name.split(".") + for attn_processor_name, attn_processor in unet.attn_processors.items(): + # Parse the attention module. 
+ attn_module = unet + for n in attn_processor_name.split(".")[:-1]: + attn_module = getattr(attn_module, n) + + # Set the `lora_layer` attribute of the attention-related matrices. + attn_module.to_q.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_q.in_features, + out_features=attn_module.to_q.out_features, + rank=rank, + ) ) - cross_attention_dim = model.config.cross_attention_dim if has_cross_attention else None - if name.startswith("mid_block"): - hidden_size = model.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(model.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = model.config.block_out_channels[block_id] - elif name.startswith("transformer_in"): - # Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148 - hidden_size = 8 * model.config.attention_head_dim + attn_module.to_k.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_k.in_features, + out_features=attn_module.to_k.out_features, + rank=rank, + ) + ) + attn_module.to_v.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_v.in_features, + out_features=attn_module.to_v.out_features, + rank=rank, + ) + ) + attn_module.to_out[0].set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_out[0].in_features, + out_features=attn_module.to_out[0].out_features, + rank=rank, + ) + ) + + if mock_weights: + with torch.no_grad(): + attn_module.to_q.lora_layer.up.weight += 1 + attn_module.to_k.lora_layer.up.weight += 1 + attn_module.to_v.lora_layer.up.weight += 1 + attn_module.to_out[0].lora_layer.up.weight += 1 + + unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) - lora_attn_procs[name] = lora_attn_procs[name].to(model.device) + return unet_lora_parameters, unet_lora_state_dict(unet) + + +def create_3d_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True): + for attn_processor_name in unet.attn_processors.keys(): + has_cross_attention = attn_processor_name.endswith("attn2.processor") and not ( + attn_processor_name.startswith("transformer_in") or "temp_attentions" in attn_processor_name.split(".") + ) + cross_attention_dim = unet.config.cross_attention_dim if has_cross_attention else None + + if attn_processor_name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif attn_processor_name.startswith("up_blocks"): + block_id = int(attn_processor_name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif attn_processor_name.startswith("down_blocks"): + block_id = int(attn_processor_name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + elif attn_processor_name.startswith("transformer_in"): + # Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148 + hidden_size = 8 * unet.config.attention_head_dim + + # Parse the attention module. 
+ attn_module = unet + for n in attn_processor_name.split(".")[:-1]: + attn_module = getattr(attn_module, n) + + attn_module.to_q.set_lora_layer( + LoRALinearLayer( + in_features=min(attn_module.to_q.in_features, hidden_size), + out_features=attn_module.to_q.out_features + if cross_attention_dim is None + else max(attn_module.to_q.out_features, cross_attention_dim), + rank=rank, + ) + ) + attn_module.to_k.set_lora_layer( + LoRALinearLayer( + in_features=min(attn_module.to_k.in_features, hidden_size), + out_features=attn_module.to_k.out_features + if cross_attention_dim is None + else max(attn_module.to_k.out_features, cross_attention_dim), + rank=rank, + ) + ) + attn_module.to_v.set_lora_layer( + LoRALinearLayer( + in_features=min(attn_module.to_v.in_features, hidden_size), + out_features=attn_module.to_v.out_features + if cross_attention_dim is None + else max(attn_module.to_v.out_features, cross_attention_dim), + rank=rank, + ) + ) + attn_module.to_out[0].set_lora_layer( + LoRALinearLayer( + in_features=min(attn_module.to_out[0].in_features, hidden_size), + out_features=attn_module.to_out[0].out_features + if cross_attention_dim is None + else max(attn_module.to_out[0].out_features, cross_attention_dim), + rank=rank, + ) + ) if mock_weights: - # add 1 to weights to mock trained weights with torch.no_grad(): - lora_attn_procs[name].to_q_lora.up.weight += 1 - lora_attn_procs[name].to_k_lora.up.weight += 1 - lora_attn_procs[name].to_v_lora.up.weight += 1 - lora_attn_procs[name].to_out_lora.up.weight += 1 + attn_module.to_q.lora_layer.up.weight += 1 + attn_module.to_k.lora_layer.up.weight += 1 + attn_module.to_v.lora_layer.up.weight += 1 + attn_module.to_out[0].lora_layer.up.weight += 1 - return lora_attn_procs + return unet_lora_state_dict(unet) def set_lora_weights(lora_attn_parameters, randn_weight=False, var=1.0): - with torch.no_grad(): - for parameter in lora_attn_parameters: - if randn_weight: - parameter[:] = torch.randn_like(parameter) * var - else: - torch.zero_(parameter) + if not isinstance(lora_attn_parameters, dict): + with torch.no_grad(): + for parameter in lora_attn_parameters: + if randn_weight: + parameter[:] = torch.randn_like(parameter) * var + else: + torch.zero_(parameter) + else: + if randn_weight: + modified_state_dict = {k: torch.rand_like(v) * var for k, v in lora_attn_parameters.items()} + else: + modified_state_dict = {k: torch.zeros_like(v) * var for k, v in lora_attn_parameters.items()} + return modified_state_dict def state_dicts_almost_equal(sd1, sd2): @@ -192,6 +249,8 @@ def state_dicts_almost_equal(sd1, sd2): @deprecate_after_peft_backend class LoraLoaderMixinTests(unittest.TestCase): + lora_rank = 4 + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( @@ -235,8 +294,13 @@ def get_dummy_components(self): text_encoder = CLIPTextModel(text_encoder_config) tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet) - text_encoder_lora_layers = create_text_encoder_lora_layers(text_encoder) + unet_lora_raw_params, unet_lora_params = create_unet_lora_layers(unet, rank=self.lora_rank) + text_encoder_lora_params = LoraLoaderMixin._modify_text_encoder( + text_encoder, dtype=torch.float32, rank=self.lora_rank + ) + text_encoder_lora_params = set_lora_weights( + text_encoder_lora_state_dict(text_encoder), randn_weight=True, var=0.1 + ) pipeline_components = { "unet": unet, @@ -249,9 +313,9 @@ def get_dummy_components(self): "image_encoder": 
None, } lora_components = { - "unet_lora_layers": unet_lora_layers, - "text_encoder_lora_layers": text_encoder_lora_layers, - "unet_lora_attn_procs": unet_lora_attn_procs, + "unet_lora_raw_params": unet_lora_raw_params, + "unet_lora_params": unet_lora_params, + "text_encoder_lora_params": text_encoder_lora_params, } return pipeline_components, lora_components @@ -290,8 +354,8 @@ def create_lora_weight_file(self, tmpdirname): _, lora_components = self.get_dummy_components() LoraLoaderMixin.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -311,71 +375,12 @@ def test_stable_diffusion_xformers_attn_processors(self): image = sd_pipe(**inputs).images assert image.shape == (1, 64, 64, 3) - @unittest.skipIf(not torch.cuda.is_available(), reason="xformers requires cuda") - def test_stable_diffusion_attn_processors(self): - # disable_full_determinism() - device = "cuda" # ensure determinism for the device-dependent torch.Generator - components, _ = self.get_dummy_components() - sd_pipe = StableDiffusionPipeline(**components) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - _, _, inputs = self.get_dummy_inputs() - - # run normal sd pipe - image = sd_pipe(**inputs).images - assert image.shape == (1, 64, 64, 3) - - # run attention slicing - sd_pipe.enable_attention_slicing() - image = sd_pipe(**inputs).images - assert image.shape == (1, 64, 64, 3) - - # run vae attention slicing - sd_pipe.enable_vae_slicing() - image = sd_pipe(**inputs).images - assert image.shape == (1, 64, 64, 3) - - # run lora attention - attn_processors, _ = create_unet_lora_layers(sd_pipe.unet) - attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()} - sd_pipe.unet.set_attn_processor(attn_processors) - image = sd_pipe(**inputs).images - assert image.shape == (1, 64, 64, 3) - - @unittest.skipIf(not torch.cuda.is_available() or not is_xformers_available(), reason="xformers requires cuda") - def test_stable_diffusion_set_xformers_attn_processors(self): - # disable_full_determinism() - device = "cuda" # ensure determinism for the device-dependent torch.Generator - components, _ = self.get_dummy_components() - sd_pipe = StableDiffusionPipeline(**components) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - _, _, inputs = self.get_dummy_inputs() - - # run normal sd pipe - image = sd_pipe(**inputs).images - assert image.shape == (1, 64, 64, 3) - - # run lora xformers attention - attn_processors, _ = create_unet_lora_layers(sd_pipe.unet) - attn_processors = { - k: LoRAXFormersAttnProcessor(hidden_size=v.hidden_size, cross_attention_dim=v.cross_attention_dim) - for k, v in attn_processors.items() - } - attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()} - sd_pipe.unet.set_attn_processor(attn_processors) - image = sd_pipe(**inputs).images - assert image.shape == (1, 64, 64, 3) - - # enable_full_determinism() - def test_stable_diffusion_lora(self): - components, _ = self.get_dummy_components() + components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionPipeline(**components) sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) + 
sd_pipe.unet.set_default_attn_processor() # forward 1 _, _, inputs = self.get_dummy_inputs() @@ -385,9 +390,7 @@ def test_stable_diffusion_lora(self): image_slice = image[0, -3:, -3:, -1] # set lora layers - lora_attn_procs = create_lora_layers(sd_pipe.unet) - sd_pipe.unet.set_attn_processor(lora_attn_procs) - sd_pipe = sd_pipe.to(torch_device) + sd_pipe.unet.load_attn_procs(lora_components["unet_lora_params"]) # forward 2 _, _, inputs = self.get_dummy_inputs() @@ -420,8 +423,8 @@ def test_lora_save_load(self): with tempfile.TemporaryDirectory() as tmpdirname: LoraLoaderMixin.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) sd_pipe.load_lora_weights(tmpdirname) @@ -434,7 +437,6 @@ def test_lora_save_load(self): def test_lora_save_load_no_safe_serialization(self): pipeline_components, lora_components = self.get_dummy_components() - unet_lora_attn_procs = lora_components["unet_lora_attn_procs"] sd_pipe = StableDiffusionPipeline(**pipeline_components) sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) @@ -445,9 +447,13 @@ def test_lora_save_load_no_safe_serialization(self): orig_image_slice = original_images[0, -3:, -3:, -1] with tempfile.TemporaryDirectory() as tmpdirname: - unet = sd_pipe.unet - unet.set_attn_processor(unet_lora_attn_procs) - unet.save_attn_procs(tmpdirname, safe_serialization=False) + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + safe_serialization=False, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) sd_pipe.load_lora_weights(tmpdirname) @@ -468,9 +474,18 @@ def test_text_encoder_lora_monkey_patch(self): assert outputs_without_lora.shape == (1, 77, 32) # monkey patch - params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) - - set_lora_weights(params, randn_weight=False) + text_encoder_lora_params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) + text_encoder_lora_params = set_lora_weights( + text_encoder_lora_state_dict(pipe.text_encoder), randn_weight=False + ) + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=None, + text_encoder_lora_layers=text_encoder_lora_params, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + pipe.load_lora_weights(tmpdirname) # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] @@ -480,13 +495,22 @@ def test_text_encoder_lora_monkey_patch(self): outputs_without_lora, outputs_with_lora ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" - # create lora_attn_procs with randn up.weights - create_text_encoder_lora_attn_procs(pipe.text_encoder) - # monkey patch - params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) + pipeline_components, _ = self.get_dummy_components() + pipe = StableDiffusionPipeline(**pipeline_components) - set_lora_weights(params, randn_weight=True) + text_encoder_lora_params = 
pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) + text_encoder_lora_params = set_lora_weights( + text_encoder_lora_state_dict(pipe.text_encoder), randn_weight=True, var=0.1 + ) + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=None, + text_encoder_lora_layers=text_encoder_lora_params, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + pipe.load_lora_weights(tmpdirname) # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] @@ -508,8 +532,15 @@ def test_text_encoder_lora_remove_monkey_patch(self): # monkey patch params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) - - set_lora_weights(params, randn_weight=True) + params = set_lora_weights(text_encoder_lora_state_dict(pipe.text_encoder), var=0.1, randn_weight=True) + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=None, + text_encoder_lora_layers=params, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + pipe.load_lora_weights(tmpdirname) # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] @@ -541,8 +572,8 @@ def test_text_encoder_lora_scale(self): with tempfile.TemporaryDirectory() as tmpdirname: LoraLoaderMixin.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) sd_pipe.load_lora_weights(tmpdirname) @@ -587,19 +618,16 @@ def test_unload_lora_sd(self): pipeline_components, lora_components = self.get_dummy_components() _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe.unet.set_default_attn_processor() original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images orig_image_slice = original_images[0, -3:, -3:, -1] - # Emulate training. 
- set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_lora_layers"].parameters(), randn_weight=True) - with tempfile.TemporaryDirectory() as tmpdirname: LoraLoaderMixin.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) sd_pipe.load_lora_weights(tmpdirname) @@ -677,7 +705,7 @@ def test_lora_save_load_with_xformers(self): with tempfile.TemporaryDirectory() as tmpdirname: LoraLoaderMixin.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -691,7 +719,9 @@ def test_lora_save_load_with_xformers(self): @deprecate_after_peft_backend -class SDXInpaintLoraMixinTests(unittest.TestCase): +class SDInpaintLoraMixinTests(unittest.TestCase): + lora_rank = 4 + def get_dummy_inputs(self, device, seed=0, img_res=64, output_pil=True): # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched if output_pil: @@ -765,6 +795,14 @@ def get_dummy_components(self): text_encoder = CLIPTextModel(text_encoder_config) tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + unet_lora_raw_params, unet_lora_params = create_unet_lora_layers(unet, rank=self.lora_rank) + text_encoder_lora_params = StableDiffusionXLLoraLoaderMixin._modify_text_encoder( + text_encoder, dtype=torch.float32, rank=self.lora_rank + ) + text_encoder_lora_params = set_lora_weights( + text_encoder_lora_state_dict(text_encoder), randn_weight=True, var=0.1 + ) + components = { "unet": unet, "scheduler": scheduler, @@ -775,15 +813,21 @@ def get_dummy_components(self): "feature_extractor": None, "image_encoder": None, } - return components + lora_components = { + "unet_lora_raw_params": unet_lora_raw_params, + "unet_lora_params": unet_lora_params, + "text_encoder_lora_params": text_encoder_lora_params, + } + return components, lora_components def test_stable_diffusion_inpaint_lora(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator - components = self.get_dummy_components() + components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionInpaintPipeline(**components) sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) + sd_pipe.unet.set_default_attn_processor() # forward 1 inputs = self.get_dummy_inputs(device) @@ -792,9 +836,7 @@ def test_stable_diffusion_inpaint_lora(self): image_slice = image[0, -3:, -3:, -1] # set lora layers - lora_attn_procs = create_lora_layers(sd_pipe.unet) - sd_pipe.unet.set_attn_processor(lora_attn_procs) - sd_pipe = sd_pipe.to(torch_device) + sd_pipe.unet.load_attn_procs(lora_components["unet_lora_params"]) # forward 2 inputs = self.get_dummy_inputs(device) @@ -814,7 +856,9 @@ def test_stable_diffusion_inpaint_lora(self): @deprecate_after_peft_backend class SDXLLoraLoaderMixinTests(unittest.TestCase): - def get_dummy_components(self): + lora_rank = 4 + + def 
get_dummy_components(self, modify_text_encoder=True): torch.manual_seed(0) unet = UNet2DConditionModel( block_out_channels=(32, 64), @@ -871,9 +915,24 @@ def get_dummy_components(self): text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet) - text_encoder_one_lora_layers = create_text_encoder_lora_layers(text_encoder) - text_encoder_two_lora_layers = create_text_encoder_lora_layers(text_encoder_2) + _, unet_lora_params = create_unet_lora_layers(unet, rank=self.lora_rank) + + if modify_text_encoder: + text_encoder_lora_params = StableDiffusionXLLoraLoaderMixin._modify_text_encoder( + text_encoder, dtype=torch.float32, rank=self.lora_rank + ) + text_encoder_lora_params = set_lora_weights( + text_encoder_lora_state_dict(text_encoder), randn_weight=True, var=0.1 + ) + text_encoder_two_lora_params = StableDiffusionXLLoraLoaderMixin._modify_text_encoder( + text_encoder_2, dtype=torch.float32, rank=self.lora_rank + ) + text_encoder_two_lora_params = set_lora_weights( + text_encoder_lora_state_dict(text_encoder_2), randn_weight=True, var=0.1 + ) + else: + text_encoder_lora_params = None + text_encoder_two_lora_params = None pipeline_components = { "unet": unet, @@ -887,10 +946,9 @@ def get_dummy_components(self): "feature_extractor": None, } lora_components = { - "unet_lora_layers": unet_lora_layers, - "text_encoder_one_lora_layers": text_encoder_one_lora_layers, - "text_encoder_two_lora_layers": text_encoder_two_lora_layers, - "unet_lora_attn_procs": unet_lora_attn_procs, + "unet_lora_params": unet_lora_params, + "text_encoder_lora_params": text_encoder_lora_params, + "text_encoder_two_lora_params": text_encoder_two_lora_params, } return pipeline_components, lora_components @@ -929,9 +987,9 @@ def test_lora_save_load(self): with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) sd_pipe.load_lora_weights(tmpdirname) @@ -946,21 +1004,17 @@ def test_unload_lora_sdxl(self): pipeline_components, lora_components = self.get_dummy_components() _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe.unet.set_default_attn_processor() original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images orig_image_slice = original_images[0, -3:, -3:, -1] - # Emulate training. 
- set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) - with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) sd_pipe.load_lora_weights(tmpdirname) @@ -992,9 +1046,9 @@ def test_load_lora_locally(self): with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], safe_serialization=False, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) @@ -1003,7 +1057,7 @@ def test_load_lora_locally(self): sd_pipe.unload_lora_weights() def test_text_encoder_lora_state_dict_unchanged(self): - pipeline_components, lora_components = self.get_dummy_components() + pipeline_components, lora_components = self.get_dummy_components(modify_text_encoder=False) sd_pipe = StableDiffusionXLPipeline(**pipeline_components) text_encoder_1_sd_keys = sorted(sd_pipe.text_encoder.state_dict().keys()) @@ -1012,12 +1066,26 @@ def test_text_encoder_lora_state_dict_unchanged(self): sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) + # Modify the text encoder. 
+ _ = StableDiffusionXLLoraLoaderMixin._modify_text_encoder( + sd_pipe.text_encoder, dtype=torch.float32, rank=self.lora_rank + ) + lora_components["text_encoder_lora_params"] = set_lora_weights( + text_encoder_lora_state_dict(sd_pipe.text_encoder), randn_weight=True, var=0.1 + ) + _ = StableDiffusionXLLoraLoaderMixin._modify_text_encoder( + sd_pipe.text_encoder_2, dtype=torch.float32, rank=self.lora_rank + ) + lora_components["text_encoder_two_lora_params"] = set_lora_weights( + text_encoder_lora_state_dict(sd_pipe.text_encoder_2), randn_weight=True, var=0.1 + ) + with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], safe_serialization=False, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) @@ -1050,9 +1118,9 @@ def test_load_lora_locally_safetensors(self): with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -1066,19 +1134,12 @@ def test_lora_fuse_nan(self): sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) - _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) - - # Emulate training. 
- set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) - with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -1087,7 +1148,7 @@ def test_lora_fuse_nan(self): # corrupt one LoRA weight with `inf` values with torch.no_grad(): sd_pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_layer.down.weight += float( - "inf" + "NaN" ) # with `safe_fusing=True` we should see an Error @@ -1112,17 +1173,12 @@ def test_lora_fusion(self): original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images orig_image_slice = original_images[0, -3:, -3:, -1] - # Emulate training. - set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) - with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -1139,23 +1195,19 @@ def test_unfuse_lora(self): sd_pipe = StableDiffusionXLPipeline(**pipeline_components) sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) + sd_pipe.unet.set_default_attn_processor() _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images orig_image_slice = original_images[0, -3:, -3:, -1] - # Emulate training. 
- set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) - with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -1190,17 +1242,12 @@ def test_lora_fusion_is_not_affected_by_unloading(self): _ = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - # Emulate training. - set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) - with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -1229,17 +1276,12 @@ def test_fuse_lora_with_different_scales(self): _ = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - # Emulate training. 
- set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) - with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -1255,9 +1297,9 @@ def test_fuse_lora_with_different_scales(self): with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -1276,22 +1318,18 @@ def test_with_different_scales(self): sd_pipe = StableDiffusionXLPipeline(**pipeline_components) sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) + sd_pipe.unet.set_default_attn_processor() _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images original_imagee_slice = original_images[0, -3:, -3:, -1] - # Emulate training. 
- set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) - with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -1323,23 +1361,19 @@ def test_with_different_scales_fusion_equivalence(self): sd_pipe = StableDiffusionXLPipeline(**pipeline_components) sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) + sd_pipe.unet.set_default_attn_processor() _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images images_slice = images[0, -3:, -3:, -1] - # Emulate training. - set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True, var=0.1) - set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True, var=0.1) - set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True, var=0.1) - with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -1376,17 +1410,12 @@ def test_save_load_fused_lora_modules(self): _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) - # Emulate training. 
- set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True, var=0.1) - set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True, var=0.1) - set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True, var=0.1) - with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -1460,10 +1489,10 @@ def test_lora_processors(self): with torch.no_grad(): sample1 = model(**inputs_dict).sample - lora_attn_procs = create_lora_layers(model) + _, lora_params = create_unet_lora_layers(model) # make sure we can set a list of attention processors - model.set_attn_processor(lora_attn_procs) + model.load_attn_procs(lora_params) model.to(torch_device) # test that attn processors can be set to itself @@ -1480,120 +1509,6 @@ def test_lora_processors(self): # sample 2 and sample 3 should be different assert (sample2 - sample3).abs().max() > 1e-4 - def test_lora_save_load(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - with torch.no_grad(): - old_sample = model(**inputs_dict).sample - - lora_attn_procs = create_lora_layers(model) - model.set_attn_processor(lora_attn_procs) - - with torch.no_grad(): - sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname, safe_serialization=False) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - torch.manual_seed(0) - new_model = self.model_class(**init_dict) - new_model.to(torch_device) - new_model.load_attn_procs(tmpdirname) - - with torch.no_grad(): - new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - assert (sample - new_sample).abs().max() < 5e-4 - - # LoRA and no LoRA should NOT be the same - assert (sample - old_sample).abs().max() > 5e-4 - - def test_lora_save_load_safetensors(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - with torch.no_grad(): - old_sample = model(**inputs_dict).sample - - lora_attn_procs = create_lora_layers(model) - model.set_attn_processor(lora_attn_procs) - - with torch.no_grad(): - sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname, safe_serialization=True) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - torch.manual_seed(0) - new_model = 
self.model_class(**init_dict) - new_model.to(torch_device) - new_model.load_attn_procs(tmpdirname) - - with torch.no_grad(): - new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - assert (sample - new_sample).abs().max() < 1e-4 - - # LoRA and no LoRA should NOT be the same - assert (sample - old_sample).abs().max() > 1e-4 - - def test_lora_save_safetensors_load_torch(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - lora_attn_procs = create_lora_layers(model, mock_weights=False) - model.set_attn_processor(lora_attn_procs) - # Saving as torch, properly reloads with directly filename - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname, safe_serialization=True) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - torch.manual_seed(0) - new_model = self.model_class(**init_dict) - new_model.to(torch_device) - new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.safetensors") - - def test_lora_save_torch_force_load_safetensors_error(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - lora_attn_procs = create_lora_layers(model, mock_weights=False) - model.set_attn_processor(lora_attn_procs) - # Saving as torch, properly reloads with directly filename - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname, safe_serialization=False) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - torch.manual_seed(0) - new_model = self.model_class(**init_dict) - new_model.to(torch_device) - with self.assertRaises(IOError) as e: - new_model.load_attn_procs(tmpdirname, use_safetensors=True) - self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception)) - def test_lora_on_off(self, expected_max_diff=1e-3): # enable deterministic behavior for gradient checkpointing init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -1607,8 +1522,8 @@ def test_lora_on_off(self, expected_max_diff=1e-3): with torch.no_grad(): old_sample = model(**inputs_dict).sample - lora_attn_procs = create_lora_layers(model) - model.set_attn_processor(lora_attn_procs) + _, lora_params = create_unet_lora_layers(model) + model.load_attn_procs(lora_params) with torch.no_grad(): sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample @@ -1637,8 +1552,8 @@ def test_lora_xformers_on_off(self, expected_max_diff=6e-4): torch.manual_seed(0) model = self.model_class(**init_dict) model.to(torch_device) - lora_attn_procs = create_lora_layers(model) - model.set_attn_processor(lora_attn_procs) + _, lora_params = create_unet_lora_layers(model) + model.load_attn_procs(lora_params) # default with torch.no_grad(): @@ -1712,10 +1627,10 @@ def test_lora_processors(self): with torch.no_grad(): sample1 = model(**inputs_dict).sample - lora_attn_procs = create_lora_3d_layers(model) + unet_lora_params = create_3d_unet_lora_layers(model) # make sure we can set a list of attention processors - model.set_attn_processor(lora_attn_procs) + 
model.load_attn_procs(unet_lora_params) model.to(torch_device) # test that attn processors can be set to itself @@ -1732,172 +1647,6 @@ def test_lora_processors(self): # sample 2 and sample 3 should be different assert (sample2 - sample3).abs().max() > 3e-3 - def test_lora_save_load(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = 8 - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - with torch.no_grad(): - old_sample = model(**inputs_dict).sample - - lora_attn_procs = create_lora_3d_layers(model) - model.set_attn_processor(lora_attn_procs) - - with torch.no_grad(): - sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname, safe_serialization=False) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - torch.manual_seed(0) - new_model = self.model_class(**init_dict) - new_model.to(torch_device) - new_model.load_attn_procs(tmpdirname) - - with torch.no_grad(): - new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - assert (sample - new_sample).abs().max() < 5e-3 - - # LoRA and no LoRA should NOT be the same - assert (sample - old_sample).abs().max() > 1e-4 - - def test_lora_save_load_safetensors(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = 8 - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - with torch.no_grad(): - old_sample = model(**inputs_dict).sample - - lora_attn_procs = create_lora_3d_layers(model) - model.set_attn_processor(lora_attn_procs) - - with torch.no_grad(): - sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname, safe_serialization=True) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - torch.manual_seed(0) - new_model = self.model_class(**init_dict) - new_model.to(torch_device) - new_model.load_attn_procs(tmpdirname) - - with torch.no_grad(): - new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - assert (sample - new_sample).abs().max() < 3e-3 - - # LoRA and no LoRA should NOT be the same - assert (sample - old_sample).abs().max() > 1e-4 - - def test_lora_save_safetensors_load_torch(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = 8 - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - lora_attn_procs = create_lora_3d_layers(model, mock_weights=False) - model.set_attn_processor(lora_attn_procs) - # Saving as torch, properly reloads with directly filename - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - torch.manual_seed(0) - new_model = self.model_class(**init_dict) - new_model.to(torch_device) - new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.safetensors") - - def test_lora_save_torch_force_load_safetensors_error(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = 8 - - torch.manual_seed(0) 
- model = self.model_class(**init_dict) - model.to(torch_device) - - lora_attn_procs = create_lora_3d_layers(model, mock_weights=False) - model.set_attn_processor(lora_attn_procs) - # Saving as torch, properly reloads with directly filename - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname, safe_serialization=False) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - torch.manual_seed(0) - new_model = self.model_class(**init_dict) - new_model.to(torch_device) - with self.assertRaises(IOError) as e: - new_model.load_attn_procs(tmpdirname, use_safetensors=True) - self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception)) - - def test_lora_on_off(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = 8 - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - with torch.no_grad(): - old_sample = model(**inputs_dict).sample - - lora_attn_procs = create_lora_3d_layers(model) - model.set_attn_processor(lora_attn_procs) - - with torch.no_grad(): - sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample - - model.set_default_attn_processor() - - with torch.no_grad(): - new_sample = model(**inputs_dict).sample - - assert (sample - new_sample).abs().max() < 1e-4 - assert (sample - old_sample).abs().max() < 3e-3 - - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", - ) - def test_lora_xformers_on_off(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = 4 - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - lora_attn_procs = create_lora_3d_layers(model) - model.set_attn_processor(lora_attn_procs) - - # default - with torch.no_grad(): - sample = model(**inputs_dict).sample - - model.enable_xformers_memory_efficient_attention() - on_sample = model(**inputs_dict).sample - - model.disable_xformers_memory_efficient_attention() - off_sample = model(**inputs_dict).sample - - assert (sample - on_sample).abs().max() < 1e-4 - assert (sample - off_sample).abs().max() < 1e-4 - @slow @deprecate_after_peft_backend
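For readers skimming the diff: the pattern the rewritten test helpers converge on is to attach `LoRALinearLayer` modules directly to the attention projections (instead of swapping in the deprecated `LoRAAttnProcessor` classes) and to serialize them with `unet_lora_state_dict` and `LoraLoaderMixin.save_lora_weights`. The sketch below distills that flow from `create_unet_lora_layers` and the save/load tests; it is not part of the patch, assumes the pre-PEFT diffusers API this deprecated test file exercises, and uses a tiny UNet config along the lines of the dummy model in `get_dummy_components`.

# Minimal sketch (not part of the patch): new-style LoRA wiring used by the updated helpers.
# Assumes the pre-PEFT diffusers API exercised by tests/lora/test_lora_layers_old_backend.py.
import tempfile

from diffusers import UNet2DConditionModel
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer
from diffusers.training_utils import unet_lora_state_dict

# Tiny dummy UNet; the exact config is an assumption mirroring the tests' dummy model.
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

rank = 4
for attn_processor_name in unet.attn_processors.keys():
    # Walk from the processor name (e.g. "mid_block.attentions.0....attn1.processor")
    # to the Attention module that owns the projection layers.
    attn_module = unet
    for part in attn_processor_name.split(".")[:-1]:
        attn_module = getattr(attn_module, part)

    # Attach a LoRA layer to each projection instead of replacing the attention processor.
    for proj in (attn_module.to_q, attn_module.to_k, attn_module.to_v, attn_module.to_out[0]):
        proj.set_lora_layer(
            LoRALinearLayer(in_features=proj.in_features, out_features=proj.out_features, rank=rank)
        )

# Save only the LoRA weights, as the tests do before reloading them with
# load_lora_weights / load_attn_procs.
with tempfile.TemporaryDirectory() as tmpdirname:
    LoraLoaderMixin.save_lora_weights(
        save_directory=tmpdirname,
        unet_lora_layers=unet_lora_state_dict(unet),
        text_encoder_lora_layers=None,
    )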