import os
import sys
import time
import argparse
import logging

import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image
import tensorboardX

from model import Generator, Discriminator
from dataset import get_dataloader
from utils import AverageMeter

# References: https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("exp_tag", type=str, help="The tag of the current experiment")
    parser.add_argument("--dset_dir", type=str, default="./data/mnist", help="where to load the MNIST dataset from")
    parser.add_argument("--save_dir", type=str, default="./generated", help="where to save generated images")
    parser.add_argument("--log_dir", type=str, default="./.checkpoints", help="where to save tensorboard logs")
    parser.add_argument("--img_size", type=int, nargs=3, default=(1, 28, 28), help="image shape as (channels, height, width)")

    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")

    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
    parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")

    parser.add_argument(
        "--update_g_per_iter", default=1, type=int, help="How many updates the generator performs in each iteration"
    )
    parser.add_argument(
        "--update_d_per_iter", default=1, type=int, help="How many updates the discriminator performs in each iteration"
    )
    # Notes: Typically, we never stop updating the discriminator during training. This option is added here for
    # demonstration purposes only.
    parser.add_argument(
        "--d_stop_update", default=int(1e10), type=int,
        help="The epoch at which the discriminator stops being updated"
    )
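    # e.g. passing "--d_stop_update 20" freezes the discriminator from epoch 20 onward.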
    parser.add_argument("--loss_d_scale", type=float, default=2., help="the discriminator's loss is divided by this factor")
    parser.add_argument("--input_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples")

    parser.add_argument("--use_cuda", action="store_true", help="Whether to use CUDA")

    args = parser.parse_args()
    for k in ("save_dir", "log_dir"):
        v = getattr(args, k)
        v = os.path.expanduser(os.path.join(v, args.exp_tag))
        setattr(args, k, v)
        os.makedirs(v, exist_ok=True)
    return args


if __name__ == "__main__":
    # Notes: For reproducibility, we often fix the random seeds (e.g. torch, numpy, random) at the very beginning
    # of training. Here we omit this step since we don't need any guarantee of reproducibility.
    # References: https://pytorch.org/docs/stable/notes/randomness.html
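    # A minimal seeding sketch (left commented out on purpose; the seed value is arbitrary):
    #     import random
    #     seed = 42
    #     random.seed(seed)
    #     np.random.seed(seed)
    #     torch.manual_seed(seed)
    #     torch.cuda.manual_seed_all(seed)
    #     torch.backends.cudnn.deterministic = True  # may slow training down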
    args = get_args()
    logger.info(f"ARGS: {args}")
    if torch.cuda.is_available() and args.use_cuda:
        cuda = True
        logger.info("Using CUDA")
    else:
        cuda = False
        logger.info("Using CPU")

    # Define the models, loss function, dataloader, optimizers and everything else we need
    G = Generator(args.input_dim, args.img_size)
    D = Discriminator(args.img_size)
    ce = torch.nn.BCELoss()
    # In each iteration we may update the discriminator multiple times, so we load `update_d_per_iter`
    # batches into memory at once by setting the dataloader's batch size to
    # `args.batch_size * args.update_d_per_iter` instead of `args.batch_size`.
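    # For example, with the defaults (batch_size=64, update_d_per_iter=1) each dataloader item holds 64
    # images; with update_d_per_iter=2 it would hold 128 images, split into two 64-image chunks below.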
    dataloader = get_dataloader(
        args.dset_dir, args.batch_size * args.update_d_per_iter,
        is_training=True, img_size=args.img_size[1:]
    )
    if cuda:
        # Notes: if you have more than one GPU, you can use DataParallel (DP) or DistributedDataParallel (DDP)
        # to enable parallelism across multiple devices. DP is easier to set up but is usually slower than DDP.
        # One thing to notice is that DDP uses multiple processes while DP uses a single process.
        # In this script we only use one GPU, so we need neither of them.
        # References:
        # https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
        # https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
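        # A minimal single-process DP sketch (not used here; shown only for illustration):
        #     if torch.cuda.device_count() > 1:
        #         G = nn.DataParallel(G)
        #         D = nn.DataParallel(D)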
        G.cuda()
        D.cuda()
        ce.cuda()
    optimizer_G = torch.optim.Adam(G.parameters(), lr=args.lr, betas=(args.b1, args.b2))
    optimizer_D = torch.optim.Adam(D.parameters(), lr=args.lr, betas=(args.b1, args.b2))
    tb_log_dir = os.path.join(args.log_dir, "log")
    os.makedirs(tb_log_dir, exist_ok=True)
    writer = tensorboardX.SummaryWriter(log_dir=tb_log_dir)

    # Start training
    tic = time.time()
    n_iter = 0
    # p_fake_is_real: the probability that the discriminator thinks the fake images are real.
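    # These probabilities are recovered from the BCE losses below: BCE is -mean(log p), so
    # exp(-loss) gives the (geometric-mean) probability the discriminator assigns to the target label.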
    recorders = {
        k: AverageMeter() for k in ("d", "g", "p_real_is_real", "p_fake_is_real", "p_fake_is_fake")
    }
    d_stopped = False
    for epoch_idx in range(1, args.n_epochs + 1):  # indexing from 1 instead of 0
        if epoch_idx >= args.d_stop_update and not d_stopped:
            logger.info(f"Epoch {epoch_idx}: Stop updating the discriminator.")
            d_stopped = True
            for p in D.parameters():
                p.requires_grad_(False)
        pbar = tqdm.tqdm(range(len(dataloader)), desc="Training", disable=False)
        pbar.set_postfix({"epoch": f"{epoch_idx}/{args.n_epochs}", "loss_d": 0., "loss_g": 0.})

        def d_step(img, fake_imgs, real, fake):
            # Perform one discriminator step: real images should be classified as real and
            # generated images as fake (fake_imgs is detached so no gradient flows back into G).
            loss_real_is_real = ce(D(img), real)
            loss_fake_is_fake = ce(D(fake_imgs.detach()), fake)
            loss_d = (loss_real_is_real + loss_fake_is_fake) / args.loss_d_scale
            return loss_d, loss_real_is_real, loss_fake_is_fake

        for i, (imgs, _) in enumerate(dataloader):
            B = imgs.size(0) // args.update_d_per_iter
            n_iter += 1

            # Prepare labels for adversarial training
            real = torch.ones((B, 1), requires_grad=False)
            fake = torch.zeros((B, 1), requires_grad=False)
            if cuda:
                imgs, real, fake = [t.cuda() for t in (imgs, real, fake)]

            for j in range(args.update_g_per_iter):
                # Train generator
                noise = torch.randn((B, args.input_dim))
                if cuda:
                    noise = noise.cuda()
                fake_imgs = G(noise)
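                # This is the non-saturating generator loss: with all-ones targets, BCE reduces to
                # -log D(G(z)), so minimizing it pushes D towards labelling the generated images as real.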
                loss_fake_is_real = ce(D(fake_imgs), real)
                loss_g = loss_fake_is_real
                optimizer_G.zero_grad()
                loss_g.backward()
                optimizer_G.step()
                recorders['g'].update(loss_g.item(), 1)
                recorders["p_fake_is_real"].update(np.exp(-loss_fake_is_real.item()), 1)

            # The dataloader is expected to drop the last incomplete batch (drop_last=True); otherwise this fails.
            assert imgs.size(0) == args.batch_size * args.update_d_per_iter
            for j in range(args.update_d_per_iter):
                # Train discriminator on the j-th chunk of the oversized batch
                img = imgs[j * args.batch_size: (j + 1) * args.batch_size]
                if not d_stopped:
                    loss_d, loss_real_is_real, loss_fake_is_fake = d_step(img, fake_imgs, real, fake)
                    optimizer_D.zero_grad()
                    loss_d.backward()
                    optimizer_D.step()
                else:
                    # Once frozen, we still evaluate the discriminator losses for logging only.
                    with torch.no_grad():
                        loss_d, loss_real_is_real, loss_fake_is_fake = d_step(img, fake_imgs, real, fake)
                recorders['d'].update(loss_d.item(), 1)
                recorders["p_real_is_real"].update(np.exp(-loss_real_is_real.item()), 1)
                recorders["p_fake_is_fake"].update(np.exp(-loss_fake_is_fake.item()), 1)

            state = {
                "epoch": f"{epoch_idx}/{args.n_epochs}",
                "loss_d": f"{recorders['d'].get():.4f}",
                "loss_g": f"{recorders['g'].get():.4f}",
            }
            pbar.set_postfix(state)
            pbar.update()
            writer.add_scalars("Training", {k: v.get() for k, v in recorders.items()}, n_iter)

            # Save a grid of generated images every `args.sample_interval` iterations
            batches_done = (epoch_idx - 1) * len(dataloader) + i
            if batches_done % args.sample_interval == 0:
                save_image(fake_imgs.data[:25], f"{args.save_dir}/{batches_done:07d}.png", nrow=5, normalize=True)
        pbar.close()

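    # Save a single checkpoint once training completes (move this block inside the epoch loop to checkpoint every epoch).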
    ckpt_ph = os.path.join(
        args.log_dir,
        "train", f"epoch{epoch_idx}_lossd_{recorders['d'].get():.3f}_lossg_{recorders['g'].get():.3f}.pt"
    )
    os.makedirs(os.path.dirname(ckpt_ph), exist_ok=True)
    # Notes: if you wrap the models in DataParallel or DistributedDataParallel, save G.module.state_dict() and
    # D.module.state_dict() instead, so the checkpoint keys are not prefixed with "module.".
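    # A sketch of that unwrapping (only relevant if the wrappers above were actually used):
    #     g_state = G.module.state_dict() if isinstance(G, nn.DataParallel) else G.state_dict()
    #     d_state = D.module.state_dict() if isinstance(D, nn.DataParallel) else D.state_dict()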
    torch.save(
        {
            "G": G.state_dict(),
            "D": D.state_dict(),
            "optimizer_G": optimizer_G.state_dict(),
            "optimizer_D": optimizer_D.state_dict(),
        },
        ckpt_ph
    )
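    # To resume or evaluate later, the checkpoint can be loaded with e.g.:
    #     ckpt = torch.load(ckpt_ph, map_location="cpu")
    #     G.load_state_dict(ckpt["G"])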

    logger.info(f"Training finished. Duration: {(time.time() - tic) / 3600:.2f}h")