Results
Some of the images used for training are posted below.
The images above are the result of trying to reproduce this person.
Can we say the person has been faithfully reproduced?
Introduction
LoRA itself is a technique that has been around for some time. What is quite recent is the published method of combining it with Pivotal Tuning. huggingface.co
Diffusers provides a training script called "train_dreambooth_lora_sdxl_advanced.py", so the procedure is very simple.
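Conceptually, Pivotal Tuning trains the embeddings of newly inserted placeholder tokens (textual inversion) at the same time as the LoRA weights. Below is a minimal sketch of just the token-insertion half, assuming a standard CLIP tokenizer and text encoder from transformers; the actual training script handles this internally:

# Minimal sketch of the token insertion behind Pivotal Tuning.
# Assumption: a simplified illustration, not the actual training script.
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Insert the placeholder tokens and grow the embedding matrix to match.
tokenizer.add_tokens(["<s0>", "<s1>"])
text_encoder.resize_token_embeddings(len(tokenizer))

# During training, only the new embedding rows (plus the LoRA weights on the
# UNet) are optimized; the rest of the text encoder stays frozen.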
Data Preparation
Training Data
I created 120 training images. touch-sp.hatenablog.com
They are saved in a folder named "face_dataset_for_train_120".
Regularization Images
I also prepared 288 other face photos. These were likewise generated by me; see the Python script in the supplementary notes below for how they were made. They are saved in a folder named "class_dataset_for_train_288".
Execution
Once the data is ready, download "train_dreambooth_lora_sdxl_advanced.py" from here. You will also need a model; this time I used "fudukiMix_v20".
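For reference, the script can also be fetched with a few lines of Python. The raw URL below is an assumption based on where the advanced training examples live in the diffusers repository:

# Hypothetical download helper; the URL assumes the script's location under
# examples/advanced_diffusion_training in the diffusers repository.
import urllib.request

url = ("https://raw.githubusercontent.com/huggingface/diffusers/main/"
       "examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py")
urllib.request.urlretrieve(url, "train_dreambooth_lora_sdxl_advanced.py")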
After that, all that is left is to run it.
accelerate launch train_dreambooth_lora_sdxl_advanced.py `
  --pretrained_model_name_or_path="fudukiMix_v20_full_nonEMA" `
  --instance_data_dir="face_dataset_for_train_120" `
  --output_dir="lora-advanced-trained-xl-with-text_encoder_ti_rank32_prodigy_snr_gamma5.0" `
  --mixed_precision="fp16" `
  --instance_prompt="a photo of TOK woman" `
  --resolution=1024 `
  --train_batch_size=1 `
  --gradient_accumulation_steps=4 `
  --gradient_checkpointing `
  --lr_scheduler="constant" `
  --lr_warmup_steps=0 `
  --max_train_steps=1600 `
  --checkpointing_steps=400 `
  --seed="0" `
  --with_prior_preservation `
  --class_data_dir="class_dataset_for_train_288" `
  --num_class_images=288 `
  --class_prompt="a photo of woman" `
  --train_text_encoder_ti `
  --token_abstraction="TOK" `
  --rank=32 `
  --optimizer="prodigy" `
  --learning_rate=1.0 `
  --text_encoder_lr=1.0 `
  --snr_gamma=5.0
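A few notes on the options: --with_prior_preservation enables the prior-preservation loss using the 288 class images, --train_text_encoder_ti is the Pivotal Tuning part (it trains embeddings for newly inserted tokens in both text encoders), and --token_abstraction="TOK" tells the script to replace "TOK" in the instance prompt with those new tokens. The learning rates are set to 1.0 because the Prodigy optimizer adapts the effective step size on its own.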
PC Environment
Windows 11
CUDA 11.8
Python 3.11
Python Environment Setup
pip install torch==2.1.2+cu118 torchvision==0.16.2+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install git+https://github.com/huggingface/diffusers
pip install accelerate transformers ftfy tensorboard Jinja2 scipy peft prodigyopt
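To confirm the installation succeeded, a quick sanity check (the expected versions reflect this particular setup):

# Quick sanity check of the installed environment.
import torch
import diffusers

print(torch.__version__)          # expect 2.1.2+cu118
print(torch.cuda.is_available())  # expect True
print(diffusers.__version__)      # installed from the git main branch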
Supplementary Notes
How to Create the Regularization Images
import os
import itertools
import argparse

import torch
from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler

parser = argparse.ArgumentParser()
parser.add_argument(
    '--repeat',
    type=int,
    default=2,
    help="number of images per prompt combination",
)
parser.add_argument(
    '--model',
    type=str,
    required=True,
    help="path to model",
)
args = parser.parse_args()

repeat = args.repeat
base_model_path = args.model

pipe = AutoPipelineForText2Image.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    variant="fp16"
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    algorithm_type="sde-dpmsolver++",
    use_karras_sigmas=True
)
pipe.to("cuda")

negative_prompt = "hand, hands, finger, cleavage, illustration, 3d, 2d, painting, cartoons, sketch, watercolor, monotone, kimono, crossed eyes, strabismus"

model_name = os.path.basename(base_model_path)
save_folder = f"class_dataset_for_train_{model_name}"
os.makedirs(save_folder, exist_ok=True)

# Prompt attributes: 2 x 3 x 2 x 2 x 3 = 72 combinations in total.
woman_list = ["japanese woman", "woman"]
age_list = ["25yo", "35yo", "45yo"]
shot_list = ["face shot", "close-up shot"]
angle_list = ["straight-on", "from side"]
lighting_list = ["natural lighting", "cinematic lighting", "studio lighting"]

for (woman, age, shot, angle, lighting) in itertools.product(woman_list, age_list, shot_list, angle_list, lighting_list):
    for i in range(repeat):
        prompt = f"{woman}, {age}, {shot}, {angle}, {lighting}, best quality"
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_images_per_prompt=1,
            width=1024,
            height=1024,
            num_inference_steps=40,
            guidance_scale=7.5
        ).images[0]
        # Strip whitespace from the attribute strings to build the file name.
        save_fname = ''.join(f"{model_name}_{woman}_{age}_{shot}_{angle}_{lighting}_{i}.png".split())
        image.save(os.path.join(save_folder, save_fname))
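The loops cover 2 x 3 x 2 x 2 x 3 = 72 prompt combinations, so the default --repeat 2 yields 144 images, while --repeat 4 yields the 288 used above. A sample invocation, assuming the script was saved as make_class_images.py (the file name is hypothetical):

python make_class_images.py --model fudukiMix_v20_full_nonEMA --repeat 4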
Python Script for Inference
import itertools
import argparse
from pathlib import Path

import torch
from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler
from safetensors.torch import load_file

parser = argparse.ArgumentParser()
parser.add_argument(
    '--steps',
    type=int,
    required=True,
    help="number of trained steps"
)
parser.add_argument(
    '--folder',
    type=str,
    default="lora-trained-xl",
    help="path to folder in which the lora-trained safetensors is saved"
)
args = parser.parse_args()

checkpoint = args.steps
checkpoint_folder = args.folder

pipe = AutoPipelineForText2Image.from_pretrained(
    "model/fudukiMix_v20",
    torch_dtype=torch.float16,
    variant="fp16"
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    algorithm_type="sde-dpmsolver++",
    use_karras_sigmas=True
)

lora_model_id = Path(checkpoint_folder, f"checkpoint-{checkpoint}", "pytorch_lora_weights.safetensors").as_posix()
ti_model_id = Path(checkpoint_folder, f"checkpoint-{checkpoint}", f"{checkpoint_folder}_emb.safetensors").as_posix()

# Load the trained token embeddings (Pivotal Tuning) into both text encoders,
# then load the LoRA weights.
state_dict = load_file(ti_model_id)
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
pipe.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
pipe.load_lora_weights(lora_model_id, adapter_name="my_lora")

pipe.to("cuda")

prompts = [
    "a photo of <s0><s1> woman, in the cafe, sitting on the sofa, 8k, RAW photo, best quality, masterpiece, photo-realistic, focus, professional lighting",
    "a photo of <s0><s1> woman, on the beach, sunset, 8k, RAW photo, best quality, masterpiece, photo-realistic, focus, professional lighting",
    "a photo of <s0><s1> woman, wearing a black dress, upper body, behind is the Eiffel Tower, 8k, RAW photo, best quality, masterpiece, photo-realistic, focus, professional lighting",
    "a photo of <s0><s1> woman, standing on street, 8k, RAW photo, detailed, plain white t-shirt, eye level angle"
]

negative_prompt = "worst quality, low quality"

#weights = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
weights = [0.6, 0.7, 0.8]

seed = 20240126

for (prompt, weight) in itertools.product(prompts, weights):
    # Re-seed for every image so results are comparable across LoRA weights.
    generator = torch.manual_seed(seed)
    prompt_n = prompts.index(prompt)
    # Apply the LoRA at the current weight before generating.
    pipe.set_adapters(["my_lora"], adapter_weights=[weight])
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        generator=generator,
        num_inference_steps=35,
        width=1152,
        height=896,
    ).images[0]
    save_folder_path = Path(checkpoint_folder, f"result_checkpoint{checkpoint}", f"weight{weight}")
    save_folder_path.mkdir(parents=True, exist_ok=True)
    image.save(Path(save_folder_path, f"weight{weight}_{prompt_n}.png").as_posix())
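For example, to run inference against the final checkpoint (the script file name inference.py is hypothetical):

python inference.py --steps 1600 --folder lora-advanced-trained-xl-with-text_encoder_ti_rank32_prodigy_snr_gamma5.0

This loads checkpoint-1600 from the training output folder and writes 4 prompts x 3 LoRA weights = 12 images under result_checkpoint1600.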