MXNetを使ってRNANで超解像

環境

Windows10 Pro 64bit (GPUなし)
Python 3.6.8

バージョン確認(pip freeze)

astroid==2.2.5
certifi==2019.3.9
chardet==3.0.4
colorama==0.4.1
graphviz==0.8.4
idna==2.6
isort==4.3.20
lazy-object-proxy==1.4.1
mccabe==0.6.1
mxnet==1.4.1
numpy==1.16.4
Pillow==6.0.0
pylint==2.3.1
requests==2.18.4
six==1.12.0
typed-ast==1.4.0
urllib3==1.22
wrapt==1.11.1

RNANの学習済みモデルをダウンロード

  • MXNetの学習済みモデルはこちらからダウンロード可能

github.com
「RNAN_SR_RGB_X4-symbol.json」と「RNAN_SR_RGB_X4-0000.params」をダウンロード

サンプル画像のダウンロード

こちら』からダウンロードして「dog.jpg」の名前で保存
https://s3.amazonaws.com/onnx-mxnet/examples/super_res_input.jpg

実行スクリプト

import numpy as np
import mxnet as mx
from mxnet import image

ctx = mx.cpu()

img = image.imread('dog.jpg')
img = img.astype(np.float32)/255
img = mx.nd.transpose(img, (2,0,1))
img = mx.nd.expand_dims(img, axis=0)

sym, arg_params, aux_params = mx.model.load_checkpoint('RNAN_SR_RGB_X4', 0)
model = mx.mod.Module(symbol=sym, label_names=None, context=ctx)
model.bind(for_training=False, data_shapes=[('data', img.shape)])

model.set_params(arg_params, aux_params)

from collections import namedtuple
Batch = namedtuple('Batch', ['data'])

model.forward(Batch([img]), is_train=False)
prob = model.get_outputs()[0].asnumpy()
prob = np.squeeze(prob)

from PIL import Image
prob = (prob.transpose(1,2,0)*255).astype(np.uint8)
img = Image.fromarray(prob)
img.save('RNAN_4x.jpg')

結果の表示

f:id:touch-sp:20190605192018j:plain