第6章:分类微调
学习目标
- 理解为什么”加载预训练 GPT 再做分类”比从零训练分类器更划算。
- 学会把生成式 GPT 的输出头改造成分类头。
- 掌握”冻结大部分参数 + 只训练顶部几层”的微调策略。
- 会用 accuracy / F1 评估分类模型。
6.1 任务设定
本章以经典的 SMS 垃圾短信分类为例:输入一条短信,输出 spam 或 ham(非垃圾)。这是一个二分类任务,数据规模在几千条量级,非常适合演示微调。
为什么不直接用 BERT?——课程主线是 GPT 路线,我们要验证 Decoder-only 模型同样能做分类。事实上,从 GPT-3 开始,“用生成模型做分类”已经成为主流路径。
6.2 数据准备
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import tiktoken
class SpamDataset(Dataset):
def __init__(self, csv_path, tokenizer, max_len=128, pad_token_id=50256):
df = pd.read_csv(csv_path) # 列: ['Label', 'Text']
self.label2id = {"ham": 0, "spam": 1}
self.max_len = max_len
self.pad_id = pad_token_id
self.encoded = []
self.labels = []
for _, row in df.iterrows():
ids = tokenizer.encode(row["Text"])
ids = ids[:max_len] # 截断
ids += [pad_token_id] * (max_len - len(ids)) # 右侧 pad
self.encoded.append(torch.tensor(ids))
self.labels.append(self.label2id[row["Label"]])
def __len__(self):
return len(self.encoded)
def __getitem__(self, idx):
return self.encoded[idx], torch.tensor(self.labels[idx])
注意:GPT-2 没有专门的 PAD token,习惯上借用 <|endoftext|>(id=50256)做填充。需要在注意力 mask 或 loss 计算时处理 padding 位置——分类任务里我们只取最后一个非 pad位置的隐藏向量,所以问题不大。
6.3 改造模型:替换输出头
预训练 GPT 的输出头是 Linear(emb_dim, vocab_size),把隐向量映射到词表 logits。分类任务下,我们不需要 50257 个输出,只需要 2 个:
def attach_classifier_head(gpt_model, num_classes=2):
emb_dim = gpt_model.out_head.in_features
gpt_model.out_head = torch.nn.Linear(emb_dim, num_classes)
return gpt_model
前向逻辑也要小改,只取最后一个有效 token 的隐藏状态作为整段文本的表示:
import torch.nn as nn
class GPTClassifier(nn.Module):
def __init__(self, gpt_model, num_classes=2):
super().__init__()
self.gpt = gpt_model
emb_dim = gpt_model.out_head.in_features
self.gpt.out_head = nn.Linear(emb_dim, num_classes)
def forward(self, idx): # (B, T)
# 复用 GPT 的前向,但拿到的是 (B, T, num_classes)
logits = self.gpt(idx)
return logits[:, -1, :] # 取最后一个时间步: (B, num_classes)
这里”最后一个位置”假设我们对齐方式是右侧 pad,最后一个真实 token 在 pad 之前。更严谨的做法是记录每条样本的真实长度并按位取出,留作练习。
6.4 冻结策略
完整反向传播 GPT-2 124M 的所有参数当然能拿到最好效果,但显存需求高、容易过拟合在小数据上。常用三档策略:
| 策略 | 训练参数 | 适用场景 |
|---|---|---|
| 全微调 | 全部 | 数据多、显存足 |
| 部分微调 | 仅最后 N 个 Block + 输出头 + 末端 LayerNorm | 中等数据,平衡之选 |
| 特征提取 | 仅输出头 | 数据少、把 GPT 当编码器 |
实现”部分微调”的代码:
def freeze_except_top(model, n_train_blocks=1):
for p in model.parameters():
p.requires_grad = False
# 解冻最后 n_train_blocks 个 Transformer Block
for blk in model.gpt.blocks[-n_train_blocks:]:
for p in blk.parameters():
p.requires_grad = True
# 解冻最终 LayerNorm 和分类头
for p in model.gpt.final_norm.parameters():
p.requires_grad = True
for p in model.gpt.out_head.parameters():
p.requires_grad = True
6.5 训练循环(分类版)
和预训练几乎一样,只是损失换成”对最后一个位置做交叉熵”:
import torch.nn.functional as F
def loss_cls(model, x, y):
logits = model(x) # (B, num_classes)
return F.cross_entropy(logits, y)
@torch.no_grad()
def accuracy(model, loader, device):
model.eval()
correct, total = 0, 0
for x, y in loader:
x, y = x.to(device), y.to(device)
pred = model(x).argmax(dim=-1)
correct += (pred == y).sum().item()
total += y.numel()
model.train()
return correct / total
经典超参数(参考值):
optimizer = torch.optim.AdamW(
[p for p in model.parameters() if p.requires_grad],
lr=5e-5, weight_decay=0.1,
)
num_epochs = 5
batch_size = 8
在 SMS Spam 数据上,仅微调最后 1 个 Block + 输出头通常能在几个 epoch 内拿到 95%+ 的准确率。
6.6 进一步评估
只看准确率会被类别不平衡欺骗(比如 ham 占 87%,全猜 ham 就有 87% 准确率)。建议同时报告:
- 混淆矩阵 (
sklearn.metrics.confusion_matrix) - 精确率 / 召回率 / F1(特别是 spam 类的 F1)
- ROC-AUC(如果模型输出概率)
from sklearn.metrics import classification_report, confusion_matrix
@torch.no_grad()
def report(model, loader, device):
model.eval()
y_true, y_pred = [], []
for x, y in loader:
x = x.to(device)
p = model(x).argmax(dim=-1).cpu()
y_true.extend(y.tolist()); y_pred.extend(p.tolist())
print(confusion_matrix(y_true, y_pred))
print(classification_report(y_true, y_pred, target_names=["ham", "spam"]))
6.7 经验建议
- 学习率:分类微调通常比预训练小 1~2 个数量级,5e-5 是个安全起点。
- 数据划分:训练 / 验证 / 测试 70/10/20,按类别分层抽样,避免类别漂移。
- 早停:在小数据上极易过拟合,验证集 loss 连续 2 个 epoch 不降就停。
- 不平衡处理:可在
cross_entropy里传weight=给少数类加权。
检查清单
- 我能解释为什么取”最后一个 token”的隐藏向量作为分类特征。
- 我能写出”全微调 / 部分微调 / 特征提取”三种代码切换。
- 我会在不平衡数据上同时看 accuracy 和 F1。
练习题
- 把”取最后一个时间步”改成”取最后一个非 pad 时间步”,需要记录每条样本的真实长度。写一版兼容的 forward。
- 把分类头从
Linear换成Linear → GELU → Linear的两层 MLP,看看是否提升。 - 用同一份数据训练一个 BERT-base 分类器作对照,比较两者在准确率与训练耗时上的差异。
📖 第6章补充材料 → — 冻结策略实验、IMDB五模型横向评测、分组权重衰减
← 上一章 · 返回目录 · 下一章 · 指令微调 →