【超解像】MMEditingでTexture Transformer Network for Image Super-Resolution (TTSR)を使ってみる

公開日:2022年7月18日
最終更新日:2022年9月15日

はじめに

以前から何度か超解像はやってきました。
touch-sp.hatenablog.com
touch-sp.hatenablog.com


今回はTexture Transformer Network for Image Super-Resolution (TTSR)というのを使ってみました。

結果

今までとの違いはモデルに参照画像を渡す必要がある点です。

左が元画像(低解像度の画像)です。Pillowで4倍に拡大したものを載せています。


真ん中が超解像画像(元画像をTTSRを使って4倍に拡大したもの)です。


右が参照画像です。


左と真ん中を比べると効果絶大なのがわかります。


画像はTTSR本家のGitHubから使わせてもらいました。
github.com


自分で用意したオリジナル画像でもやってみたのですが、残念ながらあまり良い結果は得られませんでした。

Pythonスクリプト

モデル定義のファイル、学習済みパラメーター、画像のダウンロードをすべてスクリプト内に落とし込みました。

以下のスクリプトを実行するのみです。

参照画像の縦、横は4の倍数でないとエラーが出るようです。その変更もスクリプト内に落とし込みました。

import os
import requests
from PIL import Image
from torchvision.datasets.utils import download_url
from mmedit.apis import init_model, restoration_inference
from mmedit.core import tensor2img
from mim.commands.download import download

lr_img_fname = 'lr.png'
img_url ='https://github.com/researchmm/TTSR/raw/master/test/demo/lr/0.png'
download_url(img_url, root = '.', filename = lr_img_fname)

ref_img_fname = 'rf.png'
img_url ='https://github.com/researchmm/TTSR/raw/master/test/demo/ref/0.png'
im = Image.open(requests.get(img_url, stream=True).raw)
width, height = im.size
width = (width // 4) * 4
height = (height //4) * 4
im = im.crop((0, 0, width, height))
im.save(ref_img_fname)

os.makedirs('models', exist_ok=True)

checkpoint_name = 'ttsr-gan_x4_c64b16_g1_500k_CUFED'
config_fname = checkpoint_name + '.py'

checkpoint = download(package="mmedit", configs=[checkpoint_name], dest_root="models")[0]

model = init_model(os.path.join('models', config_fname), os.path.join('models', checkpoint), device = 'cuda')

output = restoration_inference(model, lr_img_fname, ref_img_fname)
output = tensor2img(output)

pil_img = Image.fromarray(output[:,:,::-1])
pil_img.show()

動作環境

Windows 11
CUDA 11.6.2
Python 3.9.13

Python環境の構築

すべてpipでインストール可能でした。

pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu116/torch1.12.0/index.html
pip install mmedit
pip install openmim
pip install mmengine
absl-py==1.2.0
addict==2.4.0
av==9.2.0
cachetools==5.2.0
certifi==2022.9.14
charset-normalizer==2.1.1
click==8.1.3
colorama==0.4.5
commonmark==0.9.1
cycler==0.11.0
facexlib==0.2.5
filterpy==1.4.5
fonttools==4.37.1
google-auth==2.11.0
google-auth-oauthlib==0.4.6
grpcio==1.48.1
idna==3.4
importlib-metadata==4.12.0
kiwisolver==1.4.4
llvmlite==0.39.1
lmdb==1.3.0
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib==3.5.3
mmcv-full==1.6.1
mmedit==0.15.2
mmengine==0.1.0
model-index==0.1.11
numba==0.56.2
numpy==1.23.3
oauthlib==3.2.1
opencv-python==4.6.0.66
openmim==0.3.1
ordered-set==4.1.0
packaging==21.3
pandas==1.4.4
Pillow==9.2.0
protobuf==3.19.5
pyasn1==0.4.8
pyasn1-modules==0.2.8
Pygments==2.13.0
pyparsing==3.0.9
python-dateutil==2.8.2
pytz==2022.2.1
PyYAML==6.0
regex==2022.9.13
requests==2.28.1
requests-oauthlib==1.3.1
rich==12.5.1
rsa==4.9
scipy==1.9.1
six==1.16.0
tabulate==0.8.10
tensorboard==2.10.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
termcolor==2.0.1
torch==1.12.1+cu116
torchvision==0.13.1+cu116
tqdm==4.64.1
typing_extensions==4.3.0
urllib3==1.26.12
Werkzeug==2.2.2
yapf==0.32.0
zipp==3.8.1