第5章补充材料:预训练进阶技术

1. 学习率调度:线性热身 + 余弦衰减 [MLNLP]

主章的训练循环用了固定学习率,但工业级预训练几乎都使用带热身的余弦衰减调度。MLNLP 项目的 ch05/04_learning_rate_schedulers 提供了完整实现。

为什么需要热身

训练初期,模型参数是随机初始化的,此时梯度方向不稳定、范数偏大。如果直接用大学习率,参数更新幅度过大,可能”跳出”好的优化区域。热身阶段(通常占总步数的 0.1%~2%)从小学习率线性增长到目标学习率,让模型先适应梯度景观。

余弦衰减公式

$$\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\frac{t - t_{\text{warmup}}}{T - t_{\text{warmup}}}\pi\right)$$

其中 $t_{\text{warmup}}$ 是热身步数,$T$ 是总步数,$\eta_{\max}$ 是峰值学习率,$\eta_{\min}$ 通常设为 $\eta_{\max} / 10$ 或 0。

完整实现

def get_lr(step, warmup_steps, max_steps, max_lr=3e-4, min_lr=3e-5):
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    if step >= max_steps:
        return min_lr
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

梯度裁剪

除了学习率调度,梯度裁剪也是防止训练爆炸的关键手段:

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

按全局梯度范数裁剪到 max_norm。这不会改变梯度方向,只缩小过大的梯度。GPT-3、Llama 等大规模训练都使用 1.0 作为默认值。


2. 加载 GPT-2 权重:从 TensorFlow 到 PyTorch [MLNLP]

权重格式转换

OpenAI 公开的 GPT-2 权重是 TensorFlow checkpoint 格式。加载流程:

  1. 下载:从 OpenAI 的 Google Cloud Storage 下载 .h5 文件
  2. 读取:用 tensorflowh5py 读取权重数组
  3. 映射:将 TF 变量名映射到我们自定义的 PyTorch 模型

关键映射关系:

TensorFlow 名称PyTorch 对应
lm_head.weight / wtetok_emb.weight
wpepos_emb.weight
h.{i}.ln_1.{g,b}blocks.{i}.norm1.{shift,scale}
h.{i}.attn.{c_attn,c_proj}.{w,b}blocks.{i}.attn 中的 QKV 投影和输出投影
h.{i}.ln_2.{g,b}blocks.{i}.norm2.{shift,scale}
h.{i}.mlp.{c_fc,c_proj}.{w,b}blocks.{i}.ff 的两层线性

权重转置陷阱

TF 的 Dense 层权重形状为 (in, out),而 PyTorch 的 nn.Linear 权重形状为 (out, in)。加载时必须转置:

def load_tf_weights(tf_params, model):
    for name, param in model.named_parameters():
        tf_name = name_to_tf(name)
        arr = tf_params[tf_name]
        if arr.ndim == 2:
            arr = arr.T  # 关键:转置
        param.data = torch.from_numpy(arr)

替代权重来源 [MLNLP]

如果 OpenAI 的官方权重不可用,MLNLP 项目的 ch05/02_alternative_weight_loading 提供了从 Hugging Face Hub 加载的替代方案。Hugging Face 的 transformers 库封装了多种格式的 GPT-2 权重,可以通过 from_pretrained 一行加载。


3. 高效权重加载 [MLNLP]

当模型参数量较大时(如 1.5B 的 GPT-2 XL),直接 load_state_dict 会先在内存中创建一份完整的参数副本,导致峰值内存接近模型大小的 2 倍。

MLNLP 的 ch05/08_memory_efficient_weight_loading 展示了更高效的方法:

# 方法1:逐层加载,避免同时持有两份权重
state_dict = torch.load("weights.pt", map_location="cpu")
for name, param in model.named_parameters():
    param.data = state_dict[name]

# 方法2:使用 mmap 模式加载,不一次性读入内存
state_dict = torch.load("weights.pt", map_location="cpu", mmap=True)

方法 2(mmap 模式)在 PyTorch 2.1+ 可用,适合加载超大模型的检查点。


4. GPT → Llama 架构转换 [MLNLP]

这是 MLNLP 项目中最有价值的补充材料之一(ch05/07_gpt_to_llama),涵盖了从 GPT-2 到现代 Llama 3.2 架构的逐步演进。

架构演进路径

GPT-2 → Llama 2 → Llama 3/3.1 → Llama 3.2

每一步引入的关键改动:

4.1 GELU → SwiGLU

# GPT-2 的 FFN
class GPT_FFN:
    x → Linear(d, 4d) → GELU → Linear(4d, d)

# Llama 的 FFN(SwiGLU)
class Llama_FFN:
    x → Linear(d, 4d) → SiLU  ┐
    x → Linear(d, 4d)          ├→ ⊙ → Linear(4d, d)

SwiGLU 多了一个门控分支,用 SiLU($\text{SiLU}(x) = x \cdot \sigma(x)$)作为激活函数。

4.2 绝对位置编码 → RoPE(旋转位置编码)

GPT-2 使用可学习的绝对位置嵌入(nn.Embedding(max_len, d))。Llama 使用旋转位置编码(RoPE),它通过对 query 和 key 向量施加旋转矩阵来编码相对位置:

  • 不受最大序列长度限制(可外推)
  • 不增加额外参数
  • 相对位置关系直接体现在点积中

4.3 LayerNorm → RMSNorm

RMSNorm 是 LayerNorm 的简化版,去掉了均值中心化,只做缩放:

$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\text{mean}(x^2) + \varepsilon}} \odot \gamma$$

省掉了均值计算,速度更快,且实验表明效果不逊于 LayerNorm。

4.4 多头注意力 → GQA(分组查询注意力)

Llama 2 引入了 GQA:多组 query 共享同一组 key/value,减少 KV 缓存的内存占用,加速推理。GPT-2 的每个 head 有独立的 Q/K/V,Llama 2 中每 $n$ 个 query head 共享一组 K/V。


5. 在 Gutenberg 语料库上预训练 [MLNLP]

MLNLP 的 ch05/03_bonus_pretraining_on_gutenberg 提供了使用 Project Gutenberg 全部书籍语料做更长时间预训练的代码。

核心要点:

  • Gutenberg 提供了约 7 万本公版书籍,总计约 10GB 文本
  • 用 124M GPT-2 在此语料上预训练数个 epoch,可以观察到 loss 持续下降
  • 主要用于教学目的——验证训练循环的正确性

6. 超参数调优 [MLNLP]

ch05/05_bonus_hparam_tuning 提供了自动超参数搜索脚本。关键参数的推荐范围:

超参数搜索范围推荐
学习率$1\text{e-}4 \sim 6\text{e-}4$$3\text{e-}4$
weight_decay$0.01 \sim 0.2$$0.1$
batch_size$4 \sim 32$尽量大,用梯度累积
warmup_steps总步数的 $0.1% \sim 2%$$0.5%$
gradient_clip$0.5 \sim 2.0$$1.0$

7. 扩展阅读 [MLNLP]

  • 学习率调度器ch05/04_learning_rate_schedulers 包含线性热身 + 余弦衰减的完整实现
  • GPT → Llama 转换系列ch05/07_gpt_to_llama/ 下的三个 notebook 逐步演示了架构演进
  • 交互式界面ch05/06_user_interface 提供了与预训练模型对话的 Gradio 界面
  • 附录 D:更详尽的训练循环优化策略(梯度累积、混合精度等)