【Diffusers】Running FLUX.1-dev with ControlNet in under 16GB of VRAM

Introduction

ControlNet support for FLUX.1-dev recently became available, so I tried it out right away.

To keep VRAM usage at 16GB or below, I used optimum-quanto.

In the end, generating 1024x1024 images became possible with 16GB of VRAM or less.

Setting up the Python environment

pip install torch==2.4.0+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install git+https://github.com/huggingface/diffusers
pip install accelerate transformers protobuf sentencepiece
pip install optimum-quanto
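
To confirm that the CUDA build of PyTorch was installed correctly, here is a quick optional check:

# Sanity check: the version string should end in +cu118 and CUDA should be visible.
import torch
print(torch.__version__)          # e.g. 2.4.0+cu118
print(torch.cuda.is_available())  # True if the GPU is usable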

Quantization

I quantized the transformer, text_encoder_2, and controlnet, then saved them to disk.

import torch
from diffusers import FluxTransformer2DModel, FluxControlNetModel
from transformers import T5EncoderModel
from optimum.quanto import freeze, qfloat8, quantize, quantization_map
from pathlib import Path
import json

dtype = torch.bfloat16

controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny-alpha"
base_model = "black-forest-labs/FLUX.1-dev"

# Quantize the ControlNet to qfloat8, freeze the weights, and save the model
# together with its quantization map (needed to reload it later).
controlnet = FluxControlNetModel.from_pretrained(
    controlnet_model,
    torch_dtype=dtype
)
quantize(controlnet, weights=qfloat8)
freeze(controlnet)

save_directory = "fluxcontrolnet_qfloat8"
controlnet.save_pretrained(save_directory)
qmap_name = Path(save_directory, "quanto_qmap.json")
qmap = quantization_map(controlnet)
with open(qmap_name, "w", encoding="utf8") as f:
    json.dump(qmap, f, indent=4)

# Quantize the FLUX transformer (by far the largest component) the same way.
transformer = FluxTransformer2DModel.from_pretrained(
    base_model,
    subfolder="transformer",
    torch_dtype=dtype
)
quantize(transformer, weights=qfloat8)
freeze(transformer)

save_directory = "fluxtransformer2dmodel_qfloat8"
transformer.save_pretrained(save_directory)
qmap_name = Path(save_directory, "quanto_qmap.json")
qmap = quantization_map(transformer)
with open(qmap_name, "w", encoding="utf8") as f:
    json.dump(qmap, f, indent=4)

# Finally, quantize the T5 text encoder (text_encoder_2).
text_encoder_2 = T5EncoderModel.from_pretrained(
    base_model,
    subfolder="text_encoder_2",
    torch_dtype=dtype
)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

save_directory = "t5encodermodel_qfloat8"
text_encoder_2.save_pretrained(save_directory)
qmap_name = Path(save_directory, "quanto_qmap.json")
qmap = quantization_map(text_encoder_2)
with open(qmap_name, "w", encoding="utf8") as f:
    json.dump(qmap, f, indent=4)
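
After the script finishes, each directory should contain the serialized weights alongside the quanto_qmap.json written above. A quick optional way to confirm:

# List the files written for each quantized component.
from pathlib import Path

for d in ["fluxcontrolnet_qfloat8", "fluxtransformer2dmodel_qfloat8", "t5encodermodel_qfloat8"]:
    print(d, sorted(p.name for p in Path(d).iterdir()))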

Image generation

import time
import torch
from diffusers.utils import load_image
from diffusers import FluxControlNetModel, FluxTransformer2DModel, FluxControlNetPipeline
from transformers import T5EncoderModel
from optimum.quanto import QuantizedTransformersModel, QuantizedDiffusersModel

start = time.perf_counter()

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

# optimum-quanto reloads a quantized checkpoint through a thin wrapper class
# that declares which diffusers class to rebuild.
class QuantizedFluxControlNetModel(QuantizedDiffusersModel):
    base_class = FluxControlNetModel

controlnet = QuantizedFluxControlNetModel.from_pretrained(
    "fluxcontrolnet_qfloat8"
).to(dtype=dtype)

class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

transformer = QuantizedFluxTransformer2DModel.from_pretrained(
    "fluxtransformer2dmodel_qfloat8"
).to(dtype=dtype)

class QuantizedT5EncoderModelForCausalLM(QuantizedTransformersModel):
    auto_class = T5EncoderModel
    # T5EncoderModel does not expose from_config(), which the quanto wrapper
    # expects, so alias the private _from_config as a workaround.
    auto_class.from_config = auto_class._from_config

text_encoder_2 = QuantizedT5EncoderModelForCausalLM.from_pretrained(
    "t5encodermodel_qfloat8"
).to(dtype=dtype)

pipe = FluxControlNetPipeline.from_pretrained(
    bfl_repo,
    transformer=transformer,
    text_encoder_2=text_encoder_2,
    controlnet=controlnet,
    torch_dtype=dtype
)

# Offload idle components to the CPU so only the active one sits on the GPU,
# further reducing peak VRAM.
pipe.enable_model_cpu_offload()

control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
prompt = "A girl in city, 25 years old, cool, futuristic"
image = pipe(
    prompt,
    control_image=control_image,
    controlnet_conditioning_scale=0.6,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux_qfloat8.jpg")

end = time.perf_counter()
print(f"time: {(end - start):.2f}sec")

Results

time: 244.63sec

VRAM usage

These numbers are from generating a 1024x1024 image.

Usage stays just barely under 16GB.
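
To verify the peak usage on your own machine, PyTorch's allocator statistics are one option. A minimal sketch (assumes a single GPU; note that this reports the allocator's peak, which is usually a little below what nvidia-smi shows):

import torch

torch.cuda.reset_peak_memory_stats()

# ... run the pipe(...) call from the generation script here ...

peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"peak VRAM: {peak_gib:.2f} GiB")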