第6章:分类微调

学习目标

  • 理解为什么”加载预训练 GPT 再做分类”比从零训练分类器更划算。
  • 学会把生成式 GPT 的输出头改造成分类头。
  • 掌握”冻结大部分参数 + 只训练顶部几层”的微调策略。
  • 会用 accuracy / F1 评估分类模型。

6.1 任务设定

本章以经典的 SMS 垃圾短信分类为例:输入一条短信,输出 spamham(非垃圾)。这是一个二分类任务,数据规模在几千条量级,非常适合演示微调。

为什么不直接用 BERT?——课程主线是 GPT 路线,我们要验证 Decoder-only 模型同样能做分类。事实上,从 GPT-3 开始,“用生成模型做分类”已经成为主流路径。

6.2 数据准备

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import tiktoken

class SpamDataset(Dataset):
    def __init__(self, csv_path, tokenizer, max_len=128, pad_token_id=50256):
        df = pd.read_csv(csv_path)             # 列: ['Label', 'Text']
        self.label2id = {"ham": 0, "spam": 1}
        self.max_len = max_len
        self.pad_id = pad_token_id

        self.encoded = []
        self.labels  = []
        for _, row in df.iterrows():
            ids = tokenizer.encode(row["Text"])
            ids = ids[:max_len]                       # 截断
            ids += [pad_token_id] * (max_len - len(ids))  # 右侧 pad
            self.encoded.append(torch.tensor(ids))
            self.labels.append(self.label2id[row["Label"]])

    def __len__(self):
        return len(self.encoded)

    def __getitem__(self, idx):
        return self.encoded[idx], torch.tensor(self.labels[idx])

注意:GPT-2 没有专门的 PAD token,习惯上借用 <|endoftext|>(id=50256)做填充。需要在注意力 mask 或 loss 计算时处理 padding 位置——分类任务里我们只取最后一个非 pad位置的隐藏向量,所以问题不大。

6.3 改造模型:替换输出头

预训练 GPT 的输出头是 Linear(emb_dim, vocab_size),把隐向量映射到词表 logits。分类任务下,我们不需要 50257 个输出,只需要 2 个:

def attach_classifier_head(gpt_model, num_classes=2):
    emb_dim = gpt_model.out_head.in_features
    gpt_model.out_head = torch.nn.Linear(emb_dim, num_classes)
    return gpt_model

前向逻辑也要小改,只取最后一个有效 token 的隐藏状态作为整段文本的表示

import torch.nn as nn

class GPTClassifier(nn.Module):
    def __init__(self, gpt_model, num_classes=2):
        super().__init__()
        self.gpt = gpt_model
        emb_dim = gpt_model.out_head.in_features
        self.gpt.out_head = nn.Linear(emb_dim, num_classes)

    def forward(self, idx):                       # (B, T)
        # 复用 GPT 的前向,但拿到的是 (B, T, num_classes)
        logits = self.gpt(idx)
        return logits[:, -1, :]                   # 取最后一个时间步: (B, num_classes)

这里”最后一个位置”假设我们对齐方式是右侧 pad,最后一个真实 token 在 pad 之前。更严谨的做法是记录每条样本的真实长度并按位取出,留作练习。

6.4 冻结策略

完整反向传播 GPT-2 124M 的所有参数当然能拿到最好效果,但显存需求高、容易过拟合在小数据上。常用三档策略:

策略训练参数适用场景
全微调全部数据多、显存足
部分微调仅最后 N 个 Block + 输出头 + 末端 LayerNorm中等数据,平衡之选
特征提取仅输出头数据少、把 GPT 当编码器

实现”部分微调”的代码:

def freeze_except_top(model, n_train_blocks=1):
    for p in model.parameters():
        p.requires_grad = False

    # 解冻最后 n_train_blocks 个 Transformer Block
    for blk in model.gpt.blocks[-n_train_blocks:]:
        for p in blk.parameters():
            p.requires_grad = True
    # 解冻最终 LayerNorm 和分类头
    for p in model.gpt.final_norm.parameters():
        p.requires_grad = True
    for p in model.gpt.out_head.parameters():
        p.requires_grad = True

6.5 训练循环(分类版)

和预训练几乎一样,只是损失换成”对最后一个位置做交叉熵”:

import torch.nn.functional as F

def loss_cls(model, x, y):
    logits = model(x)                  # (B, num_classes)
    return F.cross_entropy(logits, y)

@torch.no_grad()
def accuracy(model, loader, device):
    model.eval()
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(dim=-1)
        correct += (pred == y).sum().item()
        total   += y.numel()
    model.train()
    return correct / total

经典超参数(参考值):

optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=5e-5, weight_decay=0.1,
)
num_epochs = 5
batch_size = 8

在 SMS Spam 数据上,仅微调最后 1 个 Block + 输出头通常能在几个 epoch 内拿到 95%+ 的准确率。

6.6 进一步评估

只看准确率会被类别不平衡欺骗(比如 ham 占 87%,全猜 ham 就有 87% 准确率)。建议同时报告:

  • 混淆矩阵 (sklearn.metrics.confusion_matrix)
  • 精确率 / 召回率 / F1(特别是 spam 类的 F1)
  • ROC-AUC(如果模型输出概率)
from sklearn.metrics import classification_report, confusion_matrix

@torch.no_grad()
def report(model, loader, device):
    model.eval()
    y_true, y_pred = [], []
    for x, y in loader:
        x = x.to(device)
        p = model(x).argmax(dim=-1).cpu()
        y_true.extend(y.tolist()); y_pred.extend(p.tolist())
    print(confusion_matrix(y_true, y_pred))
    print(classification_report(y_true, y_pred, target_names=["ham", "spam"]))

6.7 经验建议

  1. 学习率:分类微调通常比预训练小 1~2 个数量级,5e-5 是个安全起点。
  2. 数据划分:训练 / 验证 / 测试 70/10/20,按类别分层抽样,避免类别漂移。
  3. 早停:在小数据上极易过拟合,验证集 loss 连续 2 个 epoch 不降就停。
  4. 不平衡处理:可在 cross_entropy 里传 weight= 给少数类加权。

检查清单

  • 我能解释为什么取”最后一个 token”的隐藏向量作为分类特征。
  • 我能写出”全微调 / 部分微调 / 特征提取”三种代码切换。
  • 我会在不平衡数据上同时看 accuracy 和 F1。

练习题

  1. 把”取最后一个时间步”改成”取最后一个非 pad 时间步”,需要记录每条样本的真实长度。写一版兼容的 forward。
  2. 把分类头从 Linear 换成 Linear → GELU → Linear 的两层 MLP,看看是否提升。
  3. 用同一份数据训练一个 BERT-base 分类器作对照,比较两者在准确率与训练耗时上的差异。

📖 第6章补充材料 → — 冻结策略实验、IMDB五模型横向评测、分组权重衰减


← 上一章 · 返回目录 · 下一章 · 指令微调 →