MXNet Notes (3)

The specification of "mx.io.NDArrayIter" has apparently changed (per the changelog in the MXNet repository on GitHub):

Change the timing of shuffling. Previously, it shuffles only during the initialization, which didn't meet training needs.
Changes: shuffle when calling the reset.

On "shuffle"

Test code

import mxnet as mx

A = []

for i in range(6):
    a = mx.nd.array([i])
    a = mx.nd.expand_dims(a, axis=0)
    A.append(a)

batch_size = 3

train_data = mx.io.NDArrayIter(data=[mx.nd.concat(*A, dim=0)],
                               batch_size=batch_size,
                               shuffle=True)

epoch = 2
for i in range(1, epoch+1):
    print('epoch %d'%i)
    train_data.reset()
    for batch in train_data:
        print(batch.data[0])

Output

epoch 1

[[5.]
 [4.]
 [1.]]
<NDArray 3x1 @cpu(0)>

[[3.]
 [2.]
 [0.]]
<NDArray 3x1 @cpu(0)>
epoch 2

[[0.]
 [3.]
 [4.]]
<NDArray 3x1 @cpu(0)>

[[1.]
 [5.]
 [2.]]
<NDArray 3x1 @cpu(0)>

The data is now shuffled on every reset.
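The behavior can be modeled in plain Python (a simplified sketch of the iterator semantics, not MXNet's actual implementation; "SimpleIter" is a made-up name):

```python
import random

class SimpleIter:
    """Simplified model of NDArrayIter's shuffle-on-reset behavior."""
    def __init__(self, data, batch_size, shuffle=False):
        self.data = list(data)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.idx = list(range(len(self.data)))

    def reset(self):
        # After the change quoted above, the order is reshuffled on
        # every reset, not just once at construction time.
        if self.shuffle:
            random.shuffle(self.idx)

    def __iter__(self):
        # Yield only complete batches, in the current index order.
        for start in range(0, len(self.idx), self.batch_size):
            batch = self.idx[start:start + self.batch_size]
            if len(batch) == self.batch_size:
                yield [self.data[i] for i in batch]

it = SimpleIter(range(6), batch_size=3, shuffle=True)
for epoch in (1, 2):
    it.reset()  # each reset produces a fresh order
    print('epoch %d' % epoch, list(it))
```

Each epoch still covers all six samples; only the order changes between resets.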

On "last_batch_handle"

  • last_batch_handle='discard'

Test code

import mxnet as mx

A = []

for i in range(5):
    a = mx.nd.array([i])
    a = mx.nd.expand_dims(a, axis=0)
    A.append(a)

batch_size = 3

train_data = mx.io.NDArrayIter(data=[mx.nd.concat(*A, dim=0)],
                               batch_size=batch_size,
                               shuffle=False,
                               last_batch_handle='discard')

epoch = 2
for i in range(1, epoch+1):
    print('epoch %d'%i)
    train_data.reset()
    for batch in train_data:
        print(batch.data[0])

Output

epoch 1

[[0.]
 [1.]
 [2.]]
<NDArray 3x1 @cpu(0)>
epoch 2

[[0.]
 [1.]
 [2.]]
<NDArray 3x1 @cpu(0)>

The leftover samples are discarded.
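In plain Python terms (a sketch of the idea, not MXNet's code; "discard_batches" is a made-up helper), 'discard' simply keeps only the complete batches:

```python
def discard_batches(data, batch_size):
    # Leftover samples that cannot fill a batch are dropped.
    n_full = len(data) // batch_size
    return [data[i * batch_size:(i + 1) * batch_size] for i in range(n_full)]

print(discard_batches([0, 1, 2, 3, 4], 3))  # → [[0, 1, 2]]
```

With 5 samples and batch_size=3, only one batch survives, matching the output above.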

  • last_batch_handle='pad' (default)

Test code

import mxnet as mx

A = []

for i in range(5):
    a = mx.nd.array([i])
    a = mx.nd.expand_dims(a, axis=0)
    A.append(a)

batch_size = 3

train_data = mx.io.NDArrayIter(data=[mx.nd.concat(*A, dim=0)],
                               batch_size=batch_size,
                               shuffle=False,
                               last_batch_handle='pad')

epoch = 2
for i in range(1, epoch+1):
    print('epoch %d'%i)
    train_data.reset()
    for batch in train_data:
        print(batch.data[0])

Output

epoch 1

[[0.]
 [1.]
 [2.]]
<NDArray 3x1 @cpu(0)>

[[3.]
 [4.]
 [0.]]
<NDArray 3x1 @cpu(0)>
epoch 2

[[0.]
 [1.]
 [2.]]
<NDArray 3x1 @cpu(0)>

[[3.]
 [4.]
 [0.]]
<NDArray 3x1 @cpu(0)>

The last batch is padded by wrapping around to the beginning of the data.
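The padding logic can be sketched in plain Python (illustrative only; "pad_batches" is a made-up helper, not an MXNet API):

```python
def pad_batches(data, batch_size):
    pad = (-len(data)) % batch_size  # samples needed to fill the last batch
    padded = data + data[:pad]       # wrap around to the beginning
    return [padded[i:i + batch_size] for i in range(0, len(padded), batch_size)]

print(pad_batches([0, 1, 2, 3, 4], 3))  # → [[0, 1, 2], [3, 4, 0]]
```

The trailing 0 in the second batch is the wrapped-around first sample, exactly as in the output above.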

  • last_batch_handle='roll_over'

Test code

import mxnet as mx

A = []

for i in range(5):
    a = mx.nd.array([i])
    a = mx.nd.expand_dims(a, axis=0)
    A.append(a)

batch_size = 3

train_data = mx.io.NDArrayIter(data=[mx.nd.concat(*A, dim=0)],
                               batch_size=batch_size,
                               shuffle=False,
                               last_batch_handle='roll_over')

epoch = 2
for i in range(1, epoch+1):
    print('epoch %d'%i)
    train_data.reset()
    for batch in train_data:
        print(batch.data[0])

Output

epoch 1

[[0.]
 [1.]
 [2.]]
<NDArray 3x1 @cpu(0)>
epoch 2

[[3.]
 [4.]
 [0.]]
<NDArray 3x1 @cpu(0)>

[[1.]
 [2.]
 [3.]]
<NDArray 3x1 @cpu(0)>

The leftover samples are carried over to the next epoch.
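The carry-over behavior can be modeled in plain Python (a simplified sketch assuming shuffle=False; "RollOverIter" is a made-up name, not MXNet code):

```python
class RollOverIter:
    """Simplified model of last_batch_handle='roll_over'."""
    def __init__(self, data, batch_size):
        self.data = list(data)
        self.batch_size = batch_size
        self.leftover = []

    def reset(self):
        pass  # leftover samples survive the reset

    def __iter__(self):
        # Prepend last epoch's leftover, then yield full batches;
        # whatever remains is carried into the next epoch.
        buf = self.leftover + self.data
        n_full = len(buf) // self.batch_size
        self.leftover = buf[n_full * self.batch_size:]
        for i in range(n_full):
            yield buf[i * self.batch_size:(i + 1) * self.batch_size]

it = RollOverIter([0, 1, 2, 3, 4], batch_size=3)
for epoch in (1, 2):
    it.reset()
    print('epoch %d' % epoch, list(it))
# epoch 1: [[0, 1, 2]]            (leftover [3, 4] held back)
# epoch 2: [[3, 4, 0], [1, 2, 3]] (leftover [4] carried forward)
```

This reproduces the output above: epoch 1 yields only one batch, and epoch 2 starts with the carried-over samples 3 and 4.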