diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 81c0b82a307..8749dab23b2 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -61,7 +61,7 @@
 parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
 parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
 parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
-parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
+parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans and disable vae auto-conversion; useful for running without a checkpoint in CI")
 parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
 parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
 parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
diff --git a/modules/processing.py b/modules/processing.py
index 6d9c6a8de29..56c565906b1 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -11,6 +11,7 @@
 import cv2
 from skimage import exposure
 from typing import Any, Dict, List, Optional
+from packaging import version
 
 import modules.sd_hijack
 from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
@@ -653,8 +654,20 @@ def get_conds_with_caching(function, required_prompts, steps, cache):
                 samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
 
             x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
-            for x in x_samples_ddim:
-                devices.test_for_nans(x, "vae")
+            try:
+                for x in x_samples_ddim:
+                    devices.test_for_nans(x, "vae")
+            except devices.NansException as e:
+                if devices.dtype_vae == torch.float16 and version.parse(torch.__version__) >= version.parse('2.1') and torch.cuda.is_bf16_supported():
+                    print('\nA tensor with all NaNs was produced in VAE, try converting to bf16.')
+                    devices.dtype_vae = torch.bfloat16
+                    vae_file, vae_source = sd_vae.resolve_vae(p.sd_model.sd_model_checkpoint)
+                    sd_vae.load_vae(p.sd_model, vae_file, vae_source)
+                    x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
+                    for x in x_samples_ddim:
+                        devices.test_for_nans(x, "vae")
+                else:
+                    raise e
 
             x_samples_ddim = torch.stack(x_samples_ddim).float()
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 9b00f76e9c6..11fc8569f93 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -203,6 +203,8 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
         sd_model.to(devices.cpu)
 
     sd_hijack.model_hijack.undo_hijack(sd_model)
+    if not shared.cmd_opts.disable_nan_check and devices.device != devices.cpu and devices.dtype_vae == torch.bfloat16:
+        devices.dtype_vae = torch.float16
 
     load_vae(sd_model, vae_file, vae_source)
 
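
For reference, here is a minimal, self-contained sketch of the fallback pattern the processing.py hunk introduces: decode in fp16 first, and retry in bf16 only when the result is all NaNs and the runtime supports it. `decode_latents`, `test_for_nans`, and `NansException` below are local stand-ins for the webui's `decode_first_stage` and `devices` helpers, not the real API, and the CUDA check is relaxed so the sketch also runs on CPU (it still assumes PyTorch >= 2.1, like the diff).

```python
from packaging import version
import torch


class NansException(Exception):
    """Raised when a decoded tensor is entirely NaN (mirrors devices.NansException)."""


def test_for_nans(x: torch.Tensor, where: str) -> None:
    # Same idea as devices.test_for_nans: only complain when every element is NaN.
    if torch.isnan(x).all():
        raise NansException(f"A tensor with all NaNs was produced in {where}.")


def decode_latents(latents: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical stand-in for the VAE decode; pretend fp16 overflows to NaN
    # while bf16 produces a usable result.
    if dtype == torch.float16:
        return torch.full_like(latents, float("nan"))
    return latents.to(dtype)


def decode_with_bf16_fallback(latents: torch.Tensor, dtype_vae: torch.dtype) -> torch.Tensor:
    decoded = decode_latents(latents, dtype_vae)
    try:
        test_for_nans(decoded, "vae")
    except NansException:
        # Only retry when we were in fp16 and the environment can do bf16.
        bf16_ok = (
            dtype_vae == torch.float16
            and version.parse(torch.__version__) >= version.parse("2.1")
            and (not torch.cuda.is_available() or torch.cuda.is_bf16_supported())
        )
        if not bf16_ok:
            raise
        print("\nA tensor with all NaNs was produced in VAE, try converting to bf16.")
        decoded = decode_latents(latents, torch.bfloat16)
        test_for_nans(decoded, "vae")
    return decoded


if __name__ == "__main__":
    out = decode_with_bf16_fallback(torch.randn(1, 4, 8, 8), torch.float16)
    print(out.dtype)  # torch.bfloat16 after the fallback
```

The design rationale is that bfloat16 keeps float32's exponent range, so VAE activations that overflow fp16 (the usual cause of the all-NaN decode) stay finite at the cost of some mantissa precision. The sd_vae.py hunk then resets `devices.dtype_vae` back to float16 when VAE weights are reloaded, so the bf16 conversion only persists until a different VAE is loaded and is re-triggered per model as needed.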