Writing Digits with a DCGAN (MXNet)

Introduction

Train on the MNIST data and have the model write digits.
This time it simply writes only the digit "5".

Sites I referred to

aidiary.hatenablog.com
github.com

Environment

Windows 10 Pro 64bit
NVIDIA GeForce GTX 1080
CUDA 9.2
cuDNN 7.2.1
Python 3.6.6 (using venv)

astroid==2.0.4
certifi==2018.8.24
chardet==3.0.4
colorama==0.3.9
graphviz==0.8.4
idna==2.6
isort==4.3.4
lazy-object-proxy==1.3.1
mccabe==0.6.1
mxnet-cu92==1.3.1b20180927
numpy==1.14.6
Pillow==5.2.0
pylint==2.1.1
requests==2.18.4
six==1.11.0
typed-ast==1.1.0
urllib3==1.22
wrapt==1.10.11

Model (gan_model.py)

import mxnet as mx
from mxnet.gluon import Block, nn

# Generator: maps a 100-dim noise vector to a 1x28x28 image
class modelG(Block):
    def __init__(self, **kwargs):
        super(modelG, self).__init__(**kwargs)
        with self.name_scope():
            self.dense1 = nn.Dense(1024)
            self.batch1 = nn.BatchNorm()
            self.dense2 = nn.Dense(6272)
            self.batch2 = nn.BatchNorm()
            self.tconv2d1 = nn.Conv2DTranspose(64, kernel_size=4, strides=2, padding=1, use_bias=False)
            self.batch3 = nn.BatchNorm()
            self.tconv2d2 = nn.Conv2DTranspose(1, kernel_size=4, strides=2, padding=1, use_bias=False)
    def forward(self, x):
        output = self.dense1(x)
        output = self.batch1(output)
        output = mx.nd.relu(output)
        output = self.dense2(output)
        output = self.batch2(output)
        output = mx.nd.relu(output)
        output = output.reshape(-1, 128, 7, 7)  # to (N, 128, 7, 7) feature maps
        output = self.tconv2d1(output)
        output = self.batch3(output)
        output = mx.nd.relu(output)
        output = self.tconv2d2(output)
        output = mx.nd.sigmoid(output)
        return output

# Discriminator: maps a 1x28x28 image to a single real/fake score
class modelD(Block):
    def __init__(self, **kwargs):
        super(modelD, self).__init__(**kwargs)
        with self.name_scope():
            self.conv2d1 = nn.Conv2D(64, kernel_size=4, strides=2, padding=1, use_bias=False)
            self.lrelu1 = nn.LeakyReLU(alpha=0.2)
            self.conv2d2 = nn.Conv2D(128, kernel_size=4, strides=2, padding=1, use_bias=False)
            self.batch_1 = nn.BatchNorm()
            self.lrelu2 = nn.LeakyReLU(alpha=0.2)
            self.dense_1 = nn.Dense(1024)
            self.batch_2 = nn.BatchNorm()
            self.lrelu3 = nn.LeakyReLU(alpha=0.2)
            self.dense_2 = nn.Dense(1)
    def forward(self, x):
        output = self.conv2d1(x)
        output = self.lrelu1(output)
        output = self.conv2d2(output)
        output = self.batch_1(output)
        output = self.lrelu2(output)
        output = self.dense_1(output)
        output = self.batch_2(output)
        output = self.lrelu3(output)
        output = self.dense_2(output)
        output = mx.nd.sigmoid(output)
        return output
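
As a quick sanity check, both networks can be run once on the CPU before training: the generator should map a 100-dimensional noise vector to a 1x28x28 image, and the discriminator should map that image to a single score. This is a minimal sketch I added for confirmation, not part of the training flow below.

import mxnet as mx
import gan_model

# instantiate and initialize both networks on CPU just to check shapes
netG = gan_model.modelG()
netG.initialize(mx.init.Normal(0.02))
netD = gan_model.modelD()
netD.initialize(mx.init.Normal(0.02))

noise = mx.nd.random.uniform(0, 1, shape=(1, 100))
fake = netG(noise)
print(fake.shape)        # expected: (1, 1, 28, 28)
print(netD(fake).shape)  # expected: (1, 1)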

Training script

import mxnet as mx
from mxnet import gluon, autograd
import numpy as np
from PIL import Image

mnist = mx.test_utils.get_mnist()
x_train = mnist['train_data']
t_train = mnist['train_label']

# keep only the images labeled "5"
indx = np.where(t_train == 5)
x_train = x_train[indx]

x_train = mx.nd.array(x_train)

import gan_model
modelD = gan_model.modelD()
modelD.initialize(mx.init.Normal(0.02), ctx=mx.gpu())
modelG = gan_model.modelG()
modelG.initialize(mx.init.Normal(0.02), ctx=mx.gpu())

trainerD = gluon.Trainer(modelD.collect_params(), 'adam')
trainerG = gluon.Trainer(modelG.collect_params(), 'adam')

# from_sigmoid=True because modelD already applies a sigmoid to its output
lossD = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=True)

# hyperparameters
batch_size = 100
epochs = 25

loss_d = []  # for logging
loss_g = []

for epoch in range(1, epochs + 1):
    # shuffle the training indices each epoch
    indices = np.random.permutation(x_train.shape[0])
    cur_start = 0
    while cur_start < x_train.shape[0]:
        cur_end = min(cur_start + batch_size, x_train.shape[0])
        data = x_train[indices[cur_start:cur_end]].as_in_context(mx.gpu())
        label = mx.nd.ones(data.shape[0], ctx=mx.gpu())
        noise = mx.nd.random.uniform(0, 1, shape=(data.shape[0], 100), ctx=mx.gpu())
        # update D: real images should score 1, generated images 0
        with autograd.record():
            output = modelD(data)
            D_error_real = lossD(output, label)
            fake_image = modelG(noise)
            # detach so this step does not update the generator
            output = modelD(fake_image.detach())
            D_error_fake = lossD(output, label * 0)
            D_error = D_error_real + D_error_fake
            loss_d.append(np.mean(D_error.asnumpy()))
            D_error.backward()
        # advance the optimizer by the batch size
        trainerD.step(data.shape[0])

        # update G: try to make the discriminator score fakes as 1
        noise = mx.nd.random.uniform(0, 1, shape=(data.shape[0], 100), ctx=mx.gpu())
        with autograd.record():
            fake_image = modelG(noise)
            output = modelD(fake_image)
            G_error = lossD(output, label)
            loss_g.append(np.mean(G_error.asnumpy()))
            G_error.backward()
        trainerG.step(data.shape[0])
        cur_start = cur_end
    # print the per-epoch mean losses
    ll_d = np.mean(loss_d)
    ll_g = np.mean(loss_g)
    print('%d epoch G_loss = %f D_loss = %f' %(epoch, ll_g, ll_d))
    loss_d = []
    loss_g = []
    # every 5 epochs, write one generated sample to disk
    if (epoch % 5) == 0:
        noise = mx.nd.random.uniform(0, 1, shape=(1, 100), ctx=mx.gpu())
        image = modelG(noise)
        image = Image.fromarray(image[0][0].asnumpy() * 255)
        image = image.convert('L')
        image.save('epoch%d.png'%epoch)

modelD.save_parameters('modelD.params')
modelG.save_parameters('modelG.params')
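
Incidentally, the manual index shuffling and slicing above can also be expressed with gluon.data.DataLoader, which handles shuffling and batching itself. A minimal sketch, assuming the same x_train NDArray and batch_size as above:

from mxnet import gluon

# let DataLoader shuffle and batch x_train instead of permuting indices by hand
dataset = gluon.data.ArrayDataset(x_train)
loader = gluon.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, last_batch='keep')

for data in loader:
    data = data.as_in_context(mx.gpu())
    # ... same D and G updates as above ...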

Results

Training finishes in about one minute.

1 epoch G_loss = 0.777134 D_loss = 1.376519
2 epoch G_loss = 1.097475 D_loss = 0.958754
3 epoch G_loss = 1.474690 D_loss = 0.758916
4 epoch G_loss = 1.602817 D_loss = 0.771593
5 epoch G_loss = 1.678429 D_loss = 0.779117
6 epoch G_loss = 1.931964 D_loss = 0.788672
7 epoch G_loss = 2.244472 D_loss = 0.575185
8 epoch G_loss = 1.683377 D_loss = 0.910790
9 epoch G_loss = 1.919621 D_loss = 0.645680
10 epoch G_loss = 1.739710 D_loss = 1.013023
11 epoch G_loss = 1.705402 D_loss = 0.821467
12 epoch G_loss = 1.961140 D_loss = 0.780046
13 epoch G_loss = 1.573381 D_loss = 0.965425
14 epoch G_loss = 1.935019 D_loss = 0.806350
15 epoch G_loss = 2.063324 D_loss = 0.681775
16 epoch G_loss = 1.585250 D_loss = 0.917888
17 epoch G_loss = 1.395943 D_loss = 0.929869
18 epoch G_loss = 1.585122 D_loss = 0.944633
19 epoch G_loss = 1.609372 D_loss = 0.867665
20 epoch G_loss = 1.598825 D_loss = 0.868693
21 epoch G_loss = 1.390774 D_loss = 1.091793
22 epoch G_loss = 1.514822 D_loss = 0.900835
23 epoch G_loss = 1.540376 D_loss = 0.928998
24 epoch G_loss = 1.609747 D_loss = 0.955562
25 epoch G_loss = 1.333527 D_loss = 1.090823

Images generated at the end of epochs 5, 10, 15, 20, and 25:
f:id:touch-sp:20181002235549p:plain
f:id:touch-sp:20181002235603p:plain
f:id:touch-sp:20181002235609p:plain
f:id:touch-sp:20181002235618p:plain
f:id:touch-sp:20181002235625p:plain
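
To write more digits later without retraining, the saved generator parameters can be reloaded. A minimal sketch using the modelG.params file saved above (the output file name 'generated.png' is arbitrary):

import mxnet as mx
import numpy as np
from PIL import Image
import gan_model

# rebuild the generator and load the trained weights onto the GPU
netG = gan_model.modelG()
netG.load_parameters('modelG.params', ctx=mx.gpu())

# one noise vector in, one 28x28 grayscale digit out
noise = mx.nd.random.uniform(0, 1, shape=(1, 100), ctx=mx.gpu())
image = netG(noise)
image = Image.fromarray((image[0][0].asnumpy() * 255).astype(np.uint8))
image.save('generated.png')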