diff --git a/docs/source/en/using-diffusers/loading_adapters.md b/docs/source/en/using-diffusers/loading_adapters.md
index e73e042bd4d5..c14b38a9dd89 100644
--- a/docs/source/en/using-diffusers/loading_adapters.md
+++ b/docs/source/en/using-diffusers/loading_adapters.md
@@ -307,3 +307,331 @@ prompt = "a house by william eggleston, sunrays, beautiful, sunlight, sunrays, b
 image = pipeline(prompt=prompt).images[0]
 image
 ```
+
+## IP-Adapter
+
+[IP-Adapter](https://ip-adapter.github.io/) is an effective and lightweight adapter that adds image prompting capabilities to a diffusion model. It works by decoupling the cross-attention layers for the image and text features. All the other model components are frozen and only the embedded image features in the UNet are trained. As a result, IP-Adapter files are typically only ~100MB.
+
+IP-Adapter works with most of our pipelines, including Stable Diffusion, Stable Diffusion XL (SDXL), ControlNet, T2I-Adapter, and AnimateDiff, and you can use any custom model finetuned from the same base model. It also works with LCM-LoRA out of the box.
+
+You can find official IP-Adapter checkpoints in [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter).
+
+IP-Adapter was contributed by [okotaku](https://github.com/okotaku).
+
+Let's first create a Stable Diffusion pipeline.
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+from diffusers.utils import load_image
+
+pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+```
+
+Now load the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) weights with the [`~loaders.IPAdapterMixin.load_ip_adapter`] method.
+
+```py
+pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+```
+
+IP-Adapter relies on an image encoder to generate the image features. If your IP-Adapter weights folder contains an `image_encoder` subfolder, the image encoder is automatically loaded and registered to the pipeline. Otherwise, you can explicitly load a [`~transformers.CLIPVisionModelWithProjection`] model and pass it to the Stable Diffusion pipeline when you create it.
+
+```py
+from diffusers import AutoPipelineForText2Image
+from transformers import CLIPVisionModelWithProjection
+import torch
+
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+    "h94/IP-Adapter",
+    subfolder="models/image_encoder",
+    torch_dtype=torch.float16,
+).to("cuda")
+
+pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, torch_dtype=torch.float16).to("cuda")
+```
+
+IP-Adapter lets you use both an image and text to condition the image generation process. For example, let's use the bear image from the [Textual Inversion](#textual-inversion) section as the image prompt (`ip_adapter_image`) along with a text prompt to add "sunglasses". 😎
+
+```py
+pipeline.set_ip_adapter_scale(0.6)
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
+generator = torch.Generator(device="cpu").manual_seed(33)
+images = pipeline(
+    prompt="best quality, high quality, wearing sunglasses",
+    ip_adapter_image=image,
+    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+    num_inference_steps=50,
+    generator=generator,
+).images
+images[0]
+```
+
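+The example above lowers the adapter scale to `0.6` because it conditions on both a text prompt and an image prompt. If you want the image prompt alone to drive the result, you can push the scale higher and drop the text prompt. The snippet below is a minimal sketch that reuses the `pipeline` and `image` objects created above; exact outputs depend on your checkpoint and seed.
+
+```py
+# let the image prompt dominate: raise the adapter scale and use an empty text prompt
+pipeline.set_ip_adapter_scale(1.0)
+generator = torch.Generator(device="cpu").manual_seed(33)
+image_only = pipeline(
+    prompt="",
+    ip_adapter_image=image,
+    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+    num_inference_steps=50,
+    generator=generator,
+).images[0]
+
+# balance the two prompts again before moving on
+pipeline.set_ip_adapter_scale(0.5)
+```
+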
+You can use the [`~loaders.IPAdapterMixin.set_ip_adapter_scale`] method to adjust the ratio between the text prompt and image prompt conditioning. If you're only using the image prompt, you should set the scale to `1.0`. Lowering the scale gives you more generation diversity, but the output is less aligned with the image prompt. `scale=0.5` achieves good results in most cases when you use both text and image prompts.
+
+IP-Adapter also works great with image-to-image and inpainting pipelines. The examples below show how to use it with each.
+
+```py
+from diffusers import AutoPipelineForImage2Image
+import torch
+from diffusers.utils import load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+
+image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/vermeer.jpg")
+ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/river.png")
+
+pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+generator = torch.Generator(device="cpu").manual_seed(33)
+images = pipeline(
+    prompt="best quality, high quality",
+    image=image,
+    ip_adapter_image=ip_image,
+    num_inference_steps=50,
+    generator=generator,
+    strength=0.6,
+).images
+images[0]
+```
+
+```py
+from diffusers import AutoPipelineForInpainting
+import torch
+from diffusers.utils import load_image
+
+pipeline = AutoPipelineForInpainting.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+
+image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/inpaint_image.png")
+mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/mask.png")
+ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/girl.png")
+
+image = image.resize((512, 768))
+mask = mask.resize((512, 768))
+
+pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+
+generator = torch.Generator(device="cpu").manual_seed(33)
+images = pipeline(
+    prompt="best quality, high quality",
+    image=image,
+    mask_image=mask,
+    ip_adapter_image=ip_image,
+    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+    num_inference_steps=50,
+    generator=generator,
+    strength=0.5,
+).images
+images[0]
+```
+
+IP-Adapter can also be used with [SDXL](../api/pipelines/stable_diffusion/stable_diffusion_xl.md).
+
+```python
+from diffusers import AutoPipelineForText2Image
+from diffusers.utils import load_image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16
+).to("cuda")
+
+image = load_image("https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/watercolor_painting.jpeg")
+
+pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
+
+generator = torch.Generator(device="cpu").manual_seed(33)
+image = pipeline(
+    prompt="best quality, high quality",
+    ip_adapter_image=image,
+    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+    num_inference_steps=25,
+    generator=generator,
+).images[0]
+image.save("sdxl_t2i.png")
+```
+
+<!-- figure: input image | adapted image -->
+
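+The `load_ip_adapter` method also accepts `.safetensors` checkpoints. Assuming the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) repository ships a safetensors copy of the SDXL weights in the same subfolder (the file name below is an assumption), only `weight_name` needs to change — a minimal sketch:
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+).to("cuda")
+
+# the loader splits a .safetensors file into "image_proj" and "ip_adapter" weights internally
+pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.safetensors")
+```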
+
+### LCM-LoRA
+
+You can use IP-Adapter with LCM-LoRA to achieve "instant fine-tuning" with custom images. Note that you need to load the IP-Adapter weights before loading the LCM-LoRA weights.
+
+```py
+from diffusers import DiffusionPipeline, LCMScheduler
+import torch
+from diffusers.utils import load_image
+
+model_id = "sd-dreambooth-library/herge-style"
+lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
+
+pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+
+pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+pipe.load_lora_weights(lcm_lora_id)
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+pipe.enable_model_cpu_offload()
+
+prompt = "best quality, high quality"
+image = load_image("https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png")
+images = pipe(
+    prompt=prompt,
+    ip_adapter_image=image,
+    num_inference_steps=4,
+    guidance_scale=1,
+).images[0]
+```
+
+### Other pipelines
+
+IP-Adapter is compatible with any pipeline that (1) uses a text prompt and (2) uses a Stable Diffusion or Stable Diffusion XL checkpoint. To use IP-Adapter with a different pipeline, call the `load_ip_adapter()` method after you create the pipeline, and then pass your image to the pipeline as `ip_adapter_image`.
+
+🤗 Diffusers currently only supports using IP-Adapter with some of the most popular pipelines. Feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use case and need to integrate IP-Adapter with a pipeline that doesn't support it yet!
+
+The examples below show how to use IP-Adapter with ControlNet and AnimateDiff.
+
+```py
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+import torch
+from diffusers.utils import load_image
+
+controlnet_model_path = "lllyasviel/control_v11f1p_sd15_depth"
+controlnet = ControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.float16)
+
+pipeline = StableDiffusionControlNetPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16)
+pipeline.to("cuda")
+
+image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/statue.png")
+depth_map = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/depth.png")
+
+pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+
+generator = torch.Generator(device="cpu").manual_seed(33)
+images = pipeline(
+    prompt="best quality, high quality",
+    image=depth_map,
+    ip_adapter_image=image,
+    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+    num_inference_steps=50,
+    generator=generator,
+).images
+images[0]
+```
+
+<!-- figure: input image | adapted image -->
+
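+When you combine IP-Adapter with ControlNet you have two independent knobs: `set_ip_adapter_scale` for the image prompt and `controlnet_conditioning_scale` for the structural conditioning. A minimal sketch that reuses the ControlNet `pipeline`, `image`, and `depth_map` from the example above (the specific values are only a starting point):
+
+```py
+# strengthen the image prompt and relax the depth conditioning
+pipeline.set_ip_adapter_scale(0.8)
+generator = torch.Generator(device="cpu").manual_seed(33)
+result = pipeline(
+    prompt="best quality, high quality",
+    image=depth_map,
+    ip_adapter_image=image,
+    controlnet_conditioning_scale=0.5,
+    num_inference_steps=50,
+    generator=generator,
+).images[0]
+```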
+
+```py
+# AnimateDiff + IP-Adapter
+import torch
+from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
+from diffusers.utils import export_to_gif, load_image
+
+# load the motion adapter
+adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
+# load a SD 1.5 based finetuned model
+model_id = "Lykon/DreamShaper"
+pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)
+
+# scheduler
+scheduler = DDIMScheduler(
+    clip_sample=False,
+    beta_start=0.00085,
+    beta_end=0.012,
+    beta_schedule="linear",
+    timestep_spacing="trailing",
+    steps_offset=1
+)
+pipe.scheduler = scheduler
+
+# enable memory savings
+pipe.enable_vae_slicing()
+pipe.enable_model_cpu_offload()
+
+# load the IP-Adapter
+pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+
+# load motion LoRAs
+pipe.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
+pipe.load_lora_weights("guoyww/animatediff-motion-lora-tilt-up", adapter_name="tilt-up")
+pipe.load_lora_weights("guoyww/animatediff-motion-lora-pan-left", adapter_name="pan-left")
+
+seed = 42
+image = load_image("https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png")
+images = [image] * 3
+prompts = ["best quality, high quality"] * 3
+negative_prompt = "bad quality, worst quality"
+adapter_weights = [[0.75, 0.0, 0.0], [0.0, 0.0, 0.75], [0.0, 0.75, 0.75]]
+
+# generate three clips, each with a different combination of motion LoRAs
+output_frames = []
+for prompt, image, adapter_weight in zip(prompts, images, adapter_weights):
+    pipe.set_adapters(["zoom-out", "tilt-up", "pan-left"], adapter_weights=adapter_weight)
+    output = pipe(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        num_frames=16,
+        guidance_scale=7.5,
+        num_inference_steps=30,
+        ip_adapter_image=image,
+        generator=torch.Generator("cpu").manual_seed(seed),
+    )
+    frames = output.frames[0]
+    output_frames.extend(frames)
+
+export_to_gif(output_frames, "test_out_animation.gif")
+```
+
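+Under the hood, `load_ip_adapter` registers an `image_encoder` and a `feature_extractor` on the pipeline, installs IP-Adapter attention processors in the UNet, and attaches an image projection layer as `unet.encoder_hid_proj`. The sketch below is one way to inspect this on any pipeline after `load_ip_adapter` has been called; the attribute names come from the loader changes in this PR.
+
+```py
+# inspect what the IP-Adapter loader wired into the UNet
+print(pipeline.unet.config.encoder_hid_dim_type)  # "ip_image_proj"
+print(type(pipeline.unet.encoder_hid_proj))       # ImageProjection
+
+ip_layers = [
+    name
+    for name, proc in pipeline.unet.attn_processors.items()
+    if hasattr(proc, "to_k_ip")  # IPAdapterAttnProcessor / IPAdapterAttnProcessor2_0
+]
+print(f"{len(ip_layers)} cross-attention layers use IP-Adapter processors")
+```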
+ diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 14fd985f69e4..684736856029 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -62,6 +62,7 @@ def text_encoder_attn_modules(text_encoder): _import_structure["single_file"].extend(["FromSingleFileMixin"]) _import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] + _import_structure["ip_adapter"] = ["IPAdapterMixin"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -72,6 +73,7 @@ def text_encoder_attn_modules(text_encoder): from .utils import AttnProcsLayers if is_transformers_available(): + from .ip_adapter import IPAdapterMixin from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin from .single_file import FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py new file mode 100644 index 000000000000..32c558554be2 --- /dev/null +++ b/src/diffusers/loaders/ip_adapter.py @@ -0,0 +1,157 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Dict, Union + +import torch +from safetensors import safe_open + +from ..utils import ( + DIFFUSERS_CACHE, + HF_HUB_OFFLINE, + _get_model_file, + is_transformers_available, + logging, +) + + +if is_transformers_available(): + from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, + ) + + from ..models.attention_processor import ( + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, + ) + +logger = logging.get_logger(__name__) + + +class IPAdapterMixin: + """Mixin for handling IP Adapters.""" + + def load_ip_adapter( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + subfolder: str, + weight_name: str, + **kwargs, + ): + """ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. 
+ proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + """ + + # Load the main state dict first. + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + + keys = list(state_dict.keys()) + if keys != ["image_proj", "ip_adapter"]: + raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") + + # load CLIP image encoer here if it has not been registered to the pipeline yet + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + pretrained_model_name_or_path_or_dict, + subfolder=os.path.join(subfolder, "image_encoder"), + ).to(self.device, dtype=self.dtype) + self.image_encoder = image_encoder + else: + raise ValueError("`image_encoder` cannot be None when using IP Adapters.") + + # create feature extractor if it has not been registered to the pipeline yet + if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: + self.feature_extractor = CLIPImageProcessor() + + # load ip-adapter into unet + self.unet._load_ip_adapter_weights(state_dict) + + 
def set_ip_adapter_scale(self, scale): + for attn_processor in self.unet.attn_processors.values(): + if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): + attn_processor.scale = scale diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 9555ac9e7d8b..6c805672c9cd 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -18,8 +18,10 @@ import safetensors import torch +import torch.nn.functional as F from torch import nn +from ..models.embeddings import ImageProjection from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..utils import ( DIFFUSERS_CACHE, @@ -662,4 +664,72 @@ def delete_adapters(self, adapter_names: Union[List[str], str]): if hasattr(self, "peft_config"): self.peft_config.pop(adapter_name, None) + def _load_ip_adapter_weights(self, state_dict): + from ..models.attention_processor import ( + AttnProcessor, + AttnProcessor2_0, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, + ) + + # set ip-adapter cross-attention processors & load state_dict + attn_procs = {} + key_id = 1 + for name in self.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = self.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(self.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = self.config.block_out_channels[block_id] + if cross_attention_dim is None or "motion_modules" in name: + attn_processor_class = ( + AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor + ) + attn_procs[name] = attn_processor_class() + else: + attn_processor_class = ( + IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor + ) + attn_procs[name] = attn_processor_class( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0 + ).to(dtype=self.dtype, device=self.device) + + value_dict = {} + for k, w in attn_procs[name].state_dict().items(): + value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]}) + + attn_procs[name].load_state_dict(value_dict) + key_id += 2 + + self.set_attn_processor(attn_procs) + + # create image projection layers. 
+ clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1] + cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4 + + image_projection = ImageProjection( + cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4 + ) + image_projection.to(dtype=self.dtype, device=self.device) + + # load image projection layer weights + image_proj_state_dict = {} + image_proj_state_dict.update( + { + "image_embeds.weight": state_dict["image_proj"]["proj.weight"], + "image_embeds.bias": state_dict["image_proj"]["proj.bias"], + "norm.weight": state_dict["image_proj"]["norm.weight"], + "norm.bias": state_dict["image_proj"]["norm.bias"], + } + ) + + image_projection.load_state_dict(image_proj_state_dict) + + self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype) + self.config.encoder_hid_dim_type = "ip_image_proj" + delete_adapter_layers diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 1234dbd2d5ce..6b86ba66db37 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1975,6 +1975,250 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **k return attn.processor(attn, hidden_states, *args, **kwargs) +class IPAdapterAttnProcessor(nn.Module): + r""" + Attention processor for IP-Adapater. + + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + num_tokens (`int`, defaults to 4): + The context length of the image features. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + self.scale = scale + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + scale=1.0, + ): + if scale != 1.0: + logger.warning("`scale` of IPAttnProcessor should be set with `set_ip_adapter_scale`.") + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + # split hidden states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + + key = 
attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAdapterAttnProcessor2_0(torch.nn.Module): + r""" + Attention processor for IP-Adapater for PyTorch 2.0. + + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + num_tokens (`int`, defaults to 4): + The context length of the image features. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + self.scale = scale + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + scale=1.0, + ): + if scale != 1.0: + logger.warning("`scale` of IPAttnProcessor should be set by `set_ip_adapter_scale`.") + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + # split hidden states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = 
attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + LORA_ATTENTION_PROCESSORS = ( LoRAAttnProcessor, LoRAAttnProcessor2_0, @@ -1998,6 +2242,8 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **k LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, ) AttentionProcessor = Union[ diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index f248b243f376..dd91d8007229 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -1022,6 +1022,15 @@ def forward( ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) + encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) + # 2. pre-process sample = self.conv_in(sample) diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 7be1a59114ef..0bbc573e7df1 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -208,6 +208,8 @@ def __init__( motion_max_seq_length: int = 32, motion_num_attention_heads: int = 8, use_motion_mid_block: int = True, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, ): super().__init__() @@ -248,6 +250,9 @@ def __init__( act_fn=act_fn, ) + if encoder_hid_dim_type is None: + self.encoder_hid_proj = None + # class embedding self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) @@ -684,6 +689,7 @@ def forward( timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, return_dict: bool = True, @@ -767,6 +773,16 @@ def forward( emb = self.time_embedding(t_emb, timestep_cond) emb = emb.repeat_interleave(repeats=num_frames, dim=0) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) + encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) # 2. 
pre-process diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 9f51c084d5f8..843e3b8b9410 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -17,11 +17,11 @@ import torch from packaging import version -from transformers import CLIPImageProcessor, XLMRobertaTokenizer +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer from ...configuration_utils import FrozenDict -from ...image_processor import VaeImageProcessor -from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers @@ -74,7 +74,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker -class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): +class AltDiffusionPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin +): r""" Pipeline for text-to-image generation using Alt Diffusion. @@ -86,6 +88,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): @@ -108,7 +111,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -121,6 +124,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() @@ -197,6 +201,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -444,6 +449,19 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = 
self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None @@ -652,6 +670,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -698,6 +717,7 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -797,12 +817,18 @@ def __call__( lora_scale=lora_scale, clip_skip=self.clip_skip, ) + # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps @@ -823,7 +849,10 @@ def __call__( # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 6.5 Optionally get Guidance Scale Embedding + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 6.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) @@ -847,6 +876,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 129794f7fbbd..b196ac4d3f69 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -19,11 +19,11 @@ import PIL.Image import torch from packaging import version -from transformers import CLIPImageProcessor, XLMRobertaTokenizer +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers @@ -111,7 +111,7 @@ def preprocess(image): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker class AltDiffusionImg2ImgPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-guided image-to-image generation using Alt Diffusion. 
@@ -124,6 +124,7 @@ class AltDiffusionImg2ImgPipeline( - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): @@ -146,7 +147,7 @@ class AltDiffusionImg2ImgPipeline( """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -159,6 +160,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() @@ -235,6 +237,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -453,6 +456,19 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None @@ -705,6 +721,7 @@ def __call__( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -754,6 +771,7 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -846,6 +864,11 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. Preprocess image image = self.image_processor.preprocess(image) @@ -868,7 +891,10 @@ def __call__( # 7. 
Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7.5 Optionally get Guidance Scale Embedding + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 7.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) @@ -892,6 +918,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 71adb8408c88..28dc220545dc 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -18,10 +18,10 @@ import numpy as np import torch -from transformers import CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from ...image_processor import VaeImageProcessor -from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unet_motion_model import MotionAdapter @@ -77,7 +77,7 @@ class AnimateDiffPipelineOutput(BaseOutput): frames: Union[torch.Tensor, np.ndarray] -class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): +class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin): r""" Pipeline for text-to-video generation. 
@@ -101,6 +101,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLo """ model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder"] def __init__( self, @@ -117,6 +118,8 @@ def __init__( EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ], + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, ): super().__init__() unet = UNetMotionModel.from_unet2d(unet, motion_adapter) @@ -128,6 +131,8 @@ def __init__( unet=unet, motion_adapter=motion_adapter, scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -314,6 +319,20 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents @@ -512,6 +531,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -558,6 +578,7 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or `np.array`. @@ -629,6 +650,11 @@ def __call__( if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt) + if do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps @@ -649,6 +675,8 @@ def __call__( # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 7 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None # Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -664,6 +692,7 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, ).sample # perform guidance diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 4f625304fdf9..41e5e75f68e5 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -20,10 +20,10 @@ import PIL.Image import torch import torch.nn.functional as F -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers @@ -92,7 +92,7 @@ class StableDiffusionControlNetPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin ): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. 
@@ -102,6 +102,7 @@ class StableDiffusionControlNetPipeline( The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): @@ -128,7 +129,7 @@ class StableDiffusionControlNetPipeline( """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -142,6 +143,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() @@ -174,6 +176,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) @@ -430,6 +433,20 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -803,6 +820,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -860,6 +878,7 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. 
return_dict (`bool`, *optional*, defaults to `True`): @@ -997,6 +1016,11 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. Prepare image if isinstance(controlnet, ControlNetModel): image = self.prepare_image( @@ -1063,7 +1087,10 @@ def __call__( # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7.1 Create tensor stating which controlnets to keep + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 7.2 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): keeps = [ @@ -1131,6 +1158,7 @@ def __call__( cross_attention_kwargs=self.cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index c1efd8aaa397..4696781dce0c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -20,12 +20,23 @@ import PIL.Image import torch import torch.nn.functional as F -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) from diffusers.utils.import_utils import is_invisible_watermark_available from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, @@ -104,7 +115,11 @@ class StableDiffusionXLControlNetPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. 
@@ -149,7 +164,14 @@ class StableDiffusionXLControlNetPipeline( # leave controlnet out on purpose because it iterates with unet model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( @@ -164,6 +186,8 @@ def __init__( scheduler: KarrasDiffusionSchedulers, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, ): super().__init__() @@ -179,6 +203,8 @@ def __init__( unet=unet, controlnet=controlnet, scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) @@ -462,6 +488,20 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -879,6 +919,7 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -959,6 +1000,7 @@ def __call__( Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -1100,7 +1142,7 @@ def __call__( ) guess_mode = guess_mode or global_pool_conditions - # 3. 
Encode input prompt + # 3.1 Encode input prompt text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) @@ -1125,6 +1167,12 @@ def __call__( clip_skip=self.clip_skip, ) + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. Prepare image if isinstance(controlnet, ControlNetModel): image = self.prepare_image( @@ -1299,6 +1347,9 @@ def __call__( down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + if ip_adapter_image is not None: + added_cond_kwargs["image_embeds"] = image_embeds + # predict the noise residual noise_pred = self.unet( latent_model_input, diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py index cbb55b504c54..2e25a40295b4 100644 --- a/src/diffusers/pipelines/pipeline_flax_utils.py +++ b/src/diffusers/pipelines/pipeline_flax_utils.py @@ -538,12 +538,13 @@ def load_module(name, value): model = pipeline_class(**init_kwargs, dtype=dtype) return model, params - @staticmethod - def _get_signature_keys(obj): + @classmethod + def _get_signature_keys(cls, obj): parameters = inspect.signature(obj.__init__).parameters required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) expected_modules = set(required_parameters.keys()) - {"self"} + return expected_modules, optional_parameters @property diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 5fa1938983d5..0208ade020bd 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -557,7 +557,7 @@ def register_modules(self, **kwargs): for name, module in kwargs.items(): # retrieve library - if module is None: + if module is None or isinstance(module, (tuple, list)) and module[0] is None: register_dict = {name: (None, None)} else: # register the config from the original module, not the dynamo compiled one @@ -1906,12 +1906,19 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: " above." 
) from model_info_call_error - @staticmethod - def _get_signature_keys(obj): + @classmethod + def _get_signature_keys(cls, obj): parameters = inspect.signature(obj.__init__).parameters required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) expected_modules = set(required_parameters.keys()) - {"self"} + + optional_names = list(optional_parameters) + for name in optional_names: + if name in cls._optional_components: + expected_modules.add(name) + optional_parameters.remove(name) + return expected_modules, optional_parameters @property diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 5af5a42256f3..a05abe00f2b1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -17,11 +17,11 @@ import torch from packaging import version -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...configuration_utils import FrozenDict -from ...image_processor import VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers @@ -70,7 +70,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg -class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin): +class StableDiffusionPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin +): r""" Pipeline for text-to-image generation using Stable Diffusion. 
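The `_get_signature_keys` change above is what allows `image_encoder` to default to `None` in the pipeline constructors while still being saved, loaded, and offloaded like any other component: optional `__init__` arguments listed in `_optional_components` are promoted back into the expected-modules set. A standalone sketch of that promotion with a toy class (not a real pipeline):

```py
import inspect

class ToyPipeline:
    # hypothetical stand-in for a DiffusionPipeline subclass
    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]

    def __init__(self, unet, scheduler, safety_checker, feature_extractor,
                 image_encoder=None, requires_safety_checker=True):
        pass

def get_signature_keys(cls, obj):
    # same idea as the updated classmethod: optional __init__ arguments that are
    # listed in cls._optional_components are promoted to expected modules
    parameters = inspect.signature(obj.__init__).parameters
    expected = {k for k, v in parameters.items() if v.default is inspect.Parameter.empty} - {"self"}
    optional = {k for k, v in parameters.items() if v.default is not inspect.Parameter.empty}
    for name in list(optional):
        if name in cls._optional_components:
            expected.add(name)
            optional.remove(name)
    return expected, optional

expected, optional = get_signature_keys(ToyPipeline, ToyPipeline)
print(sorted(expected))  # image_encoder is expected even though it defaults to None
print(sorted(optional))  # ['requires_safety_checker']
```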
@@ -82,6 +84,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): @@ -104,7 +107,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -117,6 +120,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() @@ -193,6 +197,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -440,6 +445,19 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None @@ -649,6 +667,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -695,6 +714,7 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -794,12 +814,18 @@ def __call__( lora_scale=lora_scale, clip_skip=self.clip_skip, ) + # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps @@ -820,7 +846,10 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 6.5 Optionally get Guidance Scale Embedding + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 6.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) @@ -844,6 +873,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index c75afb0789a4..029cd2b04839 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -19,11 +19,11 @@ import PIL.Image import torch from packaging import version -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers @@ -106,7 +106,7 @@ def preprocess(image): class StableDiffusionImg2ImgPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-guided image-to-image generation using Stable Diffusion. 
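`encode_image` returns a conditional embedding plus an all-zeros unconditional one, and under classifier-free guidance the two are concatenated negatives-first, mirroring the `[negative_prompt_embeds, prompt_embeds]` ordering. A shape-only sketch with stand-in tensors; the width 1024 matches the ViT-H image encoder used with the SD 1.5 IP-Adapter but is otherwise arbitrary:

```py
import torch

batch_size, num_images_per_prompt, embed_dim = 1, 2, 1024

# stands in for self.image_encoder(image).image_embeds
image_embeds = torch.randn(batch_size, embed_dim)
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

# the unconditional branch is simply zeros of the same shape
negative_image_embeds = torch.zeros_like(image_embeds)

# classifier-free guidance: negatives first, matching the prompt embedding order
image_embeds = torch.cat([negative_image_embeds, image_embeds])
print(image_embeds.shape)  # torch.Size([4, 1024])
```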
@@ -119,6 +119,7 @@ class StableDiffusionImg2ImgPipeline( - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): @@ -141,7 +142,7 @@ class StableDiffusionImg2ImgPipeline( """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -154,6 +155,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() @@ -230,6 +232,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -450,6 +453,20 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -708,6 +725,7 @@ def __call__( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -757,6 +775,7 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. 
return_dict (`bool`, *optional*, defaults to `True`): @@ -849,6 +868,11 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. Preprocess image image = self.image_processor.preprocess(image) @@ -871,7 +895,10 @@ def __call__( # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7.5 Optionally get Guidance Scale Embedding + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 7.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) @@ -895,6 +922,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index e4a25e181e42..09e50c60a807 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -19,11 +19,11 @@ import PIL.Image import torch from packaging import version -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers @@ -170,7 +170,7 @@ def retrieve_latents(encoder_output, generator): class StableDiffusionInpaintPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-guided image inpainting using Stable Diffusion. 
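Note that the pipelines only compute `image_embeds` and forward them through `added_cond_kwargs`; the image conditioning itself happens inside the adapter's attention processors, whose `to_k_ip`/`to_v_ip` weights appear in the test helper near the end of this patch. The following is a conceptual, torch-only sketch of that decoupled cross-attention (single head, no reshaping); it is not the actual `IPAdapterAttnProcessor` code:

```py
import torch
import torch.nn.functional as F

def decoupled_cross_attention(q, text_tokens, image_tokens, to_k, to_v, to_k_ip, to_v_ip, scale=1.0):
    # standard text cross-attention
    text_out = F.scaled_dot_product_attention(q, to_k(text_tokens), to_v(text_tokens))
    # separate image cross-attention through the adapter's own key/value projections
    image_out = F.scaled_dot_product_attention(q, to_k_ip(image_tokens), to_v_ip(image_tokens))
    # the adapter contribution is blended in with a user-controlled scale
    return text_out + scale * image_out

dim = 64
to_k, to_v = torch.nn.Linear(dim, dim, bias=False), torch.nn.Linear(dim, dim, bias=False)
to_k_ip, to_v_ip = torch.nn.Linear(dim, dim, bias=False), torch.nn.Linear(dim, dim, bias=False)

q = torch.randn(1, 4096, dim)           # latent queries
text_tokens = torch.randn(1, 77, dim)   # text encoder states
image_tokens = torch.randn(1, 4, dim)   # projected IP-Adapter image tokens

out = decoupled_cross_attention(q, text_tokens, image_tokens, to_k, to_v, to_k_ip, to_v_ip, scale=0.6)
print(out.shape)  # torch.Size([1, 4096, 64])
```

The `scale` argument here plays the same role as the scale controlled through `set_ip_adapter_scale`: it weights how strongly the image prompt influences each cross-attention layer.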
@@ -182,6 +182,7 @@ class StableDiffusionInpaintPipeline( - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]): @@ -204,7 +205,7 @@ class StableDiffusionInpaintPipeline( """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "mask", "masked_image_latents"] @@ -217,6 +218,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() @@ -298,6 +300,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -521,6 +524,20 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -837,6 +854,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -902,6 +920,7 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. 
return_dict (`bool`, *optional*, defaults to `True`): @@ -1029,6 +1048,11 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps( @@ -1117,7 +1141,10 @@ def __call__( # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 9.5 Optionally get Guidance Scale Embedding + # 9.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 9.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) @@ -1146,6 +1173,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index c50a036a88f8..e32791693012 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -16,11 +16,18 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) -from ...image_processor import VaeImageProcessor +from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, + IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) @@ -94,7 +101,11 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): class StableDiffusionXLPipeline( - DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin + DiffusionPipeline, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL. 
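Each pipeline above also gains `IPAdapterMixin`, and its `load_ip_adapter` method expects the two-part checkpoint layout that `create_ip_adapter_state_dict` reproduces in the UNet tests further below: an `image_proj` dict for the projection layer and an `ip_adapter` dict of numbered `to_k_ip`/`to_v_ip` weights. One way to inspect an official checkpoint and confirm that layout (the key names in the comments are expectations, not guarantees):

```py
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(repo_id="h94/IP-Adapter", filename="ip-adapter_sd15.bin", subfolder="models")
state_dict = torch.load(ckpt_path, map_location="cpu")

print(list(state_dict.keys()))                    # expected: ['image_proj', 'ip_adapter']
print(list(state_dict["image_proj"].keys()))      # projection + layer-norm weights
print(list(state_dict["ip_adapter"].keys())[:4])  # e.g. '1.to_k_ip.weight', '1.to_v_ip.weight', ...
```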
@@ -142,7 +153,14 @@ class StableDiffusionXLPipeline( """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] _callback_tensor_inputs = [ "latents", "prompt_embeds", @@ -162,6 +180,8 @@ def __init__( tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, ): @@ -175,6 +195,8 @@ def __init__( tokenizer_2=tokenizer_2, unet=unet, scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) @@ -456,6 +478,20 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -718,6 +754,7 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -802,6 +839,7 @@ def __call__( Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -1000,6 +1038,12 @@ def __call__( add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + image_embeds = image_embeds.to(device) + # 8. 
Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -1037,6 +1081,8 @@ def __call__( # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None: + added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( latent_model_input, t, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 56f1a5196cf0..d40a037e67fe 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -17,10 +17,21 @@ import PIL.Image import torch -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, @@ -104,7 +115,11 @@ def retrieve_latents(encoder_output, generator): class StableDiffusionXLImg2ImgPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin + DiffusionPipeline, + TextualInversionLoaderMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL. 
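On the UNet side (see the `modeling_text_unet.py` hunk and the `ImageProjection` usage in the tests further below), the pooled image embedding is expanded into a handful of extra context tokens and appended to the text tokens before cross-attention. A shape-level sketch with toy widths, assuming `ImageProjection` accepts the keyword arguments the new test uses and a 2D embedding as input:

```py
import torch
from diffusers.models.embeddings import ImageProjection

cross_attention_dim, image_embed_dim = 32, 48  # toy sizes, not real model widths
proj = ImageProjection(
    cross_attention_dim=cross_attention_dim, image_embed_dim=image_embed_dim, num_image_text_embeds=4
)

text_tokens = torch.randn(2, 77, cross_attention_dim)  # encoder_hidden_states from the text encoder
image_embeds = torch.randn(2, image_embed_dim)         # pooled CLIP image embedding per sample

image_tokens = proj(image_embeds)                                      # (2, 4, cross_attention_dim)
encoder_hidden_states = torch.cat([text_tokens, image_tokens], dim=1)  # (2, 81, cross_attention_dim)
print(encoder_hidden_states.shape)
```

This is why the attention processors need their own `to_k_ip`/`to_v_ip` projections: the appended image tokens live in the same sequence but are attended to through separate key/value weights.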
@@ -155,7 +170,14 @@ class StableDiffusionXLImg2ImgPipeline( """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] _callback_tensor_inputs = [ "latents", "prompt_embeds", @@ -175,6 +197,8 @@ def __init__( tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, requires_aesthetics_score: bool = False, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, @@ -188,6 +212,8 @@ def __init__( tokenizer=tokenizer, tokenizer_2=tokenizer_2, unet=unet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, scheduler=scheduler, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) @@ -665,6 +691,20 @@ def prepare_latents( return latents + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + def _get_add_time_ids( self, original_size, @@ -850,6 +890,7 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -943,6 +984,7 @@ def __call__( Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -1162,6 +1204,12 @@ def denoising_value_valid(dnv): add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + image_embeds = image_embeds.to(device) + # 9. 
Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -1205,6 +1253,8 @@ def denoising_value_valid(dnv): # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None: + added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( latent_model_input, t, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index d618ea4c2a71..3a9d068d60f3 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -18,10 +18,21 @@ import numpy as np import PIL.Image import torch -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, @@ -249,7 +260,11 @@ def retrieve_latents(encoder_output, generator): class StableDiffusionXLInpaintPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL. 
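The SDXL image-to-image path added above is exercised by `test_image_to_image_sdxl` near the end of this patch; an end-to-end sketch that mirrors that test as user-facing code follows (the checkpoints, encoder, and image URLs are taken from the test, so treat them as assumptions of this sketch rather than requirements of the patch):

```py
import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image

# Explicit image encoder and processor, as in the new SDXL integration tests.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=torch.float16
)
feature_extractor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    image_encoder=image_encoder,
    feature_extractor=feature_extractor,
    torch_dtype=torch.float16,
).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

init_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/vermeer.jpg").resize((1024, 1024))
ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/river.png").resize((1024, 1024))

image = pipeline(
    prompt="best quality, high quality",
    image=init_image,
    ip_adapter_image=ip_image,
    num_inference_steps=25,
    strength=0.6,
).images[0]
```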
@@ -301,7 +316,14 @@ class StableDiffusionXLInpaintPipeline( model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] _callback_tensor_inputs = [ "latents", "prompt_embeds", @@ -323,6 +345,8 @@ def __init__( tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, requires_aesthetics_score: bool = False, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, @@ -336,6 +360,8 @@ def __init__( tokenizer=tokenizer, tokenizer_2=tokenizer_2, unet=unet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, scheduler=scheduler, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) @@ -386,6 +412,20 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( self, @@ -1074,6 +1114,7 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -1172,6 +1213,7 @@ def __call__( Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -1471,6 +1513,12 @@ def denoising_value_valid(dnv): add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + image_embeds = image_embeds.to(device) + # 11. 
Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -1517,6 +1565,8 @@ def denoising_value_valid(dnv): # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None: + added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( latent_model_input, t, diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 91f8c2c3dc03..a940cec5e46a 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1221,6 +1221,15 @@ def forward( ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) + encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) + # 2. pre-process sample = self.conv_in(sample) diff --git a/tests/lora/test_lora_layers_old_backend.py b/tests/lora/test_lora_layers_old_backend.py index 285c5e864a04..19505a1d906d 100644 --- a/tests/lora/test_lora_layers_old_backend.py +++ b/tests/lora/test_lora_layers_old_backend.py @@ -246,6 +246,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } lora_components = { "unet_lora_layers": unet_lora_layers, @@ -757,6 +758,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components @@ -866,6 +868,8 @@ def get_dummy_components(self): "text_encoder_2": text_encoder_2, "tokenizer": tokenizer, "tokenizer_2": tokenizer_2, + "image_encoder": None, + "feature_extractor": None, } lora_components = { "unet_lora_layers": unet_lora_layers, diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py index c290850a10b6..48ae5d197273 100644 --- a/tests/lora/test_lora_layers_peft.py +++ b/tests/lora/test_lora_layers_peft.py @@ -140,6 +140,8 @@ def get_dummy_components(self, scheduler_cls=None): "tokenizer": tokenizer, "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, + "image_encoder": None, + "feature_extractor": None, } else: pipeline_components = { @@ -150,6 +152,7 @@ def get_dummy_components(self, scheduler_cls=None): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } lora_components = { "unet_lora_layers": unet_lora_layers, diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 0db336a88029..06bf2685560d 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -24,7 +24,8 @@ from pytest import mark from diffusers import UNet2DConditionModel -from diffusers.models.attention_processor import CustomDiffusionAttnProcessor +from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, 
IPAdapterAttnProcessor +from diffusers.models.embeddings import ImageProjection from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( @@ -45,6 +46,57 @@ enable_full_determinism() +def create_ip_adapter_state_dict(model): + # "ip_adapter" (cross-attention weights) + ip_cross_attn_state_dict = {} + key_id = 1 + + for name in model.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = model.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(model.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = model.config.block_out_channels[block_id] + if cross_attention_dim is not None: + sd = IPAdapterAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0 + ).state_dict() + ip_cross_attn_state_dict.update( + { + f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"], + f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"], + } + ) + + key_id += 2 + + # "image_proj" (ImageProjection layer weights) + cross_attention_dim = model.config["cross_attention_dim"] + image_projection = ImageProjection( + cross_attention_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, num_image_text_embeds=4 + ) + + ip_image_projection_state_dict = {} + sd = image_projection.state_dict() + ip_image_projection_state_dict.update( + { + "proj.weight": sd["image_embeds.weight"], + "proj.bias": sd["image_embeds.bias"], + "norm.weight": sd["norm.weight"], + "norm.bias": sd["norm.bias"], + } + ) + + del sd + ip_state_dict = {} + ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) + return ip_state_dict + + def create_custom_diffusion_layers(model, mock_weights: bool = True): train_kv = True train_q_out = True @@ -622,6 +674,56 @@ def test_asymmetrical_unet(self): # Check if input and output shapes are the same self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + def test_ip_adapter(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + + # forward pass without ip-adapter + with torch.no_grad(): + sample1 = model(**inputs_dict).sample + + # update inputs_dict for ip-adapter + batch_size = inputs_dict["encoder_hidden_states"].shape[0] + image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device) + inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds} + + # make ip_adapter_1 and ip_adapter_2 + ip_adapter_1 = create_ip_adapter_state_dict(model) + + image_proj_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items()} + cross_attn_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items()} + ip_adapter_2 = {} + ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2}) + + # forward pass ip_adapter_1 + model._load_ip_adapter_weights(ip_adapter_1) + assert model.config.encoder_hid_dim_type == "ip_image_proj" + assert model.encoder_hid_proj is not None + assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in ( + "IPAdapterAttnProcessor", 
+ "IPAdapterAttnProcessor2_0", + ) + with torch.no_grad(): + sample2 = model(**inputs_dict).sample + + # forward pass with ip_adapter_2 + model._load_ip_adapter_weights(ip_adapter_2) + with torch.no_grad(): + sample3 = model(**inputs_dict).sample + + # forward pass with ip_adapter_1 again + model._load_ip_adapter_weights(ip_adapter_1) + with torch.no_grad(): + sample4 = model(**inputs_dict).sample + + assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4) + assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4) + assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4) + @slow class UNet2DConditionModelIntegrationTests(unittest.TestCase): diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion.py b/tests/pipelines/altdiffusion/test_alt_diffusion.py index 5befe60cf6d9..b4a2847bb84d 100644 --- a/tests/pipelines/altdiffusion/test_alt_diffusion.py +++ b/tests/pipelines/altdiffusion/test_alt_diffusion.py @@ -117,6 +117,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py index 57001f7bea52..3fd1a90172ca 100644 --- a/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py +++ b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py @@ -141,6 +141,7 @@ def test_stable_diffusion_img2img_default_case(self): tokenizer=tokenizer, safety_checker=None, feature_extractor=self.dummy_extractor, + image_encoder=None, ) alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=True) alt_pipe = alt_pipe.to(device) @@ -205,6 +206,7 @@ def test_stable_diffusion_img2img_fp16(self): tokenizer=tokenizer, safety_checker=None, feature_extractor=self.dummy_extractor, + image_encoder=None, ) alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=False) alt_pipe = alt_pipe.to(torch_device) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 3c9390f2d1b6..5cd0a45c7406 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -99,6 +99,8 @@ def get_dummy_components(self): "motion_adapter": motion_adapter, "text_encoder": text_encoder, "tokenizer": tokenizer, + "feature_extractor": None, + "image_encoder": None, } return components diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index 2d8c8869c23c..1cf52bfeebe2 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -183,6 +183,7 @@ def get_dummy_components(self, time_cond_proj_dim=None): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components @@ -341,6 +342,7 @@ def init_weights(m): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components @@ -518,6 +520,7 @@ def init_weights(m): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index 36ddee36eb52..88d2df1ec0f8 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ 
-146,6 +146,8 @@ def get_dummy_components(self, time_cond_proj_dim=None):
             "tokenizer": tokenizer,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
+            "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
 
@@ -471,6 +473,8 @@ def init_weights(m):
             "tokenizer": tokenizer,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
+            "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
 
@@ -656,6 +660,8 @@ def init_weights(m):
             "tokenizer": tokenizer,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
+            "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
 
diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
new file mode 100644
index 000000000000..57eb49013c1f
--- /dev/null
+++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
@@ -0,0 +1,221 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import numpy as np
+import torch
+from transformers import (
+    CLIPImageProcessor,
+    CLIPVisionModelWithProjection,
+)
+
+from diffusers import (
+    StableDiffusionImg2ImgPipeline,
+    StableDiffusionInpaintPipeline,
+    StableDiffusionPipeline,
+    StableDiffusionXLImg2ImgPipeline,
+    StableDiffusionXLInpaintPipeline,
+    StableDiffusionXLPipeline,
+)
+from diffusers.utils import load_image
+from diffusers.utils.testing_utils import (
+    enable_full_determinism,
+    require_torch_gpu,
+    slow,
+    torch_device,
+)
+
+
+enable_full_determinism()
+
+
+class IPAdapterNightlyTestsMixin(unittest.TestCase):
+    dtype = torch.float16
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_image_encoder(self, repo_id, subfolder):
+        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+            repo_id, subfolder=subfolder, torch_dtype=self.dtype
+        ).to(torch_device)
+        return image_encoder
+
+    def get_image_processor(self, repo_id):
+        image_processor = CLIPImageProcessor.from_pretrained(repo_id)
+        return image_processor
+
+    def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_sdxl=False):
+        image = load_image(
+            "https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png"
+        )
+        if for_sdxl:
+            image = image.resize((1024, 1024))
+
+        input_kwargs = {
+            "prompt": "best quality, high quality",
+            "negative_prompt": "monochrome, lowres, bad anatomy, worst quality, low quality",
+            "num_inference_steps": 5,
+            "generator": torch.Generator(device="cpu").manual_seed(33),
+            "ip_adapter_image": image,
+            "output_type": "np",
+        }
+        if for_image_to_image:
+            image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/vermeer.jpg")
+            ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/river.png")
+
+            if for_sdxl:
+                image = image.resize((1024, 1024))
+                ip_image = ip_image.resize((1024, 1024))
+
+            input_kwargs.update({"image": image, "ip_adapter_image": ip_image})
+
+        elif for_inpainting:
+            image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/inpaint_image.png")
+            mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/mask.png")
+            ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/girl.png")
+
+            if for_sdxl:
+                image = image.resize((1024, 1024))
+                mask = mask.resize((1024, 1024))
+                ip_image = ip_image.resize((1024, 1024))
+
+            input_kwargs.update({"image": image, "mask_image": mask, "ip_adapter_image": ip_image})
+
+        return input_kwargs
+
+
+@slow
+@require_torch_gpu
+class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
+    def test_text_to_image(self):
+        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
+        pipeline = StableDiffusionPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
+        )
+        pipeline.to(torch_device)
+        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+
+        inputs = self.get_dummy_inputs()
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+
+        expected_slice = np.array([0.8047, 0.8774, 0.9248, 0.9155, 0.9814, 1.0, 0.9678, 1.0, 1.0])
+
+        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+    def test_image_to_image(self):
+        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
+        pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
+        )
+        pipeline.to(torch_device)
+        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+
+        inputs = self.get_dummy_inputs(for_image_to_image=True)
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+
+        expected_slice = np.array([0.2307, 0.2341, 0.2305, 0.24, 0.2268, 0.25, 0.2322, 0.2588, 0.2935])
+
+        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+    def test_inpainting(self):
+        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
+        pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
+        )
+        pipeline.to(torch_device)
+        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+
+        inputs = self.get_dummy_inputs(for_inpainting=True)
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+
+        expected_slice = np.array([0.2705, 0.2395, 0.2209, 0.2312, 0.2102, 0.2104, 0.2178, 0.2065, 0.1997])
+
+        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+
+@slow
+@require_torch_gpu
+class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
+    def test_text_to_image_sdxl(self):
+        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder")
+        feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
+
+        pipeline = StableDiffusionXLPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+            torch_dtype=self.dtype,
+        )
+        pipeline.to(torch_device)
+        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
+
+        inputs = self.get_dummy_inputs()
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+
+        expected_slice = np.array([0.0968, 0.0959, 0.0852, 0.0912, 0.0948, 0.093, 0.0893, 0.0932, 0.0923])
+
+        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+    def test_image_to_image_sdxl(self):
+        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder")
+        feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
+
+        pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+            torch_dtype=self.dtype,
+        )
+        pipeline.to(torch_device)
+        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
+
+        inputs = self.get_dummy_inputs(for_image_to_image=True)
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+
+        expected_slice = np.array([0.0653, 0.0704, 0.0725, 0.0741, 0.0702, 0.0647, 0.0782, 0.0799, 0.0752])
+
+        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+    def test_inpainting_sdxl(self):
+        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder")
+        feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
+
+        pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+            torch_dtype=self.dtype,
+        )
+        pipeline.to(torch_device)
+        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
+
+        inputs = self.get_dummy_inputs(for_inpainting=True)
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+        image_slice.tolist()
+
+        expected_slice = np.array([0.1418, 0.1493, 0.1428, 0.146, 0.1491, 0.1501, 0.1473, 0.1501, 0.1516])
+
+        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index 53284b80952c..15c1c4fe6671 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -163,6 +163,7 @@ def get_dummy_components(self, time_cond_proj_dim=None):
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
 
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
index 12c6d8cf63d3..1a482b38e2ee 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
@@ -150,6 +150,7 @@ def get_dummy_components(self, time_cond_proj_dim=None):
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
 
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
index 287b2eac4d75..cbe4fb2a0ddf 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -153,6 +153,7 @@ def get_dummy_components(self, time_cond_proj_dim=None):
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
 
@@ -353,6 +354,7 @@ def get_dummy_components(self, time_cond_proj_dim=None):
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
 
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
index 4414d1ec5075..ed295f792f99 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
@@ -123,6 +123,7 @@ def get_dummy_components(self):
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
 
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
index 92e8857610ea..41b9f83914a6 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
@@ -108,6 +108,7 @@ def get_dummy_components(self):
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
 
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
index e2d476dec502..09034789c61c 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
@@ -127,6 +127,7 @@ def test_stable_diffusion_v_pred_ddim(self):
             tokenizer=tokenizer,
             safety_checker=None,
             feature_extractor=None,
+            image_encoder=None,
             requires_safety_checker=False,
         )
         sd_pipe = sd_pipe.to(device)
@@ -176,6 +177,7 @@ def test_stable_diffusion_v_pred_k_euler(self):
             tokenizer=tokenizer,
             safety_checker=None,
             feature_extractor=None,
+            image_encoder=None,
             requires_safety_checker=False,
         )
         sd_pipe = sd_pipe.to(device)
@@ -236,6 +238,7 @@ def test_stable_diffusion_v_pred_fp16(self):
             tokenizer=tokenizer,
             safety_checker=None,
             feature_extractor=None,
+            image_encoder=None,
             requires_safety_checker=False,
         )
         sd_pipe = sd_pipe.to(torch_device)
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index 95fbb658fe5e..8957ebbef5ab 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -131,6 +131,8 @@ def get_dummy_components(self, time_cond_proj_dim=None):
             "tokenizer": tokenizer,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
+            "image_encoder": None,
+            "feature_extractor": None,
         }
         return components
 
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
index 55779e5f060d..444f12ecfa9d 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
@@ -18,7 +18,15 @@
 
 import numpy as np
 import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextConfig,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionConfig,
+    CLIPVisionModelWithProjection,
+)
 
 from diffusers import (
     AutoencoderKL,
@@ -95,6 +103,31 @@ def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim
             latent_channels=4,
             sample_size=128,
         )
+        torch.manual_seed(0)
+        image_encoder_config = CLIPVisionConfig(
+            hidden_size=32,
+            image_size=224,
+            projection_dim=32,
+            intermediate_size=37,
+            num_attention_heads=4,
+            num_channels=3,
+            num_hidden_layers=5,
+            patch_size=14,
+        )
+
+        image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
+
+        feature_extractor = CLIPImageProcessor(
+            crop_size=224,
+            do_center_crop=True,
+            do_normalize=True,
+            do_resize=True,
+            image_mean=[0.48145466, 0.4578275, 0.40821073],
+            image_std=[0.26862954, 0.26130258, 0.27577711],
+            resample=3,
+            size=224,
+        )
+
         torch.manual_seed(0)
         text_encoder_config = CLIPTextConfig(
             bos_token_id=0,
@@ -125,6 +158,8 @@ def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
             "requires_aesthetics_score": True,
+            "image_encoder": image_encoder,
+            "feature_extractor": feature_extractor,
         }
         return components
 
@@ -458,6 +493,8 @@ def get_dummy_components(self):
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
             "requires_aesthetics_score": True,
+            "image_encoder": None,
+            "feature_extractor": None,
         }
         return components
 
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
index 54c750f997b6..7f7a0d81e5a2 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
@@ -20,7 +20,15 @@
 import numpy as np
 import torch
 from PIL import Image
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextConfig,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionConfig,
+    CLIPVisionModelWithProjection,
+)
 
 from diffusers import (
     AutoencoderKL,
@@ -120,6 +128,31 @@ def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim
         text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
         tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
 
+        torch.manual_seed(0)
+        image_encoder_config = CLIPVisionConfig(
+            hidden_size=32,
+            image_size=224,
+            projection_dim=32,
+            intermediate_size=37,
+            num_attention_heads=4,
+            num_channels=3,
+            num_hidden_layers=5,
+            patch_size=14,
+        )
+
+        image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
+
+        feature_extractor = CLIPImageProcessor(
+            crop_size=224,
+            do_center_crop=True,
+            do_normalize=True,
+            do_resize=True,
+            image_mean=[0.48145466, 0.4578275, 0.40821073],
+            image_std=[0.26862954, 0.26130258, 0.27577711],
+            resample=3,
+            size=224,
+        )
+
         components = {
             "unet": unet,
             "scheduler": scheduler,
@@ -128,6 +161,8 @@ def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim
             "tokenizer": tokenizer if not skip_first_text_encoder else None,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
+            "image_encoder": image_encoder,
+            "feature_extractor": feature_extractor,
             "requires_aesthetics_score": True,
         }
         return components
diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py
index 42c90e47af80..d812ce0ccb95 100644
--- a/tests/pipelines/test_pipelines.py
+++ b/tests/pipelines/test_pipelines.py
@@ -1136,8 +1136,8 @@ def test_stable_diffusion_components(self):
             safety_checker=None,
             feature_extractor=self.dummy_extractor,
         ).to(torch_device)
-        img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
-        text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)
+        img2img = StableDiffusionImg2ImgPipeline(**inpaint.components, image_encoder=None).to(torch_device)
+        text2img = StableDiffusionPipeline(**inpaint.components, image_encoder=None).to(torch_device)
 
         prompt = "A painting of a squirrel eating a burger"
 
@@ -1276,6 +1276,29 @@ def test_set_component_to_none(self):
         assert out_image.shape == (1, 64, 64, 3)
         assert np.abs(out_image - out_image_2).max() < 1e-3
 
+    def test_optional_components_is_none(self):
+        unet = self.dummy_cond_unet()
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        items = {
+            "feature_extractor": self.dummy_extractor,
+            "unet": unet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": bert,
+            "tokenizer": tokenizer,
+            "safety_checker": None,
+            # we don't add an image encoder
+        }
+
+        pipeline = StableDiffusionPipeline(**items)
+
+        assert sorted(pipeline.components.keys()) == sorted(["image_encoder"] + list(items.keys()))
+        assert pipeline.image_encoder is None
+
     def test_set_scheduler_consistency(self):
         unet = self.dummy_cond_unet()
         pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")