【Diffusers】NVIDIAが開発したSANAで画像生成してみる

はじめに

使用したのはこちらです。
github.com
今回はDiffusersを使って実行しました。

PC環境

Windows 11
RTX 3080 Laptop (VRAM 16GB)
Python 3.12
CUDA 11.8

Python環境構築

pip install torch==2.5.1+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install git+https://github.com/huggingface/diffusers
pip install transformers accelerate beautifulsoup4 ftfy

実行

PAGのありなしで画像を生成しています。

PAGについてはこちらを見て下さい。
touch-sp.hatenablog.com

No PAG

Pythonスクリプト

import torch
from diffusers import SanaPipeline
from decorator import gpu_monitor, time_monitor

@gpu_monitor(interval=0.5)
@time_monitor
def main():
    pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    variant="bf16",
    torch_dtype=torch.bfloat16,
    ).to("cuda")

    pipe.text_encoder.to(torch.bfloat16)
    pipe.vae.to(torch.bfloat16)

    prompt = 'a cyberpunk cat with a neon sign that says "Sana"'
    image = pipe(
        prompt=prompt,
        guidance_scale=5.0,
        num_inference_steps=20,
        generator=torch.Generator(device="cuda").manual_seed(42),
    )[0]

    image[0].save('no_pag.jpg')

if __name__ == "__main__":
    main()

結果

time: 15.46 sec
GPU 0 - Used memory: 12.07/16.00 GB


With PAG

Pythonスクリプト

import torch
from diffusers import SanaPAGPipeline
from decorator import gpu_monitor, time_monitor

@gpu_monitor(interval=0.5)
@time_monitor
def main():
    pipe = SanaPAGPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    variant="bf16",
    torch_dtype=torch.bfloat16,
    pag_applied_layers="transformer_blocks.8",
    ).to("cuda")

    pipe.text_encoder.to(torch.bfloat16)
    pipe.vae.to(torch.bfloat16)

    prompt = 'a cyberpunk cat with a neon sign that says "Sana"'
    image = pipe(
        prompt=prompt,
        guidance_scale=5.0,
        pag_scale=2.0,
        num_inference_steps=20,
        generator=torch.Generator(device="cuda").manual_seed(42),
    )[0]

    image[0].save('with_pag.jpg')

if __name__ == "__main__":
    main()

結果

time: 18.52 sec
GPU 0 - Used memory: 11.83/16.00 GB


その他

ベンチマークはこちらで記述したスクリプトで行いました。
touch-sp.hatenablog.com