[feat] IP Adapters (author @okotaku) #5713
Changes from 127 commits
@@ -22,11 +22,21 @@
import requests
import safetensors
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download, model_info
from packaging import version
from torch import nn

from . import __version__
from .models.attention_processor import (
    AttnProcessor,
    AttnProcessor2_0,
    IPAdapterAttnProcessor,
    IPAdapterAttnProcessor2_0,
    IPAdapterControlNetAttnProcessor,
    IPAdapterControlNetAttnProcessor2_0,
)
from .models.embeddings import ImageProjection
from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from .utils import (
    DIFFUSERS_CACHE,
@@ -3334,3 +3344,170 @@ def _remove_text_encoder_monkey_patch(self):
        else:
            self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
            self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)


class IPAdapterMixin:
    """Mixin for handling IP Adapters."""

    def set_ip_adapter(self):
        unet = self.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_processor_class = (
                    AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
                )
                attn_procs[name] = attn_processor_class()
            else:
                attn_processor_class = (
                    IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
                )
                attn_procs[name] = attn_processor_class(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
                ).to(dtype=unet.dtype, device=unet.device)

        unet.set_attn_processor(attn_procs)

        if hasattr(self, "controlnet"):
            attn_processor_class = (
                IPAdapterControlNetAttnProcessor2_0
                if hasattr(F, "scaled_dot_product_attention")
                else IPAdapterControlNetAttnProcessor
            )
            self.controlnet.set_attn_processor(attn_processor_class())
Review thread on `load_ip_adapter`:

Comment: Can we factor out the part where things are loaded into the unet and add it as a method on the unet - for now maybe as a private method? So we would add a:

```
def load_ip_adapter(....):
    # 1. load state dicts
    # 2. load CLIP image encoder and feature extractor
    # 3. load adaptable image proj and cross-attn weights
    # 4. call unet._load_ip_adaptive_weights(....)
    # pass unet, image encoder, feature extractor to pipeline
```

Comment: Can you explain the rationale?

Comment: Ohh, sounds great! So we don't have to pass the

Comment: Or maybe not needed.

Comment: I added code to do this - pretty happy about it! Now, if the image_encoder has not yet been added to the pipeline and the weight folder comes with an image encoder, it can be loaded here automatically. The user still has the option to load it manually and pass it to the pipeline; I updated the doc to reflect this. I don't think the user ever needs to worry about loading the feature_extractor - if it is not already one of the pipeline's components, we can always add it here with the default config (which is what the original repo uses). So I removed the mention from the doc to keep it simple.

Comment: I created a

Comment: The rationale here is simply to separate concerns - the unet should be responsible for loading the trainable IP projection layers; it is not the pipeline's responsibility. The pipeline should only care about loading the image encoder and feature extractor, and then offload the other responsibility to the unet. Advantages:
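To make the proposal above concrete, here is a minimal, hypothetical sketch of the suggested split. The `unet._load_ip_adaptive_weights` name comes from the review comment; the `_fetch_state_dict` helper and the auto-loading details are illustrative assumptions, not the merged API:

```python
# Hypothetical sketch of the proposed separation of concerns (illustrative only).
class IPAdapterMixin:
    def load_ip_adapter(self, pretrained_model_name_or_path_or_dict, **kwargs):
        # 1. load the combined state dict ("image_proj" + "ip_adapter" entries);
        #    _fetch_state_dict is an assumed helper, not a real diffusers function
        state_dict = self._fetch_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        # 2. load the CLIP image encoder and feature extractor only if the
        #    pipeline does not already have them as components
        if getattr(self, "image_encoder", None) is None:
            self.image_encoder = ...  # e.g. from a bundled image_encoder subfolder
        if getattr(self, "feature_extractor", None) is None:
            self.feature_extractor = ...  # default config, as in the original repo

        # 3. + 4. hand the image-projection and cross-attention weights to the
        #    unet, which owns those layers (method name taken from the comment)
        self.unet._load_ip_adaptive_weights(state_dict)
```

The diff then continues with the actual `load_ip_adapter` implementation: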
    def load_ip_adapter(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        """
        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted
                      on the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
        """
        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
            raise ValueError("`image_encoder` cannot be None when using IP Adapters.")

        self.set_ip_adapter()

        # Load the main state dict first.
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
            model_file = _get_model_file(
                pretrained_model_name_or_path_or_dict,
                weights_name=weight_name,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
            )
            if weight_name.endswith(".safetensors"):
                state_dict = safetensors.torch.load_file(model_file, device="cpu")
            else:
                state_dict = torch.load(model_file, map_location="cpu")
        else:
            state_dict = pretrained_model_name_or_path_or_dict

        keys = list(state_dict.keys())
        if keys != ["image_proj", "ip_adapter"]:
            raise ValueError("Required keys (`image_proj` and `ip_adapter`) are missing from the state dict.")

        # Handle image projection layers.
        clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
        cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4

        image_projection = ImageProjection(
            cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
        )
        image_projection.to(dtype=self.unet.dtype, device=self.unet.device)

        diffusers_state_dict = {}

        diffusers_state_dict.update(
            {
                "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
                "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
                "norm.weight": state_dict["image_proj"]["norm.weight"],
                "norm.bias": state_dict["image_proj"]["norm.bias"],
            }
        )

        image_projection.load_state_dict(diffusers_state_dict)

        self.unet.encoder_hid_proj = image_projection.to(device=self.unet.device, dtype=self.unet.dtype)
        self.unet.config.encoder_hid_dim_type = "image_proj"
        self.unet.config.encoder_hid_dim = clip_embeddings_dim

        # Handle IP-Adapter cross-attention layers.
        ip_layers = torch.nn.ModuleList(
            [
                module if isinstance(module, nn.Module) else nn.Identity()
                for module in self.unet.attn_processors.values()
            ]
        )
        ip_layers.load_state_dict(state_dict["ip_adapter"])

    def set_ip_adapter_scale(self, scale):
        for attn_processor in self.unet.attn_processors.values():
            if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
                attn_processor.scale = scale
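For orientation, a hedged usage sketch of the mixin as shown in this diff. The `h94/IP-Adapter` repo, subfolder, and weight name follow the original IP-Adapter release, and the `image_encoder` pipeline component follows this PR's pipeline changes; treat the exact arguments as assumptions rather than the final API:

```python
import torch
from diffusers import StableDiffusionPipeline
from transformers import CLIPVisionModelWithProjection

# Manually load the CLIP image encoder and pass it to the pipeline (per the
# thread above, it may also be auto-loaded from the adapter repo if bundled).
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, torch_dtype=torch.float16
).to("cuda")

# load_ip_adapter wires up the IP-Adapter attention processors and the image
# projection layers; set_ip_adapter_scale controls how strongly the image
# prompt is weighted against the text prompt.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe.set_ip_adapter_scale(0.6)

# The pipelines touched by this PR then accept a reference image at call time
# via an `ip_adapter_image` argument.
```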
Comment: Since most of the code is pretty similar for the image-to-image example, what do you think about using it with the new `hfoptions` syntax to make it a bit cleaner (you can even throw in the inpaint example if you want)? It'll look like: (screenshot of the rendered tabs omitted)