はじめに
前回は「bitsandbytes」を使って量子化を行いました。（参考: touch-sp.hatenablog.com）
公式ページでは「optimum-quanto」を使うことが推奨されていたのでさっそく試してみました。
pip install optimum-quanto
「optimum-quanto」を使って量子化したモデルを保存する方法はこちらを参照しました。
github.com
github.com
qiita.com
Pythonスクリプト
量子化を行ってそれを保存するスクリプト
"""Quantize the FLUX.1-dev transformer and T5 text encoder to qfloat8 with
optimum-quanto and save both, each alongside its quantization map."""
import json
from pathlib import Path

import torch
from diffusers import FluxTransformer2DModel
from transformers import T5EncoderModel
from optimum.quanto import freeze, qfloat8, quantize, quantization_map

BFL_REPO = "black-forest-labs/FLUX.1-dev"
DTYPE = torch.bfloat16


def _quantize_and_save(model, save_directory: str) -> None:
    """Quantize *model*'s weights to qfloat8, freeze them, and save the model.

    The quantization map is written separately as ``quanto_qmap.json`` because
    ``save_pretrained`` alone does not record which layers were quantized,
    and the map is required to reload the quantized checkpoint later.
    """
    quantize(model, weights=qfloat8)
    freeze(model)
    model.save_pretrained(save_directory)
    qmap = quantization_map(model)
    qmap_path = Path(save_directory, "quanto_qmap.json")
    with open(qmap_path, "w", encoding="utf8") as f:
        json.dump(qmap, f, indent=4)


if __name__ == "__main__":
    # Transformer (the diffusion backbone).
    transformer = FluxTransformer2DModel.from_pretrained(
        BFL_REPO, subfolder="transformer", torch_dtype=DTYPE
    )
    _quantize_and_save(transformer, "fluxtransformer2dmodel_qfloat8")

    # Second text encoder (T5); the larger of the two encoders.
    text_encoder_2 = T5EncoderModel.from_pretrained(
        BFL_REPO, subfolder="text_encoder_2", torch_dtype=DTYPE
    )
    _quantize_and_save(text_encoder_2, "t5encodermodel_qfloat8")
保存されたモデルを読み込んでText2Imageの実行
"""Load the qfloat8-quantized FLUX.1-dev components and run Text2Image,
timing the whole load + generate run."""
import time

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
from optimum.quanto import QuantizedDiffusersModel, QuantizedTransformersModel

start = time.perf_counter()

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16


class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    # Bind the quantized wrapper to the concrete diffusers transformer class.
    base_class = FluxTransformer2DModel


transformer = QuantizedFluxTransformer2DModel.from_pretrained(
    "fluxtransformer2dmodel_qfloat8"
).to(dtype=dtype)


class QuantizedT5EncoderModelForCausalLM(QuantizedTransformersModel):
    auto_class = T5EncoderModel
    # Workaround: QuantizedTransformersModel expects the auto class to expose
    # from_config, but T5EncoderModel only provides the private _from_config.
    auto_class.from_config = auto_class._from_config


text_encoder_2 = QuantizedT5EncoderModelForCausalLM.from_pretrained(
    "t5encodermodel_qfloat8"
).to(dtype=dtype)

# Build the pipeline from the repo, swapping in the two quantized components.
pipe = FluxPipeline.from_pretrained(
    bfl_repo,
    transformer=transformer,
    text_encoder_2=text_encoder_2,
    torch_dtype=dtype,
)
pipe.enable_model_cpu_offload()

prompt = "an insect robot preparing a delicious meal, anime style"
out = pipe(
    prompt=prompt,
    guidance_scale=3.5,
    height=768,
    width=1360,
    num_inference_steps=50,
    generator=torch.manual_seed(0),
).images[0]
out.save("dev_qfloat8_result.jpg")

end = time.perf_counter()
print(f"time: {(end - start):.2f}sec")
結果
time: 122.17sec
time: 119.44sec
上下に黒い帯が入るのは量子化の影響でしょうか?