一、开篇:站在巨人的肩膀上
过去两天,我们探讨了两个里程碑式的模型:
- BERT:擅长理解任务的编码器模型。
- GPT:擅长生成任务的解码器模型。
它们都在海量数据上完成了预训练,掌握了通用的语言知识。但问题来了:如果我们想在自己的特定任务上,比如判断某条微博的情绪倾向,或者让它理解某个领域的专业术语,该怎么办呢?
答案就是 微调(Fine-tuning)。
微调 = 在预训练模型的基础上,使用你自己的少量标注数据进行继续训练,让模型适应你的特定任务。
与从零训练相比,微调的优势非常明显:
- 数据需求少:几百到几千条标注数据往往就足够了。
- 训练速度快:通常只需几十分钟到几小时。
- 效果出众:效果远好于在少量数据上从零开始训练。
- 参数调整少:通常只需调整模型的顶层或少量参数。
二、微调的核心原理
2.1 预训练模型学到了什么?
像 BERT 这样的预训练模型,在数十亿词的语料上训练后,其网络的不同层级学会了不同抽象层次的语言知识:
- 底层:负责词法、词性等基础信息。
- 中层:学习句法、短语结构等。
- 高层:捕获语义、上下文关联等复杂关系。
这些知识是通用的,为各种 NLP 任务提供了强大的基础。
2.2 微调时发生了什么?
当我们加载预训练模型时,其参数已经包含了丰富的语言先验。微调的过程可以概括为三步:
- 替换任务头:移除预训练任务对应的顶层(例如 BERT 的 MLM 层),换上适合我们下游任务的输出层(比如一个分类器)。
- 继续反向传播:使用我们自己的标注数据计算损失,梯度会反向传播并更新模型所有层(或指定层)的参数。
- 模型适应任务:在保留通用语言知识的同时,模型参数被微调,使其在特定任务上的表现得到提升。
2.3 两种微调策略
| 策略 |
做法 |
适用场景 |
| 全参数微调 |
更新模型中的所有参数 |
数据量充足(>1000条),计算资源丰富 |
| 参数高效微调 |
冻结大部分预训练层,只更新顶层或少量新增参数 |
数据极少,计算资源有限,或需要快速实验迭代 |
参数高效微调是目前的热门方向,典型方法包括:
- 冻结 BERT,只训练分类头:最简单的方法,适用于小数据场景。
- Adapter:在 Transformer 层之间插入轻量级的小型模块进行训练。
- LoRA:使用低秩矩阵来近似参数更新,在 GPT-3 等大模型微调中非常流行。
三、实战:微调 BERT 进行中文情感分类
我们将使用 HuggingFace 的 Transformers 库,在 ChnSentiCorp(中文酒店评论情感分类)数据集上,实战微调一个 BERT 模型。
3.1 环境准备
首先,安装必要的库。
pip install transformers datasets torch pandas scikit-learn
3.2 加载数据集
使用 datasets 库直接加载 ChnSentiCorp 数据集。
from datasets import load_dataset
dataset = load_dataset("seamew/ChnSentiCorp")
print(dataset)
# 输出:
# DatasetDict({
# train: Dataset({ features: ['text', 'label'], num_rows: 9600 })
# validation: Dataset({ features: ['text', 'label'], num_rows: 1200 })
# test: Dataset({ features: ['text', 'label'], num_rows: 1200 })
# })
数据预览:
text: 酒店评论文本。
label: 情感标签,0 表示负向,1 表示正向。
3.3 数据预处理
我们需要将文本转换成 BERT 能接受的输入格式:input_ids 和 attention_mask。
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
def preprocess_function(examples):
return tokenizer(
examples["text"],
truncation=True,
padding="max_length",
max_length=128
)
tokenized_dataset = dataset.map(preprocess_function, batched=True)
# 设置数据格式
tokenized_dataset = tokenized_dataset.remove_columns(["text"])
tokenized_dataset = tokenized_dataset.rename_column("label", "labels")
tokenized_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
3.4 加载预训练模型
加载一个专门用于序列分类的 BERT 模型(它会自动加上分类头)。
from transformers import BertForSequenceClassification
model = BertForSequenceClassification.from_pretrained(
"bert-base-chinese",
num_labels=2
)
3.5 定义评估指标
我们使用准确率作为评估指标。
import numpy as np
from sklearn.metrics import accuracy_score
def compute_metrics(eval_pred):
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
return {"accuracy": accuracy_score(labels, predictions)}
3.6 设置训练参数
通过 TrainingArguments 来配置训练过程的关键 超参数。
from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(
output_dir="./results",
evaluation_strategy="epoch",
save_strategy="epoch",
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
num_train_epochs=3,
weight_decay=0.01,
load_best_model_at_end=True,
metric_for_best_model="accuracy",
logging_dir="./logs",
)
3.7 创建 Trainer 并训练
Trainer 类封装了训练循环,让我们可以专注于数据和模型。
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset["train"],
eval_dataset=tokenized_dataset["validation"],
tokenizer=tokenizer,
compute_metrics=compute_metrics,
)
trainer.train()
3.8 评估与保存
训练完成后,在测试集上评估模型性能并保存模型。
# 在测试集上评估
results = trainer.evaluate(tokenized_dataset["test"])
print("测试集准确率:", results["eval_accuracy"])
# 保存模型
model.save_pretrained("./my_finetuned_bert")
tokenizer.save_pretrained("./my_finetuned_bert")
3.9 使用微调后的模型进行预测
最后,我们写一个简单的函数来使用微调好的模型进行预测。
def predict(text):
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
outputs = model(**inputs)
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
pred = torch.argmax(probs, dim=-1).item()
return "正向" if pred == 1 else "负向", probs[0].tolist()
text = "房间很干净,服务态度也很好,下次还会来。"
sentiment, prob = predict(text)
print(f"文本:{text}")
print(f"情感:{sentiment}, 概率:{prob}")
四、实战:微调 GPT-2 进行文本生成
我们也可以在生成任务上微调模型。例如,让 GPT-2 学会创作特定风格的诗歌。
4.1 加载 GPT-2 模型
from transformers import GPT2LMHeadModel, GPT2Tokenizer
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
# 设置 pad_token_id(GPT-2 原始没有 pad_token)
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
4.2 准备数据
假设我们有一个 poems.txt 文件,每行是一首诗歌。
from datasets import Dataset
with open("poems.txt", "r", encoding="utf-8") as f:
poems = [line.strip() for line in f if line.strip()]
dataset = Dataset.from_dict({"text": poems})
def tokenize_function(examples):
return tokenizer(examples["text"], truncation=True, padding=False, max_length=128)
tokenized_dataset = dataset.map(tokenize_function, batched=True)
tokenized_dataset = tokenized_dataset.map(
lambda examples: {"labels": examples["input_ids"]},
batched=True
)
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
4.3 微调
对于语言模型,我们需要使用 DataCollatorForLanguageModeling。
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False # GPT 是因果语言模型,不是掩码语言模型
)
training_args = TrainingArguments(
output_dir="./gpt2-poems",
overwrite_output_dir=True,
num_train_epochs=3,
per_device_train_batch_size=4,
save_steps=500,
save_total_limit=2,
prediction_loss_only=True,
)
trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=tokenized_dataset,
)
trainer.train()
model.save_pretrained("./gpt2-poems")
tokenizer.save_pretrained("./gpt2-poems")
4.4 生成诗歌
使用微调后的模型进行诗歌续写。
prompt = "床前明月光,"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(
inputs.input_ids,
max_length=50,
do_sample=True,
top_p=0.9,
temperature=0.8,
pad_token_id=tokenizer.eos_token_id
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
总结:微调——让通用模型变成私人定制
微调是连接强大的通用预训练模型与具体下游任务的关键桥梁。它让我们能够以极低的成本和数据门槛,获得专业级的模型性能。
通过本文的实战,你已经掌握了:
- 微调的核心原理与两种主要策略(全参数与参数高效)。
- 使用 BERT 微调完成中文情感分类任务的全流程。
- 使用 GPT-2 微调进行特定风格文本生成的方法。
- 实际操作中的关键技巧与注意事项。
希望这篇结合原理与实战的指南能帮助你快速上手模型微调。如果在实践中遇到问题,欢迎到 云栈社区 与更多开发者交流讨论。