【Diffusers】FLUX.1-devを「VRAM 12GB以内」でかつ「高速」に動かす方法

はじめに

前回量子化について調べてみました。
touch-sp.hatenablog.com
今回、生成過程を分割することで「VRAM 12GB以内」かつ「高速」に動かすことができました。

結果

RTX 4090 (VRAM 24GB)で測定しています。

torch.cuda.max_memory_allocated: 6.58 GB
torch.cuda.max_memory_allocated: 6.76 GB
GPU 0 - Used memory: 9.01/23.99 GB
time: 52.11 sec

「torch.cuda.max_memory_allocated」でVRAM使用量を測定すると8GB未満ですが実際は9GB程度使用しています。

3行目は「pynvml」ライブラリで測定したものです。

Pythonスクリプト

ポイントは先述した通り、過程を分割することです。

import gc
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline
from transformers import T5EncoderModel
from diffusers import BitsAndBytesConfig as diffusers_config
from transformers import BitsAndBytesConfig as transformers_config

from decorator import gpu_monitor, time_monitor

def flush():
    gc.collect()
    torch.cuda.empty_cache()

@time_monitor
@gpu_monitor(interval=0.5)
def main():
    base_model = 'FLUX.1-dev'
    width=1360
    height=768
    
    textencoder_config = transformers_config(load_in_4bit=True)
    text_encoder_2 = T5EncoderModel.from_pretrained(
        base_model,
        subfolder="text_encoder_2",
        quantization_config=textencoder_config,
        torch_dtype=torch.bfloat16
    )

    pipeline = FluxPipeline.from_pretrained(
        base_model,
        text_encoder_2=text_encoder_2,
        transformer=None,
        vae=None,
        device_map="balanced"
    )

    prompt = "A photorealistic portrait of a young Japanese woman with long black hair and natural makeup, wearing a casual white blouse, sitting in a modern Tokyo cafe with soft window light"

    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
            prompt=prompt,
            prompt_2=None,
            #max_sequence_length=256
        )
    
    print(f"torch.cuda.max_memory_allocated: {torch.cuda.max_memory_allocated()/ 1024**3:.2f} GB")
    del text_encoder_2
    del pipeline
    flush()

    transformer_config = diffusers_config(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        base_model,
        subfolder="transformer",
        quantization_config=transformer_config,
        torch_dtype=torch.bfloat16
    )
    pipeline = FluxPipeline.from_pretrained(
        base_model,
        transformer=transformer,
        text_encoder=None,
        text_encoder_2=None,
        tokenizer=None,
        tokenizer_2=None,
        torch_dtype=torch.bfloat16
    )

    pipeline.enable_model_cpu_offload()
    generator = torch.Generator().manual_seed(123)
    image = pipeline(
        prompt_embeds=prompt_embeds.bfloat16(),
        pooled_prompt_embeds=pooled_prompt_embeds.bfloat16(),
        width=width,
        height=height,
        num_inference_steps=50, 
        guidance_scale=3.5,
        generator=generator
    ).images[0]

    image.save("woman.jpg")    
    print(f"torch.cuda.max_memory_allocated: {torch.cuda.max_memory_allocated()/ 1024**3:.2f} GB")
    
if __name__ == "__main__":
    main()

VAEの部分も分割することは可能ですが、それをやっても8GB未満にはできませんでした。