FLUX.1-devの量子化を深堀りしてみる

RTX 4090 (VRAM 24GB)で検証しています。

transformerのみを量子化

GPU 0 - Used memory: 10.61/23.99 GB
time: 99.07 sec

text_encoder_2のみを量子化

GPU 0 - Used memory: 9.32/23.99 GB
time: 184.73 sec

両方を量子化

GPU 0 - Used memory: 15.14/23.99 GB
time: 50.56 sec

なぜかVRAM使用量が増えます。

量子化することによってフルにVRAMが使用できるようになったからでしょうか。

生成速度はかなり速いです。

この速度を維持しつつ、VRAM使用量を減らす方法はこちらを見て下さい。
touch-sp.hatenablog.com

Pythonスクリプト

transformerのみを量子化

import torch 
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel, FluxPipeline
from decorator import gpu_monitor, time_monitor

@time_monitor
@gpu_monitor(interval=0.5)
def main(model_id:str) -> None:
    """Generate one image with FLUX.1-dev, quantizing ONLY the transformer to 4-bit NF4.

    VRAM usage and wall time are reported by the gpu_monitor/time_monitor decorators.
    """
    # NF4 4-bit quantization; compute still runs in bfloat16.
    quant_config = BitsAndBytesConfig(
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        load_in_4bit=True
    )
    transformer_4bit = FluxTransformer2DModel.from_pretrained(
        model_id,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
        quantization_config=quant_config
    )

    pipe = FluxPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        transformer=transformer_4bit
    )
    # Offload idle sub-models to CPU to keep peak VRAM down.
    pipe.enable_model_cpu_offload()

    # Fixed seed so runs are comparable across the three experiments.
    seed_gen = torch.Generator().manual_seed(123)
    result = pipe(
        prompt="A photorealistic portrait of a young Japanese woman with long black hair and natural makeup, wearing a casual white blouse, sitting in a modern Tokyo cafe with soft window light",
        width=1360,
        height=768,
        num_inference_steps=50,
        guidance_scale=3.5,
        generator=seed_gen,
    )
    result.images[0].save("transformer_4bit.jpg")

if __name__ == "__main__":
    main("FLUX.1-dev")

text_encoder_2のみを量子化

import torch 
from diffusers import FluxPipeline
from transformers import BitsAndBytesConfig, T5EncoderModel
from decorator import gpu_monitor, time_monitor

@time_monitor
@gpu_monitor(interval=0.5)
def main(model_id:str) -> None:
    """Generate one image with FLUX.1-dev, quantizing ONLY text_encoder_2 (T5) to 4-bit.

    VRAM usage and wall time are reported by the gpu_monitor/time_monitor decorators.
    """
    # Default bitsandbytes 4-bit settings for the T5 text encoder.
    t5_quant_config = BitsAndBytesConfig(load_in_4bit=True)
    t5_encoder_4bit = T5EncoderModel.from_pretrained(
        model_id,
        subfolder="text_encoder_2",
        torch_dtype=torch.bfloat16,
        quantization_config=t5_quant_config
    )

    # "balanced" device map spreads the remaining components across available devices.
    pipe = FluxPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        text_encoder_2=t5_encoder_4bit,
        device_map="balanced"
    )

    # Fixed seed so runs are comparable across the three experiments.
    seed_gen = torch.Generator().manual_seed(123)
    result = pipe(
        prompt="A photorealistic portrait of a young Japanese woman with long black hair and natural makeup, wearing a casual white blouse, sitting in a modern Tokyo cafe with soft window light",
        width=1360,
        height=768,
        num_inference_steps=50,
        guidance_scale=3.5,
        generator=seed_gen,
    )
    result.images[0].save("textencoder_4bit.jpg")

if __name__ == "__main__":
    main("FLUX.1-dev")

両方を量子化

import torch 
from diffusers import FluxTransformer2DModel, FluxPipeline
from transformers import T5EncoderModel
# Both libraries expose a class named BitsAndBytesConfig; alias to tell them apart.
from diffusers import BitsAndBytesConfig as DiffusersBnbConfig
from transformers import BitsAndBytesConfig as TransformersBnbConfig
from decorator import gpu_monitor, time_monitor

@time_monitor
@gpu_monitor(interval=0.5)
def main(model_id:str) -> None:
    """Generate one image with FLUX.1-dev, quantizing BOTH the transformer (NF4)
    and text_encoder_2 (T5) to 4-bit.

    VRAM usage and wall time are reported by the gpu_monitor/time_monitor decorators.
    """
    # NF4 4-bit quantization for the transformer; compute still runs in bfloat16.
    transformer_4bit = FluxTransformer2DModel.from_pretrained(
        model_id,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
        quantization_config=DiffusersBnbConfig(
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            load_in_4bit=True
        )
    )
    # Default bitsandbytes 4-bit settings for the T5 text encoder.
    t5_encoder_4bit = T5EncoderModel.from_pretrained(
        model_id,
        subfolder="text_encoder_2",
        torch_dtype=torch.bfloat16,
        quantization_config=TransformersBnbConfig(load_in_4bit=True)
    )

    # "balanced" device map spreads the remaining components across available devices.
    pipe = FluxPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        transformer=transformer_4bit,
        text_encoder_2=t5_encoder_4bit,
        device_map="balanced"
    )

    # Fixed seed so runs are comparable across the three experiments.
    seed_gen = torch.Generator().manual_seed(123)
    result = pipe(
        prompt="A photorealistic portrait of a young Japanese woman with long black hair and natural makeup, wearing a casual white blouse, sitting in a modern Tokyo cafe with soft window light",
        width=1360,
        height=768,
        num_inference_steps=50,
        guidance_scale=3.5,
        generator=seed_gen,
    )
    result.images[0].save("both_4bit.jpg")

if __name__ == "__main__":
    main("FLUX.1-dev")