FLUX.1-dev派生モデルで Image2Image をやってみる

元画像

ぱくたそからこちらの画像を使わせて頂きました。
www.pakutaso.com

使用したモデル

CIVITAIから「xeBlenderFlux_01.safetensors」をダウンロードして使わせて頂きました。

こちらと同じ方法でいったんDiffusersフォーマットに変換しました。
touch-sp.hatenablog.com

結果

左からStrength 0.5 → 0.6 → 0.7 → 0.8 です。



RTX 3080 Laptop (VRAM 16GB)で測定した結果です。
時間は4枚の画像を生成するのにかかった時間です。

GPU 0 - Used memory: 11.44/16.00 GB
time: 526.96 sec

Pythonスクリプト

プロンプトにはImage2Textを行った結果を使いました。
冒頭に「A blender rendering artwork.」という一文を足しています。
touch-sp.hatenablog.com

import torch 
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel, FluxImg2ImgPipeline
from diffusers.utils import load_image
from decorator import gpu_monitor, time_monitor
from pathlib import Path

@time_monitor
@gpu_monitor(interval=0.5)
def main(model_id:str) -> None:

    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model_nf4 = FluxTransformer2DModel.from_pretrained(
        model_id, subfolder="transformer",
        quantization_config=nf4_config,
        torch_dtype=torch.bfloat16
    )

    pipe = FluxImg2ImgPipeline.from_pretrained(
        model_id,
        transformer=model_nf4,
        torch_dtype=torch.bfloat16
    )

    pipe.enable_model_cpu_offload()

    save_folder = f"image_from_{model_id}"
    Path(save_folder).mkdir(exist_ok=True)

    ref_image = load_image("girl_1234.jpg")
    for strength in [0.5, 0.6, 0.7, 0.8]:
        generator = torch.Generator().manual_seed(12345)
        image = pipe(
            prompt="A blender rendering artwork. A young woman in a traditional Japanese kimono stands in a serene garden, holding a white fan with a gold handle. The kimono is adorned with a pink and white floral pattern, and the woman's face is lit up with a warm smile. The garden around her is lush with green trees and bushes, and a wooden structure can be seen in the background.",
            image=ref_image,
            width=1024,
            height=1024,
            strength=strength,
            num_inference_steps=50,
            generator=generator,
            guidance_scale=3.5,
        ).images[0]

        file_name = f"strength_{strength}.jpg"
        image.save(Path(save_folder, file_name).as_posix())

if __name__ == "__main__":
    main("xeBlenderFlux_01")

以下の「decorator.py」を使ってVRAM使用量と時間を計測しています。

# pip install pynvml

import functools
import threading
import time
from pynvml import *

def gpu_monitor(interval=1):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            stop_event = threading.Event()

            def monitor_thread():
                max_vram = 0
                device_id = 0
                nvmlInit()
                handle = nvmlDeviceGetHandleByIndex(device_id)
                info = nvmlDeviceGetMemoryInfo(handle)
                total_vram = info.total / (1024**3)
                try:
                    while not stop_event.is_set():
                        info = nvmlDeviceGetMemoryInfo(handle)
                        using_vram = info.used / (1024**3)
                        if using_vram > max_vram:
                            max_vram = using_vram
                        
                        time.sleep(interval)
                except NVMLError as error:
                    print(f"NVML Error: {error}")
                finally:
                    nvmlShutdown()
                    print(f"GPU {device_id} - Used memory: {max_vram:.2f}/{total_vram:.2f} GB")

            monitor = threading.Thread(target=monitor_thread)
            monitor.start()

            try:
                result = func(*args, **kwargs)
            finally:
                stop_event.set()
                monitor.join()

            return result
        return wrapper
    return decorator

def time_monitor(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        print(f"time: {(end_time - start_time):.2f} sec")
        return result
    return wrapper