Commit e3b1a85

Merge pull request AUTOMATIC1111#4 from Klace/img2img_integration
Img2img integration
2 parents 2c1bb46 + c88108f, commit e3b1a85

7 files changed (+94, -27 lines)

aes_scores.json (+1; large diff not rendered by default)

exif_data.json (+1; large diff not rendered by default)

modules/img2img.py (+2, -1)

@@ -76,7 +76,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
         processed_image.save(os.path.join(output_dir, filename))
 
 
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
     override_settings = create_override_settings_dict(override_settings_texts)
 
     is_batch = mode == 5
@@ -132,6 +132,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
         n_iter=n_iter,
         steps=steps,
         cfg_scale=cfg_scale,
+        image_cfg_scale=image_cfg_scale,
         width=width,
         height=height,
         restore_faces=restore_faces,
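Note: the new image_cfg_scale enters here as a positional parameter (external callers of img2img must be updated to pass it), is forwarded as a keyword into StableDiffusionProcessingImg2Img below, and ultimately reaches the sampler as image_scale. A self-contained stand-in sketch of the threading pattern (the class here is illustrative, not webui's):

    class Img2ImgParams:
        def __init__(self, cfg_scale: float = 7.0, image_cfg_scale: float = 1.5, **kwargs):
            self.cfg_scale = cfg_scale              # text guidance strength
            self.image_cfg_scale = image_cfg_scale  # image guidance strength (new)

    p = Img2ImgParams(cfg_scale=7.5, image_cfg_scale=1.5)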

modules/processing.py (+46, -14)

@@ -11,6 +11,8 @@
 import cv2
 from skimage import exposure
 from typing import Any, Dict, List, Optional
+from torch import autocast
+
 
 import modules.sd_hijack
 from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
@@ -186,7 +188,11 @@ def depth2img_image_conditioning(self, source_image):
         return conditioning
 
     def edit_image_conditioning(self, source_image):
-        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+        #source_image = 2 * torch.tensor(np.array(source_image)).float() / 255 - 1
+        #source_image = rearrange(source_image, "h w c -> 1 c h w").to(shared.device)
+        #source_image = rearrange(source_image, "h w c -> 1 c h w").to(shared.device)
+        #conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+        conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
 
         return conditioning_image
 
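The switch from get_first_stage_encoding(...) to encode_first_stage(...).mode() conditions on the deterministic mode of the VAE posterior rather than a scaled stochastic sample, which matches instruct-pix2pix's concat conditioning (it also skips the scale_factor multiplication that get_first_stage_encoding applies). A minimal stand-in for ldm's DiagonalGaussianDistribution illustrates the difference (a sketch, not the library's class):

    import torch

    class DiagonalGaussianSketch:
        """Stand-in: `parameters` packs mean and log-variance along dim 1."""
        def __init__(self, parameters: torch.Tensor):
            self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)

        def sample(self) -> torch.Tensor:
            # stochastic: reparameterized draw from N(mean, exp(logvar))
            return self.mean + torch.exp(0.5 * self.logvar) * torch.randn_like(self.mean)

        def mode(self) -> torch.Tensor:
            # deterministic: a Gaussian's mode equals its mean
            return self.mean

    dist = DiagonalGaussianSketch(torch.randn(1, 8, 64, 64))
    assert torch.equal(dist.mode(), dist.mean)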

@@ -450,11 +456,14 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Size": f"{p.width}x{p.height}",
         "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
         "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
+        "Batch size": (None if p.batch_size < 2 else p.batch_size),
+        "Batch pos": (None if p.batch_size < 2 else position_in_batch),
         "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
         "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
         "Denoising strength": getattr(p, 'denoising_strength', None),
         "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
+        "Eta": (None),
         "Clip skip": None if clip_skip <= 1 else clip_skip,
         "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
     }
@@ -622,15 +631,17 @@ def get_conds_with_caching(function, required_prompts, steps, cache):
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
+            print(f"c = {c} and uc = {uc}")
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                 samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
 
             x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
-            for x in x_samples_ddim:
-                devices.test_for_nans(x, "vae")
+            #for x in x_samples_ddim:
+            #    devices.test_for_nans(x, "vae")
 
             x_samples_ddim = torch.stack(x_samples_ddim).float()
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+            #x_samples_ddim = 255.0 * rearrange(x_samples_ddim, "1 c h w -> h w c")
 
             del samples_ddim

@@ -645,7 +656,7 @@ def get_conds_with_caching(function, required_prompts, steps, cache):
             for i, x_sample in enumerate(x_samples_ddim):
                 x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                 x_sample = x_sample.astype(np.uint8)
-
+                #x_sample = 255.0 * rearrange(x_sample, "1 c h w -> h w c")
                 if p.restore_faces:
                     if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                         images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
@@ -868,8 +879,8 @@ def save_intermediate(image, index):
                     save_intermediate(image, i)
 
                     image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
-                    image = np.array(image).astype(np.float32) / 255.0
-                    image = np.moveaxis(image, 2, 0)
+                    image = np.array(image).astype(np.float32) / 255.0 - 1
+                    #image = np.moveaxis(image, 2, 0)
                     batch_images.append(image)
 
                 decoded_samples = torch.from_numpy(np.array(batch_images))
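Note the asymmetry introduced here: np.array(image).astype(np.float32) / 255.0 - 1 maps pixels to [-1, 0], whereas the img2img init path later in this file (and the commented-out lines elsewhere in the commit) use 2 * x / 255 - 1, which maps to the model's expected [-1, 1]. A self-contained sketch of the latter convention, assuming an HWC uint8 input (helper name is hypothetical):

    import numpy as np
    import torch
    from einops import rearrange

    def to_model_range(image_u8: np.ndarray) -> torch.Tensor:
        """HWC uint8 in [0, 255] -> 1CHW float in [-1, 1]."""
        x = 2 * torch.tensor(image_u8).float() / 255 - 1
        return rearrange(x, "h w c -> 1 c h w")

    img = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
    t = to_model_range(img)
    assert t.shape == (1, 3, 64, 64) and t.min() >= -1 and t.max() <= 1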
@@ -901,7 +912,7 @@ def save_intermediate(image, index):
 class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
     sampler = None
 
-    def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
+    def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, image_cfg_scale: float = 7.5, initial_noise_multiplier: float = None, **kwargs):
         super().__init__(**kwargs)
 
         self.init_images = init_images
@@ -916,6 +927,7 @@ def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_str
         self.inpaint_full_res = inpaint_full_res
         self.inpaint_full_res_padding = inpaint_full_res_padding
         self.inpainting_mask_invert = inpainting_mask_invert
+        self.image_cfg_scale = image_cfg_scale
         self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
         self.mask = None
         self.nmask = None
@@ -983,9 +995,16 @@ def init(self, all_prompts, all_seeds, all_subseeds):
 
                 if add_color_corrections:
                     self.color_corrections.append(setup_color_correction(image))
-
-                image = np.array(image).astype(np.float32) / 255.0
-                image = np.moveaxis(image, 2, 0)
+                width, height = image.size
+                factor = self.width / max(width, height)
+                factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
+                width = int((width * factor) // 64) * 64
+                height = int((height * factor) // 64) * 64
+                image = ImageOps.fit(image, (width, height), method=Image.Resampling.LANCZOS)
+
+                #image = 2 * torch.tensor(np.array(image)).float() / 255 - 1
+                #image = np.array(image).astype(np.float32) / 255.0
+                #image = np.moveaxis(image, 2, 0)
 
                 imgs.append(image)
 
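This block replaces the old flat normalization with instruct-pix2pix-style resizing: scale so the longer side matches self.width, round the scale so the shorter side lands near a multiple of 64, then floor both sides to multiples of 64 for ImageOps.fit. A self-contained sketch of the arithmetic (helper name is hypothetical):

    import math

    def fit_to_multiple_of_64(width: int, height: int, target: int = 512) -> tuple[int, int]:
        factor = target / max(width, height)
        factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
        return int((width * factor) // 64) * 64, int((height * factor) // 64) * 64

    # e.g. a 1000x750 input with target 512 -> (512, 384): both sides end up
    # multiples of 64 with the aspect ratio approximately preserved.
    print(fit_to_multiple_of_64(1000, 750))  # (512, 384)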

@@ -1002,10 +1021,22 @@ def init(self, all_prompts, all_seeds, all_subseeds):
             batch_images = np.array(imgs)
         else:
             raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
-
-        image = torch.from_numpy(batch_images)
-        image = 2. * image - 1.
-        image = image.to(shared.device)
+
+        #image = torch.from_numpy(batch_images)
+        #width, height = image.size
+        #factor = 512 / max(width, height)
+        ###factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
+        #width = int((width * factor) // 64) * 64
+        #height = int((height * factor) // 64) * 64
+        #image = ImageOps.fit(image, (width, height), method=Image.Resampling.LANCZOS)
+        ##image = 2. * image - 1.
+        #image = rearrange(image, "h w c -> 1 c h w")
+        #image = image.to(shared.device)
+        #image = torch.from_numpy(batch_images)
+        #image = 2. * image - 1.
+        image = 2 * torch.tensor(np.array(image)).float() / 255 - 1
+        image = rearrange(image, "h w c -> 1 c h w").to(shared.device)
+        #image = image.to(shared.device)
 
         self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
 

@@ -1032,6 +1063,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):
         self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
 
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+
         x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
 
         if self.initial_noise_multiplier != 1.0:

modules/sd_samplers_common.py (+3)

@@ -3,6 +3,7 @@
 import torch
 from PIL import Image
 import torchsde._brownian.brownian_interval
+from einops import rearrange
 from modules import devices, processing, images, sd_vae_approx
 
 from modules.shared import opts, state
@@ -38,8 +39,10 @@ def single_sample_to_image(sample, approximation=None):
     x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
 
     x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+    x_sample = 255.0 * rearrange(x_sample, "1 c h w -> h w c")
     x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
     x_sample = x_sample.astype(np.uint8)
+
     return Image.fromarray(x_sample)
 
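For comparison, the pre-existing moveaxis line alone already converts a CHW float tensor in [0, 1] to an HWC uint8 array; the newly added rearrange line assumes a leading batch dimension and applies a second 255 multiply, so the two look redundant side by side (an observation, not a fix). The canonical conversion as a standalone sketch:

    import numpy as np
    import torch

    def chw_float_to_hwc_uint8(x: torch.Tensor) -> np.ndarray:
        """CHW float in [0, 1] -> HWC uint8 in [0, 255]."""
        return (255.0 * np.moveaxis(x.cpu().numpy(), 0, 2)).astype(np.uint8)

    arr = chw_float_to_hwc_uint8(torch.rand(3, 64, 64))
    assert arr.shape == (64, 64, 3) and arr.dtype == np.uint8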

modules/sd_samplers_kdiffusion.py (+38, -11)

@@ -1,6 +1,7 @@
 from collections import deque
 import torch
 import inspect
+import einops
 import k_diffusion.sampling
 from modules import prompt_parser, devices, sd_samplers_common
 

@@ -57,17 +58,17 @@ def __init__(self, model):
         self.init_latent = None
         self.step = 0
 
-    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
+    def combine_denoised(self, x_out, conds_list, uncond, cond_scale, image_scale):
         denoised_uncond = x_out[-uncond.shape[0]:]
         denoised = torch.clone(denoised_uncond)
 
         for i, conds in enumerate(conds_list):
             for cond_index, weight in conds:
-                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
-
+                denoised[i] += cond_scale * (x_out[cond_index] - denoised_uncond[i]) + image_scale * (denoised_uncond[i] - x_out[cond_index])
+
         return denoised
 
-    def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
+    def forward(self, x, sigma, uncond, cond, cond_scale, image_cond, image_scale):
         if state.interrupted or state.skipped:
             raise sd_samplers_common.InterruptedException
 
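Worth noting about the new combine step: the two terms simplify algebraically, cond_scale * (c - u) + image_scale * (u - c) = (cond_scale - image_scale) * (c - u), and the per-prompt weight is no longer applied. A minimal sketch of the update for a single cond/uncond pair (names are illustrative):

    import torch

    def combine(denoised_uncond: torch.Tensor, denoised_cond: torch.Tensor,
                cond_scale: float, image_scale: float) -> torch.Tensor:
        # equivalent to denoised_uncond + (cond_scale - image_scale) * (cond - uncond)
        return (denoised_uncond
                + cond_scale * (denoised_cond - denoised_uncond)
                + image_scale * (denoised_uncond - denoised_cond))

    u, c = torch.zeros(4), torch.ones(4)
    assert torch.allclose(combine(u, c, 7.5, 1.5), (7.5 - 1.5) * (c - u))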

@@ -76,10 +77,9 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
 
         batch_size = len(conds_list)
         repeats = [len(conds_list[i]) for i in range(batch_size)]
-
-        x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
-        image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
-        sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+        x_in = einops.repeat(x, "1 ... -> n ...", n=3)
+        sigma_in = einops.repeat(sigma, "1 ... -> n ...", n=3)
+        image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [image_cond])
 
         denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
         cfg_denoiser_callback(denoiser_params)
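The einops.repeat calls assume a batch of one and tile it three ways, one copy each for the text-plus-image-conditioned, image-conditioned, and unconditional passes (hence image_cond_in gaining a second [image_cond]). A quick shape check:

    import torch
    import einops

    x = torch.randn(1, 4, 64, 64)  # single latent
    sigma = torch.randn(1)
    x_in = einops.repeat(x, "1 ... -> n ...", n=3)
    sigma_in = einops.repeat(sigma, "1 ... -> n ...", n=3)
    assert x_in.shape == (3, 4, 64, 64) and sigma_in.shape == (3,)
    # note: unlike the old torch.cat/stack path, this hardcodes batch size 1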
@@ -88,7 +88,7 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
         sigma_in = denoiser_params.sigma
 
         if tensor.shape[1] == uncond.shape[1]:
-            cond_in = torch.cat([tensor, uncond])
+            cond_in = torch.cat([tensor, uncond, uncond])
 
             if shared.batch_cond_uncond:
                 x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
@@ -115,7 +115,7 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
         elif opts.live_preview_content == "Negative prompt":
             sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
 
-        denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+        denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale, image_scale)
 
         if self.mask is not None:
             denoised = self.init_latent * self.mask + self.nmask * denoised
@@ -124,6 +124,32 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
 
         return denoised
 
+class CFGDenoiserIp2p(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.inner_model = model
+        self.mask = None
+        self.nmask = None
+        self.init_latent = None
+        self.step = 0
+
+    def forward(self, z, sigma, uncond, cond, cond_scale, image_cond):
+        if state.interrupted or state.skipped:
+            raise sd_samplers_common.InterruptedException
+        cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
+        cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
+        image_cond_in = einops.repeat(image_cond, "1 ... -> n ...", n=3)
+
+        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
+        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+        cond_in = torch.cat([tensor, uncond])
+        cfg_cond = {
+            "c_crossattn": [cond_in],
+            "c_concat": [image_cond_in],
+        }
+        out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3)
+        return out_uncond + cond_scale * (out_cond - out_img_cond) + 1.5 * (out_img_cond - out_uncond)
+
 
 class TorchHijack:
     def __init__(self, sampler_noises):
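The return line implements instruct-pix2pix's three-term guidance, e = e_uncond + s_T * (e_cond - e_imgcond) + s_I * (e_imgcond - e_uncond), with the image scale s_I hardcoded to 1.5 here rather than read from image_scale / p.image_cfg_scale. As a standalone function:

    import torch

    def ip2p_guidance(out_cond: torch.Tensor, out_img_cond: torch.Tensor,
                      out_uncond: torch.Tensor, text_scale: float,
                      image_scale: float = 1.5) -> torch.Tensor:
        """Text guidance pushes away from the image-only prediction; image
        guidance pushes away from the fully unconditional prediction."""
        return (out_uncond
                + text_scale * (out_cond - out_img_cond)
                + image_scale * (out_img_cond - out_uncond))

    # with text_scale = image_scale = 1 this collapses to out_cond exactly
    c, i, u = torch.ones(4), torch.zeros(4), -torch.ones(4)
    assert torch.equal(ip2p_guidance(c, i, u, 1.0, 1.0), c)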
@@ -265,7 +291,8 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,
             'cond': conditioning,
             'image_cond': image_conditioning,
             'uncond': unconditional_conditioning,
-            'cond_scale': p.cfg_scale
+            'cond_scale': p.cfg_scale,
+            'image_scale': p.image_cfg_scale
         }, disable=False, callback=self.callback_state, **extra_params_kwargs))
 
         return samples

modules/ui.py (+3, -1)

@@ -765,7 +765,8 @@ def copy_image(img):
 
                 elif category == "cfg":
                     with FormGroup():
-                        cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
+                        cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.5, elem_id="img2img_cfg_scale")
+                        image_cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='Image CFG Scale', value=1.5, elem_id="img2img_cfg_scale")
                         denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
 
                 elif category == "seed":
@@ -861,6 +862,7 @@ def select_img2img_tab(tab):
                     batch_count,
                     batch_size,
                     cfg_scale,
+                    image_cfg_scale,
                     denoising_strength,
                     seed,
                     subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
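One detail worth flagging: the new slider reuses elem_id "img2img_cfg_scale", so two components share one DOM id. A distinct id (the name below is hypothetical, everything else mirrors the diff) would keep them individually addressable:

    import gradio as gr

    image_cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5,
                                label='Image CFG Scale', value=1.5,
                                elem_id="img2img_image_cfg_scale")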
