Add IP Adapter support for SDXL ControlNet Inpaint pipeline
atesgoral committed Jan 16, 2024
1 parent 1040dfd commit 45ac6c4
Showing 2 changed files with 36 additions and 4 deletions.
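In practice, this change lets the SDXL ControlNet inpaint pipeline take an IP-Adapter reference image alongside its usual inpainting inputs. A minimal usage sketch (editorial illustration, not part of this commit; the checkpoint names, weight file, and image URLs are placeholders):

import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# load_ip_adapter comes from the newly added IPAdapterMixin base class.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

image = load_image("https://example.com/input.png")          # image to inpaint (placeholder URL)
mask_image = load_image("https://example.com/mask.png")      # white pixels mark the region to repaint
control_image = load_image("https://example.com/canny.png")  # ControlNet conditioning image
adapter_image = load_image("https://example.com/style.png")  # IP-Adapter reference image

result = pipe(
    prompt="a photo of a cat sitting on a bench",
    image=image,
    mask_image=mask_image,
    control_image=control_image,
    ip_adapter_image=adapter_image,  # new argument introduced by this commit
    num_inference_steps=30,
).images[0]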
@@ -19,10 +19,10 @@
import PIL.Image
import torch
import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
@@ -140,7 +140,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):


class StableDiffusionXLControlNetInpaintPipeline(
DiffusionPipeline, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin
DiffusionPipeline, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
):
r"""
Pipeline for text-to-image generation using Stable Diffusion XL.
@@ -152,6 +152,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
@@ -179,7 +180,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
"""

model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2", "image_encoder"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

def __init__(
@@ -192,6 +193,7 @@ def __init__(
unet: UNet2DConditionModel,
controlnet: ControlNetModel,
scheduler: KarrasDiffusionSchedulers,
image_encoder: CLIPVisionModelWithProjection = None,
requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
@@ -210,6 +212,7 @@ def __init__(
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
image_encoder=image_encoder,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
@@ -497,6 +500,22 @@ def encode_prompt(

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype

if not isinstance(image, torch.Tensor):
image = self.feature_extractor(
image, return_tensors="pt").pixel_values

image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(
num_images_per_prompt, dim=0)

uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -1079,6 +1098,7 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
@@ -1167,6 +1187,7 @@ def __call__(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
@@ -1348,6 +1369,12 @@ def __call__(
clip_skip=self.clip_skip,
)

if ip_adapter_image is not None:
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_images_per_prompt)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])

# 4. set timesteps
def denoising_value_valid(dnv):
return isinstance(denoising_end, float) and 0 < dnv < 1
@@ -1557,6 +1584,10 @@ def denoising_value_valid(dnv):

added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

# Add image embeds for IP-Adapter
if ip_adapter_image:
added_cond_kwargs["image_embeds"] = image_embeds

# controlnet(s) inference
if guess_mode and self.do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
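Taken together, the new hunks follow one pattern: encode the IP-Adapter image once, duplicate the embedding per generated image, use a zero embedding for the unconditional branch, concatenate the two under classifier-free guidance, and hand the result to the UNet via added_cond_kwargs["image_embeds"]. A standalone sketch of that flow (editorial illustration, not part of the diff; shapes are arbitrary stand-ins for the CLIP vision embeddings):

import torch

num_images_per_prompt = 2
# Stand-in for self.image_encoder(image).image_embeds in encode_image above.
image_embeds = torch.randn(1, 1280)
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)  # (2, 1280)
uncond_image_embeds = torch.zeros_like(image_embeds)                         # negative branch is all zeros

do_classifier_free_guidance = True
if do_classifier_free_guidance:
    # Negative embeddings first, matching the text-embedding convention in the pipeline.
    image_embeds = torch.cat([uncond_image_embeds, image_embeds])  # (4, 1280)

added_cond_kwargs = {"image_embeds": image_embeds}  # later passed to the UNet with text_embeds/time_ids
print(added_cond_kwargs["image_embeds"].shape)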
1 change: 1 addition & 0 deletions tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
@@ -135,6 +135,7 @@ def get_dummy_components(self):
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"image_encoder": None,
}
return components

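The test change mirrors the pipeline change: image_encoder is now a registered optional component, so the dummy component dict must account for it even when it is None. When IP-Adapter inference is actually wanted, a CLIP vision encoder can be supplied at construction time; a hedged sketch (the repository and subfolder names are assumptions, not taken from this commit):

import torch
from transformers import CLIPVisionModelWithProjection
from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=torch.float16
)
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    image_encoder=image_encoder,  # new constructor argument added by this commit
    torch_dtype=torch.float16,
)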
