第4章补充材料:GPT架构深度剖析

1. Pre-LN vs Post-LN:为什么 GPT 选择了 Pre-LN [MLNLP]

主章提到了 GPT-2 采用 Pre-LN(子层之前做 LayerNorm),但未深究原因。这一选择对训练稳定性至关重要。

数学对比

Post-LN(原始 Transformer):

$$x_{l+1} = \text{LayerNorm}(x_l + \text{Sublayer}(x_l))$$

Pre-LN(GPT-2 及之后几乎所有 LLM):

$$x_{l+1} = x_l + \text{Sublayer}(\text{LayerNorm}(x_l))$$

为什么 Pre-LN 更稳定

Post-LN 中,每个残差分支的输出都要经过 LayerNorm 后才汇入主干,这意味着梯度必须穿过 LayerNorm 才能回传到浅层。而 LayerNorm 的梯度大小与输入方差有关——深层网络中方差波动剧烈,导致梯度尺度不稳定。

Pre-LN 中,残差连接 $x_{l+1} = x_l + \Delta_l$ 保证了一条不经过任何归一化的梯度高速公路

$$\frac{\partial \mathcal{L}}{\partial x_0} = \frac{\partial \mathcal{L}}{\partial x_L} \prod_{l=0}^{L-1}\left(I + \frac{\partial \Delta_l}{\partial x_l}\right)$$

即使 $\Delta_l$ 的梯度消失,恒等项 $I$ 保证了梯度至少为 1。

实际影响

有研究表明,在相同超参数下:

  • Post-LN 需要学习率热身才能不炸,且对学习率敏感;
  • Pre-LN 可以用更大的学习率,且不一定需要热身;
  • Pre-LN 的 loss 曲线更平滑,训练初期的梯度范数更可控。

⚠️ 代价:Pre-LN 在某些工作中报告了略差的最终性能。这是一个活跃的研究话题。实践中几乎所有大模型(GPT-3/4、Llama、Mistral)都用 Pre-LN,说明稳定性收益远大于这点性能损失。


2. GELU 激活函数:不只是”平滑的 ReLU” [MLNLP]

从直觉到公式

ReLU 的行为是二值决策:$x > 0$ 就保留,否则丢弃。GELU 用标准正态分布的累积分布函数(CDF)做”概率性保留”:

$$\text{GELU}(x) = x \cdot \Phi(x) = x \cdot P(Z \leq x), \quad Z \sim \mathcal{N}(0, 1)$$

直觉:输入越大,被保留的概率越高;输入越小(负得越多),被丢弃的概率越高。但和 ReLU 不同的是,这个”保留/丢弃”是连续的、可微的。

近似计算

精确的 $\Phi(x)$ 涉及误差函数(erf),计算开销不小。GPT-2 和 BERT 都使用 tanh 近似:

$$\text{GELU}(x) \approx 0.5x\left(1 + \tanh\left[\sqrt{2/\pi}(x + 0.044715x^3)\right]\right)$$

PyTorch 中 torch.nn.GELU(approximate='tanh') 即为此版本。

GELU vs SwiGLU

近年来 Llama 系列用 SwiGLU 替代 GELU:

$$\text{SwiGLU}(x, W, V) = (\text{SiLU}(xW) \odot xV)$$

SwiGLU 多了一组门控权重,表达力更强但参数量也更多(FFN 从 2 个矩阵变成 3 个)。Llama 2/3、Mistral 等都采用此设计。


3. 残差连接:深层网络的”生命线” [MLNLP]

梯度流分析

假设一个 L 层网络,每层为 $x_{l+1} = x_l + F_l(x_l)$,其中 $F_l$ 是子层(注意力或 FFN)。反向传播时:

$$\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L}\prod_{k=l}^{L-1}\left(I + \frac{\partial F_k}{\partial x_k}\right)$$

没有残差连接时,连乘项只有 $\partial F_k / \partial x_k$,如果每层缩小梯度(比如因子 0.8),12 层后梯度只剩 $0.8^{12} \approx 0.07$。

有了残差连接,即使 $\partial F_k / \partial x_k$ 接近 0,$(I + \partial F_k / \partial x_k)$ 仍然接近 $I$,梯度不会衰减到零

恒等初始化视角

如果子层 $F_l$ 初始化为接近零的输出(大多数 Transformer 的默认初始化),那么 $x_{l+1} \approx x_l$。整个网络在训练开始时几乎是一个恒等映射。这使得:

  1. 损失函数的初始值接近于浅层网络的损失,容易优化;
  2. 随着训练进行,网络逐步”学会”利用深层结构。

4. GPT 模型的 FLOPs 分析 [MLNLP]

MLNLP 项目提供了详细的浮点运算(FLOPs)分析工具,这里给出核心结论。

计算公式

对于 GPT-2 124M($d=768, n_h=12, n_l=12, V=50257, T=1024$):

单次前向传播的 FLOPs

组件计算FLOPs
Token Embed查表$\approx 0$
每层 Multi-Head Attention$4BTd^2 + 2BT^2d$$6.0 \times 10^9$
每层 FFN$8BTd^2$(expansion=4)$4.8 \times 10^9$
12 层合计$1.30 \times 10^{11}$
输出投影$2BTVd$$7.7 \times 10^{10}$
总计(B=1)$\approx 2.1 \times 10^{11}$

训练的总计算量

训练时还需要反向传播(约为前向的 2 倍),所以一次完整的 forward+backward:

$$\text{FLOPs}{\text{train}} \approx 6 \times \text{FLOPs}{\text{forward}}$$

GPT-2 124M 训练 8B tokens(约 8M 步 × batch 1024):

$$\text{总 FLOPs} \approx 6 \times 2.1 \times 10^{11} \times 8 \times 10^9 \approx 1.0 \times 10^{21}$$

在单张 A100(312 TFLOPS BF16)上约需 37 天,8 卡 A100 约 4.6 天。

参数效率

值得注意的是,124M 模型中:

  • 嵌入矩阵(token + position)约 $39\text{M}$ 参数,占总量的 24%;
  • 12 个 Transformer Block 约 $85\text{M}$ 参数;
  • 输出投影(如不与嵌入共享)约 $39\text{M}$ 参数。

权重共享(嵌入 = 输出投影)可以节省约 24% 的参数,同时对性能影响很小。


5. GPT-2 的完整配置对比 [Datawhale]

OpenAI 发布了四个尺寸的 GPT-2,它们的配置如下:

模型参数量$d_{\text{model}}$$n_{\text{heads}}$$n_{\text{layers}}$$d_{\text{ff}}$
GPT-2 Small124M76812123072
GPT-2 Medium355M102416244096
GPT-2 Large774M128020365120
GPT-2 XL1558M160025486400

规律:深度(层数)的增长快于宽度。从 Small 到 XL,宽度翻倍,深度翻 4 倍。这一趋势在后续的 GPT-3 和 Llama 系列中更加明显。


6. 扩展阅读 [MLNLP]

  • 性能分析笔记本:MLNLP 项目 ch04/02_performance-analysis/flops-analysis.ipynb 提供了交互式的 FLOPs 计算工具,可以自定义模型配置进行分析。
  • GPT → Llama 转换指南ch05/07_gpt_to_llama/ 包含了将第 4 章实现的 GPT 架构逐步改造为 Llama 3.2 架构的教程,涵盖 RoPE、SwiGLU、RMSNorm 等现代组件的引入。