From efaf085f6c59dff753113cf0bbd84ad3551b52b6 Mon Sep 17 00:00:00 2001
From: Ates Goral
Date: Tue, 16 Jan 2024 18:16:33 -0500
Subject: [PATCH 1/2] Add IP Adapter support for SDXL ControlNet Inpaint
 pipeline

---
 .../pipeline_controlnet_inpaint_sd_xl.py | 41 +++++++++++++++++--
 .../test_controlnet_inpaint_sdxl.py      |  1 +
 2 files changed, 38 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index 76b97b48f97c..f0a848b66720 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -19,10 +19,15 @@
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import (
+    FromSingleFileMixin,
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
@@ -140,7 +145,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
 
 
 class StableDiffusionXLControlNetInpaintPipeline(
-    DiffusionPipeline, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion XL.
@@ -152,6 +157,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
     - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
     - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
     - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+    - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         vae ([`AutoencoderKL`]):
@@ -179,7 +185,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
     """

     model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
-    _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+    _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2", "image_encoder"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
 
     def __init__(
@@ -192,6 +198,7 @@ def __init__(
         unet: UNet2DConditionModel,
         controlnet: ControlNetModel,
         scheduler: KarrasDiffusionSchedulers,
+        image_encoder: CLIPVisionModelWithProjection = None,
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
@@ -210,6 +217,7 @@ def __init__(
             unet=unet,
             controlnet=controlnet,
             scheduler=scheduler,
+            image_encoder=image_encoder,
         )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
@@ -497,6 +505,20 @@ def encode_prompt(
 
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        image_embeds = self.image_encoder(image).image_embeds
+        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+        uncond_image_embeds = torch.zeros_like(image_embeds)
+        return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -1079,6 +1101,7 @@ def __call__(
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
@@ -1167,6 +1190,7 @@ def __call__(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument.
@@ -1348,6 +1372,11 @@ def __call__(
             clip_skip=self.clip_skip,
         )
 
+        if ip_adapter_image is not None:
+            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
         # 4. set timesteps
         def denoising_value_valid(dnv):
             return isinstance(denoising_end, float) and 0 < dnv < 1
@@ -1557,6 +1586,10 @@ def denoising_value_valid(dnv):
 
             added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
+            # Add image embeds for IP-Adapter
+            if ip_adapter_image is not None:
+                added_cond_kwargs["image_embeds"] = image_embeds
+
             # controlnet(s) inference
             if guess_mode and self.do_classifier_free_guidance:
                 # Infer ControlNet only for the conditional batch.
diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py b/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
index 0ac8996fe0ef..e22b5d72491f 100644
--- a/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
@@ -135,6 +135,7 @@ def get_dummy_components(self):
             "tokenizer": tokenizer,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
+            "image_encoder": None,
         }
 
         return components

From 04d58f8a3431c09843cb4ffe520761e619658ba2 Mon Sep 17 00:00:00 2001
From: Ates Goral
Date: Tue, 16 Jan 2024 20:01:01 -0500
Subject: [PATCH 2/2] Add feature_extractor

---
 .../pipeline_controlnet_inpaint_sd_xl.py | 21 +++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index f0a848b66720..51a2bba0f192 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -19,7 +19,13 @@
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import (
@@ -178,6 +184,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
         tokenizer_2 (`CLIPTokenizer`):
             Second Tokenizer of class
             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        feature_extractor ([`~transformers.CLIPImageProcessor`]):
+            A `CLIPImageProcessor` to preprocess reference images for the `image_encoder` when using IP Adapters.
         unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
         scheduler ([`SchedulerMixin`]):
             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
@@ -185,7 +193,14 @@ class StableDiffusionXLControlNetInpaintPipeline(
     """

     model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
-    _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2", "image_encoder"]
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "feature_extractor",
+        "image_encoder",
+    ]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
 
     def __init__(
@@ -198,6 +213,7 @@ def __init__(
         unet: UNet2DConditionModel,
         controlnet: ControlNetModel,
         scheduler: KarrasDiffusionSchedulers,
+        feature_extractor: CLIPImageProcessor = None,
         image_encoder: CLIPVisionModelWithProjection = None,
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
@@ -217,6 +233,7 @@ def __init__(
             unet=unet,
             controlnet=controlnet,
             scheduler=scheduler,
+            feature_extractor=feature_extractor,
             image_encoder=image_encoder,
         )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
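
Usage note for reviewers: below is a minimal sketch of how the new `ip_adapter_image` argument and the `load_ip_adapter` method inherited from `IPAdapterMixin` fit together. The checkpoint and adapter repo IDs are the usual public ones; the local image file names are placeholder assumptions, not part of this patch.

```python
# Hedged usage sketch; the .png file names below are hypothetical placeholders.
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

# Provided by the newly mixed-in IPAdapterMixin; also loads the image encoder
# from the adapter repo when the pipeline was created without one.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

image = load_image("base.png")            # image to inpaint
mask_image = load_image("mask.png")       # white marks the region to repaint
control_image = load_image("canny.png")   # ControlNet conditioning image
ip_image = load_image("reference.png")    # IP Adapter style/content reference

result = pipe(
    prompt="a fluffy cat sitting on a bench",
    image=image,
    mask_image=mask_image,
    control_image=control_image,
    ip_adapter_image=ip_image,  # new argument introduced by patch 1
).images[0]
result.save("out.png")
```

Design note: `encode_image` uses all-zero embeddings as the unconditional branch and concatenates them ahead of the conditional embeddings, so the `image_embeds` batch matches the negative-then-positive layout the pipeline already uses for text embeddings under classifier-free guidance.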