找回密码
立即注册
搜索
热搜: Java Python Linux Go
发回帖 发新帖

214

积分

0

好友

22

主题
发表于 7 天前 | 查看: 20| 回复: 0

本教程将手把手指导你如何使用 Unsloth 框架和 TRL 库,通过 GRPO 强化学习算法对 Qwen 2.5 (3B) 大模型进行微调。整个过程聚焦实践,跳过复杂原理,直接提供可运行的完整代码。

一、环境安装与配置

首先,我们需要安装必要的依赖库。Unsloth 用于加速训练,vLLM 用于加速推理,TRL 则提供了强化学习训练的核心组件。

import os, numpy
# 设置环境变量,让 Unsloth 在 vLLM 中预留更多显存用于上下文
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"
# 获取当前 numpy 版本以防止依赖冲突
numpy_version = f"numpy=={numpy.__version__}"
# 安装依赖
!uv pip install unsloth_zoo
!uv pip install --upgrade unsloth vllm==0.9.2 {numpy_version} torchvision bitsandbytes xformers
!uv pip install triton==3.2.0
!uv pip install transformers==4.55.4
!uv pip install --no-deps trl==0.22.2

通过高效的包管理工具 uvpip,我们可以快速搭建起基于 Python 的模型训练环境。

二、加载模型与分词器

环境就绪后,我们加载预训练的 Qwen2.5-3B-Instruct 模型。为了在有限显存下运行模型,这里采用了 4-bit 量化加载,并启用 vLLM 进行快速推理。

from unsloth import FastLanguageModel
import torch
# 设置最大上下文长度
max_seq_length = 1024
# 加载预训练模型和分词器
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen2.5-3B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True,               # 4bit 量化加载
    fast_inference = True,             # 启用 vLLM 快速推理引擎
    max_lora_rank = 8,                 # LoRA 秩
    gpu_memory_utilization = 0.9,      # 显存利用率上限
)

使用GRPO算法微调Qwen2.5模型以提升数学推理能力 - 图片 - 1

三、配置 LoRA (低秩适应)

接下来,我们将模型转换为 LoRA 模式。这是一种参数高效的微调方法,只训练新增的少量参数,而冻结原始模型的大部分参数,极大节省了计算资源。

# 配置 PEFT (Parameter-Efficient Fine-Tuning)
model = FastLanguageModel.get_peft_model(
    model,
    r = 8,  # LoRA 的秩
    # 指定需要应用 LoRA 的模块(注意力层和前馈网络层)
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = 8,
    use_gradient_checkpointing = "unsloth",       # 使用梯度检查点节省显存
    random_state = 1234,
)

使用GRPO算法微调Qwen2.5模型以提升数学推理能力 - 图片 - 2

四、数据集处理与格式化

我们使用 GSM8K(小学数学)数据集来训练模型的数学推理能力。为了让模型学会“思维链”推理,我们设计了一套包含特定 XML 标签的 Prompt 格式。

import re
from datasets import load_dataset, Dataset

# 系统提示词,强制模型使用特定的 XML 格式输出推理过程和答案
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# 定义 XML 格式模板
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

# 函数:从模型输出中提取 XML 标签内的答案
def extract_xml_answer(text):
    if "<answer>" not in text or "</answer>" not in text:
        return ""
    return text.split("<answer>", 1)[-1].split("</answer>", 1)[0].strip()

# 函数:从 GSM8K 数据集的原始答案字段中提取最终数值
def extract_hash_answer(text):
    return text.split("####")[-1].strip() if "####" in text else None

# 加载并预处理 GSM8K 数据集
def get_gsm8k_dataset(split = "train"):
    data = load_dataset("openai/gsm8k", "main")[split]
    return data.map(
        lambda x: {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": x["question"]},
            ],
            "answer": extract_hash_answer(x["answer"]),           # 提取标准答案用于后续奖励计算
        }
    )

# 加载处理好的数据集
dataset = get_gsm8k_dataset()

使用GRPO算法微调Qwen2.5模型以提升数学推理能力 - 图片 - 3

五、定义奖励函数

这是 GRPO 强化学习的核心。我们将定义一系列奖励函数来评估模型生成的回答,引导模型向“格式规范且答案正确”的方向优化。

# 奖励函数 1:正确性奖励
def correctness_reward_func(prompts, completions, answer, **kwargs):
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    # 打印日志方便调试
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

# 奖励函数 2:整数奖励
def int_reward_func(completions, **kwargs):
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

# 奖励函数 3:严格格式奖励
def strict_format_reward_func(completions, **kwargs):
    pattern = r"^\n<reasoning>.*?\n</reasoning>\n\n<answer>.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

# 奖励函数 4:宽松格式奖励
def soft_format_reward_func(completions, **kwargs):
    pattern = r".*?<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

# 辅助函数:计算 XML 标签的完整性
def count_xml(text):
    count = 0.0
    # 检查各个标签是否存在,每存在一个加分,如果格式混乱(如多余换行)则扣分
    if text.count("\n<reasoning>") == 1:
        count += 0.125
    if text.count("</reasoning>\n\n<answer>") == 1:
        count += 0.125
    if text.count("\n</answer>\n") == 1:
        count += 0.125
    count -= len(text.split("\n</answer>\n")[-1])*0.001       # 惩罚项
    if text.count("\n<reasoning>\n") == 1:
        count += 0.125
    count -= (len(text.split("\n<reasoning>\n")[-1]) - 1)*0.001   # 惩罚项
    return count

# 奖励函数 5:XML 计数奖励
def xmlcount_reward_func(completions, **kwargs):
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

使用GRPO算法微调Qwen2.5模型以提升数学推理能力 - 图片 - 4

六、配置与启动 GRPO 训练

现在,我们将模型、数据和奖励函数结合起来,使用 GRPO 算法进行训练。GRPO 的核心思想是让模型针对同一个问题生成多个回答,通过对比这些回答的奖励分数来优化策略。

from trl import GRPOConfig, GRPOTrainer

# 配置训练参数
training_args = GRPOConfig(
    use_vllm = True,                  # 使用 vLLM 生成样本(极快)
    learning_rate = 5e-6,             # 学习率
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",             # 使用 8-bit 优化器节省显存
    logging_steps = 1,
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 1,
    num_generations = 4,              # GRPO 核心:每个 prompt 生成 4 个回答进行对比
    max_prompt_length = 256,
    max_completion_length = 200,
    max_steps = 250,                  # 训练总步数
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "none",
    output_dir = "outputs",
)

使用GRPO算法微调Qwen2.5模型以提升数学推理能力 - 图片 - 5

# 初始化 GRPO 训练器
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)

使用GRPO算法微调Qwen2.5模型以提升数学推理能力 - 图片 - 6

# 开始训练
trainer.train()

训练过程中,模型会根据奖励函数的反馈更新 LoRA 权重,使其更倾向于生成高分回答(格式正确且答案正确)。这种基于 人工智能 强化学习的微调方式,能有效引导模型掌握复杂的推理任务。

七、保存模型与推理测试

训练完成后,我们保存微调得到的 LoRA 权重,并进行推理测试以验证效果。

# 保存训练好的 LoRA 适配器
model.save_lora("grpo_saved_lora")
# --- 推理部分 ---
from vllm import SamplingParams
# 测试用的查询
query = "How many r's are in strawberry?"
# 构建聊天模板
text = tokenizer.apply_chat_template([
    {"role" : "user", "content" : query},
], tokenize = False, add_generation_prompt = True)
# 设置采样参数
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
# 生成回答(使用基础模型,未加载LoRA)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text
print(output)

使用GRPO算法微调Qwen2.5模型以提升数学推理能力 - 图片 - 7

# 再次生成,这次加载刚才保存的 LoRA 权重
text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role" : "user", "content" : query},
], tokenize = False, add_generation_prompt = True)

sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)

output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text
print(output)

使用GRPO算法微调Qwen2.5模型以提升数学推理能力 - 图片 - 8
通过对比加载 LoRA 权重前后的输出,可以直观地看到模型经过 GRPO 强化学习 微调后,在遵循指定格式和进行有效推理方面的能力提升。


源代码地址:完整的可运行代码已托管在 GitHub: https://github.com/ArronAI007/Awesome-AGI/blob/main/LLM%20Pipeline/Fine-Tune/trl/01_Train_Qwen_2_5(3B)_To_Reason_With_GRPO.ipynb




上一篇:微软开源VibeVoice:实现90分钟4角色自然对话的长语音合成模型
下一篇:Qt跨平台开发指南:QString与const char*转换详解与中文乱码解决
您需要登录后才可以回帖 登录 | 立即注册

手机版|小黑屋|网站地图|云栈社区 ( 苏ICP备2022046150号-2 )

GMT+8, 2025-12-24 21:10 , Processed in 0.288135 second(s), 40 queries , Gzip On.

Powered by Discuz! X3.5

© 2025-2025 云栈社区.

快速回复 返回顶部 返回列表