OpenMMLab の MMEditing を使ってサクッと超解像

公開日:2021年11月17日
最終更新日:2022年9月15日

はじめに

以前MXNetを使った超解像の記事を書きました。
touch-sp.hatenablog.com
touch-sp.hatenablog.com
今回はPyTorchとMMEditingを使ってみたいと思います。

環境

二つの環境で動作確認できています。

Ubuntu on WSL2

Ubuntu 20.04 LTS on WSL2 (Windows 11) 
CUDA 11.3.1
Python 3.9.5

Windows

Windows 11
CUDA 11.6.2
Python 3.9.13

Python環境構築

すべてpipでインストール可能です。
(CUDAのバージョンを自身の環境に合わせて下さい)

pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.12.0/index.html
pip install mmedit
pip install openmim
pip install mmengine

Pythonスクリプト

Python環境が整えばあとは以下のスクリプトを実行するのみです。
リポジトリのクローン(git clone)などは必要ありません。
必要なファイルはスクリプト内でダウンロードできるようになっています。
画像の部分は好きな画像に変更してください。

import os
from PIL import Image
import torch
from torchvision.datasets.utils import download_url
from mmedit.apis import init_model, restoration_inference
from mmedit.core import tensor2img
from mim.commands.download import download

device = 'cuda' if torch.cuda.is_available() else 'cpu' 

img_url ='https://github.com/open-mmlab/mmediting/raw/master/tests/data/lq/baboon_x4.png'
img_fname = img_url.split('/')[-1]
download_url(img_url, root = '.', filename = img_fname)

os.makedirs('models', exist_ok=True)

checkpoint_name = 'esrgan_x4c64b23g32_g1_400k_div2k'
config_fname = checkpoint_name + '.py'

checkpoint = download(package="mmedit", configs=[checkpoint_name], dest_root="models")[0]

model = init_model(os.path.join('models', config_fname), os.path.join('models', checkpoint), device = device)

output = restoration_inference(model, img_fname)
output = tensor2img(output)

pil_img = Image.fromarray(output[:,:,::-1])
pil_img.show()

関連記事

touch-sp.hatenablog.com