
Merge branch 'fqt' into feffy380
feffy380 committed Oct 13, 2024
2 parents 565a9e2 + 3244c47 commit f7543bd
Showing 12 changed files with 497 additions and 82 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
 logs
+logs.old
 __pycache__
 wd14_tagger_model
 venv
4 changes: 2 additions & 2 deletions library/attention_processors.py
@@ -196,9 +196,9 @@ def backward(ctx, do):
 # forward-only flash attention for Navi
 class FlashAttnFuncNavi(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
+    def forward(ctx, q, k, v, mask=None, causal=False, scale=None):
         dropout_p = 0.0
-        softmax_scale = q.shape[-1] ** (-0.5)
+        softmax_scale = q.shape[-1] ** (-0.5) if scale is None else scale
         return_softmax = False
 
         q, k, v = (rearrange(t, "b h n d -> b n h d") for t in (q, k, v))
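Note on the new signature: FlashAttnFuncNavi.forward no longer takes bucket sizes and instead accepts an optional softmax scale. torch.autograd.Function.apply passes arguments positionally, so omitted trailing arguments fall back to their defaults, and scale=None means 1/sqrt(head_dim). A minimal call sketch under assumed shapes (not code from this repo):

    import torch
    from library.attention_processors import FlashAttnFuncNavi

    # assumed layout: (batch, heads, tokens, head_dim)
    q = k = v = torch.randn(2, 8, 64, 64, device="cuda", dtype=torch.float16)

    with torch.no_grad():  # forward-only kernel, so keep autograd off
        # old call: FlashAttnFuncNavi.apply(q, k, v, mask, causal, q_bucket_size, k_bucket_size)
        out = FlashAttnFuncNavi.apply(q, k, v, None, False)  # scale omitted -> d ** -0.5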
9 changes: 9 additions & 0 deletions library/custom_train_functions.py
@@ -28,6 +28,15 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device):
     learned_weights /= learned_weights.mean()
     noise_scheduler.learned_weights = learned_weights.to(device)
 
+    # laplace weights
+    mu = 1.5
+    b = 2.5
+    # mu = 0
+    # b = 0.75
+    laplace_weights = ((all_snr.log() - mu).abs() / -b).exp() / (2 * b)
+    laplace_weights /= laplace_weights.mean()
+    noise_scheduler.laplace_weights = laplace_weights.to(device)
+
 
 def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
     # fix beta: zero terminal SNR
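Note: each timestep's weight follows a Laplace density over log-SNR, exp(-|log(SNR) - mu| / b) / (2b), normalized to mean 1, so training emphasis concentrates on timesteps whose log-SNR lies near mu (here 1.5) with spread b (here 2.5). This hunk only precomputes the table; a hedged sketch of how such a table is typically applied per sample (the timestep indexing and loss shape are assumptions, not shown in this commit):

    import torch

    def weight_loss_by_laplace(loss: torch.Tensor, timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor:
        # loss: (batch,) per-sample loss; timesteps: (batch,) integer timesteps
        # laplace_weights: one mean-normalized weight per scheduler timestep
        return loss * noise_scheduler.laplace_weights[timesteps]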
9 changes: 5 additions & 4 deletions library/original_unet.py
@@ -544,8 +544,9 @@ def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
 
         if not torch.is_grad_enabled() and flash_attn_installed and q.shape[-1] <= 128:
-            flash_func = FlashAttnFuncNavi
-        out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
+            out = FlashAttnFuncNavi.apply(q, k, v, mask, False)
+        else:
+            out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
 
         out = rearrange(out, "b h n d -> b n (h d)")
 
@@ -563,8 +564,8 @@ def forward_sdpa(self, x, context=None, mask=None):
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
         del q_in, k_in, v_in
 
-        if not torch.is_grad_enabled() and flash_attn_installed and q.shape[-1] <= 128:
-            out = FlashAttnFuncNavi.apply(q, k, v, mask, False, 512, 1024)
+        if not torch.is_grad_enabled() and flash_attn_installed and q.shape[-1] <= 512:
+            out = FlashAttnFuncNavi.apply(q, k, v, mask, False)
         else:
             out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
 
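Note: both UNets now share one dispatch rule: take the forward-only Navi flash-attention path only when gradients are disabled, the flash-attn package is installed, and the head dimension fits the kernel (the sdpa path raises the cap to 512); otherwise fall back to PyTorch's scaled_dot_product_attention. A self-contained sketch of the pattern (the helper name is illustrative):

    import torch
    import torch.nn.functional as F
    from library.attention_processors import FlashAttnFuncNavi, flash_attn_installed

    def dispatch_attention(q, k, v, mask=None):
        # q, k, v: (batch, heads, tokens, head_dim)
        if not torch.is_grad_enabled() and flash_attn_installed and q.shape[-1] <= 512:
            return FlashAttnFuncNavi.apply(q, k, v, mask, False)  # inference-only fast path
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)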
10 changes: 6 additions & 4 deletions library/sdxl_original_unet.py
@@ -31,6 +31,7 @@
 from torch.nn import functional as F
 from einops import rearrange
 from library.attention_processors import FlashAttentionFunction, FlashAttnFuncNavi, flash_attn_installed
+# from library.flash_attn_wmma.attention import FlashAttentionWMMA
 from .utils import setup_logging
 
 setup_logging()
@@ -349,8 +350,9 @@ def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
 
         if not torch.is_grad_enabled() and flash_attn_installed and q.shape[-1] <= 128:
-            flash_func = FlashAttnFuncNavi
-        out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
+            out = FlashAttnFuncNavi.apply(q, k, v, mask, False)
+        else:
+            out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
 
         out = rearrange(out, "b h n d -> b n (h d)")

@@ -368,8 +370,8 @@ def forward_sdpa(self, x, context=None, mask=None):
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
         del q_in, k_in, v_in
 
-        if not torch.is_grad_enabled() and flash_attn_installed and q.shape[-1] <= 128:
-            out = FlashAttnFuncNavi.apply(q, k, v, mask, False, 512, 1024)
+        if not torch.is_grad_enabled() and flash_attn_installed and q.shape[-1] <= 512:
+            out = FlashAttnFuncNavi.apply(q, k, v, mask, False)
         else:
             out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
 
11 changes: 6 additions & 5 deletions library/sdxl_train_util.py
@@ -61,11 +61,6 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
     clean_memory_on_device(accelerator.device)
     accelerator.wait_for_everyone()
 
-    # apply token merging patch
-    if args.todo_factor:
-        token_downsampling.apply_patch(unet, args, is_sdxl=True)
-        logger.info(f"enable token downsampling optimization: downsample_factor={args.todo_factor}, max_depth={args.todo_max_depth}")
-
     return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info

@@ -347,6 +342,12 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
         action="store_true",
         help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
     )
+    parser.add_argument(
+        "--sdxl_cond_dropout_rate",
+        type=float,
+        default=0,
+        help="rate (0-1) at which to drop out the text conditioning for classifier-free guidance in SDXL"
+    )
 
 
 def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
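Note: --sdxl_cond_dropout_rate supports classifier-free-guidance style training: with the given probability, a sample's text conditioning is dropped so the model also learns an unconditional prediction. The consuming code is not part of this diff; a hedged sketch of the usual mechanism (zeroing the embeddings is one common choice and an assumption here):

    import torch

    def maybe_drop_text_cond(text_embeds: torch.Tensor, dropout_rate: float) -> torch.Tensor:
        # text_embeds: (batch, ...); one Bernoulli draw per sample
        if dropout_rate <= 0:
            return text_embeds
        keep = torch.rand(text_embeds.shape[0], device=text_embeds.device) >= dropout_rate
        shape = (-1,) + (1,) * (text_embeds.dim() - 1)
        return text_embeds * keep.view(shape).to(text_embeds.dtype)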
158 changes: 158 additions & 0 deletions library/srgb_util.py
@@ -0,0 +1,158 @@
import os
from typing import TypeAlias, Mapping
from io import BytesIO

try:
    import pillow_jxl  # Ensure this is installed
except ImportError:
    from jxlpy import JXLImagePlugin
from PIL import ImageCms, Image, PngImagePlugin, ImageChops
from PIL.ImageCms import Intent
import torch
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF

# Suppress the warning for large images
Image.MAX_IMAGE_PIXELS = None
PngImagePlugin.MAX_TEXT_CHUNK = 100 * (1024**2)

# Color management profiles and intent flags
_SRGB = ImageCms.createProfile(colorSpace='sRGB')

IntentFlags: TypeAlias = Mapping[Intent, int]

_INTENT_FLAGS_INITIAL: IntentFlags = {
    Intent.PERCEPTUAL: ImageCms.FLAGS["HIGHRESPRECALC"],
    Intent.RELATIVE_COLORIMETRIC: ImageCms.FLAGS["HIGHRESPRECALC"] | ImageCms.FLAGS["BLACKPOINTCOMPENSATION"],
    Intent.SATURATION: ImageCms.FLAGS["HIGHRESPRECALC"],
    Intent.ABSOLUTE_COLORIMETRIC: ImageCms.FLAGS["HIGHRESPRECALC"]
}

_INTENT_FLAGS_FALLBACK: IntentFlags = {
    Intent.PERCEPTUAL: ImageCms.FLAGS["HIGHRESPRECALC"],
    Intent.RELATIVE_COLORIMETRIC: ImageCms.FLAGS["HIGHRESPRECALC"] | ImageCms.FLAGS["BLACKPOINTCOMPENSATION"],
    Intent.ABSOLUTE_COLORIMETRIC: ImageCms.FLAGS["HIGHRESPRECALC"]
}

def _coalesce_intent(intent: Intent | int) -> Intent:
    if isinstance(intent, Intent):
        return intent

    match intent:
        case 0:
            return Intent.PERCEPTUAL
        case 1:
            return Intent.RELATIVE_COLORIMETRIC
        case 2:
            return Intent.SATURATION
        case 3:
            return Intent.ABSOLUTE_COLORIMETRIC
        case _:
            raise ValueError("invalid ImageCms intent")

def open_srgb(
    fp,
    *,
    mode: str | None = "RGB",
    intent: Intent | int | None = Intent.RELATIVE_COLORIMETRIC,
    intent_flags: IntentFlags | None = None,
    intent_fallback: bool = True,
    formats: list[str] | tuple[str, ...] | None = None,
) -> Image.Image:
    img = Image.open(fp, formats=formats)

    if img.mode == 'P' and img.info.get('transparency'):
        img = img.convert('PA')

    if mode is None:
        match img.mode:
            case "RGBA" | "LA" | "PA":
                mode = "RGBA"
            case "RGBa" | "La":
                mode = "RGBa"
            case _:
                mode = "RGB"

    # ensure image is in sRGB color space
    if intent is not None:
        icc_raw = img.info.get("icc_profile")

        if icc_raw is not None:
            profile = ImageCms.ImageCmsProfile(BytesIO(icc_raw))
            intent = _coalesce_intent(intent)

            if img.mode == "P":
                img = img.convert("RGB")
            elif img.mode == "PA":
                img = img.convert("RGBA")

            color_profile_sus = False
            color_mode_corrected = False
            mode_conversion = {
                ('RGBA', 'GRAY'): 'LA',
                ('RGB', 'GRAY'): 'L',
                ('LA', 'RGB '): 'RGBA',
                ('L', 'RGB '): 'RGB',
                ('I;16', 'RGB '): 'RGB',
                ('RGB', 'CMYK'): 'CMYK'
            }
            valid_modes = [
                ('RGBA', 'RGB '),
                ('RGB', 'RGB '),
                ('LA', 'GRAY'),
                ('L', 'GRAY'),
                ('I;16', 'GRAY'),
                ('CMYK', 'CMYK')
            ]

            if (img.mode, profile.profile.xcolor_space) not in valid_modes:
                if (img.mode, profile.profile.xcolor_space) in mode_conversion:
                    img = img.convert(mode_conversion[(img.mode, profile.profile.xcolor_space)])
                    color_mode_corrected = True
                else:
                    print(f"WARNING: {fp} has unhandled color space mismatch: '{profile.profile.xcolor_space}' != '{img.mode}'")
                    color_profile_sus = True

            intent_issue = False
            if intent_fallback and not profile.profile.is_intent_supported(intent, ImageCms.Direction.INPUT):
                intent = _coalesce_intent(ImageCms.getDefaultIntent(profile))
                if not profile.profile.is_intent_supported(intent, ImageCms.Direction.INPUT):
                    print("Warning: This profile doesn't support any operations!")
                    intent_issue = True
                flags = (intent_flags if intent_flags is not None else _INTENT_FLAGS_FALLBACK).get(intent)
            else:
                flags = (intent_flags if intent_flags is not None else _INTENT_FLAGS_INITIAL).get(intent)

            if flags is None:
                raise KeyError(f"no flags for intent {intent}")

            try:
                if img.mode == mode:
                    ImageCms.profileToProfile(
                        img,
                        profile,
                        _SRGB,
                        renderingIntent=intent,
                        inPlace=True,
                        flags=flags
                    )
                else:
                    img = ImageCms.profileToProfile(
                        img,
                        profile,
                        _SRGB,
                        renderingIntent=intent,
                        outputMode=mode,
                        flags=flags
                    )
                if color_profile_sus and not color_mode_corrected:
                    print(f"WARNING: {fp} had a mismatched color profile but loaded fine.")
            except Exception:
                print(f"WARNING: Failed to load color profile for {fp}. Is it corrupt, or are we mishandling an edge case?")

    if img.mode != mode:
        img = img.convert(mode)

    return img
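Note: open_srgb wraps Image.open with ICC-aware conversion to sRGB, falling back to a plain mode conversion when no profile is embedded or the profile fails to apply. A quick usage sketch (the file paths are hypothetical):

    from library.srgb_util import open_srgb

    img = open_srgb("dataset/sample.jxl", mode="RGB")  # decoded, converted to sRGB, returned as RGB
    print(img.mode, img.size)

    # mode=None lets the loader keep an alpha-preserving mode (RGBA/RGBa) when present
    img_auto = open_srgb("dataset/sample.png", mode=None)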
18 changes: 12 additions & 6 deletions library/token_downsampling.py
@@ -8,11 +8,13 @@
 import torch.nn.functional as F
 
 
-def up_or_downsample(item, cur_w, cur_h, new_w, new_h, method="nearest-exact"):
+def up_or_downsample(item, cur_w, cur_h, new_w, new_h, downsample_factor, method="nearest-exact"):
     batch_size = item.shape[0]
+    scale_factor = int(downsample_factor)
 
     item = item.reshape(batch_size, cur_h, cur_w, -1).permute(0, 3, 1, 2)
-    item = F.interpolate(item, size=(new_h, new_w), mode=method).permute(0, 2, 3, 1)
+    # item = F.interpolate(item, size=(new_h, new_w), mode=method).permute(0, 2, 3, 1)
+    item = item[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1)
     item = item.reshape(batch_size, new_h * new_w, -1)
 
     return item
@@ -32,7 +34,7 @@ def compute_merge(x: torch.Tensor, todo_info: dict):
     downsample_factor = args["downsample_factor"][downsample]
     new_h = int(cur_h / downsample_factor)
     new_w = int(cur_w / downsample_factor)
-    merge_op = lambda x: up_or_downsample(x, cur_w, cur_h, new_w, new_h)
+    merge_op = lambda x: up_or_downsample(x, cur_w, cur_h, new_w, new_h, downsample_factor)
 
     return merge_op

@@ -99,9 +101,13 @@ def apply_patch(unet: torch.nn.Module, args, is_sdxl=False):
     hook_unet(unet)
 
     for _, module in unet.named_modules():
-        if module.__class__.__name__ == "BasicTransformerBlock":
-            module.attn1._todo_info = unet._todo_info
-            hook_attention(module.attn1)
+        if module.__class__.__name__ == "Transformer2DModel":
+            if is_sdxl and len(module.transformer_blocks) != 2:
+                continue
+            for _, submodule in module.named_modules():
+                if submodule.__class__.__name__ == "BasicTransformerBlock":
+                    submodule.attn1._todo_info = unet._todo_info
+                    hook_attention(submodule.attn1)
 
     return unet
 
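Note: for an integer factor k that divides the spatial size, the strided slice item[:, :, ::k, ::k] selects exactly the samples that floor-based nearest interpolation would, so the interpolate call can be skipped; the previously used "nearest-exact" mode samples at half-pixel offsets and can differ by one pixel. A small self-check sketch (sizes are illustrative):

    import torch
    import torch.nn.functional as F

    b, c, h, w, k = 2, 8, 16, 16, 2  # factor 2 on a 16x16 feature map
    x = torch.randn(b, c, h, w)

    strided = x[:, :, ::k, ::k]
    nearest = F.interpolate(x, size=(h // k, w // k), mode="nearest")
    print(torch.equal(strided, nearest))  # True: slicing == floor-based nearest sampling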