@@ -115,20 +115,17 @@ def txt2img_image_conditioning(sd_model, x, width, height):
115
115
return x .new_zeros (x .shape [0 ], 2 * sd_model .noise_augmentor .time_embed .dim , dtype = x .dtype , device = x .device )
116
116
117
117
else :
118
- sd = sd_model .model .state_dict ()
119
- diffusion_model_input = sd .get ('diffusion_model.input_blocks.0.0.weight' , None )
120
- if diffusion_model_input is not None :
121
- if diffusion_model_input .shape [1 ] == 9 :
122
- # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
123
- image_conditioning = torch .ones (x .shape [0 ], 3 , height , width , device = x .device ) * 0.5
124
- image_conditioning = images_tensor_to_samples (image_conditioning ,
125
- approximation_indexes .get (opts .sd_vae_encode_method ))
118
+ if sd_model .model .is_sdxl_inpaint :
119
+ # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
120
+ image_conditioning = torch .ones (x .shape [0 ], 3 , height , width , device = x .device ) * 0.5
121
+ image_conditioning = images_tensor_to_samples (image_conditioning ,
122
+ approximation_indexes .get (opts .sd_vae_encode_method ))
126
123
127
- # Add the fake full 1s mask to the first dimension.
128
- image_conditioning = torch .nn .functional .pad (image_conditioning , (0 , 0 , 0 , 0 , 1 , 0 ), value = 1.0 )
129
- image_conditioning = image_conditioning .to (x .dtype )
124
+ # Add the fake full 1s mask to the first dimension.
125
+ image_conditioning = torch .nn .functional .pad (image_conditioning , (0 , 0 , 0 , 0 , 1 , 0 ), value = 1.0 )
126
+ image_conditioning = image_conditioning .to (x .dtype )
130
127
131
- return image_conditioning
128
+ return image_conditioning
132
129
133
130
# Dummy zero conditioning if we're not using inpainting or unclip models.
134
131
# Still takes up a bit of memory, but no encoder call.
@@ -390,11 +387,8 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None
390
387
if self .sampler .conditioning_key == "crossattn-adm" :
391
388
return self .unclip_image_conditioning (source_image )
392
389
393
- sd = self .sampler .model_wrap .inner_model .model .state_dict ()
394
- diffusion_model_input = sd .get ('diffusion_model.input_blocks.0.0.weight' , None )
395
- if diffusion_model_input is not None :
396
- if diffusion_model_input .shape [1 ] == 9 :
397
- return self .inpainting_image_conditioning (source_image , latent_image , image_mask = image_mask )
390
+ if self .sampler .model_wrap .inner_model .model .is_sdxl_inpaint :
391
+ return self .inpainting_image_conditioning (source_image , latent_image , image_mask = image_mask )
398
392
399
393
# Dummy zero conditioning if we're not using inpainting or depth model.
400
394
return latent_image .new_zeros (latent_image .shape [0 ], 5 , 1 , 1 )
0 commit comments