【Diffusers】Stable Diffusion 3.5 Mediumを使ってみる

はじめに

Stable Diffusion 3.5 Mediumが公開されたのでVRAM使用量を調べるため色々な方法で実行してみました。

使用したPC

こちらのPCを使いました。

Windows 11
RTX 3080 Laptop (VRAM 16GB)
CUDA 11.8
Python 3.12

実行

Method-1: to("cuda")

必要なメモリ(17.60 GB)がVRAM 16GBを超えているため、使い物になりませんでした。
1枚の画像を作成するのに30分もかかっています。

torch.cuda.max_memory_allocated: 17.60 GB
time: 1774.59 sec
GPU 0 - Used memory: 15.97/16.00 GB


RTX 4090 (VRAM 24GB)で実行した結果が以下です。

torch.cuda.max_memory_allocated: 17.59 GB
time: 23.17 sec
GPU 0 - Used memory: 20.87/23.99 GB

Method-2: enable_model_cpu_offload()

torch.cuda.max_memory_allocated: 10.19 GB
time: 77.62 sec
GPU 0 - Used memory: 10.51/16.00 GB


Method-3: enable_sequential_cpu_offload()

torch.cuda.max_memory_allocated: 2.82 GB
time: 156.75 sec
GPU 0 - Used memory: 2.25/16.00 GB


Method-4: Method-2 + Method-3

Method-3と変わらない結果でした。
Method-2とMethod-3を組み合わせる意義はないと思います。

torch.cuda.max_memory_allocated: 2.82 GB
time: 162.20 sec
GPU 0 - Used memory: 2.25/16.00 GB


Method-5: 量子化 + Method-1

量子化するとVRAM 16GB以下に抑えられました。
作成画像ががらりと変わってしまいました。

torch.cuda.max_memory_allocated: 10.19 GB
time: 70.48 sec
GPU 0 - Used memory: 9.88/16.00 GB


Method-6: 量子化 + Method-2

Method-2と比較して量子化で少し速くなる程度でした。
作成画像ががらりと変わってしまいました。

torch.cuda.max_memory_allocated: 10.19 GB
time: 70.48 sec
GPU 0 - Used memory: 9.88/16.00 GB


Method-7: 量子化 + Method-3

以下のエラーで実行できませんでした。

raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
ValueError: Blockwise quantization only supports 16/32-bit floats, but got torch.uint8

Method-8: 工程を分割 + Method-1

1行目はテキストエンコード後、2行目は画像生成後に出力した値です。

torch.cuda.max_memory_allocated: 10.48 GB
torch.cuda.max_memory_allocated: 10.48 GB
time: 62.20 sec
GPU 0 - Used memory: 11.11/16.00 GB


結論

今回得られた結果から導き出せる結論は以下の通りです。
PC環境が変われば結論も変わってくるかもしれません。

VRAM 12GB未満:Method-3がお勧め
VRAM 12~18GB:Method-8がお勧め(あるいはMethod-2)
VRAM 18GB以上:Method-1がお勧め

Pythonスクリプト

すべてのPythonスクリプトを載せておきます。
ベンチマーク(gpu_monitor / time_monitor)には、こちらの記事で紹介したスクリプトを使用しました。
touch-sp.hatenablog.com

Method-1

import torch
from diffusers import StableDiffusion3Pipeline
from decorator import gpu_monitor, time_monitor

@gpu_monitor(interval=0.5)
@time_monitor
def main():
    """Method-1: load SD3.5 Medium in bfloat16 and keep the whole pipeline on the GPU.

    Generates one image, saves it as ``capybara_1.jpg`` and prints the peak
    value of ``torch.cuda.max_memory_allocated()`` in GiB.
    """
    repo_id = "stabilityai/stable-diffusion-3.5-medium"
    pipe = StableDiffusion3Pipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
    # Move every pipeline component to the GPU up front (no offloading).
    pipe.to("cuda")

    sampling = dict(num_inference_steps=40, guidance_scale=4.5)
    result = pipe(
        "A capybara holding a sign that reads Hello World",
        # torch.manual_seed seeds torch's global RNG and returns the default generator.
        generator=torch.manual_seed(42),
        **sampling,
    )
    image = result.images[0]
    image.save("capybara_1.jpg")

    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"torch.cuda.max_memory_allocated: {peak_gib:.2f} GB")


if __name__ == "__main__":
    main()

Method-2

import torch
from diffusers import StableDiffusion3Pipeline
from decorator import gpu_monitor, time_monitor

@gpu_monitor(interval=0.5)
@time_monitor
def main():
    """Method-2: SD3.5 Medium in bfloat16 with model-level CPU offloading.

    Generates one image, saves it as ``capybara_2.jpg`` and prints the peak
    value of ``torch.cuda.max_memory_allocated()`` in GiB.
    """
    repo_id = "stabilityai/stable-diffusion-3.5-medium"
    pipe = StableDiffusion3Pipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
    # Components stay on the CPU and are moved to the GPU one whole model at a time.
    pipe.enable_model_cpu_offload()

    sampling = dict(num_inference_steps=40, guidance_scale=4.5)
    result = pipe(
        "A capybara holding a sign that reads Hello World",
        # torch.manual_seed seeds torch's global RNG and returns the default generator.
        generator=torch.manual_seed(42),
        **sampling,
    )
    image = result.images[0]
    image.save("capybara_2.jpg")

    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"torch.cuda.max_memory_allocated: {peak_gib:.2f} GB")


if __name__ == "__main__":
    main()

Method-3

import torch
from diffusers import StableDiffusion3Pipeline
from decorator import gpu_monitor, time_monitor

@gpu_monitor(interval=0.5)
@time_monitor
def main():
    """Method-3: SD3.5 Medium in bfloat16 with sequential (submodule-level) CPU offloading.

    Generates one image, saves it as ``capybara_3.jpg`` and prints the peak
    value of ``torch.cuda.max_memory_allocated()`` in GiB.
    """
    repo_id = "stabilityai/stable-diffusion-3.5-medium"
    pipe = StableDiffusion3Pipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
    # Finest-grained offloading: lowest VRAM and slowest run in the article's measurements.
    pipe.enable_sequential_cpu_offload()

    sampling = dict(num_inference_steps=40, guidance_scale=4.5)
    result = pipe(
        "A capybara holding a sign that reads Hello World",
        # torch.manual_seed seeds torch's global RNG and returns the default generator.
        generator=torch.manual_seed(42),
        **sampling,
    )
    image = result.images[0]
    image.save("capybara_3.jpg")

    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"torch.cuda.max_memory_allocated: {peak_gib:.2f} GB")


if __name__ == "__main__":
    main()

Method-4

import torch
from diffusers import StableDiffusion3Pipeline
from decorator import gpu_monitor, time_monitor

@gpu_monitor(interval=0.5)
@time_monitor
def main():
    """Method-4: enable model-level offloading, then sequential offloading, on one pipeline.

    Generates one image, saves it as ``capybara_4.jpg`` and prints the peak
    value of ``torch.cuda.max_memory_allocated()`` in GiB.
    """
    repo_id = "stabilityai/stable-diffusion-3.5-medium"
    pipe = StableDiffusion3Pipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
    pipe.enable_model_cpu_offload()
    # NOTE(review): the article measured the same numbers as sequential offload alone,
    # which suggests this call supersedes the model-level offload above; confirm
    # against the installed diffusers version.
    pipe.enable_sequential_cpu_offload()

    sampling = dict(num_inference_steps=40, guidance_scale=4.5)
    result = pipe(
        "A capybara holding a sign that reads Hello World",
        # torch.manual_seed seeds torch's global RNG and returns the default generator.
        generator=torch.manual_seed(42),
        **sampling,
    )
    image = result.images[0]
    image.save("capybara_4.jpg")

    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"torch.cuda.max_memory_allocated: {peak_gib:.2f} GB")


if __name__ == "__main__":
    main()

Method-5

import torch
from diffusers import StableDiffusion3Pipeline, BitsAndBytesConfig, SD3Transformer2DModel
from decorator import gpu_monitor, time_monitor

@gpu_monitor(interval=0.5)
@time_monitor
def main():
    """Method-5: NF4-quantize the transformer, then keep the whole pipeline on the GPU.

    Only the transformer is loaded with a bitsandbytes 4-bit (NF4) config; the
    other components are loaded in bfloat16. Generates one image, saves it as
    ``capybara_5.jpg`` and prints the peak ``torch.cuda.max_memory_allocated()`` in GiB.
    """
    repo_id = "stabilityai/stable-diffusion-3.5-medium"

    # 4-bit NF4 weights, with computation carried out in bfloat16.
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    transformer = SD3Transformer2DModel.from_pretrained(
        repo_id,
        subfolder="transformer",
        quantization_config=quant_cfg,
        torch_dtype=torch.bfloat16,
    )

    # The pipeline uses the quantized transformer passed in here instead of loading its own.
    pipe = StableDiffusion3Pipeline.from_pretrained(
        repo_id, transformer=transformer, torch_dtype=torch.bfloat16
    )
    pipe.to("cuda")

    sampling = dict(num_inference_steps=40, guidance_scale=4.5)
    result = pipe(
        "A capybara holding a sign that reads Hello World",
        # torch.manual_seed seeds torch's global RNG and returns the default generator.
        generator=torch.manual_seed(42),
        **sampling,
    )
    image = result.images[0]
    image.save("capybara_5.jpg")

    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"torch.cuda.max_memory_allocated: {peak_gib:.2f} GB")


if __name__ == "__main__":
    main()

Method-6

import torch
from diffusers import StableDiffusion3Pipeline, BitsAndBytesConfig, SD3Transformer2DModel
from decorator import gpu_monitor, time_monitor

@gpu_monitor(interval=0.5)
@time_monitor
def main():
    """Method-6: NF4-quantize the transformer and run with model-level CPU offloading.

    Only the transformer is loaded with a bitsandbytes 4-bit (NF4) config; the
    other components are loaded in bfloat16. Generates one image, saves it as
    ``capybara_6.jpg`` and prints the peak ``torch.cuda.max_memory_allocated()`` in GiB.
    """
    repo_id = "stabilityai/stable-diffusion-3.5-medium"

    # 4-bit NF4 weights, with computation carried out in bfloat16.
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    transformer = SD3Transformer2DModel.from_pretrained(
        repo_id,
        subfolder="transformer",
        quantization_config=quant_cfg,
        torch_dtype=torch.bfloat16,
    )

    # The pipeline uses the quantized transformer passed in here instead of loading its own.
    pipe = StableDiffusion3Pipeline.from_pretrained(
        repo_id, transformer=transformer, torch_dtype=torch.bfloat16
    )
    pipe.enable_model_cpu_offload()

    sampling = dict(num_inference_steps=40, guidance_scale=4.5)
    result = pipe(
        "A capybara holding a sign that reads Hello World",
        # torch.manual_seed seeds torch's global RNG and returns the default generator.
        generator=torch.manual_seed(42),
        **sampling,
    )
    image = result.images[0]
    image.save("capybara_6.jpg")

    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"torch.cuda.max_memory_allocated: {peak_gib:.2f} GB")


if __name__ == "__main__":
    main()

Method-7

import torch
from diffusers import StableDiffusion3Pipeline, BitsAndBytesConfig, SD3Transformer2DModel
from decorator import gpu_monitor, time_monitor

@gpu_monitor(interval=0.5)
@time_monitor
def main():
    """Method-7: NF4-quantize the transformer and run with sequential CPU offloading.

    NOTE: on the author's setup this combination did not run; it raised
    ``ValueError: Blockwise quantization only supports 16/32-bit floats, but got
    torch.uint8``. The script is kept unchanged to reproduce that result. If it
    succeeds, it saves ``capybara_7.jpg`` and prints the peak
    ``torch.cuda.max_memory_allocated()`` in GiB.
    """
    repo_id = "stabilityai/stable-diffusion-3.5-medium"

    # 4-bit NF4 weights, with computation carried out in bfloat16.
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    transformer = SD3Transformer2DModel.from_pretrained(
        repo_id,
        subfolder="transformer",
        quantization_config=quant_cfg,
        torch_dtype=torch.bfloat16,
    )

    # The pipeline uses the quantized transformer passed in here instead of loading its own.
    pipe = StableDiffusion3Pipeline.from_pretrained(
        repo_id, transformer=transformer, torch_dtype=torch.bfloat16
    )
    pipe.enable_sequential_cpu_offload()

    sampling = dict(num_inference_steps=40, guidance_scale=4.5)
    result = pipe(
        "A capybara holding a sign that reads Hello World",
        # torch.manual_seed seeds torch's global RNG and returns the default generator.
        generator=torch.manual_seed(42),
        **sampling,
    )
    image = result.images[0]
    image.save("capybara_7.jpg")

    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"torch.cuda.max_memory_allocated: {peak_gib:.2f} GB")


if __name__ == "__main__":
    main()

Method-8

import gc
import torch
from diffusers import StableDiffusion3Pipeline
from decorator import gpu_monitor, time_monitor

def flush():
    """Run Python garbage collection, then release PyTorch's cached CUDA memory."""
    # Order matters: collect unreachable objects first so their CUDA blocks
    # are back in the cache before the cache is emptied.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

@gpu_monitor(interval=0.5)
@time_monitor
def main():
    """Method-8: split generation into two stages so fewer components share the GPU.

    Stage 1 loads the pipeline without the transformer and VAE, encodes the
    prompt, then deletes that pipeline and frees memory. Stage 2 loads the
    pipeline without text encoders/tokenizers and generates from the
    precomputed embeddings. It saves ``capybara_8.jpg``.

    ``torch.cuda.max_memory_allocated()`` is printed after each stage. The peak
    counter is reset between the stages, so the second figure is stage 2's own
    peak rather than the maximum over both stages. The overall peak is the
    larger of the two printed values.
    """
    model_id = "stabilityai/stable-diffusion-3.5-medium"

    # --- Stage 1: prompt encoding (no transformer / VAE loaded) ---
    pipe = StableDiffusion3Pipeline.from_pretrained(
        model_id,
        transformer=None,
        vae=None,
        torch_dtype=torch.bfloat16,
    )
    pipe.to("cuda")

    # Inference only: no autograd graph is needed for the embeddings.
    with torch.no_grad():
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(
            prompt="A capybara holding a sign that reads Hello World",
            prompt_2=None,
            prompt_3=None,
        )

    print(f"torch.cuda.max_memory_allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")

    # Drop the text-encoding pipeline before loading the denoising components.
    del pipe
    flush()
    # max_memory_allocated() is a running peak since program start. Without this
    # reset, the stage-2 print would report max(stage 1, stage 2) and hide
    # stage 2's own peak.
    torch.cuda.reset_peak_memory_stats()

    # --- Stage 2: denoising + decoding (no text encoders / tokenizers loaded) ---
    pipe = StableDiffusion3Pipeline.from_pretrained(
        model_id,
        text_encoder=None,
        text_encoder_2=None,
        text_encoder_3=None,
        tokenizer=None,
        tokenizer_2=None,
        tokenizer_3=None,
        torch_dtype=torch.bfloat16,
    )
    pipe.to("cuda")

    # Cast the precomputed embeddings to bfloat16 to match the loaded components.
    image = pipe(
        prompt_embeds=prompt_embeds.bfloat16(),
        pooled_prompt_embeds=pooled_prompt_embeds.bfloat16(),
        negative_prompt_embeds=negative_prompt_embeds.bfloat16(),
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.bfloat16(),
        num_inference_steps=40,
        guidance_scale=4.5,
        generator=torch.manual_seed(42),
    ).images[0]

    image.save("capybara_8.jpg")

    print(f"torch.cuda.max_memory_allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")


if __name__ == "__main__":
    main()





このエントリーをはてなブックマークに追加