初めに
こちらのコードを自分なりに書き換えてみる
Deep Convolutional Generative Adversarial Networks — The Straight Dope 0.1 documentation
環境
Windows10 Pro 64bit
NVIDIA GeForce GTX1080
CUDA9.2
cudnn7.2.1
Python3.6.6(venv使用)
astroid==2.0.4 certifi==2018.8.24 chardet==3.0.4 colorama==0.3.9 graphviz==0.8.4 idna==2.6 isort==4.3.4 lazy-object-proxy==1.3.1 mccabe==0.6.1 mxnet-cu92==1.3.1b20180927 numpy==1.14.6 Pillow==5.2.0 pylint==2.1.1 requests==2.18.4 six==1.11.0 typed-ast==1.1.0 urllib3==1.22 wrapt==1.10.11
データの取得
- データをダウンロードしてNDArrayに変換した後にpickleで保存しておく
- この先はこのデータを使用する
- mxnet.gluonのutils.downloadを使うとrequestsより楽に書ける
import os
import pickle
import tarfile

import mxnet as mx
import numpy as np
from mxnet.gluon import utils

# Download and preprocess the LFW (Labeled Faces in the Wild) face dataset.
lfw_url = 'http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz'
data_path = 'lfw_dataset'

# Fetch and extract only on the first run: the original re-downloaded and
# re-extracted the archive every time, even when the data was already there.
if not os.path.exists(data_path):
    os.makedirs(data_path)
    data_file = utils.download(lfw_url)
    with tarfile.open(data_file) as tar:
        tar.extractall(path=data_path)


def transform(data):
    """Convert one HWC uint8 image NDArray into a (1, 3, 64, 64) float NDArray in [-1, 1]."""
    # Resize images to 64x64.
    data = mx.image.imresize(data, 64, 64)
    # Transpose from (64, 64, 3) to (3, 64, 64) — channels first.
    data = mx.nd.transpose(data, (2, 0, 1))
    # Normalize all pixel values from [0, 255] to [-1, 1]
    # (matches the generator's tanh output range).
    data = data.astype(np.float32) / 127.5 - 1
    # If the image is greyscale, repeat the channel 3 times to get RGB.
    if data.shape[0] == 1:
        data = mx.nd.tile(data, (3, 1, 1))
    # Add a leading batch axis so the images can be concatenated later.
    return data.reshape((1,) + data.shape)


# Walk the extracted tree, convert every .jpg, and pickle the list of arrays
# so later scripts can load the data without re-processing.
img_list = []
for path, _, fnames in os.walk(data_path):
    for fname in fnames:
        if not fname.endswith('.jpg'):
            continue
        img = os.path.join(path, fname)
        img_arr = mx.image.imread(img)
        img_arr = transform(img_arr)
        img_list.append(img_arr)

with open('img_list.pickle', 'wb') as f:
    pickle.dump(img_list, f)
画像を表示(必ずしも必要でない)
import pickle

import numpy as np
from matplotlib import pyplot as plt

# Load the preprocessed image list produced by the data-preparation script.
with open('img_list.pickle', 'rb') as f:
    img_list = pickle.load(f)


def visualize(img_arr):
    """Display one (3, 64, 64) NDArray in [-1, 1] as an RGB image, axes hidden."""
    # Undo the preprocessing: channels-last layout, then [-1, 1] -> [0, 255].
    pixels = (img_arr.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5
    plt.imshow(pixels.astype(np.uint8))
    plt.axis('off')


# Preview four sample faces side by side.
for idx in range(4):
    plt.subplot(1, 4, idx + 1)
    visualize(img_list[idx + 10][0])
plt.show()
モデル(dcgan_model.py)
from mxnet.gluon import nn


def netG(nc=3, ngf=64):
    """Build the DCGAN generator: (N, 100, 1, 1) noise -> (N, nc, 64, 64) image.

    Parameters
    ----------
    nc : int
        Number of output image channels (default 3 = RGB, the original value).
    ngf : int
        Base feature-map width; hidden widths are ngf*8, ngf*4, ngf*2, ngf
        (defaults reproduce the original 512/256/128/64 stack exactly).
    """
    model = nn.Sequential()
    with model.name_scope():
        # Upsample 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 with transposed convs,
        # each followed by BatchNorm + ReLU.
        for channels, strides, padding in ((ngf * 8, 1, 0),
                                           (ngf * 4, 2, 1),
                                           (ngf * 2, 2, 1),
                                           (ngf, 2, 1)):
            model.add(nn.Conv2DTranspose(channels, kernel_size=4,
                                         strides=strides, padding=padding,
                                         use_bias=False))
            model.add(nn.BatchNorm())
            model.add(nn.Activation('relu'))
        # Final 32x32 -> 64x64 layer; tanh maps the output to [-1, 1],
        # matching the normalization applied to the training data.
        model.add(nn.Conv2DTranspose(nc, kernel_size=4, strides=2, padding=1,
                                     use_bias=False))
        model.add(nn.Activation('tanh'))
    return model


def netD(ndf=64):
    """Build the DCGAN discriminator: (N, C, 64, 64) image -> (N, 1, 1, 1) raw logit.

    Parameters
    ----------
    ndf : int
        Base feature-map width; hidden widths are ndf, ndf*2, ndf*4, ndf*8
        (default reproduces the original 64/128/256/512 stack exactly).
    """
    model = nn.Sequential()
    with model.name_scope():
        # First conv has no BatchNorm (as in the original layer sequence).
        model.add(nn.Conv2D(ndf, kernel_size=4, strides=2, padding=1,
                            use_bias=False))
        model.add(nn.LeakyReLU(alpha=0.2))
        # Downsample 32x32 -> 16x16 -> 8x8 -> 4x4, doubling channels each step.
        for channels in (ndf * 2, ndf * 4, ndf * 8):
            model.add(nn.Conv2D(channels, kernel_size=4, strides=2, padding=1,
                                use_bias=False))
            model.add(nn.BatchNorm())
            model.add(nn.LeakyReLU(alpha=0.2))
        # 4x4 -> 1x1: single raw score; the training loss
        # (SigmoidBinaryCrossEntropyLoss) applies the sigmoid itself.
        model.add(nn.Conv2D(1, kernel_size=4, strides=1, padding=0,
                            use_bias=False))
    return model
実行ファイル
import pickle

import numpy as np
from PIL import Image
import mxnet as mx
from mxnet import gluon, autograd

# --- Hyperparameters ---
ctx = mx.gpu()
epochs = 30
batch_size = 64
lr = 0.0002    # Adam learning rate
beta1 = 0.5    # Adam beta1

# --- Load the preprocessed dataset ---
with open('img_list.pickle', 'rb') as f:
    img_list = pickle.load(f)
train_data = mx.io.NDArrayIter(data=mx.nd.concatenate(img_list),
                               batch_size=batch_size)

# --- Build and initialize the models ---
import dcgan_model
netG = dcgan_model.netG()
netG.initialize(mx.init.Normal(0.02), ctx=ctx)
netD = dcgan_model.netD()
netD.initialize(mx.init.Normal(0.02), ctx=ctx)

# Discriminator emits raw logits, so the sigmoid lives inside the loss.
loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
trainerG = gluon.Trainer(netG.collect_params(), 'adam',
                         {'learning_rate': lr, 'beta1': beta1})
trainerD = gluon.Trainer(netD.collect_params(), 'adam',
                         {'learning_rate': lr, 'beta1': beta1})

# --- Training loop ---
real_label = mx.nd.ones((batch_size,), ctx=ctx)
fake_label = mx.nd.zeros((batch_size,), ctx=ctx)
loss_d = []
loss_g = []
print('start training...')
for epoch in range(1, epochs + 1):
    train_data.reset()
    for batch in train_data:
        ############################
        # (1) Update D network
        ###########################
        data = batch.data[0].as_in_context(ctx)
        noise = mx.nd.random_normal(0, 1, shape=(batch_size, 100, 1, 1),
                                    ctx=ctx)
        with autograd.record():
            # Train with real images.
            output = netD(data).reshape((-1, 1))
            errD_real = loss(output, real_label)
            # Train with fake images; detach() so no gradient reaches G here.
            fake = netG(noise)
            output = netD(fake.detach()).reshape((-1, 1))
            errD_fake = loss(output, fake_label)
            errD = errD_real + errD_fake
            loss_d.append(np.mean(errD.asnumpy()))
        errD.backward()
        trainerD.step(batch.data[0].shape[0])

        ############################
        # (2) Update G network
        ###########################
        noise = mx.nd.random_normal(0, 1, shape=(batch_size, 100, 1, 1),
                                    ctx=ctx)
        with autograd.record():
            fake = netG(noise)
            output = netD(fake).reshape((-1, 1))
            # G wants D to call its fakes real, hence real_label here.
            errG = loss(output, real_label)
            loss_g.append(np.mean(errG.asnumpy()))
        errG.backward()
        trainerG.step(batch.data[0].shape[0])

    # Per-epoch log. FIX: the original string literal was split by a raw
    # newline (a SyntaxError); the line break is now an explicit \n escape.
    ll_d = np.mean(loss_d)
    ll_g = np.mean(loss_g)
    print('%d epoch\nG_loss = %f D_loss = %f' % (epoch, ll_g, ll_d))
    loss_d = []
    loss_g = []

    # Save one generated sample image every 5 epochs.
    if (epoch % 5) == 0:
        noise = mx.nd.random_normal(0, 1, shape=(1, 100, 1, 1), ctx=ctx)
        output = netG(noise)
        # Undo preprocessing: channels-last, then [-1, 1] -> [0, 255].
        img_array = ((output[0].asnumpy().transpose(1, 2, 0) + 1.0)
                     * 127.5).astype(np.uint8)
        image = Image.fromarray(img_array)
        image.save('epoch%d.png' % epoch)

netD.save_parameters('netD.params')
netG.save_parameters('netG.params')
結果
約5分で学習が終わった。
1 epoch G_loss = 9.646337 D_loss = 0.880275 2 epoch G_loss = 5.208412 D_loss = 0.684651 3 epoch G_loss = 5.314768 D_loss = 0.602023 4 epoch G_loss = 5.147544 D_loss = 0.571834 5 epoch G_loss = 5.076499 D_loss = 0.515328 6 epoch G_loss = 4.932439 D_loss = 0.476169 7 epoch G_loss = 4.703326 D_loss = 0.508827 8 epoch G_loss = 4.664301 D_loss = 0.516196 9 epoch G_loss = 4.782000 D_loss = 0.494669 10 epoch G_loss = 4.648095 D_loss = 0.470350 11 epoch G_loss = 4.602288 D_loss = 0.523305 12 epoch G_loss = 4.258514 D_loss = 0.462371 13 epoch G_loss = 4.159195 D_loss = 0.583787 14 epoch G_loss = 3.770833 D_loss = 0.548336 15 epoch G_loss = 3.758740 D_loss = 0.637577 16 epoch G_loss = 3.534543 D_loss = 0.545054 17 epoch G_loss = 3.526036 D_loss = 0.617458 18 epoch G_loss = 3.369482 D_loss = 0.644616 19 epoch G_loss = 3.240339 D_loss = 0.574946 20 epoch G_loss = 3.336135 D_loss = 0.623973 21 epoch G_loss = 3.226388 D_loss = 0.686673 22 epoch G_loss = 3.134521 D_loss = 0.515629 23 epoch G_loss = 3.126158 D_loss = 0.651433 24 epoch G_loss = 3.147122 D_loss = 0.649066 25 epoch G_loss = 3.103747 D_loss = 0.595519 26 epoch G_loss = 3.068046 D_loss = 0.611841 27 epoch G_loss = 3.065959 D_loss = 0.592803 28 epoch G_loss = 3.071336 D_loss = 0.645391 29 epoch G_loss = 3.039070 D_loss = 0.585652 30 epoch G_loss = 3.066867 D_loss = 0.568904
5, 10, 15, 20, 25, 30エポック終了時に書かせた図