Gluonの「dataset」と「DataLoader」について

サンプルコード

import numpy as np
from mxnet.gluon import data

a = np.random.rand(100,3)
b = np.random.rand(100,5)
c = np.random.rand(100,8)

dataset = data.dataset.ArrayDataset(a, b, c)
data_loader = data.DataLoader(dataset, batch_size=2,shuffle=True)

for i in range(2):
    print('epoch %d'%i)
    for batch in data_loader:
        print(batch[0])
        print(batch[1])
        print(batch[2])
        break

出力

epoch 0

[[0.83138743 0.0189369  0.09041065]
 [0.7277074  0.25689951 0.99610648]]
<NDArray 2x3 @cpu(0)>

[[0.13247276 0.29411466 0.9038639  0.9184374  0.71817192]
 [0.06205099 0.26872459 0.35199091 0.10605931 0.11823489]]
<NDArray 2x5 @cpu(0)>

[[0.77203044 0.88111204 0.89137279 0.64343076 0.13709234 0.37154335
  0.90996639 0.67373178]
 [0.18116684 0.63198589 0.11531276 0.22501306 0.51462202 0.73394312
  0.99965973 0.58817562]]
<NDArray 2x8 @cpu(0)>
epoch 1

[[0.67662112 0.31150197 0.96482456]
 [0.2055873  0.07256988 0.48376131]]
<NDArray 2x3 @cpu(0)>

[[0.43729793 0.85623583 0.30163485 0.09532238 0.0142127 ]
 [0.08242432 0.24336176 0.20770583 0.78283749 0.51888601]]
<NDArray 2x5 @cpu(0)>

[[0.96016702 0.82671972 0.79739873 0.90920828 0.58451523 0.24137856
  0.85786866 0.30341021]
 [0.04190342 0.92177108 0.95736727 0.93005575 0.25801083 0.89755136
  0.75930622 0.09502917]]
<NDArray 2x8 @cpu(0)>

ここからわかること

  • datasetを作るときはnumpy.ndarrayを受け付ける(MXNetのNDArrayである必要がない)
  • 第1軸がバッチとして扱われる
  • (今回は示していないが)リストも受け付ける
  • datasetは二つ以上のデータを受け付ける
  • epoch毎にシャッフルされる