Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bf16 support for VAE as a fallback #9295

Closed
Wants to merge 9 commits into the base branch from the contributor's branch.
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion modules/cmd_args.py
Original file line number Diff line number Diff line change
@@ -61,7 +61,7 @@
parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans and disable vae auto-conversion; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
17 changes: 15 additions & 2 deletions modules/processing.py
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@
import cv2
from skimage import exposure
from typing import Any, Dict, List, Optional
from packaging import version

import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
@@ -653,8 +654,20 @@ def get_conds_with_caching(function, required_prompts, steps, cache):
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)

x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
for x in x_samples_ddim:
devices.test_for_nans(x, "vae")
try:
for x in x_samples_ddim:
devices.test_for_nans(x, "vae")
except devices.NansException as e:
if devices.dtype_vae == torch.float16 and version.parse(torch.__version__) >= version.parse('2.1') and torch.cuda.is_bf16_supported():
print('\nA tensor with all NaNs was produced in VAE, try converting to bf16.')
devices.dtype_vae = torch.bfloat16
vae_file, vae_source = sd_vae.resolve_vae(p.sd_model.sd_model_checkpoint)
sd_vae.load_vae(p.sd_model, vae_file, vae_source)
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
for x in x_samples_ddim:
devices.test_for_nans(x, "vae")
else:
raise e

x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
2 changes: 2 additions & 0 deletions modules/sd_vae.py
Original file line number Diff line number Diff line change
@@ -203,6 +203,8 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
sd_model.to(devices.cpu)

sd_hijack.model_hijack.undo_hijack(sd_model)
if not shared.cmd_opts.disable_nan_check and devices.device != devices.cpu and devices.dtype_vae == torch.bfloat16:
devices.dtype_vae = torch.float16

load_vae(sd_model, vae_file, vae_source)