上級

エージェント向けLlama 3のファインチューニング

カスタム関数呼び出しデータセットを使用したLlama 3 70Bのファインチューニングに関する包括的なガイドです。

55分
Llama 3PyTorch

エージェント向けLlama 3のファインチューニング#

効率的なトレーニング技術を使用して、関数呼び出しやエージェントタスク向けにLlama 3 70Bをファインチューニングする方法を学びます。

エージェント向けにファインチューニングする理由#

ベースのLlama 3モデルは一般的なタスクに優れていますが、以下を実現するにはファインチューニングが必要です:
  • 特定の関数呼び出しフォーマットに従う
  • ツール使用を確実に処理する
  • 一貫した出力スキーマを維持する
  • 幻覚的なツール呼び出しを減らす

前提条件#

  • 24GB以上のVRAMを搭載したGPU(A100 40GB推奨)
  • Python 3.10+
  • Llama 3へのアクセス権を持つHugging Faceアカウント

環境セットアップ#

bash
pip install torch transformers accelerate peft trl
pip install bitsandbytes datasets wandb
python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer

# GPUの確認
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0)}")

データセットの準備#

関数呼び出しフォーマット#

python
FUNCTION_CALLING_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>

あなたは以下のツールにアクセスできる役立つアシスタントです:

{tools}

ツールを使用する必要がある場合は、以下の形式で応答してください:
<tool_call>
{{"name": "tool_name", "arguments": {{"arg1": "value1"}}}}
</tool_call>
<|eot_id|><|start_header_id|>user<|end_header_id|>

{user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{assistant_response}<|eot_id|>"""

トレーニングデータの作成#

python
def create_training_example(tools: list, user_msg: str, assistant_response: str) -> dict:
    """関数呼び出しフォーマットでトレーニング例を作成します。"""
    tools_str = "\n".join([
        f"- {t['name']}: {t['description']}\n  パラメータ: {t['parameters']}"
        for t in tools
    ])

    text = FUNCTION_CALLING_TEMPLATE.format(
        tools=tools_str,
        user_message=user_msg,
        assistant_response=assistant_response
    )

    return {"text": text}

# データセットの例
training_data = [
    create_training_example(
        tools=[{
            "name": "get_weather",
            "description": "場所の現在の天気を取得します",
            "parameters": {"location": "string", "unit": "celsius|fahrenheit"}
        }],
        user_msg="東京の天気はどうですか?",
        assistant_response="""<tool_call>
{"name": "get_weather", "arguments": {"location": "Tokyo", "unit": "celsius"}}
</tool_call>"""
    )
]

QLoRAを使用したモデルの読み込み#

python
from transformers import BitsAndBytesConfig

# 4ビット量子化設定
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# モデルの読み込み
model_id = "meta-llama/Meta-Llama-3-70B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

LoRA設定#

python
# トレーニング用にモデルを準備
model = prepare_model_for_kbit_training(model)

# 関数呼び出し用のLoRA設定
lora_config = LoraConfig(
    r=64,  # ランク - 複雑なタスクには高く設定
    lora_alpha=128,
    lora_dropout=0.05,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    bias="none",
    task_type="CAUSAL_LM",
)

# LoRAを適用
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

トレーニング#

python
from transformers import TrainingArguments
from datasets import Dataset

# データセットの作成
dataset = Dataset.from_list(training_data)

# トレーニング引数
training_args = TrainingArguments(
    output_dir="./llama3-agent-ft",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    logging_steps=10,
    save_strategy="epoch",
    bf16=True,
    gradient_checkpointing=True,
    optim="paged_adamw_32bit",
    report_to="wandb",
)

# トレーナーの初期化
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=training_args,
    tokenizer=tokenizer,
    max_seq_length=2048,
    dataset_text_field="text",
    packing=True,  # 短い例をまとめてパッキング
)

# トレーニング
trainer.train()

保存と読み込み#

python
# LoRA重みの保存
model.save_pretrained("./llama3-agent-lora")
tokenizer.save_pretrained("./llama3-agent-lora")

# マージして完全なモデルを保存(オプション)
merged_model = model.merge_and_unload()
merged_model.save_pretrained("./llama3-agent-merged")

推論#

python
from peft import PeftModel

# ファインチューニング済みモデルの読み込み
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)
model = PeftModel.from_pretrained(base_model, "./llama3-agent-lora")

def generate_tool_call(user_message: str, tools: list) -> str:
    """ツール呼び出し応答を生成します。"""
    prompt = FUNCTION_CALLING_TEMPLATE.format(
        tools=format_tools(tools),
        user_message=user_message,
        assistant_response=""
    ).rsplit("<|start_header_id|>assistant<|end_header_id|>", 1)[0]
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=0.1,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    response = tokenizer.decode(outputs[0], skip_special_tokens=False)
    return extract_assistant_response(response)

評価#

python
def evaluate_function_calling(model, test_cases: list) -> dict:
    """関数呼び出しの精度を評価します。"""
    results = {
        "correct_tool": 0,
        "correct_args": 0,
        "valid_json": 0,
        "total": len(test_cases)
    }

    for case in test_cases:
        response = generate_tool_call(case["input"], case["tools"])

        try:
            parsed = parse_tool_call(response)
            results["valid_json"] += 1

            if parsed["name"] == case["expected_tool"]:
                results["correct_tool"] += 1

            if parsed["arguments"] == case["expected_args"]:
                results["correct_args"] += 1
        except:
            pass

    return {k: v/results["total"] for k, v in results.items() if k != "total"}

ベストプラクティス#

  1. 多様なトレーニングデータ: 様々なツールの組み合わせやエッジケースを含める
  2. 一貫したフォーマット: トレーニングと推論で全く同じフォーマットを使用する
  3. ネガティブ例: ツールを呼び出すべきでない例を含める
  4. 検証: パースする前に常にJSON出力を検証する
  5. 温度: 信頼性の高いツール呼び出しには低い温度(0.1-0.3)を使用する