第5章补充材料:预训练进阶技术
1. 学习率调度:线性热身 + 余弦衰减 [MLNLP]
主章的训练循环用了固定学习率,但工业级预训练几乎都使用带热身的余弦衰减调度。MLNLP 项目的 ch05/04_learning_rate_schedulers 提供了完整实现。
为什么需要热身
训练初期,模型参数是随机初始化的,此时梯度方向不稳定、范数偏大。如果直接用大学习率,参数更新幅度过大,可能”跳出”好的优化区域。热身阶段(通常占总步数的 0.1%~2%)从小学习率线性增长到目标学习率,让模型先适应梯度景观。
余弦衰减公式
$$\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\frac{t - t_{\text{warmup}}}{T - t_{\text{warmup}}}\pi\right)$$
其中 $t_{\text{warmup}}$ 是热身步数,$T$ 是总步数,$\eta_{\max}$ 是峰值学习率,$\eta_{\min}$ 通常设为 $\eta_{\max} / 10$ 或 0。
完整实现
def get_lr(step, warmup_steps, max_steps, max_lr=3e-4, min_lr=3e-5):
if step < warmup_steps:
return max_lr * (step + 1) / warmup_steps
if step >= max_steps:
return min_lr
progress = (step - warmup_steps) / (max_steps - warmup_steps)
return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
梯度裁剪
除了学习率调度,梯度裁剪也是防止训练爆炸的关键手段:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
按全局梯度范数裁剪到 max_norm。这不会改变梯度方向,只缩小过大的梯度。GPT-3、Llama 等大规模训练都使用 1.0 作为默认值。
2. 加载 GPT-2 权重:从 TensorFlow 到 PyTorch [MLNLP]
权重格式转换
OpenAI 公开的 GPT-2 权重是 TensorFlow checkpoint 格式。加载流程:
- 下载:从 OpenAI 的 Google Cloud Storage 下载
.h5文件 - 读取:用
tensorflow或h5py读取权重数组 - 映射:将 TF 变量名映射到我们自定义的 PyTorch 模型
关键映射关系:
| TensorFlow 名称 | PyTorch 对应 |
|---|---|
lm_head.weight / wte | tok_emb.weight |
wpe | pos_emb.weight |
h.{i}.ln_1.{g,b} | blocks.{i}.norm1.{shift,scale} |
h.{i}.attn.{c_attn,c_proj}.{w,b} | blocks.{i}.attn 中的 QKV 投影和输出投影 |
h.{i}.ln_2.{g,b} | blocks.{i}.norm2.{shift,scale} |
h.{i}.mlp.{c_fc,c_proj}.{w,b} | blocks.{i}.ff 的两层线性 |
权重转置陷阱
TF 的 Dense 层权重形状为 (in, out),而 PyTorch 的 nn.Linear 权重形状为 (out, in)。加载时必须转置:
def load_tf_weights(tf_params, model):
for name, param in model.named_parameters():
tf_name = name_to_tf(name)
arr = tf_params[tf_name]
if arr.ndim == 2:
arr = arr.T # 关键:转置
param.data = torch.from_numpy(arr)
替代权重来源 [MLNLP]
如果 OpenAI 的官方权重不可用,MLNLP 项目的 ch05/02_alternative_weight_loading 提供了从 Hugging Face Hub 加载的替代方案。Hugging Face 的 transformers 库封装了多种格式的 GPT-2 权重,可以通过 from_pretrained 一行加载。
3. 高效权重加载 [MLNLP]
当模型参数量较大时(如 1.5B 的 GPT-2 XL),直接 load_state_dict 会先在内存中创建一份完整的参数副本,导致峰值内存接近模型大小的 2 倍。
MLNLP 的 ch05/08_memory_efficient_weight_loading 展示了更高效的方法:
# 方法1:逐层加载,避免同时持有两份权重
state_dict = torch.load("weights.pt", map_location="cpu")
for name, param in model.named_parameters():
param.data = state_dict[name]
# 方法2:使用 mmap 模式加载,不一次性读入内存
state_dict = torch.load("weights.pt", map_location="cpu", mmap=True)
方法 2(mmap 模式)在 PyTorch 2.1+ 可用,适合加载超大模型的检查点。
4. GPT → Llama 架构转换 [MLNLP]
这是 MLNLP 项目中最有价值的补充材料之一(ch05/07_gpt_to_llama),涵盖了从 GPT-2 到现代 Llama 3.2 架构的逐步演进。
架构演进路径
GPT-2 → Llama 2 → Llama 3/3.1 → Llama 3.2
每一步引入的关键改动:
4.1 GELU → SwiGLU
# GPT-2 的 FFN
class GPT_FFN:
x → Linear(d, 4d) → GELU → Linear(4d, d)
# Llama 的 FFN(SwiGLU)
class Llama_FFN:
x → Linear(d, 4d) → SiLU ┐
x → Linear(d, 4d) ├→ ⊙ → Linear(4d, d)
SwiGLU 多了一个门控分支,用 SiLU($\text{SiLU}(x) = x \cdot \sigma(x)$)作为激活函数。
4.2 绝对位置编码 → RoPE(旋转位置编码)
GPT-2 使用可学习的绝对位置嵌入(nn.Embedding(max_len, d))。Llama 使用旋转位置编码(RoPE),它通过对 query 和 key 向量施加旋转矩阵来编码相对位置:
- 不受最大序列长度限制(可外推)
- 不增加额外参数
- 相对位置关系直接体现在点积中
4.3 LayerNorm → RMSNorm
RMSNorm 是 LayerNorm 的简化版,去掉了均值中心化,只做缩放:
$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\text{mean}(x^2) + \varepsilon}} \odot \gamma$$
省掉了均值计算,速度更快,且实验表明效果不逊于 LayerNorm。
4.4 多头注意力 → GQA(分组查询注意力)
Llama 2 引入了 GQA:多组 query 共享同一组 key/value,减少 KV 缓存的内存占用,加速推理。GPT-2 的每个 head 有独立的 Q/K/V,Llama 2 中每 $n$ 个 query head 共享一组 K/V。
5. 在 Gutenberg 语料库上预训练 [MLNLP]
MLNLP 的 ch05/03_bonus_pretraining_on_gutenberg 提供了使用 Project Gutenberg 全部书籍语料做更长时间预训练的代码。
核心要点:
- Gutenberg 提供了约 7 万本公版书籍,总计约 10GB 文本
- 用 124M GPT-2 在此语料上预训练数个 epoch,可以观察到 loss 持续下降
- 主要用于教学目的——验证训练循环的正确性
6. 超参数调优 [MLNLP]
ch05/05_bonus_hparam_tuning 提供了自动超参数搜索脚本。关键参数的推荐范围:
| 超参数 | 搜索范围 | 推荐 |
|---|---|---|
| 学习率 | $1\text{e-}4 \sim 6\text{e-}4$ | $3\text{e-}4$ |
| weight_decay | $0.01 \sim 0.2$ | $0.1$ |
| batch_size | $4 \sim 32$ | 尽量大,用梯度累积 |
| warmup_steps | 总步数的 $0.1% \sim 2%$ | $0.5%$ |
| gradient_clip | $0.5 \sim 2.0$ | $1.0$ |
7. 扩展阅读 [MLNLP]
- 学习率调度器:
ch05/04_learning_rate_schedulers包含线性热身 + 余弦衰减的完整实现 - GPT → Llama 转换系列:
ch05/07_gpt_to_llama/下的三个 notebook 逐步演示了架构演进 - 交互式界面:
ch05/06_user_interface提供了与预训练模型对话的 Gradio 界面 - 附录 D:更详尽的训练循环优化策略(梯度累积、混合精度等)