DCGANではこのようなコードをよく見かける。
########################### # Update G network ########################### noise = mx.nd.random_normal(0, 1, shape=(batch_size, 100, 1, 1)) with autograd.record(): fake = netG(noise) output = netD(fake).reshape((-1, 1)) errG = loss(output, real_label) errG.backward() trainerG.step(batch_size)
「Generator」のパラメータ更新と同時に「Discriminator」の「Batch Norm」内の何かが変更される。
調べてみると、変更されるのは「moving_mean」と「moving_variance」と呼ばれる統計値らしい（Gluonの「BatchNorm」ではパラメータ名が「running_mean」と「running_var」になっている）。
日本語で「移動平均値」と「移動分散値」。
検証したのがこちら。
- 「dcgan_model.py」
from mxnet.gluon import nn


def netG():
    """Build the DCGAN generator as a HybridSequential.

    Four transposed-convolution stages (each followed by BatchNorm and ReLU)
    are followed by a final transposed convolution to 3 channels and a tanh.
    All convolutions use a 4x4 kernel and no bias.
    """
    # (channels, strides, padding) of every BatchNorm+ReLU stage, in order.
    stages = (
        (512, 1, 0),
        (256, 2, 1),
        (128, 2, 1),
        (64, 2, 1),
    )
    model = nn.HybridSequential()
    with model.name_scope():
        for channels, strides, padding in stages:
            model.add(
                nn.Conv2DTranspose(channels, kernel_size=4, strides=strides,
                                   padding=padding, use_bias=False),
                nn.BatchNorm(),
                nn.Activation('relu'),
            )
        # Output stage: 3 channels squashed by tanh, no BatchNorm.
        model.add(
            nn.Conv2DTranspose(3, kernel_size=4, strides=2, padding=1,
                               use_bias=False),
            nn.Activation('tanh'),
        )
    return model


def netD():
    """Build the DCGAN discriminator as a HybridSequential.

    Four stride-2 convolution stages (each followed by BatchNorm and
    LeakyReLU with alpha=0.2) are followed by a final convolution to a single
    channel. All convolutions use a 4x4 kernel and no bias.
    """
    model = nn.HybridSequential()
    with model.name_scope():
        for channels in (64, 128, 256, 512):
            model.add(
                nn.Conv2D(channels, kernel_size=4, strides=2, padding=1,
                          use_bias=False),
                nn.BatchNorm(),
                nn.LeakyReLU(alpha=0.2),
            )
        # Output stage: one channel, no normalisation or activation.
        model.add(nn.Conv2D(1, kernel_size=4, strides=1, padding=0,
                            use_bias=False))
    return model
- 変更されたことを検証するためのコード
# Check whether updating only the generator changes what netD computes for a
# fixed input. Prints True when netD's output is unchanged, False otherwise.
import mxnet as mx
import numpy as np
from mxnet import autograd, gluon

import dcgan_model

netG = dcgan_model.netG()
netG.initialize(mx.init.Normal(0.02))

netD = dcgan_model.netD()
netD.initialize(mx.init.Normal(0.02))

loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
# Only the generator gets a trainer; netD's parameters are never stepped.
trainerG = gluon.Trainer(netG.collect_params(), 'adam')

batch_size = 16
real_label = mx.nd.ones((batch_size,))

# Fixed probe batch, evaluated by netD before and after the generator update.
# (Named `probe` rather than `input` so the builtin is not shadowed.)
probe = mx.nd.random.uniform(-1, 1, shape=(10, 3, 64, 64))
output_pre = netD(probe)

###########################
# Update G network
###########################
noise = mx.nd.random_normal(0, 1, shape=(batch_size, 100, 1, 1))
with autograd.record():
    fake = netG(noise)
    # netD runs inside the recorded scope for this forward pass.
    output = netD(fake).reshape((-1, 1))
    errG = loss(output, real_label)
    errG.backward()
trainerG.step(batch_size)

# Same probe, same netD weights: any difference comes from state that
# changed inside netD during the generator update.
output_post = netD(probe)

print(np.array_equal(output_pre.asnumpy(), output_post.asnumpy()))
- 結果
False
- 実際に変更された値を確認
# Print the running mean / variance of netD's first BatchNorm layer before
# and after generator-only update steps.
import mxnet as mx
from mxnet import autograd, gluon

import dcgan_model


def print_first_batchnorm_stats(net):
    """Print running_mean and running_var of the block's `batchnorm0` child.

    The parameter key is built from ``net.prefix`` instead of a hard-coded
    ``hybridsequentialN_`` string: N is a global counter that depends on how
    many blocks were created before, so a literal key breaks as soon as the
    construction order changes or the script is re-run in the same session.
    """
    params = net.collect_params()
    for stat in ('running_mean', 'running_var'):
        print(params[net.prefix + 'batchnorm0_' + stat].data())


netG = dcgan_model.netG()
netG.initialize(mx.init.Normal(0.02))

netD = dcgan_model.netD()
netD.initialize(mx.init.Normal(0.02))

loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
# Only the generator gets a trainer; netD's parameters are never stepped.
trainerG = gluon.Trainer(netG.collect_params(), 'adam')

batch_size = 8
real_label = mx.nd.ones((batch_size,))

# One forward pass so that netD's parameters have concrete shapes and can be
# read with .data(); the result itself is not needed.
dummy = mx.nd.random.uniform(0, 10, shape=(10, 3, 64, 64))
netD(dummy)

print_first_batchnorm_stats(netD)

###########################
# Update G network
###########################
# NOTE(review): the original layout was lost; the whole update step is
# assumed to be the loop body -- confirm against the author's script.
for _ in range(2):
    noise = mx.nd.random_normal(0, 1, shape=(batch_size, 100, 1, 1))
    with autograd.record():
        fake = netG(noise)
        output = netD(fake).reshape((-1, 1))
        errG = loss(output, real_label)
        errG.backward()
    trainerG.step(batch_size)

print_first_batchnorm_stats(netD)
- 結果
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] <NDArray 64 @cpu(0)> [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.] <NDArray 64 @cpu(0)> [ 9.35047574e-05 9.33420612e-04 -5.87339862e-04 6.20232022e-04 -7.31294102e-04 1.73414650e-04 3.82519444e-04 4.33211884e-04 1.44761393e-03 1.08334680e-04 -1.53198547e-03 -1.20464785e-04 -7.91981700e-04 -8.36857711e-04 -1.06310612e-03 1.93330599e-03 -6.75409450e-04 -7.49928267e-06 -2.00461177e-03 -1.60970632e-03 2.40422855e-03 1.48091814e-04 -3.40931281e-03 4.04094957e-04 1.24554336e-03 -4.83149372e-04 1.15514221e-03 3.16894584e-04 1.04401552e-03 -9.56838543e-04 3.67905508e-04 -1.60453713e-03 2.20237323e-03 1.10279478e-04 -5.06307930e-04 6.68685418e-04 -4.18046955e-04 -7.73620966e-04 -2.43184186e-04 8.01519491e-04 -1.79696144e-05 -1.94683496e-03 1.41503895e-03 -1.32936612e-03 8.90188210e-04 -1.92038203e-03 -2.20242376e-03 1.02425751e-03 1.08564377e-03 -3.88558052e-04 -5.89201227e-04 2.11480117e-04 5.23960218e-04 -1.19373817e-05 3.39562190e-04 1.55972119e-03 -8.74583609e-04 2.33166604e-04 4.00208810e-04 -5.26324671e-04 -1.19255355e-03 -8.04555137e-04 -2.55955569e-03 2.99032981e-04] <NDArray 64 @cpu(0)> [0.90006787 0.9000666 0.9000721 0.90005267 0.9000545 0.90002877 0.9000607 0.9000617 0.9000472 0.9000413 0.9000968 0.9000484 0.90005195 0.90007013 0.900058 0.9000561 0.9000638 0.9000553 0.90004385 0.9000528 0.90003955 0.90008193 0.9000904 0.900074 0.9000592 0.90004903 0.9000607 0.90007347 0.9000439 0.90007645 0.9000437 0.90006006 0.90006393 0.9000693 0.9000455 0.9000424 0.9000393 0.9000585 0.900062 0.9000492 0.9000406 0.9000627 0.9000363 0.90005255 0.9000594 0.900075 0.90005773 0.90006167 0.90006816 0.9000661 0.90006274 0.9000439 0.90005606 
0.9000475 0.90005296 0.9000594 0.90006024 0.90005535 0.9000738 0.9000466 0.90005 0.9000415 0.90006083 0.9000567 ] <NDArray 64 @cpu(0)>
「BatchNorm」に「use_global_stats=True」を記述すると変更されない。ただし、この指定では「Discriminator」自身の学習時にも移動平均値・移動分散値は更新されず、常に保持している統計値で正規化される点に注意。
def netD(use_global_stats=True):
    """Build the DCGAN discriminator with configurable BatchNorm statistics.

    Four stride-2 convolution stages (each followed by BatchNorm and
    LeakyReLU with alpha=0.2) are followed by a final convolution to a single
    channel. All convolutions use a 4x4 kernel and no bias.

    Args:
        use_global_stats: Forwarded to every ``nn.BatchNorm``. The default
            ``True`` keeps the behaviour of calling ``netD()`` with no
            arguments: the layers normalise with their stored running
            statistics, so a forward pass does not modify them. Pass
            ``False`` for regular batch normalisation.

    Returns:
        The assembled ``nn.HybridSequential`` (not yet initialized).
    """
    model = nn.HybridSequential()
    with model.name_scope():
        for channels in (64, 128, 256, 512):
            model.add(
                nn.Conv2D(channels, kernel_size=4, strides=2, padding=1,
                          use_bias=False),
                nn.BatchNorm(use_global_stats=use_global_stats),
                nn.LeakyReLU(alpha=0.2),
            )
        # Output stage: one channel, no normalisation or activation.
        model.add(nn.Conv2D(1, kernel_size=4, strides=1, padding=0,
                            use_bias=False))
    return model