DCGANで顔を描く(MXNet)

初めに

こちらのコードを自分なりに書き換えてみる
Deep Convolutional Generative Adversarial Networks — The Straight Dope 0.1 documentation

環境

Windows10 Pro 64bit
NVIDIA GeForce GTX1080
CUDA9.2
cudnn7.2.1
Python3.6.6(venv使用)

astroid==2.0.4
certifi==2018.8.24
chardet==3.0.4
colorama==0.3.9
graphviz==0.8.4
idna==2.6
isort==4.3.4
lazy-object-proxy==1.3.1
mccabe==0.6.1
mxnet-cu92==1.3.1b20180927
numpy==1.14.6
Pillow==5.2.0
pylint==2.1.1
requests==2.18.4
six==1.11.0
typed-ast==1.1.0
urllib3==1.22
wrapt==1.10.11

データの取得

  • データをダウンロードしてNDArrayに変換した後にpickleで保存しておく
  • この先はこのデータを使用する
  • mxnet.gluonのutils.downloadを使うとrequestsより楽に書ける
import os
import tarfile
import numpy as np
import pickle
import mxnet as mx
from mxnet.gluon import utils

# Fetch the LFW (Labeled Faces in the Wild) deep-funneled dataset and unpack
# it into data_path; the whole step is skipped once the directory exists.
lfw_url = 'http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz'
data_path = 'lfw_dataset'
if not os.path.exists(data_path):
    os.makedirs(data_path)
    archive = utils.download(lfw_url)
    tar = tarfile.open(archive)
    try:
        tar.extractall(path=data_path)
    finally:
        tar.close()

def transform(data, size=64):
    """Convert a decoded image NDArray into a normalized 1-image CHW batch.

    Parameters
    ----------
    data : mx.nd.NDArray
        HWC uint8 image, as returned by ``mx.image.imread``.
    size : int, optional
        Output height and width (default 64, the DCGAN input resolution).

    Returns
    -------
    mx.nd.NDArray
        Float32 array of shape (1, 3, size, size) with values in [-1, 1].
    """
    # Resize to the network's expected spatial resolution.
    data = mx.image.imresize(data, size, size)
    # HWC -> CHW, the layout MXNet convolutions expect.
    data = mx.nd.transpose(data, (2, 0, 1))
    # Map pixel values from [0, 255] to [-1, 1], matching the generator's tanh range.
    data = data.astype(np.float32) / 127.5 - 1
    # Greyscale image: replicate the single channel three times to get RGB.
    if data.shape[0] == 1:
        data = mx.nd.tile(data, (3, 1, 1))
    # Prepend a batch axis so all images can later be concatenated along dim 0.
    return data.reshape((1,) + data.shape)

# Walk the extracted dataset, convert every JPEG into a normalized NDArray,
# and cache the resulting list with pickle so later runs skip preprocessing.
img_list = []
for dirpath, _, filenames in os.walk(data_path):
    for name in filenames:
        if name.endswith('.jpg'):
            arr = mx.image.imread(os.path.join(dirpath, name))
            img_list.append(transform(arr))

with open('img_list.pickle', 'wb') as f:
    pickle.dump(img_list, f)

画像を表示(必ずしも必要でない)

from matplotlib import pyplot as plt
import pickle
import numpy as np

# Restore the preprocessed image NDArrays cached by the download script.
with open('img_list.pickle', 'rb') as f:
    img_list = pickle.load(f)

def visualize(img_arr):
    """Render one normalized CHW image (values in [-1, 1]) with matplotlib."""
    # Undo the preprocessing: CHW -> HWC, then [-1, 1] -> [0, 255] uint8.
    pixels = (img_arr.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5
    plt.imshow(pixels.astype(np.uint8))
    plt.axis('off')

# Show four sample images (list indices 10-13) side by side.
for idx in range(4):
    plt.subplot(1, 4, idx + 1)
    visualize(img_list[idx + 10][0])
plt.show()

f:id:touch-sp:20181003121717p:plain

モデル(dcgan_model.py)

from mxnet.gluon import nn

def netG():
    """Build the DCGAN generator: (N, 100, 1, 1) noise -> (N, 3, 64, 64) image.

    Four ConvTranspose+BatchNorm+ReLU stages upsample the latent vector,
    then a final ConvTranspose with tanh produces pixels in [-1, 1].
    """
    # (channels, stride, padding) for each upsampling stage; the first stage
    # expands 1x1 -> 4x4, each later stage doubles the spatial size.
    stages = [(512, 1, 0), (256, 2, 1), (128, 2, 1), (64, 2, 1)]
    model = nn.Sequential()
    with model.name_scope():
        for channels, stride, pad in stages:
            model.add(nn.Conv2DTranspose(channels, kernel_size=4, strides=stride,
                                         padding=pad, use_bias=False))
            model.add(nn.BatchNorm())
            model.add(nn.Activation('relu'))
        # Output stage: map to 3 RGB channels, tanh keeps values in [-1, 1].
        model.add(nn.Conv2DTranspose(3, kernel_size=4, strides=2, padding=1, use_bias=False))
        model.add(nn.Activation('tanh'))
    return model

def netD():
    """Build the DCGAN discriminator: (N, 3, 64, 64) image -> (N, 1, 1, 1) logit.

    Conv+LeakyReLU stages halve the spatial size; the final 4x4 valid
    convolution collapses the feature map to a single raw score (the loss
    function applies the sigmoid).
    """
    model = nn.Sequential()
    with model.name_scope():
        # First stage has no BatchNorm, following the DCGAN recipe.
        model.add(nn.Conv2D(64, kernel_size=4, strides=2, padding=1, use_bias=False))
        model.add(nn.LeakyReLU(alpha=0.2))
        for channels in (128, 256, 512):
            model.add(nn.Conv2D(channels, kernel_size=4, strides=2, padding=1, use_bias=False))
            model.add(nn.BatchNorm())
            model.add(nn.LeakyReLU(alpha=0.2))
        # Single-unit output: one real/fake logit per image.
        model.add(nn.Conv2D(1, kernel_size=4, strides=1, padding=0, use_bias=False))
    return model

実行ファイル

import pickle
import numpy as np
from PIL import Image

import mxnet as mx
from mxnet import gluon, autograd

# Hyperparameters
ctx = mx.gpu()
epochs = 30
batch_size = 64
lr = 0.0002
beta1 = 0.5  # Adam beta1 = 0.5, as recommended in the DCGAN paper

# Load the preprocessed images cached by the download script
with open('img_list.pickle', 'rb') as f:
    img_list = pickle.load(f)

# Concatenate the per-image (1, 3, 64, 64) arrays into one batch-major array
train_data = mx.io.NDArrayIter(data=mx.nd.concatenate(img_list), batch_size=batch_size)

# Build and initialize the generator and discriminator (weights ~ N(0, 0.02))
import dcgan_model
netG = dcgan_model.netG()
netG.initialize(mx.init.Normal(0.02), ctx=ctx)
netD = dcgan_model.netD()
netD.initialize(mx.init.Normal(0.02), ctx=ctx)

loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()

trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})

# Target labels for the GAN loss: 1 = real, 0 = fake
real_label = mx.nd.ones((batch_size,), ctx=ctx)
fake_label = mx.nd.zeros((batch_size,),ctx=ctx)

# Per-batch loss records, averaged and cleared once per epoch
loss_d = []
loss_g = []

print('start training...')

for epoch in range(1, epochs + 1):
    train_data.reset()
    for batch in train_data:
        ############################
        # (1) Update D network:
        #     maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        data = batch.data[0].as_in_context(ctx)
        # Latent vectors z ~ N(0, 1), shaped (batch, 100, 1, 1) for the generator
        noise = mx.nd.random_normal(0, 1, shape=(batch_size, 100, 1, 1), ctx=ctx)
        with autograd.record():
            # train with real image
            output = netD(data).reshape((-1, 1))
            errD_real = loss(output, real_label)
            # train with fake image; detach() stops gradients from flowing
            # into the generator during the discriminator update
            fake = netG(noise)
            output = netD(fake.detach()).reshape((-1, 1))
            errD_fake = loss(output, fake_label)
            errD = errD_real + errD_fake
            loss_d.append(np.mean(errD.asnumpy()))
            errD.backward()
        trainerD.step(batch.data[0].shape[0])

        ############################
        # (2) Update G network:
        #     maximize log(D(G(z))) by labelling fakes as real
        ###########################
        noise = mx.nd.random_normal(0, 1, shape=(batch_size, 100, 1, 1), ctx=ctx)
        with autograd.record():
            fake = netG(noise)
            output = netD(fake).reshape((-1, 1))
            # NOTE: real_label here is intentional — the generator is trained
            # to make the discriminator classify fakes as real
            errG = loss(output, real_label)
            loss_g.append(np.mean(errG.asnumpy()))
            errG.backward()
        trainerG.step(batch.data[0].shape[0])

    # Log the mean per-batch losses for this epoch, then reset the records
    ll_d = np.mean(loss_d)
    ll_g = np.mean(loss_g)
    print('%d epoch G_loss = %f D_loss = %f' %(epoch, ll_g, ll_d))
    loss_d = []
    loss_g = []
    # Save a generated sample image every 5 epochs
    if (epoch % 5)==0:
        noise = mx.nd.random_normal(0, 1, shape=(1, 100, 1, 1), ctx=ctx)
        output = netG(noise)
        # Undo preprocessing: CHW -> HWC, [-1, 1] -> [0, 255] uint8
        img_array = ((output[0].asnumpy().transpose(1,2,0)+1.0)*127.5).astype(np.uint8)
        image = Image.fromarray(img_array)
        image.save('epoch%d.png'%epoch)

# Persist the trained weights for later reuse
netD.save_parameters('netD.params')
netG.save_parameters('netG.params')

結果

約5分で学習が終わった。

1 epoch G_loss = 9.646337 D_loss = 0.880275
2 epoch G_loss = 5.208412 D_loss = 0.684651
3 epoch G_loss = 5.314768 D_loss = 0.602023
4 epoch G_loss = 5.147544 D_loss = 0.571834
5 epoch G_loss = 5.076499 D_loss = 0.515328
6 epoch G_loss = 4.932439 D_loss = 0.476169
7 epoch G_loss = 4.703326 D_loss = 0.508827
8 epoch G_loss = 4.664301 D_loss = 0.516196
9 epoch G_loss = 4.782000 D_loss = 0.494669
10 epoch G_loss = 4.648095 D_loss = 0.470350
11 epoch G_loss = 4.602288 D_loss = 0.523305
12 epoch G_loss = 4.258514 D_loss = 0.462371
13 epoch G_loss = 4.159195 D_loss = 0.583787
14 epoch G_loss = 3.770833 D_loss = 0.548336
15 epoch G_loss = 3.758740 D_loss = 0.637577
16 epoch G_loss = 3.534543 D_loss = 0.545054
17 epoch G_loss = 3.526036 D_loss = 0.617458
18 epoch G_loss = 3.369482 D_loss = 0.644616
19 epoch G_loss = 3.240339 D_loss = 0.574946
20 epoch G_loss = 3.336135 D_loss = 0.623973
21 epoch G_loss = 3.226388 D_loss = 0.686673
22 epoch G_loss = 3.134521 D_loss = 0.515629
23 epoch G_loss = 3.126158 D_loss = 0.651433
24 epoch G_loss = 3.147122 D_loss = 0.649066
25 epoch G_loss = 3.103747 D_loss = 0.595519
26 epoch G_loss = 3.068046 D_loss = 0.611841
27 epoch G_loss = 3.065959 D_loss = 0.592803
28 epoch G_loss = 3.071336 D_loss = 0.645391
29 epoch G_loss = 3.039070 D_loss = 0.585652
30 epoch G_loss = 3.066867 D_loss = 0.568904

5, 10, 15, 20, 25, 30エポック終了時に書かせた図
f:id:touch-sp:20181004002350p:plain
f:id:touch-sp:20181004002358p:plain
f:id:touch-sp:20181004002405p:plain
f:id:touch-sp:20181004002414p:plain
f:id:touch-sp:20181004002423p:plain
f:id:touch-sp:20181004002432p:plain