|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | + |
| 18 | +import matplotlib as mpl |
| 19 | +mpl.use('Agg') |
| 20 | +from matplotlib import pyplot as plt |
| 21 | + |
| 22 | +import argparse |
| 23 | +import mxnet as mx |
| 24 | +from mxnet import gluon |
| 25 | +from mxnet.gluon import nn |
| 26 | +from mxnet import autograd |
| 27 | +import numpy as np |
| 28 | +import logging |
| 29 | +from datetime import datetime |
| 30 | +import os |
| 31 | +import time |
| 32 | + |
| 33 | +from inception_score import get_inception_score |
| 34 | + |
| 35 | + |
def fill_buf(buf, i, img, shape):
    """
    Copy one image into its tile of a larger picture matrix.

    Tiles are numbered row by row: index 0 is the top-left tile and the index
    grows to the right first, then downwards.

    :param buf: the picture matrix, indexed (rows, columns, channels); written in place
    :param i: index of the image, which selects the tile to fill
    :param img: the image to place; it is assigned to a slice of
                shape[1] rows by shape[0] columns
    :param shape: tile size, where shape[0] is the extent along the columns of buf
                  and shape[1] is the extent along the rows of buf
    :return: None, buf is modified in place
    """
    tile_cols, tile_rows = shape[0], shape[1]
    tiles_per_row = buf.shape[1] // tile_cols

    # Translate the flat tile index into the top-left corner of the tile.
    left = (i % tiles_per_row) * tile_cols
    top = (i // tiles_per_row) * tile_rows
    buf[top:top + tile_rows, left:left + tile_cols, :] = img
| 52 | + |
| 53 | + |
def visual(title, X, name):
    """
    Tile a batch of images into one square grid and save the grid as a picture.

    :param title: title drawn on the saved figure
    :param X: 4-d array indexed (count, channels, rows, columns)
    :param name: file name the figure is saved to
    :return: None
    """
    assert len(X.shape) == 4
    # Move channels last so that every image is (rows, columns, channels).
    X = X.transpose((0, 2, 3, 1))

    # Stretch the value range of the whole batch to [0, 255]. A constant batch
    # has no range to stretch, so it is mapped to zeros instead of dividing by 0.
    lowest = np.min(X)
    value_range = np.max(X) - lowest
    if value_range > 0:
        X = np.clip((X - lowest) * (255.0 / value_range), 0, 255).astype(np.uint8)
    else:
        X = np.zeros(X.shape, dtype=np.uint8)

    # Smallest square grid that holds every image.
    n = int(np.ceil(np.sqrt(X.shape[0])))
    rows, cols, channels = X.shape[1], X.shape[2], X.shape[3]
    buff = np.zeros((n * rows, n * cols, channels), dtype=np.uint8)
    for i, img in enumerate(X):
        # fill_buf slices columns by shape[0] and rows by shape[1].
        fill_buf(buff, i, img, (cols, rows))

    # NOTE(review): this reverses the channel axis before plotting - confirm
    # that the reversed channel order is what matplotlib should display.
    buff = buff[:, :, ::-1]

    # Draw on a dedicated figure and close it, so that repeated calls do not
    # keep stacking images onto one ever-growing shared figure.
    figure = plt.figure()
    try:
        plt.imshow(buff)
        plt.title(title)
        plt.savefig(name)
    finally:
        plt.close(figure)
| 73 | + |
| 74 | + |
def _str2bool(value):
    """
    Convert a command line value into a bool.

    argparse hands over the raw string, and bool('False') is True, so the text
    has to be interpreted explicitly.

    :param value: a bool, or text such as 'true', 'false', 'yes', 'no', '1', '0'
    :return: the matching bool
    :raises argparse.ArgumentTypeError: if the text is not a recognised boolean
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


# Command line options of the training script.
parser = argparse.ArgumentParser(description='Train a DCgan model for image generation '
                                             'and then use inception_score to metric the result.')
parser.add_argument('--dataset', type=str, default='cifar10',
                    help='dataset to use. options are cifar10 and mnist.')
parser.add_argument('--batch-size', type=int, default=64,
                    help='input batch size, default is 64')
parser.add_argument('--nz', type=int, default=100,
                    help='size of the latent z vector, default is 100')
parser.add_argument('--ngf', type=int, default=64,
                    help='the channel of each generator filter layer, default is 64.')
parser.add_argument('--ndf', type=int, default=64,
                    help='the channel of each descriminator filter layer, default is 64.')
parser.add_argument('--nepoch', type=int, default=25,
                    help='number of epochs to train for, default is 25.')
# The help text now states the real default of 10.
parser.add_argument('--niter', type=int, default=10,
                    help='save generated images and inception_score per niter iters, default is 10.')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='./results',
                    help='folder to output images and model checkpoints')
# Both switches go through _str2bool so that "False" really disables them.
parser.add_argument('--check-point', type=_str2bool, default=True,
                    help="save results at each epoch or not")
parser.add_argument('--inception_score', type=_str2bool, default=True,
                    help='To record the inception_score, default is True.')
| 93 | + |
# Read the command line once and echo the parsed options.
opt = parser.parse_args()
print(opt)

logging.basicConfig(level=logging.DEBUG)

# Module-level settings shared by the functions below.
nz, ngf, ndf = int(opt.nz), int(opt.ngf), int(opt.ndf)
nc = 3
niter = opt.niter
batch_size = opt.batch_size
check_point = bool(opt.check_point)
outf = opt.outf
dataset = opt.dataset

# Use the first GPU when --cuda is given, the CPU otherwise.
ctx = mx.gpu(0) if opt.cuda else mx.cpu()

# Make sure the output folder exists before anything is written to it.
if not os.path.exists(outf):
    os.makedirs(outf)
| 115 | + |
| 116 | + |
def transformer(data, label):
    """
    Prepare one sample for the networks.

    The image is resized to 64x64, transposed with (2, 0, 1) so that the last
    axis becomes the first, cast to float32 and rescaled with x / 128 - 1.
    When the leading axis then has size 1 the image is tiled three times along it.

    :param data: the image; assumed to be laid out (rows, columns, channels)
                 with 8-bit values - TODO confirm against the dataset
    :param label: the label, passed through untouched
    :return: the transformed image and the label
    """
    resized = mx.image.imresize(data, 64, 64)
    channel_first = mx.nd.transpose(resized, (2, 0, 1))
    scaled = channel_first.astype(np.float32) / 128 - 1
    if scaled.shape[0] != 1:
        return scaled, label
    # Single channel: repeat it three times along the leading axis.
    return mx.nd.tile(scaled, (3, 1, 1)), label
| 128 | + |
| 129 | + |
| 130 | +# get dataset with the batch_size num each time |
def get_dataset(dataset):
    """
    Build the training and validation data loaders for the chosen dataset.

    Both loaders read from './data', apply `transformer` to every sample and
    use the module-level `batch_size`. The training loader is shuffled and is
    created with last_batch='discard'; the validation loader is not shuffled.

    :param dataset: name of the dataset, either "mnist" or "cifar10"
    :return: the training loader and the validation loader
    :raises ValueError: if the dataset name is not one of the supported options
    """
    if dataset == "mnist":
        dataset_cls = gluon.data.vision.MNIST
    elif dataset == "cifar10":
        dataset_cls = gluon.data.vision.CIFAR10
    else:
        # Fail with a clear message instead of an UnboundLocalError at the return.
        raise ValueError('unknown dataset %r, options are cifar10 and mnist' % (dataset,))

    train_data = gluon.data.DataLoader(
        dataset_cls('./data', train=True, transform=transformer),
        batch_size, shuffle=True, last_batch='discard')

    val_data = gluon.data.DataLoader(
        dataset_cls('./data', train=False, transform=transformer),
        batch_size, shuffle=False)

    return train_data, val_data
| 152 | + |
| 153 | + |
def get_netG():
    """
    Build the generator network.

    The network is a stack of transposed convolutions without bias. Every
    stage except the last is followed by batch normalisation and ReLU; the
    last stage produces `nc` channels and ends with tanh.

    :return: the generator, not yet initialised
    """
    generator = nn.Sequential()
    with generator.name_scope():
        # input is Z, going into a convolution
        generator.add(nn.Conv2DTranspose(ngf * 8, 4, 1, 0, use_bias=False))
        generator.add(nn.BatchNorm())
        generator.add(nn.Activation('relu'))
        # state size. (ngf*8) x 4 x 4
        # The next three stages give (ngf*4) x 8 x 8, (ngf*2) x 16 x 16
        # and (ngf) x 32 x 32.
        for channels in (ngf * 4, ngf * 2, ngf):
            generator.add(nn.Conv2DTranspose(channels, 4, 2, 1, use_bias=False))
            generator.add(nn.BatchNorm())
            generator.add(nn.Activation('relu'))
        # Output stage. state size. (nc) x 64 x 64
        generator.add(nn.Conv2DTranspose(nc, 4, 2, 1, use_bias=False))
        generator.add(nn.Activation('tanh'))

    return generator
| 180 | + |
| 181 | + |
def get_netD():
    """
    Build the discriminator network.

    The network is a stack of convolutions without bias, each followed by
    LeakyReLU(0.2). Batch normalisation is used in every stage except the
    first, and the final convolution produces two output channels.

    :return: the discriminator, not yet initialised
    """
    discriminator = nn.Sequential()
    with discriminator.name_scope():
        # input is (nc) x 64 x 64; the first stage has no batch normalisation
        discriminator.add(nn.Conv2D(ndf, 4, 2, 1, use_bias=False))
        discriminator.add(nn.LeakyReLU(0.2))
        # state size. (ndf) x 32 x 32
        # The next three stages give (ndf*2) x 16 x 16, (ndf*4) x 8 x 8
        # and (ndf*8) x 4 x 4.
        for factor in (2, 4, 8):
            discriminator.add(nn.Conv2D(ndf * factor, 4, 2, 1, use_bias=False))
            discriminator.add(nn.BatchNorm())
            discriminator.add(nn.LeakyReLU(0.2))
        # Output stage. state size. 2 x 1 x 1
        discriminator.add(nn.Conv2D(2, 4, 1, 0, use_bias=False))

    return discriminator
| 206 | + |
| 207 | + |
def get_configurations(netG, netD):
    """
    Initialise both networks and create the loss and the two trainers.

    :param netG: the generator; initialised in place with Normal(0.02) on `ctx`
    :param netD: the discriminator; initialised in place with Normal(0.02) on `ctx`
    :return: the softmax cross entropy loss, the generator trainer and the
             discriminator trainer (both adam with the lr and beta1 options)
    """
    # initialize the generator and the discriminator
    for net in (netG, netD):
        net.initialize(mx.init.Normal(0.02), ctx=ctx)

    # Each trainer gets its own dict of optimizer settings.
    trainers = [
        gluon.Trainer(net.collect_params(), 'adam',
                      {'learning_rate': opt.lr, 'beta1': opt.beta1})
        for net in (netG, netD)
    ]
    trainerG, trainerD = trainers

    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    return loss, trainerG, trainerD
| 221 | + |
| 222 | + |
def ins_save(inception_score):
    """
    Plot the recorded inception scores and save the curve as "inception_score.png".

    The picture is written to the current working directory.

    :param inception_score: sequence of scores, plotted against their position
    :return: None
    """
    # NOTE(review): the x label says "iter/100" while scores are recorded every
    # `niter` iterations - confirm which interval the label should name.
    figure = plt.figure(figsize=(8.0, 6.0))
    try:
        plt.plot(np.arange(len(inception_score)), inception_score)
        plt.xlabel("iter/100")
        plt.ylabel("inception_score")
        plt.savefig("inception_score.png")
    finally:
        # Release the figure; pyplot keeps every open figure alive otherwise.
        plt.close(figure)
| 232 | + |
| 233 | + |
| 234 | +# main function |
def main():
    """
    Train the DCGAN and save pictures, scores and parameters.

    Every iteration first updates the discriminator on a real batch and on a
    generated batch, then updates the generator. Every `niter` iterations the
    generated and the real batch are saved as pictures in `outf` and, when
    enabled, the inception score of the generated batch is recorded. Parameters
    are saved after every epoch when `check_point` is set and once more at the
    end; finally the inception score curve is saved when enabled.

    :return: None
    """
    # The validation loader is returned as well but is not used for training.
    train_data, _ = get_dataset(dataset)
    netG = get_netG()
    netD = get_netD()
    loss, trainerG, trainerD = get_configurations(netG, netD)

    # Class 1 marks real images, class 0 marks generated images.
    real_label = mx.nd.ones((opt.batch_size,), ctx=ctx)
    fake_label = mx.nd.zeros((opt.batch_size,), ctx=ctx)

    metric = mx.metric.Accuracy()
    print('Training... ')

    iteration = 0
    inception_score = []

    for epoch in range(opt.nepoch):
        tic = time.time()
        for data, _ in train_data:
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real_t
            data = data.as_in_context(ctx)
            noise = mx.nd.random.normal(0, 1, shape=(opt.batch_size, nz, 1, 1), ctx=ctx)

            with autograd.record():
                output = netD(data)
                # reshape output from (opt.batch_size, 2, 1, 1) to (opt.batch_size, 2)
                output = output.reshape((opt.batch_size, 2))
                errD_real = loss(output, real_label)

            metric.update([real_label, ], [output, ])

            with autograd.record():
                fake = netG(noise)
                # detach() keeps the discriminator loss from reaching the generator
                output = netD(fake.detach())
                output = output.reshape((opt.batch_size, 2))
                errD_fake = loss(output, fake_label)
                errD = errD_real + errD_fake

            errD.backward()
            metric.update([fake_label, ], [output, ])

            trainerD.step(opt.batch_size)

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            with autograd.record():
                output = netD(fake)
                output = output.reshape((-1, 2))
                # The generator is trained to have its images classified as real.
                errG = loss(output, real_label)

            errG.backward()

            trainerG.step(opt.batch_size)

            _, acc = metric.get()
            logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d',
                         mx.nd.mean(errD).asscalar(), mx.nd.mean(errG).asscalar(), acc, iteration, epoch)
            if iteration % niter == 0:
                visual('gout', fake.asnumpy(), name=os.path.join(outf, 'fake_img_iter_%d.png' % iteration))
                visual('data', data.asnumpy(), name=os.path.join(outf, 'real_img_iter_%d.png' % iteration))
                if opt.inception_score:
                    score, _ = get_inception_score(fake)
                    inception_score.append(score)

            iteration += 1

        # The accuracy is accumulated over one epoch and then reset.
        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f', epoch, name, acc)
        logging.info('time: %f', time.time() - tic)

        # save check_point
        if check_point:
            netG.save_parameters(os.path.join(outf, 'generator_epoch_%d.params' % epoch))
            netD.save_parameters(os.path.join(outf, 'discriminator_epoch_%d.params' % epoch))

    # save parameter
    netG.save_parameters(os.path.join(outf, 'generator.params'))
    netD.save_parameters(os.path.join(outf, 'discriminator.params'))

    # visualization the inception_score as a picture
    if opt.inception_score:
        ins_save(inception_score)
| 334 | + |
| 335 | + |
if __name__ == '__main__':
    # Tell the user up front where the inception score curve will be written.
    if opt.inception_score:
        notice = "Use inception_score to metric this DCgan model, the reusult is save as a picture named \"inception_score.png\"!"
        print(notice)
    # Run the training.
    main()
| 340 | + |
0 commit comments