MXNetでpix2pix

2018年8月22日動作確認

環境

Windows10 Pro 64bit
NVIDIA GeForce GTX1080
CUDA9.2
cudnn7.2.1

はじめに(注意)

Pixel to Pixel Generative Adversarial Networks — The Straight Dope 0.1 documentation
こちらのコードを少し変更しただけで、オリジナルではありません

Anacondaで仮想環境を作成

conda create -n mxnet python=3.6 anaconda
activate mxnet
  • pipのアップデート
python -m pip install --upgrade pip
  • msgpackのインストール
pip install msgpack

MXNetのインストール

pip install mxnet-cu92==1.3.0b20180820

「data_download.py」を作成して実行

import os
import tarfile
from mxnet.gluon import utils

dataset = 'facades'


def download_data(dataset):
    """Download the pix2pix ``<dataset>.tar.gz`` archive and extract it into the current directory.

    Does nothing if a path named ``dataset`` already exists. The directory is only
    guaranteed to exist after a successful extraction, so an interrupted download
    no longer leaves an empty directory that makes later runs skip the download.

    Args:
        dataset: dataset name, e.g. 'facades'; used both in the URL and as the
            directory name checked for existence.
    """
    if os.path.exists(dataset):
        return

    url = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/%s.tar.gz' % (dataset)
    data_file = utils.download(url)
    try:
        with tarfile.open(data_file) as tar:
            if hasattr(tarfile, 'data_filter'):
                # Python versions with extraction filters: reject absolute paths,
                # '..' components and links escaping the destination.
                tar.extractall(path='.', filter='data')
            else:
                tar.extractall(path='.')
    finally:
        # Remove the archive whether or not extraction succeeded.
        os.remove(data_file)

    # Presumably the archive already contains a top-level '<dataset>/' directory
    # (the training script reads '<dataset>/train'); make sure it exists either way.
    os.makedirs(dataset, exist_ok=True)


if __name__ == '__main__':
    download_data(dataset)
python data_download.py
  • 拡張子「tar.gz」の圧縮ファイルはWindowsでは解凍できないと思っていたが、展開はPython標準ライブラリの「tarfile」で行うため問題なく実行できた

f:id:touch-sp:20180808152936j:plain:w300

  • ダウンロードされるファイルは512×256のjpegファイルであった
  • 別のサイズの画像を自分で用意しても、読み込み時に「mx.image.imresize」で512×256にリサイズされるコードになっている
  • 1枚の画像の左右にinput画像とoutput画像が結合されているが、読み込み時に「mx.image.fixed_crop」で256×256ずつに分割する
  • 左右どちらをinput画像にするかは、データを読み込む際に「is_reversed」で指定できる（Falseなら左半分、Trueなら右半分がinputになる）

「Utility.py」の作成

import os
from PIL import Image

import mxnet as mx
from mxnet import ndarray as nd
import numpy as np

def load_data(path, batch_size, is_reversed=False):
    """Load side-by-side image pairs found under ``path`` into an NDArrayIter.

    Every ``.jpg``/``.jpeg`` file (case-insensitive) found recursively under ``path``
    is read, scaled from [0, 255] to [-1, 1], resized to 512x256 (width x height)
    and split into a left and a right 256x256 half.

    Args:
        path: root directory, searched recursively with os.walk.
        batch_size: batch size of the returned iterator.
        is_reversed: if False the left half is the input and the right half the
            target; if True the roles are swapped.

    Returns:
        mx.io.NDArrayIter whose ``data`` is [inputs, targets], each shaped
        (N, C, 256, 256) in float32 (C = 3 with mx.image.imread's default color mode).

    Raises:
        ValueError: if no image file is found under ``path``.
    """
    img_wd = 256
    img_ht = 256

    img_in_list = []
    img_out_list = []

    # Walk with a distinct name so the ``path`` argument is not shadowed.
    for dirpath, dirnames, fnames in os.walk(path):
        # os.walk order is filesystem-dependent; sort for a reproducible sample order.
        dirnames.sort()
        for fname in sorted(fnames):
            if not fname.lower().endswith(('.jpg', '.jpeg')):
                continue
            img_file = os.path.join(dirpath, fname)
            img_arr = mx.image.imread(img_file).astype(np.float32)/127.5 - 1
            img_arr = mx.image.imresize(img_arr, img_wd * 2, img_ht)
            # Crop the left and right halves of the side-by-side pair.
            left = mx.image.fixed_crop(img_arr, 0, 0, img_wd, img_ht)
            right = mx.image.fixed_crop(img_arr, img_wd, 0, img_wd, img_ht)
            # HWC -> CHW, then prepend a batch axis of size 1.
            left = nd.transpose(left, (2, 0, 1))
            right = nd.transpose(right, (2, 0, 1))
            left = left.reshape((1,) + left.shape)
            right = right.reshape((1,) + right.shape)
            img_in, img_out = (right, left) if is_reversed else (left, right)
            img_in_list.append(img_in)
            img_out_list.append(img_out)

    if not img_in_list:
        # nd.concat() with no arguments would otherwise fail with an obscure error.
        raise ValueError('no .jpg images found under %r' % (path,))

    return mx.io.NDArrayIter(data=[nd.concat(*img_in_list, dim=0), nd.concat(*img_out_list, dim=0)],
                             batch_size=batch_size)

def save_img(img_arr, filename):
    """Save one image tensor with values nominally in [-1, 1] as a JPEG.

    Args:
        img_arr: NDArray assumed to be (C, H, W) (it is transposed to HWC before
            saving), normalized to [-1, 1] as in load_data. Values outside that
            range are clipped.
        filename: destination path; written as JPEG at quality 100.
    """
    pixels = (img_arr.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5
    # Clip before the uint8 cast: casting out-of-range floats is not saturating in
    # NumPy (values typically wrap, e.g. 256 -> 0), which produces speckle artifacts.
    pixels = np.clip(pixels, 0, 255).astype(np.uint8)
    Image.fromarray(pixels).save(filename, "JPEG", quality=100)

「Model.py」の作成

import os

import mxnet as mx
from mxnet import gluon
from mxnet import ndarray as nd
from mxnet.gluon.nn import Dense, Activation, Conv2D, Conv2DTranspose, BatchNorm, LeakyReLU, Flatten, HybridSequential, HybridBlock, Dropout
from mxnet import autograd
import numpy as np

# Define Unet generator skip block
class UnetSkipUnit(HybridBlock):
    """One level of the U-Net: down-sample, run ``inner_block``, up-sample, then skip-concat.

    Encoder: [LeakyReLU] -> Conv2D(k=4, s=2, p=1) -> [BatchNorm]  (halves H and W).
    Decoder: ReLU -> Conv2DTranspose(k=4, s=2, p=1) -> BatchNorm, or tanh when
    outermost (doubles H and W back).

    Args:
        inner_channels: channels produced by this level's encoder conv.
        outer_channels: channels of this level's input and of its decoder output.
        inner_block: the nested, deeper UnetSkipUnit (ignored when ``innermost``).
        innermost: bottom of the U; no inner block and no encoder BatchNorm.
        outermost: top of the U; encoder is a bare conv, decoder ends in tanh, and
            ``hybrid_forward`` does not concatenate the skip connection.
        use_dropout: append Dropout(rate=0.5) after the decoder.
        use_bias: bias flag for the encoder conv and the non-outermost decoder conv
            (the outermost decoder conv does not pass it and keeps Gluon's default).
    """
    def __init__(self, inner_channels, outer_channels, inner_block=None, innermost=False, outermost=False,
                 use_dropout=False, use_bias=False):
        super(UnetSkipUnit, self).__init__()

        with self.name_scope():
            self.outermost = outermost
            # All candidate layers are built up front; only the ones placed in `model`
            # below are added to the HybridSequential and thus become children.
            en_conv = Conv2D(channels=inner_channels, kernel_size=4, strides=2, padding=1,
                             in_channels=outer_channels, use_bias=use_bias)
            en_relu = LeakyReLU(alpha=0.2)
            en_norm = BatchNorm(momentum=0.1, in_channels=inner_channels)
            de_relu = Activation(activation='relu')
            de_norm = BatchNorm(momentum=0.1, in_channels=outer_channels)

            if innermost:
                # No inner block: the decoder consumes the encoder output directly.
                de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2, padding=1,
                                          in_channels=inner_channels, use_bias=use_bias)
                encoder = [en_relu, en_conv]
                decoder = [de_relu, de_conv, de_norm]
                model = encoder + decoder
            elif outermost:
                # inner_block returns concat(features, skip) -> inner_channels * 2 channels.
                de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2, padding=1,
                                          in_channels=inner_channels * 2)
                encoder = [en_conv]
                decoder = [de_relu, de_conv, Activation(activation='tanh')]
                model = encoder + [inner_block] + decoder
            else:
                # Middle level: same doubled input width from the inner skip concat.
                de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2, padding=1,
                                          in_channels=inner_channels * 2, use_bias=use_bias)
                encoder = [en_relu, en_conv, en_norm]
                decoder = [de_relu, de_conv, de_norm]
                model = encoder + [inner_block] + decoder
            if use_dropout:
                model += [Dropout(rate=0.5)]

            self.model = HybridSequential()
            with self.model.name_scope():
                for block in model:
                    self.model.add(block)

    def hybrid_forward(self, F, x):
        """Run the level; non-outermost levels return concat(output, x) along channels
        (outer_channels * 2 channels), which is why the parent's decoder expects a doubled width."""
        if self.outermost:
            return self.model(x)
        else:
            return F.concat(self.model(x), x, dim=1)

# Define Unet generator
class UnetGenerator(HybridBlock):
    """U-Net generator built from nested UnetSkipUnit blocks.

    Args:
        in_channels: channels of the input image; the outermost unit's transposed
            conv emits the same number of channels.
        num_downs: number of stride-2 down-sampling stages; must be >= 5. With a
            256x256 input and num_downs=8 the innermost unit works at 1x1.
        ngf: base number of filters.
        use_dropout: add Dropout(0.5) to the (num_downs - 5) middle 8*ngf units.

    Raises:
        ValueError: if num_downs < 5.
    """
    def __init__(self, in_channels, num_downs, ngf=64, use_dropout=True):
        super(UnetGenerator, self).__init__()

        # The fixed part of the structure already contains 5 stride-2 stages
        # (innermost, three widening stages, outermost). With num_downs < 5 the
        # loop below would simply not run and a 5-stage net would be built silently.
        if num_downs < 5:
            raise ValueError('num_downs must be at least 5, got %r' % (num_downs,))

        # Build from the bottleneck outwards; each unit wraps the previous one.
        # NOTE: units are created outside self.name_scope(); kept this way so the
        # auto-generated parameter names do not change.
        unet = UnetSkipUnit(ngf * 8, ngf * 8, innermost=True)
        for _ in range(num_downs - 5):
            unet = UnetSkipUnit(ngf * 8, ngf * 8, unet, use_dropout=use_dropout)
        unet = UnetSkipUnit(ngf * 8, ngf * 4, unet)
        unet = UnetSkipUnit(ngf * 4, ngf * 2, unet)
        unet = UnetSkipUnit(ngf * 2, ngf * 1, unet)
        unet = UnetSkipUnit(ngf, in_channels, unet, outermost=True)

        with self.name_scope():
            self.model = unet

    def hybrid_forward(self, F, x):
        return self.model(x)

# Define the PatchGAN discriminator
class Discriminator(HybridBlock):
    """PatchGAN discriminator producing a single-channel map of per-patch scores.

    Args:
        in_channels: channels of the input tensor (e.g. condition and image concatenated).
        ndf: filters of the first conv; later stages use ndf * min(2**i, 8).
        n_layers: number of stride-2 convolutions.
        use_sigmoid: append a sigmoid; when False the output is raw scores.
        use_bias: bias flag for the convs that are followed by BatchNorm.
    """
    def __init__(self, in_channels, ndf=64, n_layers=3, use_sigmoid=False, use_bias=False):
        super(Discriminator, self).__init__()

        ksize = 4
        pad = int(np.ceil((ksize - 1)/2))  # 2 for a 4x4 kernel

        with self.name_scope():
            self.model = HybridSequential()
            layers = self.model

            # First stage: strided conv + LeakyReLU, no normalization.
            layers.add(Conv2D(channels=ndf, kernel_size=ksize, strides=2,
                              padding=pad, in_channels=in_channels))
            layers.add(LeakyReLU(alpha=0.2))

            # Further strided stages: conv -> BatchNorm -> LeakyReLU, width capped at 8 * ndf.
            mult = 1
            for level in range(1, n_layers):
                prev, mult = mult, min(2 ** level, 8)
                layers.add(Conv2D(channels=ndf * mult, kernel_size=ksize, strides=2,
                                  padding=pad, in_channels=ndf * prev,
                                  use_bias=use_bias))
                layers.add(BatchNorm(momentum=0.1, in_channels=ndf * mult))
                layers.add(LeakyReLU(alpha=0.2))

            # One stride-1 stage at width ndf * min(2**n_layers, 8).
            prev, mult = mult, min(2 ** n_layers, 8)
            layers.add(Conv2D(channels=ndf * mult, kernel_size=ksize, strides=1,
                              padding=pad, in_channels=ndf * prev,
                              use_bias=use_bias))
            layers.add(BatchNorm(momentum=0.1, in_channels=ndf * mult))
            layers.add(LeakyReLU(alpha=0.2))

            # Projection to the single-channel score map.
            layers.add(Conv2D(channels=1, kernel_size=ksize, strides=1,
                              padding=pad, in_channels=ndf * mult))
            if use_sigmoid:
                layers.add(Activation(activation='sigmoid'))

    def hybrid_forward(self, F, x):
        return self.model(x)

def param_init(param, ctx):
    """Initialize one Gluon parameter on ``ctx`` according to its name.

    - names containing 'conv': weights ~ N(0, 0.02), everything else (bias) = 0
    - names containing 'batchnorm': running_var = 1, gamma ~ N(1, 0.02),
      beta / running_mean = 0
    - any other parameter is left uninitialized.

    Args:
        param: a gluon Parameter.
        ctx: context(s) to allocate the parameter on.
    """
    name = param.name
    if 'conv' in name:
        if 'weight' in name:
            param.initialize(init=mx.init.Normal(0.02), ctx=ctx)
        else:
            param.initialize(init=mx.init.Zero(), ctx=ctx)
    elif 'batchnorm' in name:
        if 'running_var' in name:
            # The running variance must start at 1 (BatchNorm's own default). With 0,
            # any inference-mode forward before statistics accumulate divides by
            # sqrt(epsilon) and blows activations up by roughly 1/sqrt(epsilon).
            param.initialize(init=mx.init.One(), ctx=ctx)
        else:
            param.initialize(init=mx.init.Zero(), ctx=ctx)
            # Initialize gamma from normal distribution with mean 1 and std 0.02
            if 'gamma' in name:
                param.set_data(nd.random_normal(1, 0.02, param.data().shape))

def network_init(net, ctx):
    """Initialize every parameter of ``net`` on ``ctx`` by delegating to param_init."""
    params = net.collect_params()
    for key in params:
        param_init(params[key], ctx)

def set_network():
    """Build the pix2pix networks and return them as (generator, discriminator)."""
    # Generator: 3-channel image in, 3-channel image out, 8 down/up-sampling levels.
    generator = UnetGenerator(in_channels=3, num_downs=8)
    # Discriminator input is (condition, image) concatenated on channels: 3 + 3 = 6.
    discriminator = Discriminator(in_channels=6)
    return generator, discriminator

「Train.py」の作成

import mxnet as mx
from mxnet import gluon
from mxnet import ndarray as nd
from mxnet import autograd
import numpy as np

import time
import logging

import Utility

class ImagePool:
    """History buffer of previously generated images fed to the discriminator.

    Until ``pool_size`` images have been stored, every incoming image is stored and
    returned unchanged. Afterwards each incoming image, with probability 0.5, replaces
    a uniformly chosen stored image and the stored one is returned instead; otherwise
    the incoming image is returned. A pool_size <= 0 disables the pool.
    """
    def __init__(self, pool_size):
        self.pool_size = pool_size
        # Always define the buffer so the object is consistent for any pool_size.
        self.num_imgs = 0
        self.images = []

    def query(self, images):
        """Return a batch of the same size as ``images``, mixing in pooled history.

        Args:
            images: NDArray batch; split along axis 0.
        """
        if self.pool_size <= 0:
            return images
        ret_imgs = []
        for i in range(images.shape[0]):
            image = nd.expand_dims(images[i], axis=0)
            if self.num_imgs < self.pool_size:
                self.num_imgs += 1
                self.images.append(image)
                ret_imgs.append(image)
                continue
            p = nd.random_uniform(0, 1, shape=(1,)).asscalar()
            if p > 0.5:
                # Uniform index in [0, pool_size). An upper bound of pool_size - 1 plus
                # truncation would never pick the last slot, and a uint8 cast would wrap
                # for pools larger than 256. min() guards against float rounding.
                random_id = int(nd.random_uniform(0, self.pool_size, shape=(1,)).asscalar())
                random_id = min(random_id, self.pool_size - 1)
                tmp = self.images[random_id].copy()
                self.images[random_id] = image
                ret_imgs.append(tmp)
            else:
                ret_imgs.append(image)
        return nd.concat(*ret_imgs, dim=0)

def facc(label, pred, threshold=0.0):
    """Binary accuracy of discriminator outputs against 0/1 labels.

    The discriminator built by this project has no sigmoid and is trained with
    SigmoidBinaryCrossEntropyLoss on raw scores, so ``pred`` holds logits. A logit
    threshold of 0 is equivalent to a probability threshold of 0.5. Pass
    ``threshold=0.5`` if ``pred`` already contains probabilities.

    Args:
        label: array of 0/1 targets (any shape).
        pred: array of scores with the same number of elements as ``label``.
        threshold: decision threshold applied to ``pred``.

    Returns:
        Fraction of elements where ``pred > threshold`` matches ``label``.
    """
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > threshold) == label).mean()

def train(data, netD, netG, epochs, batch_size, ctx):
    """Train a pix2pix generator / discriminator pair.

    Every batch does one discriminator update followed by one generator update.
    After each epoch, one generated sample is saved as ``epoch<N>_out.jpg`` and the
    generator weights as ``netG_<N>.params`` in the current directory.

    Args:
        data: resettable iterator whose batches carry ``data = [inputs, targets]``
            (e.g. the iterator returned by Utility.load_data).
        netD: discriminator; receives inputs concatenated with real or generated
            images along the channel axis.
        netG: generator mapping inputs to output images.
        epochs: number of passes over ``data``.
        batch_size: only used for the samples/s figure in the log.
        ctx: MXNet context the batches are copied to.
    """
    # ---------------- Hyperparameters ----------------
    pool_size = 50   # capacity of the history pool of generated samples
    lr = 0.0002      # Adam learning rate (shared by G and D)
    beta1 = 0.5      # Adam beta1
    lambda1 = 100    # weight of the L1 term in the generator loss
    # -------------------------------------------------

    trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
    trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})

    image_pool = ImagePool(pool_size)

    metric = mx.metric.CustomMetric(facc)

    GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    L1_loss = gluon.loss.L1Loss()

    logging.basicConfig(level=logging.DEBUG)

    # Most recent generated batch; stays None until at least one batch is processed.
    fake_out = None

    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        data.reset()
        for batch_idx, batch in enumerate(data):
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ###########################
            real_in = batch.data[0].as_in_context(ctx)
            real_out = batch.data[1].as_in_context(ctx)

            # NOTE(review): netG runs outside autograd.record() here, i.e. in inference
            # mode (BatchNorm uses running statistics, Dropout is off), so D is trained on
            # samples produced differently from those in step (2). Confirm this is intended.
            fake_out = netG(real_in)
            # Use image pooling to utilize history images
            fake_concat = image_pool.query(nd.concat(real_in, fake_out, dim=1))
            with autograd.record():
                # Train with fake image
                output = netD(fake_concat)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                errD_fake = GAN_loss(output, fake_label)
                metric.update([fake_label,], [output,])

                # Train with real image
                real_concat = nd.concat(real_in, real_out, dim=1)
                output = netD(real_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)
                errD = (errD_real + errD_fake) * 0.5
                errD.backward()
                metric.update([real_label,], [output,])

            trainerD.step(batch.data[0].shape[0])

            ############################
            # (2) Update G network: maximize log(D(x, G(x, z))) - lambda1 * L1(y, G(x, z))
            ###########################
            with autograd.record():
                fake_out = netG(real_in)
                fake_concat = nd.concat(real_in, fake_out, dim=1)
                output = netD(fake_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errG = GAN_loss(output, real_label) + L1_loss(real_out, fake_out) * lambda1
                errG.backward()

            trainerG.step(batch.data[0].shape[0])

            # Print log information every ten batches
            if batch_idx % 10 == 0:
                _, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errD).asscalar(),
                                nd.mean(errG).asscalar(), acc, batch_idx, (epoch + 1)))
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('binary training acc at epoch %d: %s=%f' % ((epoch+1), name, acc))
        logging.info('time: %f' % (time.time() - tic))

        # save one generated image for each epoch (nothing to save if no batch ran yet)
        if fake_out is not None:
            fake_img = fake_out[0]
            filename = 'epoch%d_out.jpg' % (epoch+1)
            Utility.save_img(fake_img, filename)

        # save parameters for each epoch
        param_filename = 'netG_%d.params' % (epoch+1)
        netG.save_parameters(param_filename)

実行(学習)

import Model
import Utility
import Train

import mxnet as mx

batch_size = 10
epochs = 30
ctx = mx.gpu()

# Dataset directory as extracted by data_download.py
dataset = 'facades'


def main():
    """Load the training pairs, build and initialize the pix2pix networks, and train them."""
    # Load data
    train_img_path = '%s/train' % (dataset)
    # is_reversed=True: the right half of each side-by-side image is the input,
    # the left half the target.
    train_data = Utility.load_data(train_img_path, batch_size, is_reversed=True)

    # Build the networks
    netG, netD = Model.set_network()

    # Initialize parameters
    Model.network_init(netG, ctx=ctx)
    Model.network_init(netD, ctx=ctx)

    Train.train(data=train_data, netD=netD, netG=netG, epochs=epochs, batch_size=batch_size, ctx=ctx)


if __name__ == '__main__':
    main()

実行(テスト)

import Model
import numpy as np
import mxnet as mx
from mxnet import ndarray as nd
from PIL import Image

ctx = mx.gpu()


def main():
    """Translate input.jpg with the epoch-30 generator and save the result as out.jpg."""
    # Only the generator is needed for inference.
    netG, _ = Model.set_network()
    netG.load_parameters('netG_30.params', ctx=ctx)

    # Assumes input.jpg is a single input image, not a side-by-side pair.
    img = 'input.jpg'

    # Mirror Utility.load_data's preprocessing: [0, 255] -> [-1, 1], 256x256, HWC -> NCHW.
    img_arr = mx.image.imread(img).astype(np.float32)/127.5 - 1
    img_arr = mx.image.imresize(img_arr, 256, 256)
    img_arr = nd.transpose(img_arr, (2, 0, 1))
    img_arr = nd.expand_dims(img_arr, axis=0)

    img_out = netG(img_arr.as_in_context(ctx))
    # Map [-1, 1] back to [0, 255]; clip before the uint8 cast so out-of-range
    # values cannot wrap around.
    img_out = (img_out[0].asnumpy().transpose(1, 2, 0) + 1.0) * 127.5
    img_out = np.clip(img_out, 0, 255).astype(np.uint8)

    pilImg = Image.fromarray(img_out)
    pilImg.save('out.jpg', "JPEG", quality=100)


if __name__ == '__main__':
    main()