Gluonの「dataset」と「DataLoader」について

サンプルコード

import numpy as np
from mxnet.gluon import data

a = np.random.rand(100,3)
b = np.random.rand(100,5)
c = np.random.rand(100,8)

dataset = data.dataset.ArrayDataset(a, b, c)
data_loader = data.DataLoader(dataset, batch_size=2,shuffle=True)

for i in range(2):
    print('epoch %d'%i)
    for batch in data_loader:
        print(batch[0])
        print(batch[1])
        print(batch[2])
        break

出力

epoch 0

[[0.83138743 0.0189369  0.09041065]
 [0.7277074  0.25689951 0.99610648]]
<NDArray 2x3 @cpu(0)>

[[0.13247276 0.29411466 0.9038639  0.9184374  0.71817192]
 [0.06205099 0.26872459 0.35199091 0.10605931 0.11823489]]
<NDArray 2x5 @cpu(0)>

[[0.77203044 0.88111204 0.89137279 0.64343076 0.13709234 0.37154335
  0.90996639 0.67373178]
 [0.18116684 0.63198589 0.11531276 0.22501306 0.51462202 0.73394312
  0.99965973 0.58817562]]
<NDArray 2x8 @cpu(0)>
epoch 1

[[0.67662112 0.31150197 0.96482456]
 [0.2055873  0.07256988 0.48376131]]
<NDArray 2x3 @cpu(0)>

[[0.43729793 0.85623583 0.30163485 0.09532238 0.0142127 ]
 [0.08242432 0.24336176 0.20770583 0.78283749 0.51888601]]
<NDArray 2x5 @cpu(0)>

[[0.96016702 0.82671972 0.79739873 0.90920828 0.58451523 0.24137856
  0.85786866 0.30341021]
 [0.04190342 0.92177108 0.95736727 0.93005575 0.25801083 0.89755136
  0.75930622 0.09502917]]
<NDArray 2x8 @cpu(0)>

ここからわかること

  • datasetを作るときはnumpy.ndarrayを受け付ける(MXNetのNDArrayである必要がない)
  • 第1軸がバッチとして扱われる
  • (今回は示していないが)リストも受け付ける
  • datasetは二つ以上のデータを受け付ける
  • epoch毎にシャッフルされる

補足1

DataLoaderを実行した時点でNDArrayに変更される。
それがいやならRandomSamplerとBatchSamplerを使用すればよい。

a = [1, 2, 3, 4, 5]
b = [[2,2,2], [3,3,3], [4,4,4], [5,5,5], [6,6,6]]
dataset = data.dataset.ArrayDataset(a, b)
sampler = data.RandomSampler(5)
data_loader = data.BatchSampler(sampler, batch_size=2)

毎回シャッフルもされる。

>>> for i in range(3):
...     print('epoch %d'%i)
...     for batch in data_loader:
...         [dataset[i] for i in batch]
...
epoch 0
[(1, [2, 2, 2]), (3, [4, 4, 4])]
[(5, [6, 6, 6]), (2, [3, 3, 3])]
[(4, [5, 5, 5])]
epoch 1
[(1, [2, 2, 2]), (2, [3, 3, 3])]
[(3, [4, 4, 4]), (4, [5, 5, 5])]
[(5, [6, 6, 6])]
epoch 2
[(2, [3, 3, 3]), (3, [4, 4, 4])]
[(4, [5, 5, 5]), (1, [2, 2, 2])]
[(5, [6, 6, 6])]

または

a = [[1, [2,2,2]],[2,[3,3,3]], [3,[4,4,4]], [4,[5,5,5]],[5,[6,6,6,]]]
sampler = data.RandomSampler(len(a))
data_loader = data.BatchSampler(sampler, batch_size=2)

for i in range(3):
    print('epoch %d'%i)
    for batch in data_loader:
        print([a[i] for i in batch])

「BatchSampler」自体は「shuffle=True」の引数をとれないので、シャッフルしたい時は「RandomSampler」と一緒に使用する。

補足2

画像のクラス分類でフォルダごとに画像が分類されている時は以下も使用できる。

dataset = gluon.data.vision.ImageFolderDataset(train_path)
class_name = dataset.synsets
train_data = gluon.data.DataLoader(
    dataset,
    batch_size=batch_size, shuffle=True)

補足3

Gluon Datasets and DataLoader — mxnet documentation
こちらのチュートリアルにはこのような記載がある。

Before Gluon’s DataLoader, MXNet used DataIter objects for loading data for training and testing. 
DataIter has a similar interface for iterating through data, but it isn’t directly compatible with typical Gluon DataLoader loops. 
Unlike Gluon DataLoader which often returns a tuple of (data, label), a DataIter returns a DataBatch object that has data and label properties. 
Switching to DataLoaders is highly recommended when using Gluon, but you’ll need to take care of pre-processing steps such as augmentations in a transform function.

一部を翻訳すると、

GluonのDataLoaderの前は、MXNetはDataIterオブジェクトを使用していた。
Gluonを使用する場合は、DataLoaderへの切り替えを強くお勧めします。