OpenMMLab の MMagic で ViCo (Detail-Preserving Visual Condition for Personalized Text-to-Image Generation) を試してみる

目的

画像生成モデルに自前の物体を描画させる試みです。

以前にもDreamBoothやLoRAで同様のことをやっています。


この物体を描写させます。具体的にはビーチに立たせてみます。

ViCo (Detail-Preserving Visual Condition for Personalized Text-to-Image Generation) という手法を使います。

DreamBoothやLoRAの記事は最後にリンクを貼っておきますのでよかったら見て下さい。

結果




ん~、いまいちです。

設定を煮詰めると良くなるのでしょうか?

環境

Windows 11
CUDA 11.7
Python 3.11

Python環境構築

yapfは0.40.1までを使用する必要があります。

0.40.2以降を使用すると以下のエラーが出ます。

TypeError: FormatCode() got an unexpected keyword argument 'verify'


pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 --index-url https://download.pytorch.org/whl/cu117
pip install yapf==0.40.1
pip install mmcv==2.0.1 -f https://download.openmmlab.com/mmcv/dist/cu117/torch2.0.0/index.html
pip install openmim==0.3.9
pip install mmagic==1.1.0
pip install accelerate==0.23.0
pip install albumentations==1.3.1

準備

リポジトリのクローン

git clone https://github.com/open-mmlab/mmagic

学習データの準備

こちらから「imagenet_templates_small.txt」をダウンロードします。

こちらから犬の写真を5枚ダウンロードします。

ファイル構造はこのようにします。

mmagic
 └─data
    └─vico
        │  imagenet_templates_small.txt
        │  
        └─robot
                00.jpg
                01.jpg
                02.jpg
                03.jpg
                04.jpg
                 .
                 .

学習

Configファイル

「mmagic/configs/vico/vico.py」を以下のように書き換え「my_config.py」としました。

_base_ = '../_base_/gen_default_runtime.py'

randomness = dict(seed=2023, diff_rank_seed=True)
# dtype="fp32"
# config for model
stable_diffusion_v15_url = 'model/stable-diffusion-v1-5'

data_root = './data/vico'
concept_dir = 'robot'

# 1 for using image cross
image_cross_layers = [
    # down blocks (2x transformer block) * (3x down blocks) = 6
    0,
    0,
    0,
    0,
    0,
    0,
    # mid block (1x transformer block) * (1x mid block)= 1
    0,
    # up blocks (3x transformer block) * (3x up blocks) = 9
    0,
    1,
    0,
    1,
    0,
    1,
    0,
    1,
    0,
]
reg_loss_weight: float = 5e-4
placeholder: str = 'S*'
val_prompts = ['a photo of a S*']
initialize_token: str = 'robot'
num_vectors_per_token: int = 1

model = dict(
    type='ViCo',
    vae=dict(
        type='AutoencoderKL',
        from_pretrained=stable_diffusion_v15_url,
        subfolder='vae'),
    unet=dict(
        type='UNet2DConditionModel',
        subfolder='unet',
        from_pretrained=stable_diffusion_v15_url),
    text_encoder=dict(
        type='ClipWrapper',
        clip_type='huggingface',
        pretrained_model_name_or_path=stable_diffusion_v15_url,
        subfolder='text_encoder'),
    tokenizer=stable_diffusion_v15_url,
    scheduler=dict(
        type='DDPMScheduler',
        from_pretrained=stable_diffusion_v15_url,
        subfolder='scheduler'),
    test_scheduler=dict(
        type='DDIMScheduler',
        from_pretrained=stable_diffusion_v15_url,
        subfolder='scheduler'),
    # dtype=dtype,
    data_preprocessor=dict(type='DataPreprocessor', data_keys=None),
    image_cross_layers=image_cross_layers,
    reg_loss_weight=reg_loss_weight,
    placeholder=placeholder,
    initialize_token=initialize_token,
    num_vectors_per_token=num_vectors_per_token,
    val_prompts=val_prompts)

train_cfg = dict(max_iters=2000)

paramwise_cfg = dict(
    custom_keys={
        'image_cross_attention': dict(lr_mult=2e-3),
        'trainable_embeddings': dict(lr_mult=1.0)
    })
optim_wrapper = dict(
    optimizer=dict(type='AdamW', lr=0.005, weight_decay=0.01),
    constructor='DefaultOptimWrapperConstructor',
    paramwise_cfg=paramwise_cfg,
    accumulative_counts=1)

pipeline = [
    dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
    dict(type='LoadImageFromFile', key='img_ref', channel_order='rgb'),
    dict(type='Resize', keys=['img', 'img_ref'], scale=(512, 512)),
    dict(
        type='PackInputs',
        keys=['img', 'img_ref'],
        data_keys='prompt',
        meta_keys=[
            'img_channel_order', 'img_color_type', 'img_ref_channel_order',
            'img_ref_color_type'
        ])
]
dataset = dict(
    type='TextualInversionDataset',
    data_root=data_root,
    concept_dir=concept_dir,
    placeholder=placeholder,
    template='data/vico/imagenet_templates_small.txt',
    with_image_reference=True,
    pipeline=pipeline)
train_dataloader = dict(
    dataset=dataset,
    num_workers=16,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    persistent_workers=True,
    batch_size=1)
val_cfg = val_evaluator = val_dataloader = None
test_cfg = test_evaluator = test_dataloader = None

# hooks
default_hooks = dict(logger=dict(interval=10))
custom_hooks = [
    dict(
        type='VisualizationHook',
        interval=200,
        fixed_input=True,
        # visualize train dataset
        vis_kwargs_list=dict(type='Data', name='fake_img'),
        n_samples=1)
]

実行

python tools/train.py configs/vico/my_config.py

推論

保存されたpthを変換する必要があります。

import torch
def extract_vico_parameters(state_dict):
    new_state_dict = dict()
    for k, v in state_dict.items():
        if 'image_cross_attention' in k or 'trainable_embeddings' in k:
            new_k = k.replace('module.', '')
            new_state_dict[new_k] = v
    return new_state_dict

checkpoint = torch.load("work_dirs/my_config/iter_2000.pth")
new_checkpoint = extract_vico_parameters(checkpoint['state_dict'])
torch.save(new_checkpoint, "work_dirs/my_config/robot2000.pth")

これによって「iter_2000.pth」から新たに「robot2000.pth」が作成されます。

推論には「robot2000.pth」を使います。

import torch
import os
from mmengine import Config
from PIL import Image
from mmagic.registry import MODELS
from mmagic.utils import register_all_modules

register_all_modules()

save_dir = "work_dirs/my_config"
cfg = Config.fromfile(os.path.join(save_dir, "my_config.py"))
state_dict = torch.load(os.path.join(save_dir, "robot2000.pth"))

vico = MODELS.build(cfg.model)
vico.load_state_dict(state_dict, strict=False)
vico = vico.cuda()

prompt = ["a photo of a S* on the beach"]
reference = "data/vico/robot/1_220.png"
image_ref = Image.open(reference)

os.makedirs(os.path.join(save_dir, "infer_results"), exist_ok=True)

for i in range(3):
    seed=10000 + i * 999
    with torch.no_grad():
        output = vico.infer(
            prompt=prompt, 
            image_reference=image_ref,
            seed=seed,
            num_images_per_prompt=1)['samples'][0]

    output.save(os.path.join(save_dir, "infer_results", f"infer_seed{seed}.png"))

関連記事

touch-sp.hatenablog.com
touch-sp.hatenablog.com




このエントリーをはてなブックマークに追加