GANにおけるパラメータの更新について

DCGANの学習コードでは、Generatorを更新する部分に次のようなコードをよく見かける。

###########################
# Update G network
###########################
# Draw a batch of latent noise vectors with shape (batch_size, 100, 1, 1).
noise = mx.nd.random_normal(0, 1, shape=(batch_size, 100, 1, 1))
with autograd.record():
    fake = netG(noise)
    # The fake batch also passes through D inside record(), i.e. in training
    # mode -- this is where D's BatchNorm running mean/variance get updated
    # even though only G is being trained.
    output = netD(fake).reshape((-1, 1))
    # real_label is presumably all ones: G is trained so that D scores its
    # fakes as real.
    errG = loss(output, real_label)
    errG.backward()
# Applies gradients only to the parameters registered with trainerG
# (assumed to be netG's parameters only -- confirm where trainerG is built).
trainerG.step(batch_size)

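「Generator」のパラメータを更新すると、それと同時に「Discriminator」の「BatchNorm」層が保持している値も変更されてしまう。
調べてみると、変更されているのは「移動平均値」と「移動分散値」で、Gluonのパラメータ名では「running_mean」と「running_var」（MXNetのオペレータ側の名前では「moving_mean」と「moving_var」）と呼ばれるものである。
これは、Generatorの更新時にも偽画像が`autograd.record()`の中で（つまり学習モードで）Discriminatorを通るため、DiscriminatorのBatchNormがこれらの統計値を更新してしまうからである。
検証したのがこちら。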

  • 「dcgan_model.py」
from mxnet.gluon import nn

def netG():
    """Build the DCGAN generator as an uninitialized HybridSequential.

    Four Conv2DTranspose -> BatchNorm -> ReLU stages (512, 256, 128 and 64
    channels) are followed by a Conv2DTranspose to 3 channels and a tanh.
    With 4x4 kernels the spatial size grows 1 -> 4 -> 8 -> 16 -> 32 -> 64,
    so a (batch, 100, 1, 1) noise tensor becomes a (batch, 3, 64, 64) image.
    The caller is responsible for calling ``initialize()``.
    """
    # (output channels, stride, padding) of each hidden upsampling stage.
    # The first stage expands the 1x1 input to 4x4; the rest double the size.
    hidden_stages = [
        (512, 1, 0),
        (256, 2, 1),
        (128, 2, 1),
        (64, 2, 1),
    ]
    generator = nn.HybridSequential()
    with generator.name_scope():
        for channels, stride, pad in hidden_stages:
            generator.add(nn.Conv2DTranspose(channels, kernel_size=4, strides=stride,
                                             padding=pad, use_bias=False))
            generator.add(nn.BatchNorm())
            generator.add(nn.Activation('relu'))
        # Output stage: 3 image channels squashed into [-1, 1] by tanh.
        generator.add(nn.Conv2DTranspose(3, kernel_size=4, strides=2, padding=1, use_bias=False))
        generator.add(nn.Activation('tanh'))
    return generator

def netD(use_global_stats=False):
    """Build the DCGAN discriminator as an uninitialized HybridSequential.

    Four Conv2D -> BatchNorm -> LeakyReLU(0.2) stages (64, 128, 256 and 512
    channels, 4x4 kernels, stride 2) each halve the spatial size, and a final
    4x4 Conv2D maps to a single output channel. A (batch, 3, 64, 64) input
    therefore yields a (batch, 1, 1, 1) score. The caller is responsible for
    calling ``initialize()``.

    Args:
        use_global_stats: Forwarded to every BatchNorm layer. With the default
            (False) each BatchNorm normalizes training batches with batch
            statistics and updates its running mean/variance whenever D runs
            in training mode -- e.g. when a Generator update passes fake
            images through D inside ``autograd.record()``. With True the
            layers always normalize with their running statistics and never
            update them (MXNet documents this as turning BatchNorm into a
            scale-shift operator).
    """
    model = nn.HybridSequential()
    with model.name_scope():
        for channels in (64, 128, 256, 512):
            model.add(nn.Conv2D(channels, kernel_size=4, strides=2, padding=1, use_bias=False))
            model.add(nn.BatchNorm(use_global_stats=use_global_stats))
            model.add(nn.LeakyReLU(alpha=0.2))
        # Collapse the remaining 4x4 feature map into one score per sample.
        model.add(nn.Conv2D(1, kernel_size=4, strides=1, padding=0, use_bias=False))
    return model
  • 変更されたことを検証するためのコード
import mxnet as mx
import numpy as np
from mxnet import gluon, autograd

import dcgan_model

netG = dcgan_model.netG()
netG.initialize(mx.init.Normal(0.02))
netD = dcgan_model.netD()
netD.initialize(mx.init.Normal(0.02))

loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()

# Only G's parameters are handed to the trainer, so trainerG.step() never
# changes D's weights. Any change in D's output below must therefore come
# from state that is not a trained weight -- the BatchNorm running stats.
trainerG = gluon.Trainer(netG.collect_params(), 'adam')

batch_size = 16

real_label = mx.nd.ones((batch_size,))

# Fixed probe batch fed to D before and after the G update.
# (Named `probe` rather than `input` to avoid shadowing the builtin.)
probe = mx.nd.random.uniform(-1, 1, shape=(10, 3, 64, 64))

# Outside autograd.record() D runs in inference mode, so its BatchNorm layers
# normalize with their running statistics and the output reflects them.
output_pre = netD(probe)

###########################
# Update G network
###########################
noise = mx.nd.random_normal(0, 1, shape=(batch_size, 100, 1, 1))
with autograd.record():
    fake = netG(noise)
    # D runs in training mode here and updates its BatchNorm running stats.
    output = netD(fake).reshape((-1, 1))
    errG = loss(output, real_label)
    errG.backward()
trainerG.step(batch_size)

output_post = netD(probe)

# Prints False: D's output for the same probe batch changed.
print(np.array_equal(output_pre.asnumpy(), output_post.asnumpy()))
  • 結果
False
  • 実際に変更された値を確認
import mxnet as mx
from mxnet import gluon, autograd

import dcgan_model

netG = dcgan_model.netG()
netG.initialize(mx.init.Normal(0.02))
netD = dcgan_model.netD()
netD.initialize(mx.init.Normal(0.02))

loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()

# Only G's parameters are registered with the trainer.
trainerG = gluon.Trainer(netG.collect_params(), 'adam')

batch_size = 8

real_label = mx.nd.ones((batch_size,))

# BatchNorm's parameter shapes are inferred lazily, so one forward pass is
# needed before .data() can be read. Run outside autograd.record(), this pass
# is in inference mode and leaves the running statistics at their initial
# values (all zeros / all ones in the printed result).
dummy = mx.nd.random.uniform(0, 10, shape=(10, 3, 64, 64))
netD(dummy)

# netD[1] is the BatchNorm right after the first Conv2D in dcgan_model.netD().
# Looking it up by position avoids the auto-generated parameter name
# ('hybridsequential1_batchnorm0_...'), whose prefix depends on how many
# HybridSequential blocks were created earlier in the same process.
first_bn = netD[1]


def show_running_stats(bn):
    """Print the running mean and then the running variance of BatchNorm *bn*."""
    print(bn.running_mean.data())
    print(bn.running_var.data())


show_running_stats(first_bn)

###########################
# Update G network
###########################
for _ in range(2):
    noise = mx.nd.random_normal(0, 1, shape=(batch_size, 100, 1, 1))
    with autograd.record():
        fake = netG(noise)
        # D runs in training mode here, which updates first_bn's running stats.
        output = netD(fake).reshape((-1, 1))
        errG = loss(output, real_label)
        errG.backward()
    trainerG.step(batch_size)

# D's weights were never stepped, yet its running statistics have moved.
show_running_stats(first_bn)
  • 結果
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
<NDArray 64 @cpu(0)>

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
<NDArray 64 @cpu(0)>

[ 9.35047574e-05  9.33420612e-04 -5.87339862e-04  6.20232022e-04
 -7.31294102e-04  1.73414650e-04  3.82519444e-04  4.33211884e-04
  1.44761393e-03  1.08334680e-04 -1.53198547e-03 -1.20464785e-04
 -7.91981700e-04 -8.36857711e-04 -1.06310612e-03  1.93330599e-03
 -6.75409450e-04 -7.49928267e-06 -2.00461177e-03 -1.60970632e-03
  2.40422855e-03  1.48091814e-04 -3.40931281e-03  4.04094957e-04
  1.24554336e-03 -4.83149372e-04  1.15514221e-03  3.16894584e-04
  1.04401552e-03 -9.56838543e-04  3.67905508e-04 -1.60453713e-03
  2.20237323e-03  1.10279478e-04 -5.06307930e-04  6.68685418e-04
 -4.18046955e-04 -7.73620966e-04 -2.43184186e-04  8.01519491e-04
 -1.79696144e-05 -1.94683496e-03  1.41503895e-03 -1.32936612e-03
  8.90188210e-04 -1.92038203e-03 -2.20242376e-03  1.02425751e-03
  1.08564377e-03 -3.88558052e-04 -5.89201227e-04  2.11480117e-04
  5.23960218e-04 -1.19373817e-05  3.39562190e-04  1.55972119e-03
 -8.74583609e-04  2.33166604e-04  4.00208810e-04 -5.26324671e-04
 -1.19255355e-03 -8.04555137e-04 -2.55955569e-03  2.99032981e-04]
<NDArray 64 @cpu(0)>

[0.90006787 0.9000666  0.9000721  0.90005267 0.9000545  0.90002877
 0.9000607  0.9000617  0.9000472  0.9000413  0.9000968  0.9000484
 0.90005195 0.90007013 0.900058   0.9000561  0.9000638  0.9000553
 0.90004385 0.9000528  0.90003955 0.90008193 0.9000904  0.900074
 0.9000592  0.90004903 0.9000607  0.90007347 0.9000439  0.90007645
 0.9000437  0.90006006 0.90006393 0.9000693  0.9000455  0.9000424
 0.9000393  0.9000585  0.900062   0.9000492  0.9000406  0.9000627
 0.9000363  0.90005255 0.9000594  0.900075   0.90005773 0.90006167
 0.90006816 0.9000661  0.90006274 0.9000439  0.90005606 0.9000475
 0.90005296 0.9000594  0.90006024 0.90005535 0.9000738  0.9000466
 0.90005    0.9000415  0.90006083 0.9000567 ]
<NDArray 64 @cpu(0)>

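Discriminatorの各「BatchNorm」に「use_global_stats=True」を指定すると、移動平均値・移動分散値は変更されなくなる。
ただしこの指定をすると、学習時も常に移動平均値・移動分散値で正規化され、しかもそれらは一切更新されない（初期値の平均0・分散1のまま）。そのため、Generatorの更新時だけでなくDiscriminator自身の学習時にもバッチ正規化が実質的に働かなくなる点に注意が必要である。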

def netD():
    """Build the DCGAN discriminator with every BatchNorm pinned to global stats.

    Layer stack: four Conv2D(kernel 4, stride 2, padding 1) -> BatchNorm ->
    LeakyReLU(0.2) stages with 64, 128, 256 and 512 channels, then a
    Conv2D(kernel 4, stride 1, padding 0) down to one output channel.

    Each BatchNorm is created with ``use_global_stats=True``: it normalizes with
    its running mean/variance and never updates them, so passing fake images
    through this network during a Generator update leaves D's BatchNorm state
    unchanged.

    NOTE: MXNet documents this flag as turning BatchNorm into a plain
    scale-shift operator. Because the running statistics then stay at their
    initial values (mean 0, variance 1), D effectively trains without batch
    normalization as well -- not only during the Generator update.
    """
    disc = nn.HybridSequential()
    with disc.name_scope():
        for filters in (64, 128, 256, 512):
            disc.add(nn.Conv2D(filters, kernel_size=4, strides=2, padding=1, use_bias=False))
            disc.add(nn.BatchNorm(use_global_stats=True))
            disc.add(nn.LeakyReLU(alpha=0.2))
        disc.add(nn.Conv2D(1, kernel_size=4, strides=1, padding=0, use_bias=False))
    return disc