目的
画像生成モデルに自前の物体を描画させる試みです。以前にもDreamBoothやLoRAで同様のことをやっています。この物体を描写させます。具体的にはビーチに立たせてみます。
ViCo (Detail-Preserving Visual Condition for Personalized Text-to-Image Generation) という手法を使います。
DreamBoothやLoRAの記事は最後にリンクを貼っておきますのでよかったら見て下さい。
結果
ん~、いまいちです。
設定を煮詰めると良くなるのでしょうか?
環境
Windows 11 CUDA 11.7 Python 3.11
Python環境構築
yapfは0.40.1までを使用する必要があります。0.40.2以降を使用すると以下のエラーが出ます。TypeError: FormatCode() got an unexpected keyword argument 'verify'
pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 --index-url https://download.pytorch.org/whl/cu117 pip install yapf==0.40.1 pip install mmcv==2.0.1 -f https://download.openmmlab.com/mmcv/dist/cu117/torch2.0.0/index.html pip install openmim==0.3.9 pip install mmagic==1.1.0 pip install accelerate==0.23.0 pip install albumentations==1.3.1
準備
リポジトリのクローン
git clone https://github.com/open-mmlab/mmagic
学習データの準備
こちらから「imagenet_templates_small.txt」をダウンロードします。こちらから犬の写真を5枚ダウンロードします。
ファイル構造はこのようにします。
mmagic └─data └─vico │ imagenet_templates_small.txt │ └─robot 00.jpg 01.jpg 02.jpg 03.jpg 04.jpg . .
学習Configファイル
「mmagic/configs/vico/vico.py」を以下のように書き換え「my_config.py」としました。
_base_ = '../_base_/gen_default_runtime.py'
randomness = dict(seed=2023, diff_rank_seed=True)
# dtype="fp32"
# config for model
stable_diffusion_v15_url = 'model/stable-diffusion-v1-5'
data_root = './data/vico'
concept_dir = 'robot'
# 1 for using image cross
image_cross_layers = [
# down blocks (2x transformer block) * (3x down blocks) = 6
0,
0,
0,
0,
0,
0,
# mid block (1x transformer block) * (1x mid block)= 1
0,
# up blocks (3x transformer block) * (3x up blocks) = 9
0,
1,
0,
1,
0,
1,
0,
1,
0,
]
reg_loss_weight: float = 5e-4
placeholder: str = 'S*'
val_prompts = ['a photo of a S*']
initialize_token: str = 'robot'
num_vectors_per_token: int = 1
model = dict(
type='ViCo',
vae=dict(
type='AutoencoderKL',
from_pretrained=stable_diffusion_v15_url,
subfolder='vae'),
unet=dict(
type='UNet2DConditionModel',
subfolder='unet',
from_pretrained=stable_diffusion_v15_url),
text_encoder=dict(
type='ClipWrapper',
clip_type='huggingface',
pretrained_model_name_or_path=stable_diffusion_v15_url,
subfolder='text_encoder'),
tokenizer=stable_diffusion_v15_url,
scheduler=dict(
type='DDPMScheduler',
from_pretrained=stable_diffusion_v15_url,
subfolder='scheduler'),
test_scheduler=dict(
type='DDIMScheduler',
from_pretrained=stable_diffusion_v15_url,
subfolder='scheduler'),
# dtype=dtype,
data_preprocessor=dict(type='DataPreprocessor', data_keys=None),
image_cross_layers=image_cross_layers,
reg_loss_weight=reg_loss_weight,
placeholder=placeholder,
initialize_token=initialize_token,
num_vectors_per_token=num_vectors_per_token,
val_prompts=val_prompts)
train_cfg = dict(max_iters=2000)
paramwise_cfg = dict(
custom_keys={
'image_cross_attention': dict(lr_mult=2e-3),
'trainable_embeddings': dict(lr_mult=1.0)
})
optim_wrapper = dict(
optimizer=dict(type='AdamW', lr=0.005, weight_decay=0.01),
constructor='DefaultOptimWrapperConstructor',
paramwise_cfg=paramwise_cfg,
accumulative_counts=1)
pipeline = [
dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
dict(type='LoadImageFromFile', key='img_ref', channel_order='rgb'),
dict(type='Resize', keys=['img', 'img_ref'], scale=(512, 512)),
dict(
type='PackInputs',
keys=['img', 'img_ref'],
data_keys='prompt',
meta_keys=[
'img_channel_order', 'img_color_type', 'img_ref_channel_order',
'img_ref_color_type'
])
]
dataset = dict(
type='TextualInversionDataset',
data_root=data_root,
concept_dir=concept_dir,
placeholder=placeholder,
template='data/vico/imagenet_templates_small.txt',
with_image_reference=True,
pipeline=pipeline)
train_dataloader = dict(
dataset=dataset,
num_workers=16,
sampler=dict(type='InfiniteSampler', shuffle=True),
persistent_workers=True,
batch_size=1)
val_cfg = val_evaluator = val_dataloader = None
test_cfg = test_evaluator = test_dataloader = None
# hooks
default_hooks = dict(logger=dict(interval=10))
custom_hooks = [
dict(
type='VisualizationHook',
interval=200,
fixed_input=True,
# visualize train dataset
vis_kwargs_list=dict(type='Data', name='fake_img'),
n_samples=1)
]
実行
python tools/train.py configs/vico/my_config.py
推論
保存されたpthを変換する必要があります。import torch def extract_vico_parameters(state_dict): new_state_dict = dict() for k, v in state_dict.items(): if 'image_cross_attention' in k or 'trainable_embeddings' in k: new_k = k.replace('module.', '') new_state_dict[new_k] = v return new_state_dict checkpoint = torch.load("work_dirs/my_config/iter_2000.pth") new_checkpoint = extract_vico_parameters(checkpoint['state_dict']) torch.save(new_checkpoint, "work_dirs/my_config/robot2000.pth")
これによって「iter_2000.pth」から新たに「robot2000.pth」が作成されます。
推論には「robot2000.pth」を使います。
import torch import os from mmengine import Config from PIL import Image from mmagic.registry import MODELS from mmagic.utils import register_all_modules register_all_modules() save_dir = "work_dirs/my_config" cfg = Config.fromfile(os.path.join(save_dir, "my_config.py")) state_dict = torch.load(os.path.join(save_dir, "robot2000.pth")) vico = MODELS.build(cfg.model) vico.load_state_dict(state_dict, strict=False) vico = vico.cuda() prompt = ["a photo of a S* on the beach"] reference = "data/vico/robot/1_220.png" image_ref = Image.open(reference) os.makedirs(os.path.join(save_dir, "infer_results"), exist_ok=True) for i in range(3): seed=10000 + i * 999 with torch.no_grad(): output = vico.infer( prompt=prompt, image_reference=image_ref, seed=seed, num_images_per_prompt=1)['samples'][0] output.save(os.path.join(save_dir, "infer_results", f"infer_seed{seed}.png"))
関連記事
touch-sp.hatenablog.comtouch-sp.hatenablog.com