バージョン確認(pip freeze)
astroid==2.2.5 certifi==2019.3.9 chardet==3.0.4 colorama==0.4.1 graphviz==0.8.4 idna==2.6 isort==4.3.20 lazy-object-proxy==1.4.1 mccabe==0.6.1 mxnet==1.4.1 numpy==1.16.4 Pillow==6.0.0 pylint==2.3.1 requests==2.18.4 six==1.12.0 typed-ast==1.4.0 urllib3==1.22 wrapt==1.11.1
RNANの学習済みモデルをダウンロード
- MXNetの学習済みモデルはこちらからダウンロード可能
github.com
「RNAN_SR_RGB_X4-symbol.json」と「RNAN_SR_RGB_X4-0000.params」をダウンロード
サンプル画像のダウンロード
『こちら』からダウンロードして「dog.jpg」の名前で保存
実行スクリプト
import numpy as np import mxnet as mx from mxnet import image ctx = mx.cpu() img = image.imread('dog.jpg') img = img.astype(np.float32)/255 img = mx.nd.transpose(img, (2,0,1)) img = mx.nd.expand_dims(img, axis=0) sym, arg_params, aux_params = mx.model.load_checkpoint('RNAN_SR_RGB_X4', 0) model = mx.mod.Module(symbol=sym, label_names=None, context=ctx) model.bind(for_training=False, data_shapes=[('data', img.shape)]) model.set_params(arg_params, aux_params) from collections import namedtuple Batch = namedtuple('Batch', ['data']) model.forward(Batch([img]), is_train=False) prob = model.get_outputs()[0].asnumpy() prob = np.squeeze(prob) from PIL import Image prob = (prob.transpose(1,2,0)*255).astype(np.uint8) img = Image.fromarray(prob) img.save('RNAN_4x.jpg')
結果の表示