Training a FLUX.1-dev LoRA on an RTX 4090 (24GB VRAM)

Introduction

The goal is to train FLUX.1-dev on a specific person.

PC Environment

Windows 11
CUDA 11.8
Python 3.12
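
As a quick sanity check, you can confirm that PyTorch sees the GPU and which CUDA build it uses (a minimal sketch, assuming PyTorch is already installed):

import torch

print(torch.__version__, torch.version.cuda)  # PyTorch version and its CUDA build
print(torch.cuda.get_device_name(0))          # should report the RTX 4090
print(f"{torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GiB of VRAM")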

Source Image

I prepared a single image.
It was generated with an SDXL-derived model.
This is the person the LoRA will learn.

Final Result

a photo of f5h8_woman holding a sign that says 'I LOVE LoRA!'

"f5h8_woman" is an arbitrary trigger word I made up.

Creating the LoRA Training Data

I created the data using the same method as in this post:
touch-sp.hatenablog.com
At this stage the face already differs slightly from the source image.
I modified the script so that it also writes a caption text file for each image.
In total I generated 25 images.

import os
import cv2
from insightface.app import FaceAnalysis
from insightface.utils import face_align
import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
import argparse

from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDPlusXL

parser = argparse.ArgumentParser()
parser.add_argument(
    '--repeat',
    type=int,
    default=5,
    help="number of repeats",
)
args = parser.parse_args()

repeat = args.repeat
image_size = 640  # must satisfy image_size % 112 == 0 or image_size % 128 == 0

# Detect the face in the source image with InsightFace
app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))

image = cv2.imread("face.png")
faces = app.get(image)

# ID embedding and aligned face crop, both consumed by IP-Adapter-FaceID Plus
faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
face_image = face_align.norm_crop(image, landmark=faces[0].kps, image_size=image_size)

#base_model_path = "fudukiMix_v20"
base_model_path = "Juggernaut-XL-v9"
#base_model_path = "stable-diffusion-xl-base-1.0"

image_encoder_path = "CLIP-ViT-H-14-laion2B-s32B-b79K"
ip_ckpt = "IP-Adapter-FaceID/ip-adapter-faceid-plusv2_sdxl.bin"
device = "cuda"

noise_scheduler = DPMSolverMultistepScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    steps_offset=1,
    algorithm_type="sde-dpmsolver++",
    use_karras_sigmas=True,
)
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    scheduler=noise_scheduler,
    torch_dtype=torch.float16,
    variant="fp16"
)

# Wrap the SDXL pipeline with IP-Adapter-FaceID Plus
ip_model = IPAdapterFaceIDPlusXL(pipe, image_encoder_path, ip_ckpt, device)

negative_prompt = "mole, beauty spot, hand, finger, earring, cleavage, illustration, 3d, 2d, painting, cartoons, sketch, watercolor, monotone, kimono, crossed eyes, strabismus"

model_name = os.path.basename(base_model_path)
save_folder = f"face_dataset_for_train_{model_name}"
os.makedirs(save_folder, exist_ok=True)

prompts = [
    "a photo of {trigger_word}, in the cafe, sitting on the sofa, 8k, RAW photo, best quality, masterpiece, photo-realistic, focus",
    "a photo of {trigger_word}, on the beach, 8k, RAW photo, best quality, masterpiece, photo-realistic, focus, natural lighting",
    "a photo of {trigger_word}, wearing a black dress, upper body, behind is the Eiffel Tower, 8k, RAW photo, best quality, masterpiece, photo-realistic",
    "a photo of {trigger_word}, standing on street, 8k, RAW photo, detailed, plain white t-shirt, eye level angle",
    "a photo of {trigger_word}, smiling, with arms outstretched, in the forest, natural lighting, RAW photo, best quality"
]

# Captions saved next to each image; ai-toolkit later replaces [trigger] with the configured trigger_word
text_from_prompt = [
    "a photo of [trigger] sitting on the sofa in the cafe",
    "a photo of [trigger] on the beach",
    "a photo of [trigger] with a black dress standing in front of the Eiffel Tower",
    "a photo of [trigger] with plain white t-shirt standing on street",
    "a photo of [trigger] with arms outstretched in the forest"
]

for i in range(repeat):
    for j, prompt in enumerate(prompts):
        # For generation, the placeholder is filled with a generic description rather than the trigger word
        prompt = prompt.format(trigger_word="a japanese woman")
        images = ip_model.generate(
            prompt=prompt,
            negative_prompt=negative_prompt,
            face_image=face_image,
            faceid_embeds=faceid_embeds,
            shortcut=True,
            s_scale=1.0,
            num_samples=1,
            width=1024,
            height=1024,
            num_inference_steps=40,
            guidance_scale=7.5
        )
        save_fname = ''.join(f"{prompt}_{i}".split())  # prompt with whitespace stripped, plus repeat index
        images[0].save(os.path.join(save_folder, f"{save_fname}.png"))

        with open(os.path.join(save_folder, f"{save_fname}.txt"), "w", encoding="utf-8") as f:
            f.write(text_from_prompt[j])
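
ai-toolkit pairs each image with a caption file of the same name, so it is worth checking the pairs before training. A minimal sketch, assuming the generated folder has been renamed to 25faces (the folder referenced in the config below):

import os

folder = "25faces"  # dataset folder referenced in my_config.yml (assumed rename)
images = [f for f in os.listdir(folder) if f.endswith(".png")]
missing = [f for f in images
           if not os.path.exists(os.path.join(folder, f[:-4] + ".txt"))]
print(f"{len(images)} images, {len(missing)} missing captions")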

LoRA Training

I followed this repository:
github.com
Once the training data is ready, starting the training takes a single line:

python run.py my_config.yml



The my_config.yml passed as the argument looks like this:

job: extension
config:
  name: flux_lora_v1_0911_2
  process:
  - type: sd_trainer
    training_folder: output
    device: cuda:0
    trigger_word: f5h8_woman
    network:
      type: lora
      linear: 16
      linear_alpha: 16
    save:
      dtype: float16
      save_every: 400
      max_step_saves_to_keep: 4
      push_to_hub: false
    datasets:
    - folder_path: 25faces
      caption_ext: txt
      caption_dropout_rate: 0.05
      shuffle_tokens: false
      cache_latents_to_disk: true
      resolution:
      - 512
      - 768
      - 1024
    train:
      batch_size: 1
      steps: 1200
      gradient_accumulation_steps: 1
      train_unet: true
      train_text_encoder: false
      gradient_checkpointing: true
      noise_scheduler: flowmatch
      optimizer: adamw8bit
      lr: 0.0001
      linear_timesteps: true
      ema_config:
        use_ema: true
        ema_decay: 0.99
      dtype: bf16
    model:
      name_or_path: FLUX.1-dev
      is_flux: true
      quantize: true
      low_vram: true
    sample:
      sampler: flowmatch
      sample_every: 400
      width: 1024
      height: 1024
      prompts:
      - A close-up portrait of f5h8_woman, holding a sign that says 'I LOVE PROMPTS!'
      neg: ''
      seed: 42
      walk_seed: true
      guidance_scale: 4
      sample_steps: 20
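
With 25 images and batch_size: 1, the 1200 steps above correspond to 48 passes over the dataset. A quick sketch of the arithmetic implied by the config:

# All numbers are taken from my_config.yml above
num_images = 25
batch_size = 1
steps = 1200
save_every = 400

steps_per_epoch = num_images // batch_size             # 25 steps = 1 pass over the data
print(f"epochs: {steps / steps_per_epoch:.0f}")        # -> 48
print(list(range(save_every, steps + 1, save_every)))  # checkpoints at 400, 800, 1200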

Inference

import torch
from diffusers import FluxPipeline
import gc

def flush():
    # Release Python references and return cached VRAM to the driver
    gc.collect()
    torch.cuda.empty_cache()

model_id = "FLUX.1-dev"

prompt = "a photo of f5h8_woman holding a sign that says 'I LOVE LoRA!'"

# Step 1: load only the text encoders and tokenizers to encode the prompt
pipeline = FluxPipeline.from_pretrained(
        model_id,
        transformer=None,
        vae=None
).to("cuda")

with torch.no_grad():
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt,
        prompt_2=None,
        #max_sequence_length=256
    )

# Free the text encoders before loading the transformer
del pipeline
flush()

# Step 2: load the transformer and VAE without the text encoders
pipeline = FluxPipeline.from_pretrained(
    model_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    torch_dtype=torch.bfloat16
)

# LoRA checkpoint saved at step 800 of training
pipeline.load_lora_weights("flux_lora_v1_0911_2/flux_lora_v1_0911_2_000000800.safetensors")

# Stream layers between CPU and GPU to stay within 24GB of VRAM
pipeline.enable_sequential_cpu_offload()

seed = 12345
generator = torch.Generator().manual_seed(seed)
image = pipeline(
    prompt_embeds=prompt_embeds.bfloat16(),
    pooled_prompt_embeds=pooled_prompt_embeds.bfloat16(),
    width=1024,
    height=1024,
    num_inference_steps=27,
    generator=generator,
    guidance_scale=3.5,
    joint_attention_kwargs={"scale": 1.0},
).images[0]

image.save(f"lora_result_seed{seed}.jpg")
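
Because the prompt embeddings are computed once up front, they can be reused to try more seeds without reloading the text encoders. A minimal follow-up sketch with hypothetical seed values:

# Reuse the cached embeddings for additional seeds
for seed in (23456, 34567, 45678):
    generator = torch.Generator().manual_seed(seed)
    image = pipeline(
        prompt_embeds=prompt_embeds.bfloat16(),
        pooled_prompt_embeds=pooled_prompt_embeds.bfloat16(),
        width=1024,
        height=1024,
        num_inference_steps=27,
        generator=generator,
        guidance_scale=3.5,
        joint_attention_kwargs={"scale": 1.0},
    ).images[0]
    image.save(f"lora_result_seed{seed}.jpg")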
