@@ -115,17 +115,20 @@ def txt2img_image_conditioning(sd_model, x, width, height):
115
115
return x .new_zeros (x .shape [0 ], 2 * sd_model .noise_augmentor .time_embed .dim , dtype = x .dtype , device = x .device )
116
116
117
117
else :
118
- if sd_model .model .is_sdxl_inpaint :
119
- # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
120
- image_conditioning = torch .ones (x .shape [0 ], 3 , height , width , device = x .device ) * 0.5
121
- image_conditioning = images_tensor_to_samples (image_conditioning ,
122
- approximation_indexes .get (opts .sd_vae_encode_method ))
118
+ sd = sd_model .model .state_dict ()
119
+ diffusion_model_input = sd .get ('diffusion_model.input_blocks.0.0.weight' , None )
120
+ if diffusion_model_input is not None :
121
+ if diffusion_model_input .shape [1 ] == 9 :
122
+ # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
123
+ image_conditioning = torch .ones (x .shape [0 ], 3 , height , width , device = x .device ) * 0.5
124
+ image_conditioning = images_tensor_to_samples (image_conditioning ,
125
+ approximation_indexes .get (opts .sd_vae_encode_method ))
123
126
124
- # Add the fake full 1s mask to the first dimension.
125
- image_conditioning = torch .nn .functional .pad (image_conditioning , (0 , 0 , 0 , 0 , 1 , 0 ), value = 1.0 )
126
- image_conditioning = image_conditioning .to (x .dtype )
127
+ # Add the fake full 1s mask to the first dimension.
128
+ image_conditioning = torch .nn .functional .pad (image_conditioning , (0 , 0 , 0 , 0 , 1 , 0 ), value = 1.0 )
129
+ image_conditioning = image_conditioning .to (x .dtype )
127
130
128
- return image_conditioning
131
+ return image_conditioning
129
132
130
133
# Dummy zero conditioning if we're not using inpainting or unclip models.
131
134
# Still takes up a bit of memory, but no encoder call.
@@ -387,8 +390,11 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None
387
390
if self .sampler .conditioning_key == "crossattn-adm" :
388
391
return self .unclip_image_conditioning (source_image )
389
392
390
- if self .sampler .model_wrap .inner_model .model .is_sdxl_inpaint :
391
- return self .inpainting_image_conditioning (source_image , latent_image , image_mask = image_mask )
393
+ sd = self .sampler .model_wrap .inner_model .model .state_dict ()
394
+ diffusion_model_input = sd .get ('diffusion_model.input_blocks.0.0.weight' , None )
395
+ if diffusion_model_input is not None :
396
+ if diffusion_model_input .shape [1 ] == 9 :
397
+ return self .inpainting_image_conditioning (source_image , latent_image , image_mask = image_mask )
392
398
393
399
# Dummy zero conditioning if we're not using inpainting or depth model.
394
400
return latent_image .new_zeros (latent_image .shape [0 ], 5 , 1 , 1 )
0 commit comments