【PyTorch】JoJoGANというものを使わせて頂きました

はじめに

今回JoJoGANというのを使わせて頂きました。
「ジョジョの奇妙な冒険」のJoJoです。
github.com
以前にTargetCLIPについて記事を書きました。
touch-sp.hatenablog.com
二つは方法が異なるようですが目的は同じように見えます。
どちらもStyleGANをベースとしています。


環境構築

Ubuntu 20.04 on WSL2
Python 3.8.10




Ninjaを使うとのことなのでGitHubにあるとおりにインストールしました。

wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
sudo unzip ninja-linux.zip -d /usr/local/bin/
sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force




venvでPython仮想環境を作って必要なモジュールをインストールしました。

pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
pip install tqdm gdown scikit-learn==0.22 scipy lpips dlib opencv-python wandb matplotlib

dlibをインストールする際にCMakeが必要になるかもしれません。エラーがでるようならインストールしてみて下さい。

sudo apt install cmake

それでもダメならこちらもインストールしてみて下さい。

 sudo apt install build-essential




リポジトリのクローン

リポジトリをクローンして「models」フォルダを作成しました。

git clone https://github.com/mchong6/JoJoGAN.git
cd JoJoGAN
mkdir models




必要なファイルのダウンロード

3つのファイルをダウンロードしました。いずれも先ほど作った「models」フォルダに保存します。

こちらから「e4e_ffhq_encode.pt」というファイルをダウンロード
こちらから「stylegan2-ffhq-config-f.pt」というファイルをダウンロード
こちらから「jojo.pt」というファイルをダウンロード


Pythonスクリプトの実行

クローンしたリポジトリのトップ「JoJoGAN」フォルダ内で以下のPythonスクリプトを実行しました。

import torch
from torchvision import transforms, utils
from util import *
from PIL import Image
import os
from model import *
from e4e_projection import projection as e4e_projection

device = 'cuda'

latent_dim = 512

# Load original generator
generator = Generator(1024, latent_dim, 8, 2).to(device)
ckpt = torch.load('models/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
mean_latent = generator.mean_latent(10000)

plt.rcParams['figure.dpi'] = 150

filename = 'iu.jpeg'
filepath = f'test_input/{filename}'

name = strip_path_extension(filepath)+'.pt'

# aligns and crops face
aligned_face = align_face(filepath)

my_w = e4e_projection(aligned_face, name, device).unsqueeze(0)

#@param ['art', 'arcane_multi', 'supergirl', 'arcane_jinx', 'arcane_caitlyn', 'jojo_yasuho', 'jojo', 'disney']
pretrained = 'jojo' 

ckpt = f'{pretrained}.pt'

ckpt = torch.load(os.path.join('models', ckpt), map_location=lambda storage, loc: storage)
generator.load_state_dict(ckpt["g"], strict=False)

n_sample =  1

seed = 3000
torch.manual_seed(seed)

with torch.no_grad():
    generator.eval()
    my_sample = generator(my_w, input_is_latent=True)

tensor = utils.make_grid(my_sample, normalize=True, range=(-1, 1), nrow=1)
result_image = transforms.ToPILImage()(tensor)

if pretrained == 'arcane_multi':
    style_path = f'style_images_aligned/arcane_jinx.png'
else:   
    style_path = f'style_images_aligned/{pretrained}.png'

style_image = Image.open(style_path)

aligned_face.show()
style_image.show()
result_image.show()




結果

このようになりました。一番右の画像が今回作成された画像です。




補足

Pythonスクリプト内に以下の数行を追加しておくとGoogle Driveからのファイルダウンロードを省略できます。

from gdown import download

if not os.path.exists('models/stylegan2-ffhq-config-f.pt'):
    download('https://drive.google.com/uc?id=1Yr7KuD959btpmcKGAUsbAk5rPjX2MytK', 'models/stylegan2-ffhq-config-f.pt', quiet = False)

if not os.path.exists('models/e4e_ffhq_encode.pt'):
    download('https://drive.google.com/uc?id=1o6ijA3PkcewZvwJJ73dJ0fxhndn0nnh7', 'models/e4e_ffhq_encode.pt', quiet = False)

if not os.path.exists('models/jojo.pt'):
    download('https://drive.google.com/uc?id=13cR2xjIBj8Ga5jMO7gtxzIJj2PDsBYK4', 'models/jojo.pt', quiet = False)

2022年8月3日追記

つづきの記事を書きました。
touch-sp.hatenablog.com