Introduction
ControlNet is now usable with FLUX.1-dev, so I tried it out right away. To keep VRAM usage at 16GB or less, I used optimum-quanto. In the end, generating 1024x1024 images became possible with under 16GB of VRAM.
Setting up the Python environment
pip install torch==2.4.0+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install git+https://github.com/huggingface/diffusers
pip install accelerate transformers protobuf sentencepiece
pip install optimum-quanto
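Before quantizing anything, it may help to confirm the environment is set up correctly. A minimal sanity check (the exact versions printed will depend on your setup; diffusers is installed from git, so it reports a dev version):
import torch
import diffusers

print(torch.__version__)          # should report 2.4.0+cu118
print(torch.cuda.is_available())  # must be True for GPU inference
print(diffusers.__version__)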
Quantization
I quantized and saved the transformer, text_encoder_2, and controlnet.
import torch
from diffusers import FluxTransformer2DModel, FluxControlNetModel
from transformers import T5EncoderModel
from optimum.quanto import freeze, qfloat8, quantize, quantization_map
from pathlib import Path
import json

dtype = torch.bfloat16
controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny-alpha"
base_model = "black-forest-labs/FLUX.1-dev"

# ControlNet: quantize the weights to qfloat8, freeze them,
# then save the model together with its quantization map
controlnet = FluxControlNetModel.from_pretrained(
    controlnet_model,
    torch_dtype=dtype
)
quantize(controlnet, weights=qfloat8)
freeze(controlnet)
save_directory = "fluxcontrolnet_qfloat8"
controlnet.save_pretrained(save_directory)
qmap_name = Path(save_directory, "quanto_qmap.json")
qmap = quantization_map(controlnet)
with open(qmap_name, "w", encoding="utf8") as f:
    json.dump(qmap, f, indent=4)

# Transformer (the FLUX denoiser): same quantize -> freeze -> save flow
transformer = FluxTransformer2DModel.from_pretrained(
    base_model,
    subfolder="transformer",
    torch_dtype=dtype
)
quantize(transformer, weights=qfloat8)
freeze(transformer)
save_directory = "fluxtransformer2dmodel_qfloat8"
transformer.save_pretrained(save_directory)
qmap_name = Path(save_directory, "quanto_qmap.json")
qmap = quantization_map(transformer)
with open(qmap_name, "w", encoding="utf8") as f:
    json.dump(qmap, f, indent=4)

# T5 text encoder (text_encoder_2): the largest of the two text encoders
text_encoder_2 = T5EncoderModel.from_pretrained(
    base_model,
    subfolder="text_encoder_2",
    torch_dtype=dtype
)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
save_directory = "t5encodermodel_qfloat8"
text_encoder_2.save_pretrained(save_directory)
qmap_name = Path(save_directory, "quanto_qmap.json")
qmap = quantization_map(text_encoder_2)
with open(qmap_name, "w", encoding="utf8") as f:
    json.dump(qmap, f, indent=4)
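The three blocks above repeat the same quantize → freeze → save pattern, so the same thing can be factored into a small helper. A sketch (quantize_and_save is a name of my own, not part of optimum-quanto):
def quantize_and_save(model, save_directory):
    # Quantize the weights to qfloat8, freeze them, then save the model
    # together with the quantization map needed to reload it later.
    quantize(model, weights=qfloat8)
    freeze(model)
    model.save_pretrained(save_directory)
    with open(Path(save_directory, "quanto_qmap.json"), "w", encoding="utf8") as f:
        json.dump(quantization_map(model), f, indent=4)

quantize_and_save(controlnet, "fluxcontrolnet_qfloat8")
quantize_and_save(transformer, "fluxtransformer2dmodel_qfloat8")
quantize_and_save(text_encoder_2, "t5encodermodel_qfloat8")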
Image generation
import time
import torch
from diffusers.utils import load_image
from diffusers import FluxControlNetModel, FluxTransformer2DModel, FluxControlNetPipeline
from transformers import T5EncoderModel
from optimum.quanto import QuantizedTransformersModel, QuantizedDiffusersModel

start = time.perf_counter()

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

# Wrapper classes tell optimum-quanto which base class to rebuild
# from the saved weights and quantization map
class QuantizedFluxControlNetModel(QuantizedDiffusersModel):
    base_class = FluxControlNetModel

controlnet = QuantizedFluxControlNetModel.from_pretrained(
    "fluxcontrolnet_qfloat8"
).to(dtype=dtype)

class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

transformer = QuantizedFluxTransformer2DModel.from_pretrained(
    "fluxtransformer2dmodel_qfloat8"
).to(dtype=dtype)

class QuantizedT5EncoderModelForCausalLM(QuantizedTransformersModel):
    auto_class = T5EncoderModel
    # T5EncoderModel has no public from_config, so point it at _from_config
    auto_class.from_config = auto_class._from_config

text_encoder_2 = QuantizedT5EncoderModelForCausalLM.from_pretrained(
    "t5encodermodel_qfloat8"
).to(dtype=dtype)

# Build the pipeline from the base repo, swapping in the quantized modules
pipe = FluxControlNetPipeline.from_pretrained(
    bfl_repo,
    transformer=transformer,
    text_encoder_2=text_encoder_2,
    controlnet=controlnet,
    torch_dtype=dtype
)
# Offload idle submodules to CPU to keep peak VRAM down
pipe.enable_model_cpu_offload()

control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
prompt = "A girl in city, 25 years old, cool, futuristic"

image = pipe(
    prompt,
    control_image=control_image,
    controlnet_conditioning_scale=0.6,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux_qfloat8.jpg")

end = time.perf_counter()
print(f"time: {(end - start):.2f}sec")
Results
time: 244.63sec
VRAM usage
These are the figures from generating a 1024x1024 image. VRAM usage stayed just barely under 16GB.
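To check the peak yourself, torch can report the maximum VRAM it allocated during generation. A minimal sketch to run right after the pipe(...) call (note that nvidia-smi will show a somewhat higher total because of CUDA context overhead):
# Peak VRAM allocated by torch since the start of the process
peak_gib = torch.cuda.max_memory_allocated() / (1024 ** 3)
print(f"peak VRAM allocated: {peak_gib:.2f} GiB")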