Solving MNIST with Short, Readable Code

This may feel a bit late to the party, but I have gradually gotten the hang of writing MXNet code, so I tried to write an MNIST solver that is as short and as readable as possible.

import mxnet as mx
from mxnet import gluon, autograd
from mxnet.gluon import nn, rnn, data

mnist = mx.test_utils.get_mnist()
x_train = mnist['train_data']
t_train = mnist['train_label']
x_test = mnist['test_data']
t_test = mnist['test_label']

# Drop the channel axis: each 28x28 image becomes a sequence of 28 rows
# with 28 features each, matching the LSTM's 'NTC' layout below
x_train = x_train.reshape(-1, 28, 28)
x_test = x_test.reshape(-1, 28, 28)

# The DataLoader converts the NumPy arrays into NDArrays as it iterates
dataset = data.dataset.ArrayDataset(x_train, t_train)
data_loader = data.DataLoader(dataset, batch_size=100, shuffle=True)
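# (Illustrative check, not part of the original script: peek at one batch
# to confirm that the DataLoader really yields NDArrays)
x0, t0 = next(iter(data_loader))
print(type(x0), x0.shape)  # <class 'mxnet.ndarray.ndarray.NDArray'> (100, 28, 28)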

x_test = mx.nd.array(x_test)
t_test = mx.nd.array(t_test)

net = nn.HybridSequential()
with net.name_scope():
    # LSTM with 128 hidden units; layout 'NTC' = (batch, time, channel),
    # so each image is read row by row as a 28-step sequence
    net.add(rnn.LSTM(128, layout='NTC'))
    # Dense flattens the LSTM output and maps it to the 10 digit classes
    net.add(nn.Dense(10))
net.initialize(mx.init.Xavier())
net.hybridize()

def evaluate_accuracy(x, label, net):
    output = net(x)
    predictions = mx.nd.argmax(output, axis=1)
    # fraction of predictions that match the labels
    acc = (predictions == label).mean()
    return acc.asscalar()

loss_func = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(net.collect_params(), 'adam')

print('start training...')
epochs = 10
loss_n = []  # for logging

for epoch in range(1, epochs + 1):
    for x, label in data_loader:
        # forward pass through the network
        with autograd.record():
            output = net(x)
            # compute the loss
            loss = loss_func(output, label)
        # backpropagate from the loss
        loss.backward()
        # store the summed loss for logging
        loss_n.append(loss.sum().asscalar())
        # update the parameters; the step is normalized by the batch size
        trainer.step(x.shape[0])
    # print per-epoch statistics
    ll = sum(loss_n)
    test_acc = evaluate_accuracy(x_test, t_test, net)

    print('%d epoch loss = %f test_acc = %f' % (epoch, ll, test_acc))
    loss_n = []

net.save_parameters('lstm.params')
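
Since the trained parameters are saved at the end, they can be loaded back later for inference. Here is a minimal sketch, assuming the same architecture is rebuilt first; the random placeholder input stands in for a real image and is my own assumption:

import mxnet as mx
from mxnet.gluon import nn, rnn

# rebuild the exact same architecture, then load the trained weights
net = nn.HybridSequential()
with net.name_scope():
    net.add(rnn.LSTM(128, layout='NTC'))
    net.add(nn.Dense(10))
net.load_parameters('lstm.params')

# classify a batch of one 28x28 image (random placeholder, an assumption)
x = mx.nd.random.uniform(shape=(1, 28, 28))
predicted_digit = mx.nd.argmax(net(x), axis=1)
print(predicted_digit.asscalar())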

Results

start training...
1 epoch loss = 16976.946788 test_acc = 0.968500
2 epoch loss = 5060.839172 test_acc = 0.976700
3 epoch loss = 3500.368934 test_acc = 0.981900
4 epoch loss = 2642.134119 test_acc = 0.984200
5 epoch loss = 2146.035752 test_acc = 0.988000
6 epoch loss = 1826.101408 test_acc = 0.988100
7 epoch loss = 1540.338854 test_acc = 0.988600
8 epoch loss = 1302.081520 test_acc = 0.986600
9 epoch loss = 1164.962564 test_acc = 0.987500
10 epoch loss = 1034.104569 test_acc = 0.990400
  • The loss values are large because the summed loss is not divided by the number of samples (a per-sample version is sketched below).
  • Test accuracy reaches about 99%.
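
If a per-sample average is preferred in the log, the accumulated sum can be divided by the number of examples seen during the epoch. A minimal sketch of one epoch, reusing net, loss_func, trainer, and data_loader from the script above (loss_sum and n_samples are counters I added, not in the original):

# one training epoch that logs the mean loss per sample
loss_sum = 0.0
n_samples = 0
for x, label in data_loader:
    with autograd.record():
        loss = loss_func(net(x), label)
    loss.backward()
    trainer.step(x.shape[0])
    loss_sum += loss.sum().asscalar()
    n_samples += x.shape[0]
print('mean loss per sample = %f' % (loss_sum / n_samples))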