Adafactor fused backward pass and optimizer step: lowers SDXL VRAM usage (at 1024 resolution) to 10 GB in BF16 / 16.4 GB in FP32 #1259

Merged 1 commit on May 6, 2024
106 changes: 106 additions & 0 deletions library/adafactor_fused.py
@@ -0,0 +1,106 @@
import math
import torch
from transformers import Adafactor

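# Per-parameter variant of transformers.Adafactor's step: updating a single parameter as soon
# as its gradient is available lets the optimizer step be fused into the backward pass and the
# gradient be freed immediately.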
@torch.no_grad()
def adafactor_step_param(self, p, group):
if p.grad is None:
return
grad = p.grad
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients.")

state = self.state[p]
grad_shape = grad.shape

factored, use_first_moment = Adafactor._get_options(group, grad_shape)
# State Initialization
if len(state) == 0:
state["step"] = 0

if use_first_moment:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
if factored:
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
else:
state["exp_avg_sq"] = torch.zeros_like(grad)

state["RMS"] = 0
else:
if use_first_moment:
state["exp_avg"] = state["exp_avg"].to(grad)
if factored:
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
else:
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

p_data_fp32 = p
if p.dtype in {torch.float16, torch.bfloat16}:
p_data_fp32 = p_data_fp32.float()

state["step"] += 1
state["RMS"] = Adafactor._rms(p_data_fp32)
lr = Adafactor._get_lr(group, state)

beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
update = (grad ** 2) + group["eps"][0]
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]

exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

# Approximation of exponential moving average of square of gradient
update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
else:
exp_avg_sq = state["exp_avg_sq"]

exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
update = exp_avg_sq.rsqrt().mul_(grad)

update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
update.mul_(lr)

if use_first_moment:
exp_avg = state["exp_avg"]
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
update = exp_avg

if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

p_data_fp32.add_(-update)

if p.dtype in {torch.float16, torch.bfloat16}:
p.copy_(p_data_fp32)


@torch.no_grad()
def adafactor_step(self, closure=None):
"""
Performs a single optimization step

Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
for p in group["params"]:
adafactor_step_param(self, p, group)

return loss

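# Bind the fused functions above onto an existing Adafactor instance: step_param performs a
# per-parameter update, and step replaces the stock full-loop step.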
def patch_adafactor_fused(optimizer: Adafactor):
optimizer.step_param = adafactor_step_param.__get__(optimizer)
optimizer.step = adafactor_step.__get__(optimizer)
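For orientation, here is a minimal sketch (not part of the PR) of how the patched optimizer could be exercised directly. It assumes PyTorch and transformers are installed and that it is run from the repository root so library.adafactor_fused is importable; the toy parameter and hyperparameters are illustrative. In the PR itself, step_param is not called by hand but from a gradient hook, as shown in the sdxl_train.py change below.

import torch
from transformers import Adafactor

from library.adafactor_fused import patch_adafactor_fused

# A single toy parameter standing in for a model weight.
param = torch.nn.Parameter(torch.randn(4, 4))

# Adafactor with a fixed lr; relative_step must be disabled when an explicit lr is given.
optimizer = Adafactor([param], lr=1e-3, relative_step=False, scale_parameter=False)
patch_adafactor_fused(optimizer)

loss = (param ** 2).sum()
loss.backward()

# Update just this parameter and drop its gradient right away.
optimizer.step_param(param, optimizer.param_groups[0])
param.grad = None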
13 changes: 13 additions & 0 deletions library/train_util.py
@@ -2920,6 +2920,11 @@
default=1,
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
)
parser.add_argument(
"--fused_backward_pass",
action="store_true",
help="Combines backward pass and optimizer step to reduce VRAM usage / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。",
)


def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -3087,7 +3092,7 @@
)
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument(
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする"

GitHub Actions (build) check warning on line 3095 in library/train_util.py: "grandient" should be "gradient".
)
parser.add_argument(
"--gradient_accumulation_steps",
@@ -3846,6 +3851,14 @@
optimizer_type = "AdamW"
optimizer_type = optimizer_type.lower()

if args.fused_backward_pass:
assert (
optimizer_type == "Adafactor".lower()
), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します"
assert (
args.gradient_accumulation_steps == 1
), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません"

# 引数を分解する
optimizer_kwargs = {}
if args.optimizer_args is not None and len(args.optimizer_args) > 0:
29 changes: 23 additions & 6 deletions sdxl_train.py
@@ -430,6 +430,20 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

if args.fused_backward_pass:
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group in optimizer.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
def __grad_hook(tensor: torch.Tensor, param_group=param_group):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None

parameter.register_post_accumulate_grad_hook(__grad_hook)

# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
@@ -619,13 +633,16 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)

accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
if not args.fused_backward_pass:
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()

lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

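To make the hook registration in sdxl_train.py easier to follow in isolation, here is a self-contained sketch of the same pattern on a toy model; the model and hyperparameters are illustrative assumptions, not code from the PR. Each trainable parameter gets a post-accumulate-grad hook (available in PyTorch 2.1+) that clips, applies the per-parameter Adafactor step, and frees the gradient as soon as backward produces it, so optimizer.step() is never called after the backward pass. Plain torch.nn.utils.clip_grad_norm_ stands in here for accelerator.clip_grad_norm_.

import torch
from transformers import Adafactor

from library.adafactor_fused import patch_adafactor_fused

model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))
optimizer = Adafactor(model.parameters(), lr=1e-3, relative_step=False, scale_parameter=False)
patch_adafactor_fused(optimizer)

max_grad_norm = 1.0

for param_group in optimizer.param_groups:
    for parameter in param_group["params"]:
        if parameter.requires_grad:

            def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                # Runs inside backward, right after this tensor's gradient is accumulated.
                if max_grad_norm != 0.0:
                    torch.nn.utils.clip_grad_norm_(tensor, max_grad_norm)
                optimizer.step_param(tensor, param_group)
                tensor.grad = None  # free the gradient immediately

            parameter.register_post_accumulate_grad_hook(__grad_hook)

x = torch.randn(8, 16)
loss = model(x).pow(2).mean()
loss.backward()  # parameters are updated during this call; no optimizer.step() afterwards

Because each gradient is consumed and cleared inside backward, peak gradient memory drops from the whole model to roughly one parameter tensor at a time, which is presumably where the VRAM savings reported in the PR title come from.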