【Diffusers】「CogVideoX」でImage2Video

はじめに

以前「Text2Video」を行いました。
touch-sp.hatenablog.com
今回は「Image2Video」を行いました。

PC環境

Windows 11
RTX 4090 (VRAM 24GB)
CUDA 11.8
Python 3.12

Python環境構築

pip install torch==2.4.1+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install diffusers[torch]
pip install transformers sentencepiece opencv-python

Pythonスクリプト

import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from decorator import gpu_monitor, time_monitor

@gpu_monitor(interval=0.5)
@time_monitor
def main():
    """Generate a short video from one image plus a text prompt.

    Loads the ``THUDM/CogVideoX-5b-I2V`` pipeline in bfloat16, runs it on a
    sample image fetched by URL, and writes the first returned frame
    sequence to ``output.mp4`` at 8 fps.
    """
    pipeline = CogVideoXImageToVideoPipeline.from_pretrained(
        "THUDM/CogVideoX-5b-I2V",
        torch_dtype=torch.bfloat16,
    )

    # Memory optimizations. The pipeline is intentionally NOT moved with
    # `pipe.to("cuda")`: that call must be omitted when model CPU
    # offloading is enabled.
    pipeline.enable_model_cpu_offload()
    pipeline.vae.enable_tiling()

    text_prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
    source_image = load_image(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
    )

    output = pipeline(source_image, text_prompt, use_dynamic_cfg=True)

    # Only the first generated frame sequence is exported.
    first_clip = output.frames[0]
    export_to_video(first_clip, "output.mp4", fps=8)

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
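# decorator.py -- helper module imported by the script above
# (`from decorator import gpu_monitor, time_monitor`)
# pip install pynvml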

import functools
import threading
import time
from pynvml import *

def gpu_monitor(interval=1):
    """Build a decorator that reports peak GPU memory seen while a function runs.

    While the decorated function executes, a background thread polls
    ``nvmlDeviceGetMemoryInfo`` for GPU index 0 and tracks the largest
    ``used`` value observed. When the function finishes (normally or by
    raising), the thread is stopped and a one-line summary is printed.

    Args:
        interval: Seconds to wait between two memory samples.

    Returns:
        A decorator. The wrapped function returns whatever the original
        returns and propagates any exception it raises.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            stop_event = threading.Event()

            def monitor_thread():
                max_vram = 0
                total_vram = None
                device_id = 0
                initialized = False
                try:
                    # Initialisation lives inside the try block so that an
                    # NVML failure here is reported like any other NVML
                    # error and the library is still shut down afterwards.
                    nvmlInit()
                    initialized = True
                    handle = nvmlDeviceGetHandleByIndex(device_id)
                    total_vram = nvmlDeviceGetMemoryInfo(handle).total / (1024**3)
                    while not stop_event.is_set():
                        using_vram = nvmlDeviceGetMemoryInfo(handle).used / (1024**3)
                        max_vram = max(max_vram, using_vram)
                        # Wait on the event instead of sleeping so the
                        # thread wakes immediately once the wrapped
                        # function is done, rather than after `interval`.
                        stop_event.wait(interval)
                except NVMLError as error:
                    print(f"NVML Error: {error}")
                finally:
                    # Only shut down what was actually initialised.
                    if initialized:
                        nvmlShutdown()
                    # Without a total there is nothing meaningful to report.
                    if total_vram is not None:
                        print(f"GPU {device_id} - Used memory: {max_vram:.2f}/{total_vram:.2f} GB")

            monitor = threading.Thread(target=monitor_thread)
            monitor.start()

            try:
                result = func(*args, **kwargs)
            finally:
                # Always stop and join the sampler, even if func raised.
                stop_event.set()
                monitor.join()

            return result
        return wrapper
    return decorator

def time_monitor(func):
    """Decorator that prints how long each call to ``func`` took.

    The elapsed time is printed as ``time: <seconds> sec`` with two
    decimals after ``func`` returns. If ``func`` raises, the exception
    propagates and nothing is printed.

    Args:
        func: The callable to time.

    Returns:
        A wrapper that forwards all arguments to ``func`` and returns its
        result unchanged.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so the measured interval cannot be
        # distorted by system clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"time: {elapsed:.2f} sec")
        return result
    return wrapper

結果

time: 384.74 sec
GPU 0 - Used memory: 17.17/23.99 GB





このエントリーをはてなブックマークに追加