
Fix DDP issues and Support DDP for all training scripts #448

Merged: 13 commits, May 3, 2023
5 changes: 4 additions & 1 deletion fine_tune.py
@@ -90,7 +90,7 @@ def train(args):
     weight_dtype, save_dtype = train_util.prepare_dtype(args)

     # load the models
-    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
+    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)

     # verify load/save model formats
     if load_stable_diffusion_format:
@@ -228,6 +228,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     else:
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

+    # transform DDP after prepare
+    text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
+
     # experimental feature: train in fp16 including gradients; patch PyTorch to enable grad scaling in fp16
     if args.full_fp16:
         train_util.patch_accelerator_for_fp16_training(accelerator)
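Why the unwrap step exists: when the script is launched on multiple GPUs, accelerator.prepare() hands back each module wrapped in DistributedDataParallel, so code below that reaches into attributes such as text_encoder.text_model would otherwise need the .module indirection everywhere. A minimal sketch of the idea behind train_util.transform_DDP (the helper name unwrap_ddp is hypothetical, not part of this PR):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def unwrap_ddp(module: torch.nn.Module) -> torch.nn.Module:
    # DDP exposes the original model as .module; plain modules pass through unchanged
    return module.module if isinstance(module, DDP) else module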
34 changes: 33 additions & 1 deletion library/train_util.py
@@ -19,6 +19,7 @@
     Union,
 )
 from accelerate import Accelerator
+import gc
 import glob
 import math
 import os
@@ -30,6 +31,7 @@

 from tqdm import tqdm
 import torch
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torchvision import transforms
 from transformers import CLIPTokenizer
@@ -2850,7 +2852,7 @@ def prepare_dtype(args: argparse.Namespace):
     return weight_dtype, save_dtype


-def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
+def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
     name_or_path = args.pretrained_model_name_or_path
     name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
     load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
@@ -2879,6 +2881,36 @@ def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
     return text_encoder, vae, unet, load_stable_diffusion_format


+def transform_DDP(text_encoder, unet, network=None):
+    # unwrap text_encoder, unet and network from their DistributedDataParallel wrappers, if any
+    return (m.module if isinstance(m, DDP) else m for m in [text_encoder, unet, network])
+
+
+def load_target_model(args, weight_dtype, accelerator):
+    # load models for each process
+    for pi in range(accelerator.state.num_processes):
+        if pi == accelerator.state.local_process_index:
+            print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
+
+            text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
+                args, weight_dtype, accelerator.device if args.lowram else "cpu"
+            )
+
+            # work on low-ram device
+            if args.lowram:
+                text_encoder.to(accelerator.device)
+                unet.to(accelerator.device)
+                vae.to(accelerator.device)
+
+            gc.collect()
+            torch.cuda.empty_cache()
+        accelerator.wait_for_everyone()
+
+    text_encoder, unet, _ = transform_DDP(text_encoder, unet, network=None)
+
+    return text_encoder, vae, unet, load_stable_diffusion_format
+
+
 def patch_accelerator_for_fp16_training(accelerator):
     org_unscale_grads = accelerator.scaler._unscale_grads_

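The new load_target_model staggers checkpoint loading across ranks: each process loads only on its own turn, and the barrier at the end of every loop iteration keeps the other ranks waiting, so peak host RAM stays near one copy of the model instead of num_processes copies. A stripped-down sketch of the same sequencing, assuming any load_checkpoint() callable in place of _load_target_model:

from accelerate import Accelerator

def load_staggered(load_checkpoint):
    accelerator = Accelerator()
    model = None
    for pi in range(accelerator.state.num_processes):
        if pi == accelerator.state.local_process_index:
            model = load_checkpoint()  # only this rank touches the disk now
        accelerator.wait_for_everyone()  # all ranks sync before the next one loads
    return model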
2 changes: 1 addition & 1 deletion networks/lora_interrogator.py
@@ -23,7 +23,7 @@ def interrogate(args):
print(f"loading SD model: {args.sd_model}")
args.pretrained_model_name_or_path = args.sd_model
args.vae = None
text_encoder, vae, unet, _ = train_util.load_target_model(args,weights_dtype, DEVICE)
text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE)

print(f"loading LoRA: {args.model}")
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
8 changes: 6 additions & 2 deletions train_db.py
@@ -92,7 +92,7 @@ def train(args):
     weight_dtype, save_dtype = train_util.prepare_dtype(args)

     # load the models
-    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
+    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)

     # verify load/save model formats
     if load_stable_diffusion_format:
@@ -196,6 +196,9 @@ def train(args):
     else:
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

+    # transform DDP after prepare
+    text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
+
     if not train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)  # to avoid 'cpu' vs 'cuda' error

@@ -297,7 +300,8 @@ def train(args):
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                 # Predict the noise residual
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                with accelerator.autocast():
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                 if args.v_parameterization:
                     # v-parameterization training
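The autocast change matters for mixed-precision runs: the U-Net forward is now explicitly wrapped in accelerator.autocast(), so it executes under the configured fp16/bf16 autocast even when the surrounding training-loop code runs outside any autocast context. A small sketch of the pattern, assuming a CUDA device and an fp16 run:

import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")  # assumption: GPU fp16 run
model = accelerator.prepare(torch.nn.Linear(8, 8))
x = torch.randn(2, 8, device=accelerator.device)

with accelerator.autocast():
    y = model(x)  # forward runs under the configured autocast dtype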
36 changes: 6 additions & 30 deletions train_network.py
@@ -1,4 +1,3 @@
-from torch.nn.parallel import DistributedDataParallel as DDP
 import importlib
 import argparse
 import gc
@@ -144,24 +143,7 @@ def train(args):
     weight_dtype, save_dtype = train_util.prepare_dtype(args)

     # load the models
-    for pi in range(accelerator.state.num_processes):
-        # TODO: modify other training scripts as well
-        if pi == accelerator.state.local_process_index:
-            print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
-
-            text_encoder, vae, unet, _ = train_util.load_target_model(
-                args, weight_dtype, accelerator.device if args.lowram else "cpu"
-            )
-
-            # work on low-ram device
-            if args.lowram:
-                text_encoder.to(accelerator.device)
-                unet.to(accelerator.device)
-                vae.to(accelerator.device)
-
-            gc.collect()
-            torch.cuda.empty_cache()
-        accelerator.wait_for_everyone()
+    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)

     # incorporate xformers or memory efficient attention into the model
     train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -279,6 +261,9 @@ def train(args):
     else:
         network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)

+    # transform DDP after prepare (only train_network passes the network as well)
+    text_encoder, unet, network = train_util.transform_DDP(text_encoder, unet, network)
+
     unet.requires_grad_(False)
     unet.to(accelerator.device, dtype=weight_dtype)
     text_encoder.requires_grad_(False)
@@ -288,20 +273,11 @@ def train(args):
         text_encoder.train()

         # set top parameter requires_grad = True so gradient checkpointing works
-        if type(text_encoder) == DDP:
-            text_encoder.module.text_model.embeddings.requires_grad_(True)
-        else:
-            text_encoder.text_model.embeddings.requires_grad_(True)
+        text_encoder.text_model.embeddings.requires_grad_(True)
     else:
         unet.eval()
         text_encoder.eval()

-    # support DistributedDataParallel
-    if type(text_encoder) == DDP:
-        text_encoder = text_encoder.module
-        unet = unet.module
-        network = network.module
-
     network.prepare_grad_etc(text_encoder, unet)

     if not cache_latents:
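Two things fall out of unwrapping right after prepare in train_network.py: the type(text_encoder) == DDP branches become dead code, and the gradient-checkpointing workaround shrinks to a single line. The workaround is needed because checkpointed segments only build a backward graph when some input requires grad, and for LoRA the text encoder itself is frozen. A toy reproduction of that failure mode, assuming a recent PyTorch:

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)
layer.requires_grad_(False)  # frozen, as the text encoder is during LoRA training
x = torch.randn(2, 8)

# With every parameter and input frozen, checkpoint() has nothing to
# differentiate and backward produces no grads; flagging one leading
# tensor restores the graph, which is what
# text_encoder.text_model.embeddings.requires_grad_(True) does above.
x.requires_grad_(True)
y = checkpoint(layer, x, use_reentrant=True)
y.sum().backward()
print(x.grad is not None)  # True: gradients flow through the checkpointed block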
5 changes: 4 additions & 1 deletion train_textual_inversion.py
@@ -98,7 +98,7 @@ def train(args):
     weight_dtype, save_dtype = train_util.prepare_dtype(args)

     # load the models
-    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
+    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)

     # Convert the init_word to token_id
     if args.init_word is not None:
@@ -280,6 +280,9 @@ def train(args):
         text_encoder, optimizer, train_dataloader, lr_scheduler
     )

+    # transform DDP after prepare
+    text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
+
     index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
     # print(len(index_no_updates), torch.sum(index_no_updates))
     orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
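For context on index_no_updates: textual inversion trains only the newly appended token embeddings, but the optimizer step touches the whole embedding matrix, so after each step the scripts copy the original rows back from orig_embeds_params. A toy sketch of that masking mechanism (the sizes below are made up):

import torch

vocab_size, dim, n_new = 49408, 768, 1  # assumption: CLIP-sized vocab, 1 new token
emb = torch.nn.Embedding(vocab_size + n_new, dim)
orig = emb.weight.data.detach().clone()
index_no_updates = torch.arange(vocab_size + n_new) < vocab_size

# ... run forward/backward and optimizer.step() here ...

with torch.no_grad():
    # restore every pre-existing row; only the new token's row keeps its update
    emb.weight[index_no_updates] = orig[index_no_updates]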
5 changes: 4 additions & 1 deletion train_textual_inversion_XTI.py
@@ -104,7 +104,7 @@ def train(args):
     weight_dtype, save_dtype = train_util.prepare_dtype(args)

     # load the models
-    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
+    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)

     # Convert the init_word to token_id
     if args.init_word is not None:
@@ -314,6 +314,9 @@ def train(args):
         text_encoder, optimizer, train_dataloader, lr_scheduler
     )

+    # transform DDP after prepare
+    text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
+
     index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
     # print(len(index_no_updates), torch.sum(index_no_updates))
     orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()