Support of ip-adapter to the StableDiffusionControlNetInpaintPipeline (#5887)

* Change pipeline_controlnet_inpaint.py to add ip-adapter support. The changes are similar to those in pipeline_controlnet.py.

* Change tests for the StableDiffusionControlNetInpaintPipeline by adding `image_encoder: None` to the dummy components.

* Update src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
juancopi81 and yiyixuxu authored Nov 29, 2023
1 parent 895c4b7 commit 9f7b2cf
Showing 2 changed files with 37 additions and 6 deletions.
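For context, a minimal usage sketch of the new support. The checkpoint names, weight file, and image URLs below are illustrative assumptions, not part of this commit; any SD 1.5-compatible inpainting + ControlNet + IP-Adapter combination should follow the same pattern:

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
from diffusers.utils import load_image

# Hypothetical checkpoints, chosen for illustration only.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# New in this commit: the pipeline inherits IPAdapterMixin, so an
# IP-Adapter can be loaded and driven by an image prompt.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

image = load_image("https://example.com/input.png")          # image to inpaint (placeholder URL)
mask = load_image("https://example.com/mask.png")            # white = region to repaint
control = load_image("https://example.com/canny.png")        # ControlNet conditioning
style_ref = load_image("https://example.com/style_ref.png")  # IP-Adapter image prompt

result = pipe(
    prompt="a red sofa in a bright living room",
    image=image,
    mask_image=mask,
    control_image=control,
    ip_adapter_image=style_ref,  # the new argument added by this commit
    num_inference_steps=30,
).images[0]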
40 changes: 34 additions & 6 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -21,10 +21,10 @@
import PIL.Image
import torch
import torch.nn.functional as F
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -241,7 +241,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False


class StableDiffusionControlNetInpaintPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
):
r"""
Pipeline for image inpainting using Stable Diffusion with ControlNet guidance.
@@ -251,6 +251,7 @@ class StableDiffusionControlNetInpaintPipeline(
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
<Tip>
@@ -288,7 +289,7 @@ class StableDiffusionControlNetInpaintPipeline(
"""

model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

@@ -302,6 +303,7 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
+image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -334,6 +336,7 @@ def __init__(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
+image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -593,6 +596,20 @@ def encode_prompt(

return prompt_embeds, negative_prompt_embeds

+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+def encode_image(self, image, device, num_images_per_prompt):
+    dtype = next(self.image_encoder.parameters()).dtype
+
+    if not isinstance(image, torch.Tensor):
+        image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+    image = image.to(device=device, dtype=dtype)
+    image_embeds = self.image_encoder(image).image_embeds
+    image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+    uncond_image_embeds = torch.zeros_like(image_embeds)
+    return image_embeds, uncond_image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -1053,6 +1070,7 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -1131,6 +1149,7 @@ def __call__(
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -1264,6 +1283,11 @@ def __call__(
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

+if ip_adapter_image is not None:
+    image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+    if self.do_classifier_free_guidance:
+        image_embeds = torch.cat([negative_image_embeds, image_embeds])

# 4. Prepare image
if isinstance(controlnet, ControlNetModel):
control_image = self.prepare_control_image(
@@ -1299,7 +1323,7 @@ def __call__(
else:
assert False

-# 4. Preprocess mask and image - resizes image and mask w.r.t height and width
+# 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width
init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)

@@ -1360,7 +1384,10 @@ def __call__(
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-# 7.1 Create tensor stating which controlnets to keep
+# 7.1 Add image embeds for IP-Adapter
+added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
+# 7.2 Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
@@ -1423,6 +1450,7 @@ def __call__(
cross_attention_kwargs=self.cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
+added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]

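Taken together, encode_image and the guidance concatenation above are plain batch bookkeeping. A standalone illustration of the resulting shapes follows; the 1024-d embed size matches the ViT-H image encoder commonly paired with IP-Adapter and is an assumption here, not something fixed by this diff:

import torch

# One image prompt, duplicated per generated image (as encode_image does),
# then stacked [unconditional | conditional] for classifier-free guidance.
num_images_per_prompt = 2
image_embeds = torch.randn(1, 1024)  # illustrative embed size
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)  # zero negative branch, as in encode_image
image_embeds = torch.cat([uncond_image_embeds, image_embeds])
print(image_embeds.shape)  # torch.Size([4, 1024])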
3 changes: 3 additions & 0 deletions tests/pipelines/controlnet/test_controlnet_inpaint.py
@@ -132,6 +132,7 @@ def get_dummy_components(self):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components

@@ -248,6 +249,7 @@ def get_dummy_components(self):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components

@@ -342,6 +344,7 @@ def init_weights(m):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components

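The tests simply register the new component as None. If a real (tiny) encoder were ever needed in these fixtures, it could be wired into get_dummy_components along these lines; this is a sketch with made-up dimensions, not part of the commit:

from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModelWithProjection

# Hypothetical dummy encoder sized to match the toy UNet in these tests;
# the actual commit passes image_encoder=None.
image_encoder_config = CLIPVisionConfig(
    hidden_size=32,
    intermediate_size=37,
    num_attention_heads=4,
    num_hidden_layers=2,
    image_size=32,
    patch_size=4,
    projection_dim=32,
)
components["image_encoder"] = CLIPVisionModelWithProjection(image_encoder_config)
# encode_image also needs a matching preprocessor in that case:
components["feature_extractor"] = CLIPImageProcessor(crop_size=32, size=32)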
