MXNetで超解像(GPUなし)

2018年7月26日動作確認

環境

Windows10 Pro 64bit

はじめに(注意)

オリジナルではありません。
GitHubで公開されているものを少しいじっただけです。
github.com

Anacondaで仮想環境を作成

conda create -n onnx python=3.6 anaconda
activate onnx
  • pipのアップデート
python -m pip install --upgrade pip
  • msgpackのインストール
pip install msgpack

MXNetとONNXのインストール

pip install mxnet==1.3.0b20180724
pip isntall onnx

学習済みモデルのダウンロード

  • modelフォルダを新規作成(現在のディレクトリ直下)
  • こちら』から学習済みモデルをダウンロードしてmodelフォルダに入れる

サンプル画像のダウンロード

実行ファイルの記述

from __future__ import absolute_import as _abs
from __future__ import print_function
from collections import namedtuple
import logging
import numpy as np
from PIL import Image
import mxnet as mx
from mxnet.test_utils import download
import mxnet.contrib.onnx as onnx_mxnet
import argparse

# set up logger
logging.basicConfig()
LOGGER = logging.getLogger()
LOGGER.setLevel(logging.INFO)

parser = argparse.ArgumentParser()
parser.add_argument("--image",type=str,required=True,
                    help="path to image")

def import_onnx():
    """Import the onnx model into mxnet"""
    sym, arg_params, aux_params = onnx_mxnet.import_model('model/super_resolution.onnx')
    LOGGER.info("Successfully Converted onnx format to mxnet's symbol and params...")
    return sym, arg_params, aux_params

def get_test_image(file_name):
    input_image_dim = 224
    img = Image.open(file_name).resize((input_image_dim, input_image_dim))
    img_ycbcr = img.convert("YCbCr")
    img_y, img_cb, img_cr = img_ycbcr.split()
    input_image = np.array(img_y)[np.newaxis, np.newaxis, :, :]
    return input_image, img_cb, img_cr

def perform_inference(sym, arg_params, aux_params, input_img, img_cb, img_cr):
    """Perform inference on image using mxnet"""
    metadata = onnx_mxnet.get_model_metadata('model/super_resolution.onnx')
    data_names = [input_name[0] for input_name in metadata.get('input_tensor_data')]
    # create module
    mod = mx.mod.Module(symbol=sym, data_names=data_names, label_names=None)
    mod.bind(for_training=False, data_shapes=[(data_names[0], input_img.shape)])
    mod.set_params(arg_params=arg_params, aux_params=aux_params)

    # run inference
    batch = namedtuple('Batch', ['data'])
    mod.forward(batch([mx.nd.array(input_img)]))

    # Save the result
    img_out_y = Image.fromarray(np.uint8(mod.get_outputs()[0][0][0].
                                         asnumpy().clip(0, 255)), mode='L')

    result_img = Image.merge(
        "YCbCr", [img_out_y,
                  img_cb.resize(img_out_y.size, Image.BICUBIC),
                  img_cr.resize(img_out_y.size, Image.BICUBIC)]).convert("RGB")
    output_img_dim = 672
    assert result_img.size == (output_img_dim, output_img_dim)
    LOGGER.info("Super Resolution example success.")
    result_img.save("super_res_output.jpg")
    return result_img

if __name__ == '__main__':
    args = parser.parse_args()
    MX_SYM, MX_ARG_PARAM, MX_AUX_PARAM = import_onnx()
    INPUT_IMG, IMG_CB, IMG_CR = get_test_image(args.image)
    perform_inference(MX_SYM, MX_ARG_PARAM, MX_AUX_PARAM, INPUT_IMG, IMG_CB, IMG_CR)

実行

python super_res.py --image super_res_input.jpg