Skip to content

Commit 04db361

Browse files
pengxin99lanking520
authored andcommitted
Extending the DCGAN example implemented by gluon API to provide a more straight-forward evaluation on the generated image (apache#12790)
* add inception_score to metric dcgan model * Update README.md * add two pic * updata readme * updata * Update README.md * add license * refine1 * refine2 * refine3 * fix review comments * Update README.md * Update example/gluon/DCGAN/README.md * Update example/gluon/DCGAN/README.md * Update example/gluon/DCGAN/README.md * Update example/gluon/DCGAN/README.md * Update example/gluon/DCGAN/README.md * Update example/gluon/DCGAN/README.md * Update example/gluon/DCGAN/README.md * Update example/gluon/DCGAN/README.md * Update example/gluon/DCGAN/README.md * Update example/gluon/DCGAN/README.md * Update example/gluon/DCGAN/README.md * modify sn_gan file links to DCGAN * update pic links to web-data * update the pic path of readme.md * rm folder pic/, and related links update to https://github.com/dmlc/web-data/mxnet/example/gluon/DCGAN/ * Update README.md
1 parent 27ccc39 commit 04db361

File tree

9 files changed

+506
-240
lines changed

9 files changed

+506
-240
lines changed

example/gluon/DCGAN/README.md

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# DCGAN in MXNet
2+
3+
[Deep Convolutional Generative Adversarial Networks(DCGAN)](https://arxiv.org/abs/1511.06434) implementation with Apache MXNet GLUON.
4+
This implementation uses [inception_score](https://github.com/openai/improved-gan) to evaluate the model.
5+
6+
You can use this reference implementation on the MNIST and CIFAR-10 datasets.
7+
8+
9+
#### Generated image output examples from the CIFAR-10 dataset
10+
![Generated image output examples from the CIFAR-10 dataset](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/gluon/DCGAN/fake_img_iter_13900.png)
11+
12+
#### Generated image output examples from the MNIST dataset
13+
![Generated image output examples from the MNIST dataset](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/gluon/DCGAN/fake_img_iter_21700.png)
14+
15+
#### inception_score in cpu and gpu (the real image`s score is around 3.3)
16+
CPU & GPU
17+
18+
![inception score with CPU](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/gluon/DCGAN/inception_score_cifar10_cpu.png)
19+
![inception score with GPU](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/gluon/DCGAN/inception_score_cifar10.png)
20+
21+
## Quick start
22+
Use the following code to see the configurations you can set:
23+
```bash
24+
python dcgan.py -h
25+
```
26+
27+
28+
optional arguments:
29+
-h, --help show this help message and exit
30+
--dataset DATASET dataset to use. options are cifar10 and mnist.
31+
--batch-size BATCH_SIZE input batch size, default is 64
32+
--nz NZ size of the latent z vector, default is 100
33+
--ngf NGF the channel of each generator filter layer, default is 64.
34+
--ndf NDF the channel of each descriminator filter layer, default is 64.
35+
--nepoch NEPOCH number of epochs to train for, default is 25.
36+
--niter NITER save generated images and inception_score per niter iters, default is 100.
37+
--lr LR learning rate, default=0.0002
38+
--beta1 BETA1 beta1 for adam. default=0.5
39+
--cuda enables cuda
40+
--netG NETG path to netG (to continue training)
41+
--netD NETD path to netD (to continue training)
42+
--outf OUTF folder to output images and model checkpoints
43+
--check-point CHECK_POINT
44+
save results at each epoch or not
45+
--inception_score INCEPTION_SCORE
46+
To record the inception_score, default is True.
47+
48+
49+
Use the following Python script to train a DCGAN model with default configurations using the CIFAR-10 dataset and record metrics with `inception_score`:
50+
```bash
51+
python dcgan.py
52+
```

example/gluon/DCGAN/__init__.py

Whitespace-only changes.

example/gluon/DCGAN/dcgan.py

+340
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import matplotlib as mpl
19+
mpl.use('Agg')
20+
from matplotlib import pyplot as plt
21+
22+
import argparse
23+
import mxnet as mx
24+
from mxnet import gluon
25+
from mxnet.gluon import nn
26+
from mxnet import autograd
27+
import numpy as np
28+
import logging
29+
from datetime import datetime
30+
import os
31+
import time
32+
33+
from inception_score import get_inception_score
34+
35+
36+
def fill_buf(buf, i, img, shape):
37+
"""
38+
Reposition the images generated by the generator so that it can be saved as picture matrix.
39+
:param buf: the images metric
40+
:param i: index of each image
41+
:param img: images generated by generator once
42+
:param shape: each image`s shape
43+
:return: Adjust images for output
44+
"""
45+
n = buf.shape[0]//shape[1]
46+
m = buf.shape[1]//shape[0]
47+
48+
sx = (i%m)*shape[0]
49+
sy = (i//m)*shape[1]
50+
buf[sy:sy+shape[1], sx:sx+shape[0], :] = img
51+
return None
52+
53+
54+
def visual(title, X, name):
55+
"""
56+
Image visualization and preservation
57+
:param title: title
58+
:param X: images to visualized
59+
:param name: saved picture`s name
60+
:return:
61+
"""
62+
assert len(X.shape) == 4
63+
X = X.transpose((0, 2, 3, 1))
64+
X = np.clip((X - np.min(X))*(255.0/(np.max(X) - np.min(X))), 0, 255).astype(np.uint8)
65+
n = np.ceil(np.sqrt(X.shape[0]))
66+
buff = np.zeros((int(n*X.shape[1]), int(n*X.shape[2]), int(X.shape[3])), dtype=np.uint8)
67+
for i, img in enumerate(X):
68+
fill_buf(buff, i, img, X.shape[1:3])
69+
buff = buff[:, :, ::-1]
70+
plt.imshow(buff)
71+
plt.title(title)
72+
plt.savefig(name)
73+
74+
75+
parser = argparse.ArgumentParser()
76+
parser = argparse.ArgumentParser(description='Train a DCgan model for image generation '
77+
'and then use inception_score to metric the result.')
78+
parser.add_argument('--dataset', type=str, default='cifar10', help='dataset to use. options are cifar10 and mnist.')
79+
parser.add_argument('--batch-size', type=int, default=64, help='input batch size, default is 64')
80+
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector, default is 100')
81+
parser.add_argument('--ngf', type=int, default=64, help='the channel of each generator filter layer, default is 64.')
82+
parser.add_argument('--ndf', type=int, default=64, help='the channel of each descriminator filter layer, default is 64.')
83+
parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for, default is 25.')
84+
parser.add_argument('--niter', type=int, default=10, help='save generated images and inception_score per niter iters, default is 100.')
85+
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
86+
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
87+
parser.add_argument('--cuda', action='store_true', help='enables cuda')
88+
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
89+
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
90+
parser.add_argument('--outf', default='./results', help='folder to output images and model checkpoints')
91+
parser.add_argument('--check-point', default=True, help="save results at each epoch or not")
92+
parser.add_argument('--inception_score', type=bool, default=True, help='To record the inception_score, default is True.')
93+
94+
opt = parser.parse_args()
95+
print(opt)
96+
97+
logging.basicConfig(level=logging.DEBUG)
98+
99+
nz = int(opt.nz)
100+
ngf = int(opt.ngf)
101+
ndf = int(opt.ndf)
102+
niter = opt.niter
103+
nc = 3
104+
if opt.cuda:
105+
ctx = mx.gpu(0)
106+
else:
107+
ctx = mx.cpu()
108+
batch_size = opt.batch_size
109+
check_point = bool(opt.check_point)
110+
outf = opt.outf
111+
dataset = opt.dataset
112+
113+
if not os.path.exists(outf):
114+
os.makedirs(outf)
115+
116+
117+
def transformer(data, label):
118+
# resize to 64x64
119+
data = mx.image.imresize(data, 64, 64)
120+
# transpose from (64, 64, 3) to (3, 64, 64)
121+
data = mx.nd.transpose(data, (2, 0, 1))
122+
# normalize to [-1, 1]
123+
data = data.astype(np.float32)/128 - 1
124+
# if image is greyscale, repeat 3 times to get RGB image.
125+
if data.shape[0] == 1:
126+
data = mx.nd.tile(data, (3, 1, 1))
127+
return data, label
128+
129+
130+
# get dataset with the batch_size num each time
131+
def get_dataset(dataset):
132+
# mnist
133+
if dataset == "mnist":
134+
train_data = gluon.data.DataLoader(
135+
gluon.data.vision.MNIST('./data', train=True, transform=transformer),
136+
batch_size, shuffle=True, last_batch='discard')
137+
138+
val_data = gluon.data.DataLoader(
139+
gluon.data.vision.MNIST('./data', train=False, transform=transformer),
140+
batch_size, shuffle=False)
141+
# cifar10
142+
elif dataset == "cifar10":
143+
train_data = gluon.data.DataLoader(
144+
gluon.data.vision.CIFAR10('./data', train=True, transform=transformer),
145+
batch_size, shuffle=True, last_batch='discard')
146+
147+
val_data = gluon.data.DataLoader(
148+
gluon.data.vision.CIFAR10('./data', train=False, transform=transformer),
149+
batch_size, shuffle=False)
150+
151+
return train_data, val_data
152+
153+
154+
def get_netG():
155+
# build the generator
156+
netG = nn.Sequential()
157+
with netG.name_scope():
158+
# input is Z, going into a convolution
159+
netG.add(nn.Conv2DTranspose(ngf * 8, 4, 1, 0, use_bias=False))
160+
netG.add(nn.BatchNorm())
161+
netG.add(nn.Activation('relu'))
162+
# state size. (ngf*8) x 4 x 4
163+
netG.add(nn.Conv2DTranspose(ngf * 4, 4, 2, 1, use_bias=False))
164+
netG.add(nn.BatchNorm())
165+
netG.add(nn.Activation('relu'))
166+
# state size. (ngf*4) x 8 x 8
167+
netG.add(nn.Conv2DTranspose(ngf * 2, 4, 2, 1, use_bias=False))
168+
netG.add(nn.BatchNorm())
169+
netG.add(nn.Activation('relu'))
170+
# state size. (ngf*2) x 16 x 16
171+
netG.add(nn.Conv2DTranspose(ngf, 4, 2, 1, use_bias=False))
172+
netG.add(nn.BatchNorm())
173+
netG.add(nn.Activation('relu'))
174+
# state size. (ngf) x 32 x 32
175+
netG.add(nn.Conv2DTranspose(nc, 4, 2, 1, use_bias=False))
176+
netG.add(nn.Activation('tanh'))
177+
# state size. (nc) x 64 x 64
178+
179+
return netG
180+
181+
182+
def get_netD():
183+
# build the discriminator
184+
netD = nn.Sequential()
185+
with netD.name_scope():
186+
# input is (nc) x 64 x 64
187+
netD.add(nn.Conv2D(ndf, 4, 2, 1, use_bias=False))
188+
netD.add(nn.LeakyReLU(0.2))
189+
# state size. (ndf) x 32 x 32
190+
netD.add(nn.Conv2D(ndf * 2, 4, 2, 1, use_bias=False))
191+
netD.add(nn.BatchNorm())
192+
netD.add(nn.LeakyReLU(0.2))
193+
# state size. (ndf*2) x 16 x 16
194+
netD.add(nn.Conv2D(ndf * 4, 4, 2, 1, use_bias=False))
195+
netD.add(nn.BatchNorm())
196+
netD.add(nn.LeakyReLU(0.2))
197+
# state size. (ndf*4) x 8 x 8
198+
netD.add(nn.Conv2D(ndf * 8, 4, 2, 1, use_bias=False))
199+
netD.add(nn.BatchNorm())
200+
netD.add(nn.LeakyReLU(0.2))
201+
# state size. (ndf*8) x 4 x 4
202+
netD.add(nn.Conv2D(2, 4, 1, 0, use_bias=False))
203+
# state size. 2 x 1 x 1
204+
205+
return netD
206+
207+
208+
def get_configurations(netG, netD):
209+
# loss
210+
loss = gluon.loss.SoftmaxCrossEntropyLoss()
211+
212+
# initialize the generator and the discriminator
213+
netG.initialize(mx.init.Normal(0.02), ctx=ctx)
214+
netD.initialize(mx.init.Normal(0.02), ctx=ctx)
215+
216+
# trainer for the generator and the discriminator
217+
trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate': opt.lr, 'beta1': opt.beta1})
218+
trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate': opt.lr, 'beta1': opt.beta1})
219+
220+
return loss, trainerG, trainerD
221+
222+
223+
def ins_save(inception_score):
224+
# draw the inception_score curve
225+
length = len(inception_score)
226+
x = np.arange(0, length)
227+
plt.figure(figsize=(8.0, 6.0))
228+
plt.plot(x, inception_score)
229+
plt.xlabel("iter/100")
230+
plt.ylabel("inception_score")
231+
plt.savefig("inception_score.png")
232+
233+
234+
# main function
235+
def main():
236+
print("|------- new changes!!!!!!!!!")
237+
# to get the dataset and net configuration
238+
train_data, val_data = get_dataset(dataset)
239+
netG = get_netG()
240+
netD = get_netD()
241+
loss, trainerG, trainerD = get_configurations(netG, netD)
242+
243+
# set labels
244+
real_label = mx.nd.ones((opt.batch_size,), ctx=ctx)
245+
fake_label = mx.nd.zeros((opt.batch_size,), ctx=ctx)
246+
247+
metric = mx.metric.Accuracy()
248+
print('Training... ')
249+
stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
250+
251+
iter = 0
252+
253+
# to metric the network
254+
loss_d = []
255+
loss_g = []
256+
inception_score = []
257+
258+
for epoch in range(opt.nepoch):
259+
tic = time.time()
260+
btic = time.time()
261+
for data, _ in train_data:
262+
############################
263+
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
264+
###########################
265+
# train with real_t
266+
data = data.as_in_context(ctx)
267+
noise = mx.nd.random.normal(0, 1, shape=(opt.batch_size, nz, 1, 1), ctx=ctx)
268+
269+
with autograd.record():
270+
output = netD(data)
271+
# reshape output from (opt.batch_size, 2, 1, 1) to (opt.batch_size, 2)
272+
output = output.reshape((opt.batch_size, 2))
273+
errD_real = loss(output, real_label)
274+
275+
metric.update([real_label, ], [output, ])
276+
277+
with autograd.record():
278+
fake = netG(noise)
279+
output = netD(fake.detach())
280+
output = output.reshape((opt.batch_size, 2))
281+
errD_fake = loss(output, fake_label)
282+
errD = errD_real + errD_fake
283+
284+
errD.backward()
285+
metric.update([fake_label,], [output,])
286+
287+
trainerD.step(opt.batch_size)
288+
289+
############################
290+
# (2) Update G network: maximize log(D(G(z)))
291+
###########################
292+
with autograd.record():
293+
output = netD(fake)
294+
output = output.reshape((-1, 2))
295+
errG = loss(output, real_label)
296+
297+
errG.backward()
298+
299+
trainerG.step(opt.batch_size)
300+
301+
name, acc = metric.get()
302+
logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d'
303+
% (mx.nd.mean(errD).asscalar(), mx.nd.mean(errG).asscalar(), acc, iter, epoch))
304+
if iter % niter == 0:
305+
visual('gout', fake.asnumpy(), name=os.path.join(outf, 'fake_img_iter_%d.png' % iter))
306+
visual('data', data.asnumpy(), name=os.path.join(outf, 'real_img_iter_%d.png' % iter))
307+
# record the metric data
308+
loss_d.append(errD)
309+
loss_g.append(errG)
310+
if opt.inception_score:
311+
score, _ = get_inception_score(fake)
312+
inception_score.append(score)
313+
314+
iter = iter + 1
315+
btic = time.time()
316+
317+
name, acc = metric.get()
318+
metric.reset()
319+
logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
320+
logging.info('time: %f' % (time.time() - tic))
321+
322+
# save check_point
323+
if check_point:
324+
netG.save_parameters(os.path.join(outf,'generator_epoch_%d.params' %epoch))
325+
netD.save_parameters(os.path.join(outf,'discriminator_epoch_%d.params' % epoch))
326+
327+
# save parameter
328+
netG.save_parameters(os.path.join(outf, 'generator.params'))
329+
netD.save_parameters(os.path.join(outf, 'discriminator.params'))
330+
331+
# visualization the inception_score as a picture
332+
if opt.inception_score:
333+
ins_save(inception_score)
334+
335+
336+
if __name__ == '__main__':
337+
if opt.inception_score:
338+
print("Use inception_score to metric this DCgan model, the reusult is save as a picture named \"inception_score.png\"!")
339+
main()
340+

0 commit comments

Comments
 (0)