Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoRA refactor 2] feat: add support for load_lora(). #5958

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 102 additions & 2 deletions src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
USE_PEFT_BACKEND,
_get_model_file,
delete_adapter_layers,
get_adapter_name,
is_accelerate_available,
logging,
set_adapter_layers,
Expand All @@ -49,6 +50,10 @@
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

PEFT_WEIGHT_NAME = "adapter_model.bin"
PEFT_WEIGHT_NAME_SAFE = "adapter_model.safetensors"
PEFT_CONFIG_NAME = "adapter_config.json"
Comment on lines +53 to +55
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
PEFT_WEIGHT_NAME = "adapter_model.bin"
PEFT_WEIGHT_NAME_SAFE = "adapter_model.safetensors"
PEFT_CONFIG_NAME = "adapter_config.json"
PEFT_WEIGHT_NAME = "adapter_model.bin"
PEFT_WEIGHT_NAME_SAFE = "adapter_model.safetensors"
PEFT_CONFIG_NAME = "adapter_config.json"

@BenjaminBossan @pacman100 @younesbelkada @apolinario - We need to discuss this a bit. I'm all in favor of supporting adapter_config.json mid-term, but we have some peculiarities here:

  • Currently LoRAs serialization format is a single file that contains weights for both the Unet and the text_encoder. However, it's also not uncommon that one only trains the unet in which case the single file only contains the unet
  • I'd say pretty much always text encoder and unet LoRAs are trained together, but I do see use cases where one might want to load two different LoRA files (one for it's text encoder, one for it's unet)

=> As a consequence the main "loading" function of LoRA is part of the pipeline:

pipeline = DiffusionPipeline.from_pretrained("...")
pipeline.load_lora_weights("...")  # this internally then dispatches to loading the text encoder specific loras to the text encoder and the unet specific loras to the unet

However, it's also very important that we allow loading LoRAs just from the unet or the text encoder:

pipeline.text_encoder.load_lora   # <- here we call Transformers' lora loading function
pipeline.unet.load_lora  # <- this needs to be implemented well in Diffusers

So the question here is, what format and serialization should we use for PEFT that is both:

  • Easy to use for training (meaning easy to use just on the unet object when no pipeline is loaded)
  • Easy to use for inference (when loaded via pipeline.load_lora_weights)
  • Easy to use to mix different LoRAs

Can we allow adapter.json to have a config for either just a text encoder, just a unet, both text encoder and unet
Can adapter.safetensors have both formats?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if I grasped the problem completely, but the 2nd part (separate LoRAs for text encoder and unet) would be the easier one, right? We would just point to different folders, each with their own adapter_config.json etc.

For the combined adapter, we would need to find a new solution, it could be some kind of convention (sub-folder names) or we would need to have a single adapter that takes care of both, but that seems unwieldy and more error prone.

Another use case I'd like to bring up, not sure if it's relevant at this moment: Having a pipeline load multiple different LoRA adapters (or even other type of adapters).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but the 2nd part (separate LoRAs for text encoder and unet) would be the easier one, right?

Yes, it'd def be cleaner as a serialization format, but it's very much not the convention right now. At the moment, all LoRAs always only have a single safetensors file. So if we introduce a new 4 file convention (text_encoder's adapter_config.json, text_encoder's adapter_model.safetensors unet's adapter_config.json and unet's adapter_model.safetensors) it might be quite confusing for the community. I'm not too happy about creating a new folder structure as this wasn't super well perceived in the first place.

Copy link
Contributor

@pacman100 pacman100 Nov 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer the file format with subfolders for unet/text-encoder each having a adapter_model.safetensors and adapter_config.json. This is exactly mimicing the folder structure of full finetuning/pretrained SD/SDXL in Diffusers just that the config/weights are only for adapters. So, that would be clearly following the standard persistence format of Diffusers

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For portability, I had written a conversion script which would put both the unet and text-encoder weights and their config's in a single safetensors file. The file is this: https://github.com/pacman100/peft-dreambooth/blob/master/peft_utils/merge_peft_sd_ckpt_to_single_safetensors.py

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This way, current paradigm of single safetensors for LoRA would be achieved for PEFT and users are happy. But, this single safetensors file won't work in Auto1111 as it requires changing the state dict keys which is also implemented here https://github.com/pacman100/peft-sd-webui-additional-networks/blob/main/scripts/peft_lora.py

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your thoughts. But the thing is the diffusion community very much prefers the single file-format at least for LoRAs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

@pacman100 pacman100 Nov 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another thing to keep in mind is that the utils need to generic. For example, A PR to add OFT in PEFT is close to being merged. It follows the standard format of subfolders with adapter+config files. The utils need to be generic to convert both LoRA or OFT into single safetensors format.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having a utility to output a single file seems to be the best choice here without compromising too much.


CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"

Expand Down Expand Up @@ -345,7 +350,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
is_model_cpu_offload = False
is_sequential_cpu_offload = False

# For PEFT backend the Unet is already offloaded at this stage as it is handled inside `lora_lora_weights_into_unet`
# For PEFT backend the Unet is already offloaded at this stage as it is handled inside `load_lora_weights_into_unet`
if not USE_PEFT_BACKEND:
if _pipeline is not None:
for _, component in _pipeline.components.items():
Expand Down Expand Up @@ -823,4 +828,99 @@ def _load_ip_adapter_weights(self, state_dict):
self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
self.config.encoder_hid_dim_type = "ip_image_proj"

delete_adapter_layers
def load_lora(self, pretrained_model_name_or_path: str, **kwargs):
r"""
Load LoRA checkpoints with PEFT.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for `load_lora()`.")

from peft import PeftConfig, inject_adapter_in_model, set_peft_model_state_dict

cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
adapter_name = kwargs.pop("adapter_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
allow_pickle = False

if use_safetensors is None:
use_safetensors = True
allow_pickle = True

user_agent = {
"file_type": "load_lora_peft",
"framework": "pytorch",
}

# Load the state dict.
model_file = None
try:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=weight_name or PEFT_WEIGHT_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except IOError as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
pass

if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=weight_name or PEFT_WEIGHT_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")

state_dict = {k.replace("base_model.model.", ""): v for k, v in state_dict.items()}

rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]

# Load the PEFT config.
lora_config = PeftConfig.from_pretrained(pretrained_model_name_or_path)

# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)

# Adapter injection
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)

if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
68 changes: 67 additions & 1 deletion tests/lora/test_lora_layers_peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from accelerate.utils import release_memory

if is_peft_available():
from peft import LoraConfig
from peft import LoraConfig, get_peft_model
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import get_peft_model_state_dict

Expand Down Expand Up @@ -1393,6 +1393,72 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
}


@require_peft_backend
class UNet2DConditionModelLoRATests(unittest.TestCase):
def get_dummy_components(self):
unet_kwargs = {
"block_out_channels": (32, 64),
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
"up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
"cross_attention_dim": 32,
"attention_head_dim": 8,
"out_channels": 4,
"in_channels": 4,
"layers_per_block": 2,
"sample_size": 32,
}
torch.manual_seed(0)
unet = UNet2DConditionModel(**unet_kwargs)
unet_lora_config = LoraConfig(
r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
)
return unet, unet_lora_config

def get_dummy_inputs(self):
batch_size = 2
num_channels = 4
sizes = (32, 32)

noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)

return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}

def test_inference(self):
unet, unet_lora_config = self.get_dummy_components()
inputs = self.get_dummy_inputs()
outputs = unet(**inputs).sample

unet = get_peft_model(unet, unet_lora_config)
outputs_with_lora = unet(**inputs).sample

with tempfile.TemporaryDirectory() as tmpdirname:
unet.save_pretrained(tmpdirname)
# `peft` stable release doesn't default to safetensors yet.
# we run checks both with `peft` main and stable release.
try:
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "adapter_model.safetensors")))
except Exception:
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "adapter_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "adapter_config.json")))

unet, _ = self.get_dummy_components()
unet.load_lora(tmpdirname)

has_peft_layer = any(isinstance(unet_module, BaseTunerLayer) for unet_module in unet.modules())
assert has_peft_layer, "No PEFT layer found"

outputs_with_lora_loaded = unet(**inputs).sample

assert not torch.allclose(
outputs, outputs_with_lora, atol=1e-3, rtol=1e-3
), "LoRA layers should affect the outputs."
assert torch.allclose(
outputs_with_lora_loaded, outputs_with_lora, atol=1e-3, rtol=1e-3
), "Loaded LoRA layers should match the outputs."


@slow
@require_torch_gpu
class LoraIntegrationTests(unittest.TestCase):
Expand Down