[feat] IP Adapters (author @okotaku ) #5713

Merged
merged 161 commits into from
Nov 21, 2023
Changes from 127 commits
Commits
08a1828
add ip-adapter
okotaku Sep 8, 2023
c4646f8
modularize.
sayakpaul Nov 2, 2023
f3755d4
add to inits.
sayakpaul Nov 2, 2023
5887af0
fix
sayakpaul Nov 2, 2023
f9aaa54
fix
sayakpaul Nov 2, 2023
a45292b
fix
sayakpaul Nov 2, 2023
023c2b7
fix
sayakpaul Nov 2, 2023
8fe3064
fix
sayakpaul Nov 2, 2023
dded7c4
fix
sayakpaul Nov 2, 2023
651302b
fix
sayakpaul Nov 2, 2023
f10eb25
fix
sayakpaul Nov 2, 2023
f051c9e
fix
sayakpaul Nov 2, 2023
6031383
device placement
sayakpaul Nov 2, 2023
95e38ac
device placement
sayakpaul Nov 2, 2023
3d69688
device placement fix.
sayakpaul Nov 2, 2023
cacee6d
Merge branch 'main' into feat/ip_adapter
okotaku Nov 3, 2023
351180f
fix import
okotaku Nov 3, 2023
2e83d6c
composable ip adapter module
sayakpaul Nov 3, 2023
1d64cb8
add image_encoder to sd as optional components
Nov 7, 2023
70fae5c
add image_prompt arg
Nov 7, 2023
3aaaa23
move image_projection to unet, refactor
Nov 8, 2023
2154d01
update comments
Nov 8, 2023
2807ee3
fix
sayakpaul Nov 8, 2023
bc52810
make image_encoder default to None.
sayakpaul Nov 8, 2023
eaf94bb
fully delegate the image encoding logic.
sayakpaul Nov 8, 2023
c22cd90
Merge branch 'main' into feat/ip_adapter
sayakpaul Nov 8, 2023
7cf7f70
debug
sayakpaul Nov 8, 2023
03e2961
fix
sayakpaul Nov 8, 2023
982a557
fix
sayakpaul Nov 8, 2023
6059099
fix:
sayakpaul Nov 8, 2023
c56503b
fix
sayakpaul Nov 8, 2023
59c933a
separate the loacder.
sayakpaul Nov 8, 2023
7ece033
circular import problem
sayakpaul Nov 8, 2023
4cb0432
circular imports.
sayakpaul Nov 8, 2023
17223d4
added_cond_kwargs not needed now.
sayakpaul Nov 8, 2023
8001d24
remove save_ip_adapter.
sayakpaul Nov 8, 2023
7887ba7
remove ip adapter pipeline from the face of the earth
sayakpaul Nov 8, 2023
46c668b
refactor __call__
sayakpaul Nov 8, 2023
6e28231
fix init.
sayakpaul Nov 8, 2023
3241c96
Merge branch 'main' into feat/ip_adapter
sayakpaul Nov 8, 2023
ef937be
remove none
sayakpaul Nov 8, 2023
7043443
image_encoder
sayakpaul Nov 8, 2023
d7e390f
module registration
sayakpaul Nov 8, 2023
86b0e4a
does defaulting to None work for modules?
sayakpaul Nov 8, 2023
d0cf0cc
unet
Nov 8, 2023
f2431b3
style
Nov 8, 2023
7fdbf86
fix a test
Nov 8, 2023
ba43e03
attemp to fix image_encoder none test
Nov 8, 2023
1d2b58b
sdxl
Nov 8, 2023
6c0106b
sd img2img
Nov 9, 2023
84bcbd6
inpaint
Nov 9, 2023
9b8b11a
another attemp to fix lora test
Nov 9, 2023
d662f6c
fix more tests
Nov 9, 2023
a77b1e5
more image_encoder: none
Nov 9, 2023
ecb2a5f
add to sdxl inpainting
sayakpaul Nov 9, 2023
c0042c1
add to sdxl image-to-image
sayakpaul Nov 9, 2023
44eb034
stylw
sayakpaul Nov 9, 2023
0f1e364
style
sayakpaul Nov 9, 2023
b2f7af0
feat: safetensors loading.
sayakpaul Nov 9, 2023
6af2112
fix: tests.
sayakpaul Nov 9, 2023
88efe67
fix more tests
sayakpaul Nov 9, 2023
5baa910
doc
Nov 9, 2023
abc1372
fix sdxl img2img + inpaint tests
Nov 9, 2023
5e60de5
add: integration test
sayakpaul Nov 9, 2023
95797b5
add: test_ prefix.
sayakpaul Nov 9, 2023
b04cdcf
subfolder
sayakpaul Nov 9, 2023
be73167
tests
sayakpaul Nov 9, 2023
fb401d4
add: image-to-image
sayakpaul Nov 9, 2023
86b4f09
fix bunch tests
Nov 9, 2023
426fdb3
Merge branch 'ip-adapter' of github.com:huggingface/diffusers into ip…
Nov 9, 2023
66f7023
fix: assertion values.
sayakpaul Nov 9, 2023
c904c63
add: inpainting
sayakpaul Nov 9, 2023
9085797
fix: assertion values for inpainting
sayakpaul Nov 9, 2023
ab060c4
fix: inpainting tests
sayakpaul Nov 9, 2023
0d7ef92
fix: more
sayakpaul Nov 9, 2023
b132f50
fix auto test
Nov 9, 2023
0dee7fa
Merge branch 'ip-adapter' of github.com:huggingface/diffusers into ip…
Nov 9, 2023
188f1d7
add: sdxl integration tests
sayakpaul Nov 9, 2023
36e7903
fix: assertion values for sdxl.
sayakpaul Nov 9, 2023
4f34e08
fix last one
Nov 9, 2023
2d2a7b1
Merge branch 'ip-adapter' of github.com:huggingface/diffusers into ip…
Nov 9, 2023
756534b
fix: assertion for inpainting
sayakpaul Nov 9, 2023
a17655b
fix tiny encoder
Nov 9, 2023
fcf60f3
Merge branch 'ip-adapter' of github.com:huggingface/diffusers into ip…
Nov 9, 2023
4d08930
make quality
Nov 9, 2023
cb451b0
fix
Nov 9, 2023
85f3959
Merge branch 'main' into ip-adapter
sayakpaul Nov 9, 2023
eed9900
add: fast test.
sayakpaul Nov 9, 2023
565c7c0
Merge branch 'main' into ip-adapter
sayakpaul Nov 9, 2023
82a7e4d
add sdxl docs
sayakpaul Nov 9, 2023
5c179b9
uP
patrickvonplaten Nov 9, 2023
eda593b
lcm add tests
patrickvonplaten Nov 9, 2023
5c838e4
Add co-author
patrickvonplaten Nov 9, 2023
7183b15
lcm add tests
patrickvonplaten Nov 9, 2023
9d7939f
Merge branch 'main' of https://github.com/huggingface/diffusers into …
patrickvonplaten Nov 9, 2023
0b15eb1
uP
patrickvonplaten Nov 9, 2023
eec02db
fix
sayakpaul Nov 9, 2023
819ed61
be explicit about @okotaku
sayakpaul Nov 9, 2023
6e52db7
Apply suggestions from code review
patrickvonplaten Nov 9, 2023
d9d7672
remove mention
sayakpaul Nov 9, 2023
f35ce5b
Apply suggestions from code review
sayakpaul Nov 9, 2023
49234b1
uP
patrickvonplaten Nov 9, 2023
e9cdb69
uP
patrickvonplaten Nov 9, 2023
1cd4b23
Merge branch 'main' into ip-adapter
patrickvonplaten Nov 9, 2023
2ecbc44
uP
patrickvonplaten Nov 9, 2023
584138c
uP
patrickvonplaten Nov 9, 2023
82f0cc9
uP
patrickvonplaten Nov 9, 2023
9471dd9
more debug
patrickvonplaten Nov 9, 2023
7ecfcfe
Apply suggestions from code review
yiyixuxu Nov 9, 2023
e6c8934
Merge branch 'main' into ip-adapter
sayakpaul Nov 10, 2023
774f0dd
style
sayakpaul Nov 10, 2023
3ab4049
remove ipdb
sayakpaul Nov 10, 2023
319e003
Merge branch 'main' into ip-adapter
sayakpaul Nov 10, 2023
a106e83
style
sayakpaul Nov 10, 2023
90f9a58
debug
sayakpaul Nov 10, 2023
679bcf3
debug
sayakpaul Nov 10, 2023
d43f075
debug
sayakpaul Nov 10, 2023
be3d3e8
debug
sayakpaul Nov 10, 2023
105bd35
debug
sayakpaul Nov 10, 2023
1a28c32
debug
sayakpaul Nov 10, 2023
dc76816
debug
sayakpaul Nov 10, 2023
f06ba21
debug
sayakpaul Nov 10, 2023
af88728
more debug
patrickvonplaten Nov 10, 2023
9ece001
Merge branch 'ip-adapter' of https://github.com/huggingface/diffusers…
patrickvonplaten Nov 10, 2023
087417c
more debug
patrickvonplaten Nov 10, 2023
f4a04c0
more debug
patrickvonplaten Nov 10, 2023
5e4b53d
Apply suggestions from code review
patrickvonplaten Nov 10, 2023
286cb1a
add tests
patrickvonplaten Nov 10, 2023
54b3b21
Merge branch 'ip-adapter' of https://github.com/huggingface/diffusers…
patrickvonplaten Nov 10, 2023
9ff5f6b
Apply suggestions from code review
yiyixuxu Nov 10, 2023
e8f6a85
refactor load_ip_adapter: add capabiliity to load clip image encoder …
Nov 11, 2023
d50a19f
style
Nov 11, 2023
7e7f1dc
refacotr 2: unet._load_ip_adapter_weights
Nov 11, 2023
10b79b5
update doc
Nov 11, 2023
e00dcfe
controlnet
Nov 13, 2023
60049ca
animatediff
Nov 13, 2023
5641a64
fix tests + remove controlnet attn processor
Nov 13, 2023
9d94e20
sdxl
Nov 13, 2023
fed72fb
add doc
Nov 13, 2023
b4b32df
resolve merge
Nov 16, 2023
f46c2e4
fix circular import
Nov 16, 2023
3203eeb
fix
Nov 16, 2023
b40e94f
fix
Nov 16, 2023
fae2a05
fix
Nov 16, 2023
8fe9798
fix
Nov 16, 2023
f97a797
Update src/diffusers/pipelines/pipeline_utils.py
yiyixuxu Nov 16, 2023
c607878
Update src/diffusers/pipelines/pipeline_utils.py
yiyixuxu Nov 16, 2023
2c2c607
update tests
Nov 16, 2023
97c68eb
Merge branch 'ip-adapter' of github.com:huggingface/diffusers into ip…
Nov 16, 2023
dc1b7eb
style
Nov 16, 2023
dd67bcd
text_context_len -> num_tokens
Nov 18, 2023
55b6f5c
Merge branch 'main' into ip-adapter
yiyixuxu Nov 18, 2023
d4edc4e
fix
Nov 18, 2023
75022d0
support safetensors + make subfolder and weight_name required argument
Nov 18, 2023
aaba4d4
make aggresive tests
Nov 20, 2023
c0e9e5d
Merge branch 'main' into ip-adapter
yiyixuxu Nov 20, 2023
2fd1685
copies
Nov 20, 2023
b5029fb
Merge branch 'ip-adapter' of github.com:huggingface/diffusers into ip…
Nov 20, 2023
0162a45
add
Nov 20, 2023
304c790
fix
Nov 20, 2023
6645776
Merge branch 'main' into ip-adapter
patrickvonplaten Nov 21, 2023
141 changes: 141 additions & 0 deletions docs/source/en/using-diffusers/loading_adapters.md
@@ -307,3 +307,144 @@ prompt = "a house by william eggleston, sunrays, beautiful, sunlight, sunrays, b
image = pipeline(prompt=prompt).images[0]
image
```

## IP-Adapter

[IP-Adapter](https://ip-adapter.github.io/) is an effective and lightweight adapter that adds image prompting capabilities to a diffusion model. It works by decoupling the cross-attention layers for the image and text features. All other model components are frozen, and only the embedded image features in the UNet are trained. As a result, IP-Adapter files are typically only ~100MB.
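
To make the decoupling concrete, here is a minimal sketch in plain Python (no PyTorch) of attention for a single query vector. It assumes the text and image branches each have their own keys and values and that their outputs are summed, with the image branch weighted by `scale`. The function names are hypothetical and this is not the library's attention processor, only an illustration of the idea.

```python
import math


def attention(query, keys, values):
    """Scaled dot-product attention for one query vector over plain lists."""
    dim = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(dim) for key in keys]
    peak = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(score - peak) for score in scores]
    total = sum(weights)
    weights = [weight / total for weight in weights]
    return [
        sum(weight * value[i] for weight, value in zip(weights, values))
        for i in range(len(values[0]))
    ]


def decoupled_cross_attention(query, text_keys, text_values, image_keys, image_values, scale=1.0):
    """Attend to text and image features separately, then add the two results."""
    text_out = attention(query, text_keys, text_values)
    image_out = attention(query, image_keys, image_values)
    # `scale` controls how strongly the image branch contributes.
    return [t + scale * i for t, i in zip(text_out, image_out)]


output = decoupled_cross_attention(
    query=[1.0, 0.0],
    text_keys=[[1.0, 0.0]],
    text_values=[[2.0, 4.0]],
    image_keys=[[0.0, 1.0]],
    image_values=[[10.0, 20.0]],
    scale=0.5,
)
print(output)
```

With `scale=0.0` the image branch drops out and the result equals ordinary text-only cross-attention.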

IP-Adapter works with most of our pipelines, including Stable Diffusion, Stable Diffusion XL (SDXL), ControlNet, T2I-Adapter, and any custom models finetuned from the same base models.

<Tip>

You can find official IP-Adapter checkpoints in [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter).

IP-Adapter was contributed by [okotaku](https://github.com/okotaku).

</Tip>

IP-Adapter relies on an image encoder to generate the image features, so let's load a [`~transformers.CLIPVisionModelWithProjection`] model and then pass it to a Stable Diffusion pipeline.

```py
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection
import torch

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
).to("cuda")

pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, torch_dtype=torch.float16).to("cuda")
```

Now load the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) weights with the [`~loaders.IPAdapterMixin.load_ip_adapter`] method.

```py
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
```

IP-Adapter allows you to use both image and text to condition the image generation process. For example, let's use the bear image from the [Textual Inversion](#textual-inversion) section as the image prompt (`ip_adapter_image`) along with a text prompt to add "sunglasses". 😎

```py
pipeline.set_ip_adapter_scale(0.6)
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
    prompt='best quality, high quality, wearing sunglasses',
    ip_adapter_image=image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=50,
    generator=generator,
).images
images[0]
```

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip-bear.png" />
</div>

<Tip>

You can use the [`~loaders.IPAdapterMixin.set_ip_adapter_scale`] method to adjust the ratio between text prompt and image prompt conditioning. If you're only using the image prompt, you should set the scale to `1.0`. You can lower the scale to get more generation diversity, but the result will be less aligned with the image prompt.
`scale=0.5` can achieve good results in most cases when you use both text and image prompts.
</Tip>

IP-Adapter also works great with Image-to-Image and Inpainting pipelines. Here is an example of how you can use it with Image-to-Image.
Member

Since most of the code is pretty similar for the image-to-image example, what do you think about using it with the new hfoptions syntax to make it a bit cleaner (you can even throw in the inpaint example if you want)? For example:

<hfoptions id="tasks">
<hfoption id="text-to-image">

text + code example

</hfoption>
<hfoption id="image-to-image">

text + code example

</hfoption>
</hfoptions>

It'll look like:
Screenshot 2023-11-09 at 9 12 52 AM


```py
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection
import torch

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
).to("cuda")

pipeline = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, torch_dtype=torch.float16).to("cuda")

image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/vermeer.jpg")
ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/river.png")

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
    prompt='best quality, high quality',
    image=image,
    ip_adapter_image=ip_image,
    num_inference_steps=50,
    generator=generator,
    strength=0.6,
).images
images[0]
```

IP-Adapters can also be used with [SDXL](../api/pipelines/stable_diffusion/stable_diffusion_xl.md), but you'll also need to load a [`~transformers.CLIPImageProcessor`] as your feature extractor and pass it to the pipeline.

```python
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
import torch

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="sdxl_models/image_encoder",
    torch_dtype=torch.float16,
).to("cuda")
feature_extractor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    image_encoder=image_encoder,
    feature_extractor=feature_extractor,
    torch_dtype=torch.float16,
).to("cuda")

image = load_image("https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/watercolor_painting.jpeg")

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

generator = torch.Generator(device="cpu").manual_seed(33)
image = pipeline(
    prompt="best quality, high quality",
    ip_adapter_image=image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=25,
    generator=generator,
).images[0]
image.save("sdxl_t2i.png")
```

<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/watercolor_painting.jpeg"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">input image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/sdxl_t2i.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">adapted image</figcaption>
</div>
</div>
5 changes: 1 addition & 4 deletions examples/text_to_image/train_text_to_image_flax.py
@@ -272,10 +272,7 @@ def main():
     if args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         dataset = load_dataset(
-            args.dataset_name,
-            args.dataset_config_name,
-            cache_dir=args.cache_dir,
-            data_dir=args.train_data_dir
+            args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir
         )
     else:
         data_files = {}
5 changes: 1 addition & 4 deletions examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -765,10 +765,7 @@ def load_model_hook(models, input_dir):
     if args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         dataset = load_dataset(
-            args.dataset_name,
-            args.dataset_config_name,
-            cache_dir=args.cache_dir,
-            data_dir=args.train_data_dir
+            args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir
         )
     else:
         data_files = {}
177 changes: 177 additions & 0 deletions src/diffusers/loaders.py
@@ -22,11 +22,21 @@
import requests
import safetensors
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download, model_info
from packaging import version
from torch import nn

from . import __version__
from .models.attention_processor import (
    AttnProcessor,
    AttnProcessor2_0,
    IPAdapterAttnProcessor,
    IPAdapterAttnProcessor2_0,
    IPAdapterControlNetAttnProcessor,
    IPAdapterControlNetAttnProcessor2_0,
)
from .models.embeddings import ImageProjection
from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from .utils import (
    DIFFUSERS_CACHE,
@@ -3334,3 +3344,170 @@ def _remove_text_encoder_monkey_patch(self):
        else:
            self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
            self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)


class IPAdapterMixin:
    """Mixin for handling IP Adapters."""

    def set_ip_adapter(self):
        unet = self.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_processor_class = (
                    AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
                )
                attn_procs[name] = attn_processor_class()
            else:
                attn_processor_class = (
                    IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
                )
                attn_procs[name] = attn_processor_class(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
                ).to(dtype=unet.dtype, device=unet.device)

        unet.set_attn_processor(attn_procs)

        if hasattr(self, "controlnet"):
            attn_processor_class = (
                IPAdapterControlNetAttnProcessor2_0
                if hasattr(F, "scaled_dot_product_attention")
                else IPAdapterControlNetAttnProcessor
            )
            # The mixin is applied to the pipeline itself, so the ControlNet is reached via `self`.
            self.controlnet.set_attn_processor(attn_processor_class())
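
The `hidden_size` lookup in `set_ip_adapter` above can be tried in isolation. The sketch below copies that branching into a hypothetical helper, `hidden_size_for`, and runs it on an assumed example value for `block_out_channels`; both the helper name and the channel tuple are illustrative and not part of this PR. Because the block index is read from a single character of the name, the approach only covers block indices 0 through 9.

```python
def hidden_size_for(name, block_out_channels):
    """Map an attention processor name to the channel width of its UNet block."""
    if name.startswith("mid_block"):
        return block_out_channels[-1]
    if name.startswith("up_blocks"):
        # Up blocks run in the reverse channel order of the down blocks.
        block_id = int(name[len("up_blocks.")])
        return list(reversed(block_out_channels))[block_id]
    if name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        return block_out_channels[block_id]
    raise ValueError(f"Unrecognized attention processor name: {name}")


# Assumed example configuration, used here only for illustration.
channels = (320, 640, 1280, 1280)

for processor_name in (
    "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor",
    "mid_block.attentions.0.transformer_blocks.0.attn2.processor",
    "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor",
):
    print(processor_name, hidden_size_for(processor_name, channels))
```

Unlike the loop above, this helper raises on an unrecognized prefix instead of silently reusing the previous iteration's `hidden_size`; that is a design choice of the sketch, not behavior taken from the PR.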

    def load_ip_adapter(
Contributor

Can we factor out the part where things are loaded into the unet and add it as a method on the unet, maybe as a private method for now?

So we should add a _load_ip_adaptive_weights function to the UNet.

def load_ip_adapter(....):
    # 1. load state dicts
    # 2. load CLIP Image encoder and feature extractor
    # 3. load adaptable image proj and cross-attn weights
    # 4. call unet._load_ip_adaptive_weights(....)
    # pass unet, image encoder, feature extractor to pipeline

Member

Can you explain the rationale?

Collaborator Author

Oh, sounds great!

So we don't have to pass the image_encoder to pipe.from_pretrained() anymore; we can just pick it up automatically from the IP-Adapter weights folder here. Could we also add an optional path for the feature extractor?

Collaborator Author

@yiyixuxu yiyixuxu Nov 11, 2023

> load CLIP Image encoder and feature extractor

I added code to do this and I'm pretty happy about it. Now, if the image_encoder has not been added to the pipeline yet and the weights folder comes with an image encoder, it is loaded automatically here. The user still has the option to load it manually and pass it to the pipeline. I updated the doc to reflect this.

I don't think the user ever needs to worry about loading the feature_extractor: if it is not already a component of the pipeline, we can always add it here with the default config (that's what the original repo uses). So I removed the mention from the doc to keep it simple.

> 3. load adaptable image proj and cross-attn weights
> 4. call unet._load_ip_adaptive_weights(....)

I created a unet._load_ip_adaptive_weights(state_dict) that does both 3 and 4. Let me know if this is what you had in mind. Happy to iterate if it is not!

Contributor

The rationale here is simply to separate concerns: the unet should be responsible for loading the trainable IP projection layers. That's not the pipeline's responsibility. The pipeline should only care about loading the image encoder and feature extractor, and then offload the remaining responsibility to the unet. Advantages:

  • 1.) It's cleaner
  • 2.) It's easier to test (just need to test the unet class now)
  • 3.) It's easier to extend to new use cases
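
The split discussed in this thread could look roughly like the toy sketch below. `ToyUNet`, `ToyPipeline`, and the `load_image_encoder` callback are hypothetical stand-ins, not classes from this PR; the private method name follows the later commit message (`unet._load_ip_adapter_weights`), and its body here is only a placeholder for the real weight loading.

```python
class ToyUNet:
    """Stand-in for the UNet: it owns the trainable IP-Adapter layers."""

    def __init__(self):
        self.ip_adapter_weights = None

    def _load_ip_adapter_weights(self, state_dict):
        # Steps 3 and 4 of the outline: image projection and cross-attention
        # weights are loaded by the UNet, not by the pipeline.
        self.ip_adapter_weights = {
            "image_proj": state_dict["image_proj"],
            "ip_adapter": state_dict["ip_adapter"],
        }


class ToyPipeline:
    """Stand-in for the pipeline: it only handles the image encoder."""

    def __init__(self, unet, image_encoder=None):
        self.unet = unet
        self.image_encoder = image_encoder

    def load_ip_adapter(self, state_dict, load_image_encoder):
        # Step 2: load an encoder only when the user did not pass one in.
        if self.image_encoder is None:
            self.image_encoder = load_image_encoder()
        # Everything UNet-specific is delegated.
        self.unet._load_ip_adapter_weights(state_dict)


unet = ToyUNet()
pipeline = ToyPipeline(unet)
pipeline.load_ip_adapter(
    {"image_proj": {"proj.weight": "..."}, "ip_adapter": {}},
    load_image_encoder=lambda: "image-encoder-placeholder",
)
print(pipeline.image_encoder, sorted(unet.ip_adapter_weights))
```

With this shape, the UNet method can be tested by passing it a state dict directly, which matches the testability argument in the list above.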

        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        """
        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
        """
        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
            raise ValueError("`image_encoder` cannot be None when using IP Adapters.")

        self.set_ip_adapter()

        # Load the main state dict first.
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
            model_file = _get_model_file(
                pretrained_model_name_or_path_or_dict,
                weights_name=weight_name,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
            )
            if weight_name.endswith(".safetensors"):
                state_dict = safetensors.torch.load_file(model_file, device="cpu")
            else:
                state_dict = torch.load(model_file, map_location="cpu")
        else:
            state_dict = pretrained_model_name_or_path_or_dict

        keys = list(state_dict.keys())
        if keys != ["image_proj", "ip_adapter"]:
            raise ValueError("Required keys (`image_proj` and `ip_adapter`) are missing from the state dict.")

        # Handle image projection layers.
        clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
        cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4

        image_projection = ImageProjection(
            cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
        )
        image_projection.to(dtype=self.unet.dtype, device=self.unet.device)

        diffusers_state_dict = {}

        diffusers_state_dict.update(
            {
                "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
                "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
                "norm.weight": state_dict["image_proj"]["norm.weight"],
                "norm.bias": state_dict["image_proj"]["norm.bias"],
            }
        )

        image_projection.load_state_dict(diffusers_state_dict)

        self.unet.encoder_hid_proj = image_projection.to(device=self.unet.device, dtype=self.unet.dtype)
        self.unet.config.encoder_hid_dim_type = "image_proj"
        self.unet.config.encoder_hid_dim = clip_embeddings_dim

        # Handle IP-Adapter cross-attention layers.
        ip_layers = torch.nn.ModuleList(
            [
                module if isinstance(module, nn.Module) else nn.Identity()
                for module in self.unet.attn_processors.values()
            ]
        )
        ip_layers.load_state_dict(state_dict["ip_adapter"])

    def set_ip_adapter_scale(self, scale):
        for attn_processor in self.unet.attn_processors.values():
            if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
                attn_processor.scale = scale
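
The key check, shape arithmetic, and key renaming in `load_ip_adapter` can be exercised without any tensors. The sketch below uses shape tuples in place of weights. The helper names and the example shapes are assumptions made for illustration, and the set-based `check_keys` is a possible alternative to the order-sensitive list comparison in the diff, not the code in this PR.

```python
REQUIRED_KEYS = {"image_proj", "ip_adapter"}

# Original IP-Adapter key -> key used for the ImageProjection module,
# copied from the mapping in `load_ip_adapter`.
KEY_MAP = {
    "proj.weight": "image_embeds.weight",
    "proj.bias": "image_embeds.bias",
    "norm.weight": "norm.weight",
    "norm.bias": "norm.bias",
}


def check_keys(state_dict):
    """Raise if a required top-level key is absent, regardless of key order."""
    missing = REQUIRED_KEYS - set(state_dict)
    if missing:
        raise ValueError(f"Required keys missing from the state dict: {sorted(missing)}")


def infer_dims(proj_weight_shape, num_tokens=4):
    """Return (clip_embeddings_dim, cross_attention_dim) from the `proj.weight` shape."""
    out_features, clip_embeddings_dim = proj_weight_shape
    # The projection emits `num_tokens` image tokens, each `cross_attention_dim` wide.
    return clip_embeddings_dim, out_features // num_tokens


def remap_image_proj(image_proj):
    """Rename the projection entries to the names in KEY_MAP."""
    return {KEY_MAP[key]: value for key, value in image_proj.items()}


# Assumed example shapes, used here only for illustration.
example = {
    "ip_adapter": {},
    "image_proj": {
        "proj.weight": (3072, 1024),
        "proj.bias": (3072,),
        "norm.weight": (768,),
        "norm.bias": (768,),
    },
}

check_keys(example)
print(infer_dims(example["image_proj"]["proj.weight"]))
print(sorted(remap_image_proj(example["image_proj"])))
```

Note that the list comparison in the diff also rejects state dicts with extra keys or a different key order, while this sketch only checks for missing keys.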