
Commit 7ee8716

Add files via upload
1 parent efefa5e commit 7ee8716

5 files changed: +385 -0 lines changed


dataset.py

+45
import os
import sys
import logging

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def get_dataloader(dset_dir, batch_size, is_training, img_size, dset_name="MNIST"):
    """Build a DataLoader over a torchvision dataset, resized to `img_size` and normalized to [-1, 1]."""
    assert batch_size > 1  # require more than one sample per batch (BatchNorm1d cannot run on size-1 batches in training mode)
    dset_cls = getattr(datasets, dset_name)
    dset = dset_cls(
        dset_dir,
        train=is_training,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    )
    dataloader = DataLoader(
        dataset=dset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=is_training,
    )
    logger.info(
        "Loading {} dataset from directory: {}, "
        "batch_size: {}, "
        "img_size: {}, "
        "is_training: {}.".format(
            dset_name, dset_dir, batch_size, img_size, is_training
        )
    )
    return dataloader

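A minimal usage sketch of `get_dataloader`, assuming the MNIST defaults from `train.py` (`./data/mnist`, 28x28 images, batch size 64):

# Illustrative only: fetch one batch of normalized MNIST images.
from dataset import get_dataloader

loader = get_dataloader("./data/mnist", batch_size=64, is_training=True, img_size=(28, 28))
imgs, labels = next(iter(loader))
print(imgs.shape)    # torch.Size([64, 1, 28, 28]), values in [-1, 1] after Normalize
print(labels.shape)  # torch.Size([64])
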
model.py

+90
import os
import sys
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Generator(nn.Module):
    def __init__(self, idim: int, img_size: tuple):
        """
        References: https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py

        Args:
            idim: the hidden dim of the generator. It should be the same as `args.input_dim` in `train.py`
            img_size: the size of the expected images
        """
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            # Linear -> (optional BatchNorm) -> LeakyReLU building block
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # 0.8 is passed positionally as `eps`, following the referenced implementation
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*layers)

        self.idim = idim
        self.odim = np.prod(img_size)
        self.img_size = img_size

        self.input_emb = block(idim, 128, normalize=False)
        self.convs = nn.ModuleList([
            block(128, 256),
            block(256, 512),
            block(512, 1024),
        ])
        self.lin = nn.Linear(1024, self.odim)
        self.tanh = nn.Tanh()

        logger.info(self)

    def forward(self, x):
        out = self.input_emb(x)
        for conv in self.convs:
            out = conv(out)
        out = self.tanh(self.lin(out))
        out = out.view(out.size(0), *self.img_size)
        return out


class Discriminator(nn.Module):
    def __init__(self, img_size):
        """
        References: https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py

        Args:
            img_size: the size of input images
        """
        super(Discriminator, self).__init__()

        self.img_size = img_size
        self.idim = np.prod(img_size)

        self.model = nn.Sequential(
            nn.Linear(self.idim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

        logger.info(self)

    def forward(self, x):
        # Flatten images to vectors and output the probability of being real
        x = x.view(x.size(0), -1)
        p = self.model(x)
        return p

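A minimal shape-check sketch for these two modules, assuming the defaults from `train.py` (latent dimension 100, image size (1, 28, 28)):

# Illustrative only: check the tensor shapes flowing through Generator and Discriminator.
import torch
from model import Generator, Discriminator

G = Generator(idim=100, img_size=(1, 28, 28))
D = Discriminator(img_size=(1, 28, 28))
z = torch.randn(16, 100)   # a batch of 16 latent vectors
fake = G(z)                # -> (16, 1, 28, 28), in [-1, 1] because of the Tanh output
p = D(fake)                # -> (16, 1), probabilities from the Sigmoid head
print(fake.shape, p.shape)
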
run.sh

+1
python3 ./train.py baseline

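The positional argument `baseline` is the `exp_tag` consumed by `train.py`; any of the optional flags defined there can be appended, for example (an illustrative invocation):

python3 ./train.py baseline --use_cuda --n_epochs 50 --batch_size 64
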
train.py

+213
import os
import sys
import time
import argparse
import logging

import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image
import tensorboardX


from model import Generator, Discriminator
from dataset import get_dataloader
from utils import AverageMeter

# References: https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("exp_tag", type=str, help="The tag of the current experiment")
    parser.add_argument("--dset_dir", type=str, default="./data/mnist", help="where to load the MNIST dataset from")
    parser.add_argument("--save_dir", type=str, default="./generated", help="where to save generated images")
    parser.add_argument("--log_dir", type=str, default="./.checkpoints", help="where to save tensorboard logs")
    parser.add_argument("--img_size", type=int, nargs=3, default=(1, 28, 28), help="size of each image as (channels, height, width)")

    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")

    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
    parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")

    parser.add_argument(
        "--update_g_per_iter", default=1, type=int, help="How many updates the generator performs in each iteration"
    )
    parser.add_argument(
        "--update_d_per_iter", default=1, type=int, help="How many updates the discriminator performs in each iteration"
    )
    # Notes: Typically, we don't stop updating the discriminator during training. This option is only added
    # for demonstration purposes.
    parser.add_argument(
        "--d_stop_update", default=int(1e10), type=int,
        help="The epoch at which the discriminator stops updating"
    )

    parser.add_argument("--loss_d_scale", type=float, default=2., help="The scaling factor of the discriminator's loss")
    parser.add_argument("--input_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples")

    parser.add_argument("--use_cuda", action="store_true", help="Whether to use CUDA")

    args = parser.parse_args()
    for k in ("save_dir", "log_dir"):
        v = getattr(args, k)
        v = os.path.join(v, args.exp_tag)
        setattr(args, k, v)
        os.makedirs(os.path.expanduser(v), exist_ok=True)
    return args


if __name__ == "__main__":
    # Notes: For reproducibility, we often fix the random seeds (e.g. torch, numpy, random) at the very beginning
    # of training. Here we omit this step since we don't need any guarantee of reproducibility.
    # References: https://pytorch.org/docs/stable/notes/randomness.html
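    # An illustrative sketch of what that step could look like (kept commented out; the seed value 0 is arbitrary):
    # import random
    # random.seed(0)
    # np.random.seed(0)
    # torch.manual_seed(0)            # seeds the CPU RNG (and CUDA RNGs on recent PyTorch versions)
    # torch.cuda.manual_seed_all(0)   # explicit CUDA seeding; a no-op on CPU-only machines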
    args = get_args()
    logger.info(f"ARGS: {args}")
    if torch.cuda.is_available() and args.use_cuda:
        cuda = True
        logger.info("Using CUDA")
    else:
        cuda = False
        logger.info("Using CPU")

    # Define the models, loss function, dataloader, optimizers and everything else we need
    G = Generator(args.input_dim, args.img_size)
    D = Discriminator(args.img_size)
    ce = torch.nn.BCELoss()
    # In each iteration we may update the discriminator multiple times, so we directly load `update_d_per_iter`
    # batches at once; therefore the dataloader's `batch_size` is set to `args.batch_size * args.update_d_per_iter`
    # instead of `args.batch_size`.
    dataloader = get_dataloader(
        args.dset_dir, args.batch_size * args.update_d_per_iter,
        is_training=True, img_size=args.img_size[1:]
    )
    if cuda:
        # Notes: if you have more than one GPU, you can use DataParallel (DP) or DistributedDataParallel (DDP) to
        # parallelize across multiple devices. DP is easier to implement but may be slower than DDP.
        # In this script we only utilize one GPU, so we don't need either of them.
        # One thing to notice is that DDP uses multiple processes while DP uses only one process.
        # References:
        # https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
        # https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
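        # A minimal DP sketch (commented out, since only one GPU is assumed in this script):
        # G = nn.DataParallel(G)
        # D = nn.DataParallel(D)
        # When the models are wrapped like this, the checkpoint code below should save G.module.state_dict()
        # and D.module.state_dict() instead.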
        G.cuda()
        D.cuda()
        ce.cuda()
    optimizer_G = torch.optim.Adam(G.parameters(), lr=args.lr, betas=(args.b1, args.b2))
    optimizer_D = torch.optim.Adam(D.parameters(), lr=args.lr, betas=(args.b1, args.b2))
    d = os.path.join(args.log_dir, "log")
    os.makedirs(d, exist_ok=True)
    writer = tensorboardX.SummaryWriter(log_dir=d)

    # Start training
    tic = time.time()
    n_iter = 0
    # p_fake_is_real: the probability that the discriminator thinks the fake images are real.
    # (BCE against an all-ones/all-zeros target gives loss = -mean(log p), so exp(-loss) recovers the batch's
    # geometric-mean probability; the same trick is used for the other two probabilities below.)
    recorders = {
        k: AverageMeter() for k in ('d', 'g', "p_real_is_real", "p_fake_is_real", "p_fake_is_fake")
    }
    d_stopped = False
    for epoch_idx in range(1, args.n_epochs + 1):  # indexing from 1 instead of 0
        if epoch_idx >= args.d_stop_update:
            logger.info(f"Epoch {epoch_idx}: Stop updating the discriminator.")
            d_stopped = True
            for n, p in D.named_parameters():
                p.requires_grad_(False)
        pbar = tqdm.tqdm(range(len(dataloader)), desc="Training", disable=False)
        pbar.set_postfix({"epoch": f"{0}/{args.n_epochs}", "loss_d": 0., "loss_g": 0.})

        def d_step(img):
            # Perform one step of the discriminator
            loss_real_is_real = ce(D(img), real)
            loss_fake_is_fake = ce(D(fake_imgs.detach()), fake)
            loss_d = (loss_real_is_real + loss_fake_is_fake) / args.loss_d_scale
            return loss_d, loss_real_is_real, loss_fake_is_fake

        for i, (imgs, _) in enumerate(dataloader):
            B = imgs.size(0) // args.update_d_per_iter
            n_iter += 1

            # Prepare labels for adversarial training
            real = torch.ones((B, 1), requires_grad=False)
            fake = torch.zeros((B, 1), requires_grad=False)
            if cuda:
                imgs, real, fake = [t.cuda() for t in (imgs, real, fake)]

            for j in range(args.update_g_per_iter):
                # Train the generator
                noise = torch.randn((B, args.input_dim))
                if cuda:
                    noise = noise.cuda()
                fake_imgs = G(noise)
                loss_fake_is_real = ce(D(fake_imgs), real)
                loss_g = loss_fake_is_real
                optimizer_G.zero_grad()
                loss_g.backward()
                optimizer_G.step()
                recorders['g'].update(loss_g.item(), 1)
                recorders["p_fake_is_real"].update(np.exp(-loss_fake_is_real.item()), 1)

            assert imgs.size(0) == args.batch_size * args.update_d_per_iter
            for j in range(args.update_d_per_iter):
                # Train the discriminator
                img = imgs[j * args.batch_size: (j + 1) * args.batch_size]
                if not d_stopped:
                    loss_d, loss_real_is_real, loss_fake_is_fake = d_step(img)
                    optimizer_D.zero_grad()
                    loss_d.backward()
                    optimizer_D.step()
                else:
                    with torch.no_grad():
                        loss_d, loss_real_is_real, loss_fake_is_fake = d_step(img)
                recorders['d'].update(loss_d.item(), 1)
                recorders["p_real_is_real"].update(np.exp(-loss_real_is_real.item()), 1)
                recorders["p_fake_is_fake"].update(np.exp(-loss_fake_is_fake.item()), 1)

            state = {
                "epoch": f"{epoch_idx}/{args.n_epochs}",
                "loss_d": f"{recorders['d'].get():.4f}",
                "loss_g": f"{recorders['g'].get():.4f}",
            }
            pbar.set_postfix(state)
            pbar.update()
            writer.add_scalars("Training", {k: v.get() for k, v in recorders.items()}, n_iter)

            # Save images every `args.sample_interval` iterations
            batches_done = epoch_idx * len(dataloader) + i
            if batches_done % args.sample_interval == 0:
                save_image(fake_imgs.data[:25], f"{args.save_dir}/{batches_done:07d}.png", nrow=5, normalize=True)
        # pbar.reset()

        ckpt_ph = os.path.join(
            args.log_dir,
            "train", f"epoch{epoch_idx}_lossd_{recorders['d'].get():.3f}_lossg_{recorders['g'].get():.3f}.pt"
        )
        os.makedirs(os.path.dirname(ckpt_ph), exist_ok=True)
        # Notes: if you're using DataParallel or DistributedDataParallel, you may prefer G.module.state_dict() and
        # D.module.state_dict() to unwrap G and D first.
        torch.save(
            {
                "G": G.state_dict(),
                "D": D.state_dict(),
                "optimizer_G": optimizer_G.state_dict(),
                "optimizer_D": optimizer_D.state_dict(),
            },
            ckpt_ph
        )

    logger.info(f"Training finished. Duration: {(time.time() - tic) / 3600:.2f}h")

utils.py

+36
from typing import Union


class AverageMeter(object):
    def __init__(self, init_c: float = 0., init_n: int = 0):
        self.c = init_c  # running sum of values
        self.n = init_n  # running count
        self._recent = 0.  # the most recent value

    def reset(self):
        self.c, self.n, self._recent = 0., 0, 0.

    def set(self, c: Union[list, float], n: Union[int, float]):
        # assert check_argument_types()
        if isinstance(c, list):
            assert len(c) == n
            self.c += sum(c)
            self.n += n
            self._recent = sum(c) / n
        else:
            self.c += c
            self.n += n
            self._recent = c / n

    def update(self, *args, **kwargs):
        return self.set(*args, **kwargs)

    def get(self):
        return self.c / self.n if self.n != 0 else 0.

    @property
    def recent(self):
        return self._recent

    def __repr__(self):
        return str(self.get())  # __repr__ must return a string

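A small illustrative sketch of how `AverageMeter` accumulates values (the numbers are arbitrary):

# Illustrative only: a running average over two calls to update().
from utils import AverageMeter

meter = AverageMeter()
meter.update(2.0, 1)           # a single value with count 1
meter.update([1.0, 3.0], 2)    # a list of values with its count
print(meter.get())             # 2.0, i.e. (2.0 + 1.0 + 3.0) / (1 + 2)
print(meter.recent)            # 2.0, the average of the most recent call: (1.0 + 3.0) / 2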