From 4aa115d5c167b03df5ad0fd2e966f5e4bc99233e Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Sun, 2 Apr 2023 17:28:44 +0800 Subject: [PATCH 1/9] Add bf16 support. --- modules/processing.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 6d9c6a8de29..ce0dbbabf35 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -653,8 +653,20 @@ def get_conds_with_caching(function, required_prompts, steps, cache): samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] - for x in x_samples_ddim: - devices.test_for_nans(x, "vae") + try: + for x in x_samples_ddim: + devices.test_for_nans(x, "vae") + except devices.NansException as e: + if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae and torch.cuda.get_device_capability()[0] >= 8: + print('\nA tensor with all NaNs was produced in VAE, try converting to bf16.') + devices.dtype_vae = torch.bfloat16 + vae_file, vae_source = sd_vae.resolve_vae(p.sd_model.sd_model_checkpoint) + sd_vae.load_vae(p.sd_model, vae_file, vae_source) + x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] + for x in x_samples_ddim: + devices.test_for_nans(x, "vae") + else: + raise e x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) From 157c25f123bfa24aa5d72d700016de5642fcb715 Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Sun, 2 Apr 2023 17:41:30 +0800 Subject: [PATCH 2/9] Restore type --- modules/sd_vae.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 9b00f76e9c6..707d1fb2c06 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -183,6 +183,8 @@ def clear_loaded_vae(): def reload_vae_weights(sd_model=None, vae_file=unspecified): from modules import lowvram, devices, sd_hijack + if devices.dtype_vae == torch.bfloat16: + devices.dtype_vae = torch.float16 if not sd_model: sd_model = shared.sd_model From d19d227138c6f0448849ca7b19ac8a8e876f249c Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Thu, 6 Apr 2023 19:52:18 +0800 Subject: [PATCH 3/9] Add startup parameters and version check --- modules/cmd_args.py | 1 + modules/processing.py | 2 +- modules/sd_vae.py | 2 +- webui.py | 10 ++++++++++ 4 files changed, 13 insertions(+), 2 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 81c0b82a307..547e8dc89d2 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -101,3 +101,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers") parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) +parser.add_argument("--rollback-vae", action='store_true', help="trying to roll back vae when produced nan image, need to enable nan check", default=False) diff --git a/modules/processing.py b/modules/processing.py index ce0dbbabf35..98402aa57f9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -657,7 +657,7 @@ def get_conds_with_caching(function, required_prompts, steps, cache): for x in x_samples_ddim: devices.test_for_nans(x, "vae") except devices.NansException as e: - if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae and torch.cuda.get_device_capability()[0] >= 8: + if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae and shared.cmd_opts.rollback_vae: print('\nA tensor with all NaNs was produced in VAE, try converting to bf16.') devices.dtype_vae = torch.bfloat16 vae_file, vae_source = sd_vae.resolve_vae(p.sd_model.sd_model_checkpoint) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 707d1fb2c06..ee3902a4b5c 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -183,7 +183,7 @@ def clear_loaded_vae(): def reload_vae_weights(sd_model=None, vae_file=unspecified): from modules import lowvram, devices, sd_hijack - if devices.dtype_vae == torch.bfloat16: + if shared.cmd_opts.rollback_vae and devices.dtype_vae == torch.bfloat16: devices.dtype_vae = torch.float16 if not sd_model: sd_model = shared.sd_model diff --git a/webui.py b/webui.py index b570895fb2c..2f8a3e9fc21 100644 --- a/webui.py +++ b/webui.py @@ -97,9 +97,19 @@ def check_versions(): Use --skip-version-check commandline argument to disable this check. """.strip()) +def check_rollback_vae(): + if shared.cmd_opts.rollback_vae: + if version.parse(torch.__version__) < version.parse('2.1'): + print("If your PyTorch version is lower than PyTorch 2.1, Rollback VAE will not work.") + shared.cmd_opts.rollback_vae = False + elif 0 < torch.cuda.get_device_capability()[0] < 8: + print('Rollback VAE will not work because your device does not support it.') + shared.cmd_opts.rollback_vae = False + def initialize(): check_versions() + check_rollback_vae() extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) From 942c7d6158a160bf675ef8d0ce2630318edb827c Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Sat, 8 Apr 2023 23:50:22 +0800 Subject: [PATCH 4/9] Bug fix --- modules/sd_vae.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index ee3902a4b5c..8cff159161b 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -183,8 +183,6 @@ def clear_loaded_vae(): def reload_vae_weights(sd_model=None, vae_file=unspecified): from modules import lowvram, devices, sd_hijack - if shared.cmd_opts.rollback_vae and devices.dtype_vae == torch.bfloat16: - devices.dtype_vae = torch.float16 if not sd_model: sd_model = shared.sd_model @@ -205,6 +203,8 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): sd_model.to(devices.cpu) sd_hijack.model_hijack.undo_hijack(sd_model) + if shared.cmd_opts.rollback_vae and devices.dtype_vae == torch.bfloat16: + devices.dtype_vae = torch.float16 load_vae(sd_model, vae_file, vae_source) From 236fd989f00f52f113291eea4ce602df46242eb8 Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Sun, 30 Apr 2023 00:35:17 +0800 Subject: [PATCH 5/9] Exclude CPU devices --- webui.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/webui.py b/webui.py index 2f8a3e9fc21..0267fc63a91 100644 --- a/webui.py +++ b/webui.py @@ -97,8 +97,12 @@ def check_versions(): Use --skip-version-check commandline argument to disable this check. """.strip()) + def check_rollback_vae(): - if shared.cmd_opts.rollback_vae: + if devices.device == devices.cpu: + print("Rollback VAE does not support CPU devices and will not work.") + shared.cmd_opts.rollback_vae = False + elif shared.cmd_opts.rollback_vae: if version.parse(torch.__version__) < version.parse('2.1'): print("If your PyTorch version is lower than PyTorch 2.1, Rollback VAE will not work.") shared.cmd_opts.rollback_vae = False From c5b03fae510218337d8ebbfbf5756bd4d43a779b Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Tue, 9 May 2023 23:11:13 +0800 Subject: [PATCH 6/9] Remove startup parameters --- modules/processing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 98402aa57f9..56c565906b1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -11,6 +11,7 @@ import cv2 from skimage import exposure from typing import Any, Dict, List, Optional +from packaging import version import modules.sd_hijack from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts @@ -657,7 +658,7 @@ def get_conds_with_caching(function, required_prompts, steps, cache): for x in x_samples_ddim: devices.test_for_nans(x, "vae") except devices.NansException as e: - if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae and shared.cmd_opts.rollback_vae: + if devices.dtype_vae == torch.float16 and version.parse(torch.__version__) >= version.parse('2.1') and torch.cuda.is_bf16_supported(): print('\nA tensor with all NaNs was produced in VAE, try converting to bf16.') devices.dtype_vae = torch.bfloat16 vae_file, vae_source = sd_vae.resolve_vae(p.sd_model.sd_model_checkpoint) From 7d4561b7a7a130970fa28aab702a0a6da6750230 Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Tue, 9 May 2023 23:16:04 +0800 Subject: [PATCH 7/9] Revert "Add startup parameters and version check" This reverts commit d19d2271 --- modules/cmd_args.py | 1 - webui.py | 14 -------------- 2 files changed, 15 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 547e8dc89d2..81c0b82a307 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -101,4 +101,3 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers") parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) -parser.add_argument("--rollback-vae", action='store_true', help="trying to roll back vae when produced nan image, need to enable nan check", default=False) diff --git a/webui.py b/webui.py index 0267fc63a91..b570895fb2c 100644 --- a/webui.py +++ b/webui.py @@ -98,22 +98,8 @@ def check_versions(): """.strip()) -def check_rollback_vae(): - if devices.device == devices.cpu: - print("Rollback VAE does not support CPU devices and will not work.") - shared.cmd_opts.rollback_vae = False - elif shared.cmd_opts.rollback_vae: - if version.parse(torch.__version__) < version.parse('2.1'): - print("If your PyTorch version is lower than PyTorch 2.1, Rollback VAE will not work.") - shared.cmd_opts.rollback_vae = False - elif 0 < torch.cuda.get_device_capability()[0] < 8: - print('Rollback VAE will not work because your device does not support it.') - shared.cmd_opts.rollback_vae = False - - def initialize(): check_versions() - check_rollback_vae() extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) From c48797e07e0d5533d033aa2043d939580ec0cfca Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Tue, 9 May 2023 23:25:01 +0800 Subject: [PATCH 8/9] Logic adjustments --- modules/sd_vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 8cff159161b..11fc8569f93 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -203,7 +203,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): sd_model.to(devices.cpu) sd_hijack.model_hijack.undo_hijack(sd_model) - if shared.cmd_opts.rollback_vae and devices.dtype_vae == torch.bfloat16: + if not shared.cmd_opts.disable_nan_check and devices.device != devices.cpu and devices.dtype_vae == torch.bfloat16: devices.dtype_vae = torch.float16 load_vae(sd_model, vae_file, vae_source) From 6f1680110eb8e64846d99c9c2d53adee8b3ade89 Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Tue, 9 May 2023 23:51:55 +0800 Subject: [PATCH 9/9] Add description --- modules/cmd_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 81c0b82a307..8749dab23b2 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -61,7 +61,7 @@ parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*") parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") -parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") +parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans and disable vae auto-conversion; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)