【Diffusers】anzu_flux をなんとか VRAM 12GB 未満で動かす

huggingface.co

はじめに

前回テキストエンコーダーを8bitに量子化して何とか16GB未満(実際は13GB強)で動作させることができました。
touch-sp.hatenablog.com
今回はトランスフォーマーを4bitに量子化して12GB未満で動かしてみました。

結果

成功です。

PC環境

Windows 11
Python 3.12
CUDA 11.8

Python環境構築

pip install torch==2.4.0+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install git+https://github.com/huggingface/diffusers.git@quantization-config
pip install accelerate transformers protobuf sentencepiece bitsandbytes
pip uninstall numpy
pip install numpy==1.26.4

方法

「anzu_flux_Mix_beta01_.safetensors」をダウンロードしたあといったんDiffusersフォーマットに変換しました。

このモデルの詳細はわかりませんがトランスフォーマーの部分のみ使用しています。

もしテキストエンコーダーもファインチューニングされていたら、一部利用していないことになります。

import torch
from diffusers import FluxTransformer2DModel, FluxPipeline

dtype = torch.bfloat16

transformer = FluxTransformer2DModel.from_single_file(
    "anzu_flux_Mix_beta01_.safetensors",
    torch_dtype=dtype
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=dtype
)

pipe.save_pretrained("anzu_flux_mix_beta01")



変換後は以下を実行しました。

import torch 
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel, FluxPipeline

model_id = "anzu_flux_mix_beta01"

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model_nf4 = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16
)

pipe = FluxPipeline.from_pretrained(
    model_id,
    transformer=model_nf4,
    torch_dtype=torch.bfloat16
)

pipe.enable_model_cpu_offload()

generator = torch.Generator().manual_seed(123)
image = pipe(
    prompt="A photorealistic portrait of a young Japanese woman with long black hair and natural makeup, wearing a casual white blouse, sitting in a modern Tokyo cafe with soft window light",
    width=1360,
    height=768,
    num_inference_steps=50,
    generator=generator,
    guidance_scale=3.5,
).images[0]

image.save("woman.jpg")

作成された画像