はじめに
- Enhanced SRGAN (ESRGAN)、RRDBについてはこちらを参照
[1809.00219] ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
- MXNetの学習済みモデルはこちらからダウンロード可能
サンプル画像のダウンロード
『こちら』からダウンロードして「dog.jpg」の名前で保存
実行スクリプト
# Upscale dog.jpg with the pretrained ESRGAN_4x checkpoint (MXNet Module API)
# and save the result as ESRGAN_4x.jpg.
from collections import namedtuple

import mxnet as mx
import numpy as np
from mxnet import image
from PIL import Image

ctx = mx.cpu()

# Read the image, scale pixel values by 1/255, move the last axis (channels)
# to the front and add a leading batch axis.
img = image.imread('dog.jpg')
img = img.astype(np.float32) / 255
img = mx.nd.transpose(img, (2, 0, 1))
img = mx.nd.expand_dims(img, axis=0)

# Load the checkpoint with prefix 'ESRGAN_4x', epoch 0, and bind it for
# inference on the exact shape of this input.
sym, arg_params, aux_params = mx.model.load_checkpoint('ESRGAN_4x', 0)
model = mx.mod.Module(symbol=sym, label_names=None, context=ctx)
model.bind(for_training=False, data_shapes=[('data', img.shape)])
model.set_params(arg_params, aux_params)

# Module.forward expects a batch object exposing a `data` attribute.
Batch = namedtuple('Batch', ['data'])
model.forward(Batch([img]), is_train=False)
prob = model.get_outputs()[0].asnumpy()
prob = np.squeeze(prob)

# The network output is not guaranteed to stay inside [0, 1]. Clamp it before
# scaling: casting out-of-range floats to uint8 wraps around (or is
# platform-dependent) and shows up as speckle artifacts in the saved image.
prob = np.clip(prob.transpose(1, 2, 0), 0.0, 1.0)
prob = (prob * 255).astype(np.uint8)
img = Image.fromarray(prob)
img.save('ESRGAN_4x.jpg')
# Upscale dog.jpg with the pretrained RRDB_4x checkpoint (MXNet Module API)
# and save the result as RRDB_4x.jpg.
from collections import namedtuple

import mxnet as mx
import numpy as np
from mxnet import image
from PIL import Image

ctx = mx.cpu()

# Read the image, scale pixel values by 1/255, move the last axis (channels)
# to the front and add a leading batch axis.
img = image.imread('dog.jpg')
img = img.astype(np.float32) / 255
img = mx.nd.transpose(img, (2, 0, 1))
img = mx.nd.expand_dims(img, axis=0)

# Load the checkpoint with prefix 'RRDB_4x', epoch 0, and bind it for
# inference on the exact shape of this input.
sym, arg_params, aux_params = mx.model.load_checkpoint('RRDB_4x', 0)
model = mx.mod.Module(symbol=sym, label_names=None, context=ctx)
model.bind(for_training=False, data_shapes=[('data', img.shape)])
model.set_params(arg_params, aux_params)

# Module.forward expects a batch object exposing a `data` attribute.
Batch = namedtuple('Batch', ['data'])
model.forward(Batch([img]), is_train=False)
prob = model.get_outputs()[0].asnumpy()
prob = np.squeeze(prob)

# The network output is not guaranteed to stay inside [0, 1]. Clamp it before
# scaling: casting out-of-range floats to uint8 wraps around (or is
# platform-dependent) and shows up as speckle artifacts in the saved image.
prob = np.clip(prob.transpose(1, 2, 0), 0.0, 1.0)
prob = (prob * 255).astype(np.uint8)
img = Image.fromarray(prob)
img.save('RRDB_4x.jpg')
結果の表示
- 上から「元画像(1024×1024にリサイズ)」「ESRGAN」「RRDB」の順に、同じ領域を切り出して並べたもの
- 「ESRGAN」はうまくいっていない
# Build a vertical comparison strip: the same region cropped from the resized
# original, the ESRGAN output and the RRDB output, stacked top to bottom.
from PIL import Image

# Crop rectangle as (left, upper, right, lower); it is 630 x 450 pixels.
crop_box = (220, 200, 850, 650)
tile_width = crop_box[2] - crop_box[0]
tile_height = crop_box[3] - crop_box[1]

# (file name, size to resize to before cropping, or None to keep as is).
# Only the original is resized, so the crop box addresses the same region
# in all three images.
sources = (
    ('dog.jpg', (1024, 1024)),
    ('ESRGAN_4x.jpg', None),
    ('RRDB_4x.jpg', None),
)

canvas = Image.new('RGB', (tile_width, tile_height * len(sources)))
for row, (file_name, resize_to) in enumerate(sources):
    picture = Image.open(file_name)
    if resize_to is not None:
        picture = picture.resize(resize_to)
    canvas.paste(picture.crop(crop_box), (0, row * tile_height))

canvas.save('result2.jpg')
2018年10月18日追記
- 学習済みモデルはGluon(`gluon.nn.SymbolBlock.imports`)でもロードできた
# Upscale dog.jpg with the pretrained ESRGAN_4x model loaded through Gluon
# and save the result as ESRGAN_4x.jpg.
import mxnet as mx
import numpy as np
from mxnet import image, gluon
from PIL import Image

# Read the image, scale pixel values by 1/255, move the last axis (channels)
# to the front and add a leading batch axis.
img = image.imread('dog.jpg')
img = img.astype(np.float32) / 255
img = mx.nd.transpose(img, (2, 0, 1))
img = mx.nd.expand_dims(img, axis=0)

# Import the exported symbol and parameter files as a Gluon block whose
# single input is named 'data'.
net = gluon.nn.SymbolBlock.imports("ESRGAN_4x-symbol.json", ['data'], "ESRGAN_4x-0000.params")

output = net(img)
output = mx.nd.squeeze(output)
output = output.asnumpy()

# The network output is not guaranteed to stay inside [0, 1]. Clamp it before
# scaling: casting out-of-range floats to uint8 wraps around (or is
# platform-dependent) and shows up as speckle artifacts in the saved image.
output = np.clip(output.transpose(1, 2, 0), 0.0, 1.0)
output = (output * 255).astype(np.uint8)
img = Image.fromarray(output)
img.save('ESRGAN_4x.jpg')
# Upscale dog.jpg with the pretrained RRDB_4x model loaded through Gluon
# and save the result as RRDB_4x.jpg.
import mxnet as mx
import numpy as np
from mxnet import image, gluon
from PIL import Image

# Read the image, scale pixel values by 1/255, move the last axis (channels)
# to the front and add a leading batch axis.
img = image.imread('dog.jpg')
img = img.astype(np.float32) / 255
img = mx.nd.transpose(img, (2, 0, 1))
img = mx.nd.expand_dims(img, axis=0)

# Import the exported symbol and parameter files as a Gluon block whose
# single input is named 'data'.
net = gluon.nn.SymbolBlock.imports("RRDB_4x-symbol.json", ['data'], "RRDB_4x-0000.params")

output = net(img)
output = mx.nd.squeeze(output)
output = output.asnumpy()

# The network output is not guaranteed to stay inside [0, 1]. Clamp it before
# scaling: casting out-of-range floats to uint8 wraps around (or is
# platform-dependent) and shows up as speckle artifacts in the saved image.
output = np.clip(output.transpose(1, 2, 0), 0.0, 1.0)
output = (output * 255).astype(np.uint8)
img = Image.fromarray(output)
img.save('RRDB_4x.jpg')