MXNetを使ってRRDBで超解像

はじめに

  • Enhanced SRGAN (ESRGAN)、RRDBについてはこちらを参照

[1809.00219] ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks

  • MXNetの学習済みモデルはこちらからダウンロード可能

github.com

サンプル画像のダウンロード

こちらからダウンロードして「dog.jpg」の名前で保存
https://s3.amazonaws.com/onnx-mxnet/examples/super_res_input.jpg

実行スクリプト

# Upscale dog.jpg with the pretrained ESRGAN_4x checkpoint via the Module API
# and write the result to ESRGAN_4x.jpg.
from collections import namedtuple

import mxnet as mx
import numpy as np
from mxnet import image
from PIL import Image

ctx = mx.cpu()

# Read the image, divide by 255, move channels first and add a batch axis.
img = image.imread('dog.jpg')
img = img.astype(np.float32) / 255
img = mx.nd.transpose(img, (2, 0, 1))
img = mx.nd.expand_dims(img, axis=0)

# Loads ESRGAN_4x-symbol.json and ESRGAN_4x-0000.params (epoch 0).
sym, arg_params, aux_params = mx.model.load_checkpoint('ESRGAN_4x', 0)
model = mx.mod.Module(symbol=sym, label_names=None, context=ctx)
model.bind(for_training=False, data_shapes=[('data', img.shape)])
model.set_params(arg_params, aux_params)

# Module.forward only needs an object exposing a `data` attribute.
Batch = namedtuple('Batch', ['data'])

model.forward(Batch([img]), is_train=False)
prob = model.get_outputs()[0].asnumpy()
prob = np.squeeze(prob)

# Clip before casting: the network output is not guaranteed to stay inside
# [0, 1], and casting out-of-range floats to uint8 wraps around instead of
# saturating, which shows up as speckle artifacts in the saved image.
prob = np.clip(prob.transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)
img = Image.fromarray(prob)
img.save('ESRGAN_4x.jpg')
# Upscale dog.jpg with the pretrained RRDB_4x checkpoint via the Module API
# and write the result to RRDB_4x.jpg.
from collections import namedtuple

import mxnet as mx
import numpy as np
from mxnet import image
from PIL import Image

ctx = mx.cpu()

# Read the image, divide by 255, move channels first and add a batch axis.
img = image.imread('dog.jpg')
img = img.astype(np.float32) / 255
img = mx.nd.transpose(img, (2, 0, 1))
img = mx.nd.expand_dims(img, axis=0)

# Loads RRDB_4x-symbol.json and RRDB_4x-0000.params (epoch 0).
sym, arg_params, aux_params = mx.model.load_checkpoint('RRDB_4x', 0)
model = mx.mod.Module(symbol=sym, label_names=None, context=ctx)
model.bind(for_training=False, data_shapes=[('data', img.shape)])
model.set_params(arg_params, aux_params)

# Module.forward only needs an object exposing a `data` attribute.
Batch = namedtuple('Batch', ['data'])

model.forward(Batch([img]), is_train=False)
prob = model.get_outputs()[0].asnumpy()
prob = np.squeeze(prob)

# Clip before casting: the network output is not guaranteed to stay inside
# [0, 1], and casting out-of-range floats to uint8 wraps around instead of
# saturating, which shows up as speckle artifacts in the saved image.
prob = np.clip(prob.transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)
img = Image.fromarray(prob)
img.save('RRDB_4x.jpg')

結果の表示

  • 上から「元画像」「ESRGAN」「RRDB」
  • 「ESRGAN」はうまくいっていない

f:id:touch-sp:20181016160124j:plain

# Stack the same crop of the original, ESRGAN and RRDB images vertically
# and save the comparison as result2.jpg.
from PIL import Image

# (left, upper, right, lower) -> a 630 x 450 pixel region.
box = (220, 200, 850, 650)

# The original is resized to 1024x1024 first (presumably to match the size
# of the 4x outputs) so the same crop box applies to all three images.
panels = [
    Image.open('dog.jpg').resize((1024, 1024)).crop(box),
    Image.open('ESRGAN_4x.jpg').crop(box),
    Image.open('RRDB_4x.jpg').crop(box),
]

# One 450-pixel-high row per panel, top to bottom.
canvas = Image.new('RGB', (630, 450 * 3))
for row, panel in enumerate(panels):
    canvas.paste(panel, (0, 450 * row))

canvas.save('result2.jpg')

2018年10月18日追記

  • 学習済みモデルはGluonでロードできた
# Upscale dog.jpg with the pretrained ESRGAN_4x model loaded through Gluon's
# SymbolBlock and write the result to ESRGAN_4x.jpg.
import mxnet as mx
import numpy as np
from mxnet import gluon, image
from PIL import Image

# Read the image, divide by 255, move channels first and add a batch axis.
img = image.imread('dog.jpg')
img = img.astype(np.float32) / 255
img = mx.nd.transpose(img, (2, 0, 1))
img = mx.nd.expand_dims(img, axis=0)

net = gluon.nn.SymbolBlock.imports("ESRGAN_4x-symbol.json", ['data'], "ESRGAN_4x-0000.params")
output = net(img)
output = mx.nd.squeeze(output)
output = output.asnumpy()

# Clip before casting: the network output is not guaranteed to stay inside
# [0, 1], and casting out-of-range floats to uint8 wraps around instead of
# saturating, which shows up as speckle artifacts in the saved image.
output = np.clip(output.transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)

img = Image.fromarray(output)
img.save('ESRGAN_4x.jpg')
# Upscale dog.jpg with the pretrained RRDB_4x model loaded through Gluon's
# SymbolBlock and write the result to RRDB_4x.jpg.
import mxnet as mx
import numpy as np
from mxnet import gluon, image
from PIL import Image

# Read the image, divide by 255, move channels first and add a batch axis.
img = image.imread('dog.jpg')
img = img.astype(np.float32) / 255
img = mx.nd.transpose(img, (2, 0, 1))
img = mx.nd.expand_dims(img, axis=0)

net = gluon.nn.SymbolBlock.imports("RRDB_4x-symbol.json", ['data'], "RRDB_4x-0000.params")
output = net(img)
output = mx.nd.squeeze(output)
output = output.asnumpy()

# Clip before casting: the network output is not guaranteed to stay inside
# [0, 1], and casting out-of-range floats to uint8 wraps around instead of
# saturating, which shows up as speckle artifacts in the saved image.
output = np.clip(output.transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)

img = Image.fromarray(output)
img.save('RRDB_4x.jpg')