第4章补充材料:GPT架构深度剖析
1. Pre-LN vs Post-LN:为什么 GPT 选择了 Pre-LN [MLNLP]
主章提到了 GPT-2 采用 Pre-LN(子层之前做 LayerNorm),但未深究原因。这一选择对训练稳定性至关重要。
数学对比
Post-LN(原始 Transformer):
$$x_{l+1} = \text{LayerNorm}(x_l + \text{Sublayer}(x_l))$$
Pre-LN(GPT-2 及之后几乎所有 LLM):
$$x_{l+1} = x_l + \text{Sublayer}(\text{LayerNorm}(x_l))$$
为什么 Pre-LN 更稳定
Post-LN 中,每个残差分支的输出都要经过 LayerNorm 后才汇入主干,这意味着梯度必须穿过 LayerNorm 才能回传到浅层。而 LayerNorm 的梯度大小与输入方差有关——深层网络中方差波动剧烈,导致梯度尺度不稳定。
Pre-LN 中,残差连接 $x_{l+1} = x_l + \Delta_l$ 保证了一条不经过任何归一化的梯度高速公路:
$$\frac{\partial \mathcal{L}}{\partial x_0} = \frac{\partial \mathcal{L}}{\partial x_L} \prod_{l=0}^{L-1}\left(I + \frac{\partial \Delta_l}{\partial x_l}\right)$$
即使 $\Delta_l$ 的梯度消失,恒等项 $I$ 保证了梯度至少为 1。
实际影响
有研究表明,在相同超参数下:
- Post-LN 需要学习率热身才能不炸,且对学习率敏感;
- Pre-LN 可以用更大的学习率,且不一定需要热身;
- Pre-LN 的 loss 曲线更平滑,训练初期的梯度范数更可控。
⚠️ 代价:Pre-LN 在某些工作中报告了略差的最终性能。这是一个活跃的研究话题。实践中几乎所有大模型(GPT-3/4、Llama、Mistral)都用 Pre-LN,说明稳定性收益远大于这点性能损失。
2. GELU 激活函数:不只是”平滑的 ReLU” [MLNLP]
从直觉到公式
ReLU 的行为是二值决策:$x > 0$ 就保留,否则丢弃。GELU 用标准正态分布的累积分布函数(CDF)做”概率性保留”:
$$\text{GELU}(x) = x \cdot \Phi(x) = x \cdot P(Z \leq x), \quad Z \sim \mathcal{N}(0, 1)$$
直觉:输入越大,被保留的概率越高;输入越小(负得越多),被丢弃的概率越高。但和 ReLU 不同的是,这个”保留/丢弃”是连续的、可微的。
近似计算
精确的 $\Phi(x)$ 涉及误差函数(erf),计算开销不小。GPT-2 和 BERT 都使用 tanh 近似:
$$\text{GELU}(x) \approx 0.5x\left(1 + \tanh\left[\sqrt{2/\pi}(x + 0.044715x^3)\right]\right)$$
PyTorch 中 torch.nn.GELU(approximate='tanh') 即为此版本。
GELU vs SwiGLU
近年来 Llama 系列用 SwiGLU 替代 GELU:
$$\text{SwiGLU}(x, W, V) = (\text{SiLU}(xW) \odot xV)$$
SwiGLU 多了一组门控权重,表达力更强但参数量也更多(FFN 从 2 个矩阵变成 3 个)。Llama 2/3、Mistral 等都采用此设计。
3. 残差连接:深层网络的”生命线” [MLNLP]
梯度流分析
假设一个 L 层网络,每层为 $x_{l+1} = x_l + F_l(x_l)$,其中 $F_l$ 是子层(注意力或 FFN)。反向传播时:
$$\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L}\prod_{k=l}^{L-1}\left(I + \frac{\partial F_k}{\partial x_k}\right)$$
没有残差连接时,连乘项只有 $\partial F_k / \partial x_k$,如果每层缩小梯度(比如因子 0.8),12 层后梯度只剩 $0.8^{12} \approx 0.07$。
有了残差连接,即使 $\partial F_k / \partial x_k$ 接近 0,$(I + \partial F_k / \partial x_k)$ 仍然接近 $I$,梯度不会衰减到零。
恒等初始化视角
如果子层 $F_l$ 初始化为接近零的输出(大多数 Transformer 的默认初始化),那么 $x_{l+1} \approx x_l$。整个网络在训练开始时几乎是一个恒等映射。这使得:
- 损失函数的初始值接近于浅层网络的损失,容易优化;
- 随着训练进行,网络逐步”学会”利用深层结构。
4. GPT 模型的 FLOPs 分析 [MLNLP]
MLNLP 项目提供了详细的浮点运算(FLOPs)分析工具,这里给出核心结论。
计算公式
对于 GPT-2 124M($d=768, n_h=12, n_l=12, V=50257, T=1024$):
单次前向传播的 FLOPs:
| 组件 | 计算 | FLOPs |
|---|---|---|
| Token Embed | 查表 | $\approx 0$ |
| 每层 Multi-Head Attention | $4BTd^2 + 2BT^2d$ | $6.0 \times 10^9$ |
| 每层 FFN | $8BTd^2$(expansion=4) | $4.8 \times 10^9$ |
| 12 层合计 | $1.30 \times 10^{11}$ | |
| 输出投影 | $2BTVd$ | $7.7 \times 10^{10}$ |
| 总计(B=1) | $\approx 2.1 \times 10^{11}$ |
训练的总计算量
训练时还需要反向传播(约为前向的 2 倍),所以一次完整的 forward+backward:
$$\text{FLOPs}{\text{train}} \approx 6 \times \text{FLOPs}{\text{forward}}$$
GPT-2 124M 训练 8B tokens(约 8M 步 × batch 1024):
$$\text{总 FLOPs} \approx 6 \times 2.1 \times 10^{11} \times 8 \times 10^9 \approx 1.0 \times 10^{21}$$
在单张 A100(312 TFLOPS BF16)上约需 37 天,8 卡 A100 约 4.6 天。
参数效率
值得注意的是,124M 模型中:
- 嵌入矩阵(token + position)约 $39\text{M}$ 参数,占总量的 24%;
- 12 个 Transformer Block 约 $85\text{M}$ 参数;
- 输出投影(如不与嵌入共享)约 $39\text{M}$ 参数。
权重共享(嵌入 = 输出投影)可以节省约 24% 的参数,同时对性能影响很小。
5. GPT-2 的完整配置对比 [Datawhale]
OpenAI 发布了四个尺寸的 GPT-2,它们的配置如下:
| 模型 | 参数量 | $d_{\text{model}}$ | $n_{\text{heads}}$ | $n_{\text{layers}}$ | $d_{\text{ff}}$ |
|---|---|---|---|---|---|
| GPT-2 Small | 124M | 768 | 12 | 12 | 3072 |
| GPT-2 Medium | 355M | 1024 | 16 | 24 | 4096 |
| GPT-2 Large | 774M | 1280 | 20 | 36 | 5120 |
| GPT-2 XL | 1558M | 1600 | 25 | 48 | 6400 |
规律:深度(层数)的增长快于宽度。从 Small 到 XL,宽度翻倍,深度翻 4 倍。这一趋势在后续的 GPT-3 和 Llama 系列中更加明显。
6. 扩展阅读 [MLNLP]
- 性能分析笔记本:MLNLP 项目
ch04/02_performance-analysis/flops-analysis.ipynb提供了交互式的 FLOPs 计算工具,可以自定义模型配置进行分析。 - GPT → Llama 转换指南:
ch05/07_gpt_to_llama/包含了将第 4 章实现的 GPT 架构逐步改造为 Llama 3.2 架构的教程,涵盖 RoPE、SwiGLU、RMSNorm 等现代组件的引入。