【FLUX.1-dev】量子化を行う時にDiffusersでは「bitsandbytes」より「optimum-quanto」を使うことが推奨されていたのでさっそく使ってみました

はじめに

前回「bitsandbytes」を使って量子化を行いました。
touch-sp.hatenablog.com
公式ページでは「optimum-quanto」を使うことが推奨されていたのでさっそく試してみました。

pip install optimum-quanto



「optimum-quanto」を使って量子化したモデルを保存する方法はこちらを参照しました。
github.com
github.com
qiita.com

Pythonスクリプト

量子化を行ってそれを保存するスクリプト

import torch
from diffusers import FluxTransformer2DModel
from transformers import T5EncoderModel
from optimum.quanto import freeze, qfloat8, quantize, quantization_map
from pathlib import Path
import json

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

transformer = FluxTransformer2DModel.from_pretrained(
    bfl_repo,
    subfolder="transformer",
    torch_dtype=dtype
)
quantize(transformer, weights=qfloat8)
freeze(transformer)

save_directory = "fluxtransformer2dmodel_qfloat8"
transformer.save_pretrained(save_directory)
qmap_name = Path(save_directory, "quanto_qmap.json")
qmap = quantization_map(transformer)
with open(qmap_name, "w", encoding="utf8") as f:
    json.dump(qmap, f, indent=4)

text_encoder_2 = T5EncoderModel.from_pretrained(
    bfl_repo,
    subfolder="text_encoder_2",
    torch_dtype=dtype
)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

save_directory = "t5encodermodel_qfloat8"
text_encoder_2.save_pretrained(save_directory)
qmap_name = Path(save_directory, "quanto_qmap.json")
qmap = quantization_map(text_encoder_2)
with open(qmap_name, "w", encoding="utf8") as f:
    json.dump(qmap, f, indent=4)

保存されたモデルを読み込んでText2Imageの実行

import time
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline
from transformers import T5EncoderModel
from optimum.quanto import QuantizedTransformersModel, QuantizedDiffusersModel

start = time.perf_counter()

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

transformer = QuantizedFluxTransformer2DModel.from_pretrained(
    "fluxtransformer2dmodel_qfloat8"
).to(dtype=dtype)

class QuantizedT5EncoderModelForCausalLM(QuantizedTransformersModel):
    auto_class = T5EncoderModel
    auto_class.from_config = auto_class._from_config

text_encoder_2 = QuantizedT5EncoderModelForCausalLM.from_pretrained(
    "t5encodermodel_qfloat8"
).to(dtype=dtype)

pipe = FluxPipeline.from_pretrained(
    bfl_repo,
    transformer=transformer,
    text_encoder_2=text_encoder_2,
    torch_dtype=dtype
)

pipe.enable_model_cpu_offload()

prompt = "an insect robot preparing a delicious meal, anime style"
out = pipe(
    prompt=prompt,
    guidance_scale=3.5,
    height=768,
    width=1360,
    num_inference_steps=50,
    generator=torch.manual_seed(0)
).images[0]
out.save("dev_qfloat8_result.jpg")

end = time.perf_counter()
print(f"time: {(end - start):.2f}sec")

結果

time: 122.17sec

time: 119.44sec

上下に黒い帯が入るのは量子化の影響でしょうか?



このエントリーをはてなブックマークに追加