附录D:训练循环增强

把”能跑”的训练循环升级成”能跑得稳、跑得快、跑得好”。本章覆盖三个标准技巧:学习率热身 + 余弦衰减、梯度裁剪、参数分组的 weight decay。

D.1 学习率热身 + 余弦衰减

直接用大学习率从 step 0 开始训练,前几百步极容易”炸”——loss 先降一点然后突然爆涨到 NaN。标准应对:

  1. 热身(warmup):前 n_warmup 步把学习率从 0 线性升到目标值;
  2. 余弦衰减:之后按余弦曲线缓慢降到一个最小值(通常是峰值的 10%)。
import math

def lr_at(step, peak_lr, n_warmup, n_total, min_ratio=0.1):
    if step < n_warmup:
        return peak_lr * step / max(1, n_warmup)
    progress = (step - n_warmup) / max(1, n_total - n_warmup)
    progress = min(1.0, progress)
    cosine = 0.5 * (1 + math.cos(math.pi * progress))
    return peak_lr * (min_ratio + (1 - min_ratio) * cosine)

在训练循环里:

for step, (x, y) in enumerate(loader):
    lr = lr_at(step, peak_lr=3e-4, n_warmup=500, n_total=20_000)
    for g in optimizer.param_groups:
        g["lr"] = lr
    ...

经验:热身步数大约取总步数的 1%3%,或固定 2002000 步。

D.2 梯度裁剪

少数 batch 的梯度异常大(数值溢出、坏样本等)会让参数被一脚踹飞。全局梯度范数裁剪是最简单也最有效的防御:

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

放在 loss.backward() 之后、optimizer.step() 之前。max_norm=1.0 是 GPT 系列的常用值。

D.3 weight decay 的正确分组

AdamW 的 weight decay 是对所有参数都生效的。但不应该给以下参数加 decay:

  • 所有 LayerNormscaleshift
  • 所有 bias
  • 所有 Embedding 表(这点有争议,本课程不衰减它们)。

否则模型会在初期被无谓地往 0 拉,训练变慢。

def make_param_groups(model, weight_decay=0.1):
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad: continue
        if p.ndim < 2 or name.endswith(".bias") or "norm" in name.lower() or "emb" in name.lower():
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay,    "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

optimizer = torch.optim.AdamW(make_param_groups(model, 0.1), lr=3e-4, betas=(0.9, 0.95))

D.4 梯度累积

显存不够装下你想要的”等效 batch”时,把一个大 batch 切成 N 份,前向反向 N 次,最后只调用一次 optimizer.step()

ACCUM = 8
for step, (x, y) in enumerate(loader):
    x, y = x.to(device), y.to(device)
    loss = compute_loss(model, x, y) / ACCUM
    loss.backward()
    if (step + 1) % ACCUM == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

注意 loss / ACCUM——这是为了让平均梯度的尺度和”真正的大 batch”一致。

D.5 检查点 (Checkpointing)

长时间训练务必每隔 N 步存一次:

if step % 1000 == 0:
    torch.save({
        "step": step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }, f"ckpt_{step}.pt")

恢复时记得把 optimizer.state_dict() 也加载回来——否则 Adam 的一阶/二阶动量重置,loss 会”回弹”几百步才稳定。


← 附录 A · 返回目录 · 附录 E · LoRA →