ELYZAが公開している「Llama-3-ELYZA-JP-8B」をGradioを使ってローカルで使用する

はじめに

前回「CyberAgentLM3-22B-Chat」で同じことをしました。
touch-sp.hatenablog.com
今回は「Llama-3-ELYZA-JP-8B」です。

ELYZAの「Llama-3-ELYZA-JP-70B」はGPT-4を上回る日本語性能と言われています。

今回使用したのはそれよりもはるかに小規模な「Llama-3-ELYZA-JP-8B」です。質問の回答としては「CyberAgentLM3-22B-Chat」に劣る印象です。やはりモデルが大きいほど高い性能が出せるんでしょうね。

モデルの量子化

8bit量子化を行いました。

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# model was downloaded from https://huggingface.co/elyza/Llama-3-ELYZA-JP-8B
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    "Llama-3-ELYZA-JP-8B",
    quantization_config=quantization_config
)
model.save_pretrained("Llama-3-ELYZA-JP-8B-8bit")

Gradioで実行

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import torch

system_prompt_text = "あなたは誠実で優秀な日本人のアシスタントです。特に指示が無い場合は、常に日本語で回答してください。"
init = {
    "role": "system",
    "content": system_prompt_text,
}

model = AutoModelForCausalLM.from_pretrained(
    "Llama-3-ELYZA-JP-8B-8bit",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("Llama-3-ELYZA-JP-8B")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

def call_llm(
    message: str,
    history: list[dict],
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    history_openai_format = []
    if len(history) == 0:
        history_openai_format.append(init)
        history_openai_format.append({"role": "user", "content": message})
    else:
        history_openai_format.append(init)
        for human, assistant in history:
            history_openai_format.append({"role": "user", "content": human})
            history_openai_format.append({"role": "assistant", "content": assistant})
        history_openai_format.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        history_openai_format,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    generation_kwargs = dict(
        inputs=input_ids,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id
    )

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text

def run():
    chatbot = gr.Chatbot(
        elem_id="chatbot",
        scale=1,
        show_copy_button=True,
        height="70%",
        layout="panel",
    )
    with gr.Blocks(fill_height=True) as demo:
        gr.Markdown("# Llama-3-ELYZA-JP-8B")
        gr.ChatInterface(
            fn=call_llm,
            stop_btn="Stop Generation",
            cache_examples=False,
            multimodal=False,
            chatbot=chatbot,
            additional_inputs_accordion=gr.Accordion(
                label="Parameters", open=False, render=False
            ),
            additional_inputs=[
                gr.Slider(
                    minimum=1,
                    maximum=4096,
                    step=1,
                    value=1200,
                    label="Max tokens",
                    visible=True,
                    render=False,
                ),
                gr.Slider(
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.6,
                    label="Temperature",
                    visible=True,
                    render=False,
                ),
                gr.Slider(
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.9,
                    label="Top-p",
                    visible=True,
                    render=False,
                ),
            ],
        )
    demo.launch(share=False)

if __name__ == "__main__":
    run()

警告

二つの警告が出ました。

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



一つ目の警告は「generation_kwargs」内に「pad_token_id=tokenizer.eos_token_id」を追加することで回避できました。

二つ目の警告の消し方はわかりませんでした。