Gradio 5.0以上でチャットボットを作る

はじめに

Gradio 5.0からチャットボットの作り方が大きく変わりました。
今回、実際に作ってみました。

Pythonスクリプト

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import torch

model_name = "tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.2"

system_prompt_text = "あなたは誠実で優秀な日本人のアシスタントです。"
init = {
    "role": "system",
    "metadata": {"title": None},
    "content": system_prompt_text,
}

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

def user(
    message: str,
    history: list[dict]
):
    if len(history)==0:  
        history.insert(0, init)
    history.append(
        {
            "role": "user", 
            "metadata": {"title": None},
            "content": message
        }
    )
    return "", history

def bot(
    history: list[dict]
):
    input_tensors = tokenizer.apply_chat_template(
        history,
        add_generation_prompt=True,
        return_tensors="pt",
         return_dict=True
    ).to(model.device)

    input_ids = input_tensors["input_ids"]
    attention_mask = input_tensors["attention_mask"]

    generation_kwargs = dict(
        inputs=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        max_new_tokens=512,
        temperature=0.6,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id
    )
    history.append({"role": "assistant", "content": ""})

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    for new_text in streamer:
        history[-1]["content"] += new_text
        yield history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox()
    clear = gr.Button("新しいチャットを開始")
    
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()

実際の画面


回答が短いのは使用したモデルによるものだと思います。