本教程将手把手指导你如何使用 Unsloth 框架和 TRL 库,通过 GRPO 强化学习算法对 Qwen 2.5 (3B) 大模型进行微调。整个过程聚焦实践,跳过复杂原理,直接提供可运行的完整代码。
一、环境安装与配置
首先,我们需要安装必要的依赖库。Unsloth 用于加速训练,vLLM 用于加速推理,TRL 则提供了强化学习训练的核心组件。
import os, numpy
# 设置环境变量,让 Unsloth 在 vLLM 中预留更多显存用于上下文
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"
# 获取当前 numpy 版本以防止依赖冲突
numpy_version = f"numpy=={numpy.__version__}"
# 安装依赖
!uv pip install unsloth_zoo
!uv pip install --upgrade unsloth vllm==0.9.2 {numpy_version} torchvision bitsandbytes xformers
!uv pip install triton==3.2.0
!uv pip install transformers==4.55.4
!uv pip install --no-deps trl==0.22.2
通过高效的包管理工具 uv 和 pip,我们可以快速搭建起基于 Python 的模型训练环境。
二、加载模型与分词器
环境就绪后,我们加载预训练的 Qwen2.5-3B-Instruct 模型。为了在有限显存下运行模型,这里采用了 4-bit 量化加载,并启用 vLLM 进行快速推理。
from unsloth import FastLanguageModel
import torch
# 设置最大上下文长度
max_seq_length = 1024
# 加载预训练模型和分词器
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/Qwen2.5-3B-Instruct",
max_seq_length = max_seq_length,
load_in_4bit = True, # 4bit 量化加载
fast_inference = True, # 启用 vLLM 快速推理引擎
max_lora_rank = 8, # LoRA 秩
gpu_memory_utilization = 0.9, # 显存利用率上限
)

三、配置 LoRA (低秩适应)
接下来,我们将模型转换为 LoRA 模式。这是一种参数高效的微调方法,只训练新增的少量参数,而冻结原始模型的大部分参数,极大节省了计算资源。
# 配置 PEFT (Parameter-Efficient Fine-Tuning)
model = FastLanguageModel.get_peft_model(
model,
r = 8, # LoRA 的秩
# 指定需要应用 LoRA 的模块(注意力层和前馈网络层)
target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
lora_alpha = 8,
use_gradient_checkpointing = "unsloth", # 使用梯度检查点节省显存
random_state = 1234,
)

四、数据集处理与格式化
我们使用 GSM8K(小学数学)数据集来训练模型的数学推理能力。为了让模型学会“思维链”推理,我们设计了一套包含特定 XML 标签的 Prompt 格式。
import re
from datasets import load_dataset, Dataset
# 系统提示词,强制模型使用特定的 XML 格式输出推理过程和答案
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
# 定义 XML 格式模板
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
# 函数:从模型输出中提取 XML 标签内的答案
def extract_xml_answer(text):
if "<answer>" not in text or "</answer>" not in text:
return ""
return text.split("<answer>", 1)[-1].split("</answer>", 1)[0].strip()
# 函数:从 GSM8K 数据集的原始答案字段中提取最终数值
def extract_hash_answer(text):
return text.split("####")[-1].strip() if "####" in text else None
# 加载并预处理 GSM8K 数据集
def get_gsm8k_dataset(split = "train"):
data = load_dataset("openai/gsm8k", "main")[split]
return data.map(
lambda x: {
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": x["question"]},
],
"answer": extract_hash_answer(x["answer"]), # 提取标准答案用于后续奖励计算
}
)
# 加载处理好的数据集
dataset = get_gsm8k_dataset()

五、定义奖励函数
这是 GRPO 强化学习的核心。我们将定义一系列奖励函数来评估模型生成的回答,引导模型向“格式规范且答案正确”的方向优化。
# 奖励函数 1:正确性奖励
def correctness_reward_func(prompts, completions, answer, **kwargs):
responses = [completion[0]['content'] for completion in completions]
q = prompts[0][-1]['content']
extracted_responses = [extract_xml_answer(r) for r in responses]
# 打印日志方便调试
print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
# 奖励函数 2:整数奖励
def int_reward_func(completions, **kwargs):
responses = [completion[0]['content'] for completion in completions]
extracted_responses = [extract_xml_answer(r) for r in responses]
return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
# 奖励函数 3:严格格式奖励
def strict_format_reward_func(completions, **kwargs):
pattern = r"^\n<reasoning>.*?\n</reasoning>\n\n<answer>.*?\n</answer>\n$"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in matches]
# 奖励函数 4:宽松格式奖励
def soft_format_reward_func(completions, **kwargs):
pattern = r".*?<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in matches]
# 辅助函数:计算 XML 标签的完整性
def count_xml(text):
count = 0.0
# 检查各个标签是否存在,每存在一个加分,如果格式混乱(如多余换行)则扣分
if text.count("\n<reasoning>") == 1:
count += 0.125
if text.count("</reasoning>\n\n<answer>") == 1:
count += 0.125
if text.count("\n</answer>\n") == 1:
count += 0.125
count -= len(text.split("\n</answer>\n")[-1])*0.001 # 惩罚项
if text.count("\n<reasoning>\n") == 1:
count += 0.125
count -= (len(text.split("\n<reasoning>\n")[-1]) - 1)*0.001 # 惩罚项
return count
# 奖励函数 5:XML 计数奖励
def xmlcount_reward_func(completions, **kwargs):
contents = [completion[0]["content"] for completion in completions]
return [count_xml(c) for c in contents]

六、配置与启动 GRPO 训练
现在,我们将模型、数据和奖励函数结合起来,使用 GRPO 算法进行训练。GRPO 的核心思想是让模型针对同一个问题生成多个回答,通过对比这些回答的奖励分数来优化策略。
from trl import GRPOConfig, GRPOTrainer
# 配置训练参数
training_args = GRPOConfig(
use_vllm = True, # 使用 vLLM 生成样本(极快)
learning_rate = 5e-6, # 学习率
adam_beta1 = 0.9,
adam_beta2 = 0.99,
weight_decay = 0.1,
warmup_ratio = 0.1,
lr_scheduler_type = "cosine",
optim = "adamw_8bit", # 使用 8-bit 优化器节省显存
logging_steps = 1,
per_device_train_batch_size = 4,
gradient_accumulation_steps = 1,
num_generations = 4, # GRPO 核心:每个 prompt 生成 4 个回答进行对比
max_prompt_length = 256,
max_completion_length = 200,
max_steps = 250, # 训练总步数
save_steps = 250,
max_grad_norm = 0.1,
report_to = "none",
output_dir = "outputs",
)

# 初始化 GRPO 训练器
trainer = GRPOTrainer(
model = model,
processing_class = tokenizer,
reward_funcs = [
xmlcount_reward_func,
soft_format_reward_func,
strict_format_reward_func,
int_reward_func,
correctness_reward_func,
],
args = training_args,
train_dataset = dataset,
)

# 开始训练
trainer.train()
训练过程中,模型会根据奖励函数的反馈更新 LoRA 权重,使其更倾向于生成高分回答(格式正确且答案正确)。这种基于 人工智能 强化学习的微调方式,能有效引导模型掌握复杂的推理任务。
七、保存模型与推理测试
训练完成后,我们保存微调得到的 LoRA 权重,并进行推理测试以验证效果。
# 保存训练好的 LoRA 适配器
model.save_lora("grpo_saved_lora")
# --- 推理部分 ---
from vllm import SamplingParams
# 测试用的查询
query = "How many r's are in strawberry?"
# 构建聊天模板
text = tokenizer.apply_chat_template([
{"role" : "user", "content" : query},
], tokenize = False, add_generation_prompt = True)
# 设置采样参数
sampling_params = SamplingParams(
temperature = 0.8,
top_p = 0.95,
max_tokens = 1024,
)
# 生成回答(使用基础模型,未加载LoRA)
output = model.fast_generate(
[text],
sampling_params = sampling_params,
lora_request = None,
)[0].outputs[0].text
print(output)

# 再次生成,这次加载刚才保存的 LoRA 权重
text = tokenizer.apply_chat_template([
{"role" : "system", "content" : SYSTEM_PROMPT},
{"role" : "user", "content" : query},
], tokenize = False, add_generation_prompt = True)
sampling_params = SamplingParams(
temperature = 0.8,
top_p = 0.95,
max_tokens = 1024,
)
output = model.fast_generate(
[text],
sampling_params = sampling_params,
lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text
print(output)

通过对比加载 LoRA 权重前后的输出,可以直观地看到模型经过 GRPO 强化学习 微调后,在遵循指定格式和进行有效推理方面的能力提升。
源代码地址:完整的可运行代码已托管在 GitHub: https://github.com/ArronAI007/Awesome-AGI/blob/main/LLM%20Pipeline/Fine-Tune/trl/01_Train_Qwen_2_5(3B)_To_Reason_With_GRPO.ipynb