【JoJoGAN】自前の顔写真に対してJoJoGANの学習済みモデルを使用する

はじめに

JoJoGANはこちらです。
github.com
以前にも使用させて頂きました。
touch-sp.hatenablog.com
前回はデモを動かす程度でした。


今回は自分で用意した顔写真に対して学習済みモデルを使用してみます。

OS環境

Ubuntu 20.04 on WSL2
CUDA Toolkit 11.3.1
Python 3.9.5

最初に以下の2つをインストールしました。

sudo apt install cmake
sudo apt install build-essential

次にNinjaをインストールしました。

wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
sudo unzip ninja-linux.zip -d /usr/local/bin/
sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force

Python環境

以下をインストールしました。すべてpipでインストール可能でした。

pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install cython
pip install tqdm gdown scikit-learn==0.22 scipy lpips dlib opencv-python wandb matplotlib

実行

環境構築ができたら、リポジトリをクローンして以下のスクリプトを実行するだけです。

git clone https://github.com/mchong6/JoJoGAN.git
cd JoJoGAN



必要な学習済みモデルのダウンロードなどはスクリプト内で行われます。事前準備は必要ありません。


スクリプトを「jojoexe.py」、顔写真を「face.jpg」という名前で保存している仮定で、以下のようにするだけです。

python jojoexe.py face.jpg

これで10通りの学習済みモデルが一気に適応されます。

Pythonスクリプト

import os
import sys
from gdown import download

import torch
from torchvision import transforms, utils
from util import *
from PIL import Image
from model import *
from e4e_projection import projection as e4e_projection

google_drive_paths = [
    ("stylegan2-ffhq-config-f.pt", "1Yr7KuD959btpmcKGAUsbAk5rPjX2MytK"),
    ("dlibshape_predictor_68_face_landmarks.dat", "11BDmNKS1zxSZxkgsEvQoKgFd8J264jKp"),
    ("e4e_ffhq_encode.pt", "1o6ijA3PkcewZvwJJ73dJ0fxhndn0nnh7"),
    ("restyle_psp_ffhq_encode.pt", "1nbxCIVw9H3YnQsoIPykNEFwWJnHVHlVd"),
    ("arcane_caitlyn.pt", "1gOsDTiTPcENiFOrhmkkxJcTURykW1dRc"),
    ("arcane_caitlyn_preserve_color.pt", "1cUTyjU-q98P75a8THCaO545RTwpVV-aH"),
    ("arcane_jinx_preserve_color.pt", "1jElwHxaYPod5Itdy18izJk49K1nl4ney"),
    ("arcane_jinx.pt", "1quQ8vPjYpUiXM4k1_KIwP4EccOefPpG_"),
    ("disney.pt", "1zbE2upakFUAx8ximYnLofFwfT8MilqJA"),
    ("disney_preserve_color.pt", "1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_tsi"),
    ("jojo.pt", "13cR2xjIBj8Ga5jMO7gtxzIJj2PDsBYK4"),
    ("jojo_preserve_color.pt", "1ZRwYLRytCEKi__eT2Zxv1IlV6BGVQ_K2"),
    ("jojo_yasuho.pt", "1grZT3Gz1DLzFoJchAmoj3LoM9ew9ROX_"),
    ("jojo_yasuho_preserve_color.pt", "1SKBu1h0iRNyeKBnya_3BBmLr4pkPeg_L")
]

for file_name, file_id in google_drive_paths:
    if not os.path.exists(os.path.join('models', file_name)):
        download(f'https://drive.google.com/uc?id={file_id}', os.path.join('models', file_name), quiet = False)

device = 'cuda'

latent_dim = 512

# Load original generator
generator = Generator(1024, latent_dim, 8, 2).to(device)
ckpt = torch.load('models/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
mean_latent = generator.mean_latent(10000)

plt.rcParams['figure.dpi'] = 150

filepath = sys.argv[1]

name = strip_path_extension(filepath)+'.pt'

# aligns and crops face
aligned_face = align_face(filepath)

my_w = e4e_projection(aligned_face, name, device).unsqueeze(0)

pretrained_list = google_drive_paths[4:]

os.makedirs('results', exist_ok=True)

for each_pretrained in pretrained_list:
    ckpt_name = each_pretrained[0]
    ckpt = torch.load(os.path.join('models', ckpt_name), map_location=lambda storage, loc: storage)
    generator.load_state_dict(ckpt["g"], strict=False)

    with torch.no_grad():
        generator.eval()
        my_sample = generator(my_w, input_is_latent=True)

    tensor = utils.make_grid(my_sample, normalize=True, range=(-1, 1), nrow=1)
    result_image = transforms.ToPILImage()(tensor)

    save_filename = ckpt_name.replace('.pt', '.png')
    result_image.save(os.path.join('results', save_filename))

結果

今回はフリー素材「ぱくたそ」から顔写真を使わせて頂きました。
こちらの写真です。






いっきに10枚の画像が保存されます。