【inpainting】従来のGAN(AOT-GAN)とPaintByExampleを比較してみる

はじめに

以前PaintByExampleというのを使って画像から犬を消しました。
touch-sp.hatenablog.com
今回はAOT-GANで同じことをやって結果を比較してみました。


AOT-GANはOpenMMLabのMMEditingから使用しています。

結果

元画像
左がAOT-GAN 右がPaintByExample

PaintByExampleの方が犬を消した後の芝生がうまく描画されています。


ただしよく見るとPaintByExampleの画像は女性の顔や手に持っているボールが崩れてしまっています。


一長一短といったところでしょうか。状況によって使い分けるのが良さそうです。

AOT-GANの使い方

環境

Python 3.10ではうまくいかなかったのでPython 3.9を使用しています。

Windows 11
CUDA 11.6.2
Python 3.9.13
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install mmcv-full==1.7.1 -f https://download.openmmlab.com/mmcv/dist/cu116/torch1.13.0/index.html
pip install mmedit==0.16.0
pip install openmim==0.3.4
pip install mmengine==0.4.0

Pythonスクリプト

import argparse
import os
import mmcv
import torch
from mim.commands.download import download
from mmedit.apis import init_model, inpainting_inference
from mmedit.core import tensor2img

def parse_args():
    parser = argparse.ArgumentParser(description='Inpainting demo')
    
    parser.add_argument(
        '--image',
        type=str,
        help='path to input image file')

    parser.add_argument(
        '--mask',
        type=str,    
        help='path to input mask file')
    
    parser.add_argument('--device', type=int, default=0, help='CUDA device id')
    args = parser.parse_args()
    return args

def main():
    args = parse_args()

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    os.makedirs('models', exist_ok=True)

    checkpoint_name = 'AOT-GAN_512x512_4x12_places'
    config_fname = checkpoint_name + '.py'

    checkpoint = download(package="mmedit", configs=[checkpoint_name], dest_root="models")[0]

    model = init_model(os.path.join('models', config_fname), os.path.join('models', checkpoint), device = device)

    result = inpainting_inference(model, args.image, args.mask)
    result = tensor2img(result, min_max=(-1, 1))[..., ::-1]

    mmcv.imwrite(result, 'result.png')
    mmcv.imshow(result, 'predicted inpainting result')

if __name__ == '__main__':
    main()

実行

上記スクリプトは「inpaint.py」という名前で保存しています。

python inpaint.py --image Dog-Park.jpg --mask mask.png