MXNet の NumPy Interface を使ってみる

2022年3月16日記事を更新しました。

はじめに

Numpy Interfaceについてはこちらを参照して下さい。
今回は画像分類を行いました。

環境

二つの環境で動作確認しました。

Windows 10
NVIDIA GeForce GTX1080
CUDA Toolkit 10.1
Python 3.7.9

mxnet-cu101==1.7.0

Windows10
GPUなし
Python 3.8.6

mxnet==1.7.0.post1

mxnetのインストール方法はこちらを参照して下さい。

Pythonスクリプト

from mxnet import np, npx, gluon, image
npx.set_np()

ctx = npx.gpu() if npx.num_gpus() > 0 else npx.cpu()

url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/b/b5/Golden_Retriever_medium-to-light-coat.jpg/365px-Golden_Retriever_medium-to-light-coat.jpg'
fname = gluon.utils.download(url)
x = image.imread(fname)

url = 'http://data.mxnet.io/models/imagenet/synset.txt'
fname = gluon.utils.download(url)
with open(fname, 'r') as f:
    text_labels = [' '.join(l.split()[1:]) for l in f]

net = gluon.model_zoo.vision.resnet50_v2(pretrained=True, ctx=ctx)

x = image.resize_short(x, 256)
x, _ = image.center_crop(x, (224,224))

def transform(data):
    data = np.expand_dims(np.transpose(data, (2,0,1)), axis=0)
    rgb_mean = np.array([0.485, 0.456, 0.406]).reshape((1,3,1,1))
    rgb_std = np.array([0.229, 0.224, 0.225]).reshape((1,3,1,1))
    return (data.astype('float32') / 255 - rgb_mean) / rgb_std

prob = npx.softmax(net(transform(x).as_in_context(ctx)))

idx = npx.topk(prob, k=5)[0]

for i in idx:
    print('With prob = %.5f, it contains %s' % (
        prob[0, int(i)], text_labels[int(i)]))

結果

f:id:touch-sp:20201006130930j:plain:w250

With prob = 0.98240, it contains golden retriever
With prob = 0.00809, it contains English setter
With prob = 0.00262, it contains Irish setter, red setter
With prob = 0.00223, it contains cocker spaniel, English cocker spaniel, cocker
With prob = 0.00177, it contains Labrador retriever

その他

WSL2上のUbuntuでも実行可能でした。

環境1

Ubuntu 18.04
Python 3.7.5

certifi==2020.6.20
chardet==3.0.4    
graphviz==0.8.4   
idna==2.10        
mxnet==2.0.0b20201005
numpy==1.19.2        
pkg-resources==0.0.0 
requests==2.24.0     
urllib3==1.25.10

環境2

Ubuntu 18.04
Python 3.6.9

certifi==2020.12.5
chardet==4.0.0
contextvars==2.4
graphviz==0.8.4
idna==2.10
immutables==0.15
mxnet==2.0.0b20210219
numpy==1.19.5
pkg-resources==0.0.0
requests==2.25.1
urllib3==1.26.3

環境3

Ubuntu 20.04
Python 3.8.10

certifi==2021.10.8
charset-normalizer==2.0.6
graphviz==0.8.4
idna==3.2
mxnet==2.0.0b20211010
numpy==1.21.2
pkg_resources==0.0.0
requests==2.26.0
urllib3==1.26.7

環境4

Ubuntu 20.04
Python 3.8.10

certifi==2021.10.8
charset-normalizer==2.0.12
graphviz==0.8.4
idna==3.3
mxnet-cu112==2.0.0b20220315
numpy==1.22.3
pkg_resources==0.0.0
requests==2.27.1
urllib3==1.26.8

mxnet==2.0.0の比較的新しいベータ版を使用すると「npx.set_np()」は不要です。