【DragGAN】【StyleGAN-Human】DragGANで自前の人物画像を使ってみました(PyTorch=2.0.1+cu117)

はじめに

以前「PyTorch=1.12.1+cu116」を使ってやったことを「PyTorch=2.0.1+cu117」で実行することに成功しました。
touch-sp.hatenablog.com

環境

Ubuntu 22.04 on WSL2
CUDA 11.7.1
cuDNN 8.5.0
Python 3.10
torch==2.0.1+cu117

Python環境構築

pip install -U setuptools wheel
pip install -r https://raw.githubusercontent.com/dai-ichiro/myEnvironments/main/StyleGAN-Human_plus_DragGAN/requirements_cu117_torch201.txt
pip install paddlepaddle-gpu==2.4.2.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
git clone https://github.com/PaddlePaddle/PaddleSeg
cd PaddleSeg
git checkout release/2.5
pip install -v -e .

Gradio

前回と違ってGradioを使ってみます。

まずはDragGANのリポジトリをクローンします。

git clone https://github.com/XingangPan/DragGAN



「DragGAN/legacy.py」を開いて40行目と41行目を削除するかコメントアウトする必要があります。

#assert isinstance(data['G'], torch.nn.Module)
#assert isinstance(data['D'], torch.nn.Module)



次に「DragGAN」フォルダ直下に「checkpoints」フォルダを作成し、そのフォルダ内にStyleGAN-Humanを使って作成した「model_girl.pkl」をコピーします。
その際に「model_girl.pkl」は「stylegan_human.pkl」に名前を変更する必要があります。
実際にはファイル名のどこかに「stylegan_human」という文字列が含まれていれば問題ないようですが今回は単純に「stylegan_human.pkl」としました。

「DragGAN」フォルダ直下にStyleGAN-Humanを使って作成した「0.pt」をコピーします。

「DragGAN/visualizer_drag_gradio.py」の中を2か所変更する必要があります。

変更前

def init_images(global_state):
    """This function is called only ones with Gradio App is started.
    0. pre-process global_state, unpack value from global_state of need
    1. Re-init renderer
    2. run `renderer._render_drag_impl` with `is_drag=False` to generate
       new image
    3. Assign images to global state and re-generate mask
    """

    if isinstance(global_state, gr.State):
        state = global_state.value
    else:
        state = global_state

    state['renderer'].init_network(
        state['generator_params'],  # res
        valid_checkpoints_dict[state['pretrained_weight']],  # pkl
        state['params']['seed'],  # w0_seed,
        None,  # w_load
init_pkl = 'stylegan2_lions_512_pytorch'

変更後

def init_images(global_state):
    """This function is called only ones with Gradio App is started.
    0. pre-process global_state, unpack value from global_state of need
    1. Re-init renderer
    2. run `renderer._render_drag_impl` with `is_drag=False` to generate
       new image
    3. Assign images to global state and re-generate mask
    """

    if isinstance(global_state, gr.State):
        state = global_state.value
    else:
        state = global_state

    w_pivot = torch.load('0.pt')

    state['renderer'].init_network(
        state['generator_params'],  # res
        valid_checkpoints_dict[state['pretrained_weight']],  # pkl
        state['params']['seed'],  # w0_seed,
        w_pivot,  # w_load
init_pkl = 'stylegan2_lions_512_pytorch'



あとは実行するだけです。

python visualizer_drag_gradio.py



このエントリーをはてなブックマークに追加