第3章补充材料:注意力机制扩展

本文是「第3章:注意力机制」的补充阅读材料,内容综合自 [MLNLP-World 中文翻译项目] 和 [Datawhale 中文翻译项目],对核心概念做更深度的直觉解释和数学展开。标注 [MLNLP] 表示来自 MLNLP-World 项目,[Datawhale] 表示来自 Datawhale 项目。

1. 为什么需要注意力机制?——从翻译问题说起 [Datawhale]

在 Transformer 出现之前,机器翻译的主流方案是编码器-解码器 RNN。编码器逐词读入源语言句子,不断更新隐藏状态,最终把整个句子的含义压缩进一个固定大小的向量。解码器只能从这个最终隐藏状态出发,逐词生成译文。

这就有问题了:句子一长,编码器不可能把所有信息塞进一个向量。尤其是远距离依赖——比如德语中动词经常在句末,但翻译成英语时它对应句首的主语——RNN 很容易丢失这种关系。

注意力机制的诞生动机就是打破这个瓶颈:让解码器在每一步都能”回头看”编码器的所有隐藏状态,按需取用信息,而不是被迫从一个压缩向量里猜。 [Datawhale]

虽然 GPT 只用了解码器一侧,但自注意力的核心思想是一样的——每个位置都可以直接和序列中其他位置交互,不需要信息经过漫长的递归传递。

2. 注意力机制的直觉类比 [Datawhale] [MLNLP]

2.1 信息检索类比

Q/K/V 的命名来自信息检索系统:

  • Query(查询):你在搜索框里输入的关键词,代表”我想要什么”。
  • Key(键):数据库中每条记录的索引标签,代表”我这里有什么”。
  • Value(值):记录的实际内容。

Query 和 Key 做匹配(点积 → 相似度),匹配度越高,对应的 Value 被赋予越大的权重。最终输出是所有 Value 的加权和。 [Datawhale]

2.2 鸡尾酒会类比

想象你在一个嘈鸡尾酒会上:

  • 你的Query 是”谁在叫我名字?”
  • 每个人的声音是一对 Key-Value:Key 是音色/方向(用于判断相关性),Value 是实际说的话。
  • 你的大脑自动给不同声源分配不同的注意力权重,滤掉无关的噪音,聚焦于和你相关的对话。

自注意力就是让序列中的每个 token 都像酒会上的一个人,同时既在”说话”(提供 Key 和 Value),又在”倾听”(发出 Query)。 [MLNLP]

3. 自注意力的完整数学推导 [Datawhale]

3.1 从简化版到完整版

最简化版的自注意力直接用输入向量做点积:

$$\omega_{ij} = x_i^\top x_j, \quad \alpha_{ij} = \frac{\exp(\omega_{ij})}{\sum_k \exp(\omega_{ik})}, \quad z_i = \sum_j \alpha_{ij} x_j$$

问题:没有可学习参数,无法调整注意力的”视角”。 [Datawhale]

3.2 引入可训练权重矩阵

为每个输入 $x_i$ 计算三个投影向量:

$$q_i = x_i W_Q, \quad k_i = x_i W_K, \quad v_i = x_i W_V$$

其中 $W_Q, W_K, W_V \in \mathbb{R}^{d_{in} \times d_{out}}$ 是可训练参数。

批量矩阵形式(处理整个序列):

$$Q = X W_Q, \quad K = X W_K, \quad V = X W_V$$

其中 $X \in \mathbb{R}^{T \times d_{in}}$。 [Datawhale]

3.3 缩放点积注意力的完整公式

$$\text{Attention}(Q, K, V) = \text{softmax}!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

逐步展开:

  1. 计算注意力得分:$\omega = QK^\top$,形状 $(T \times T)$
  2. 缩放:$\omega’ = \omega / \sqrt{d_k}$
  3. 归一化:$\alpha = \text{softmax}(\omega’)$,每行之和为 1
  4. 加权聚合:$Z = \alpha V$,形状 $(T \times d_{out})$

3.4 为什么要除以 $\sqrt{d_k}$?——数学深入 [Datawhale]

假设 $q$ 和 $k$ 的每个分量独立采样自均值为 0、方差为 1 的分布,则点积 $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ 的均值为 0、方差为 $d_k$

当 $d_k = 1024$(GPT-2 的典型值)时,未缩放点积的方差就是 1024,标准差约 32。这么大的值喂给 softmax 会产生接近 one-hot 的分布:

$$\text{softmax}([32, 0, 0, \dots]) \approx [1.0, 0.0, 0.0, \dots]$$

梯度几乎为零——模型学不到东西。除以 $\sqrt{d_k}$ 把方差拉回 1,softmax 的输入分布在合理范围内,梯度流动正常。 [Datawhale]

这就是为什么这种机制被叫做”缩放点积注意力”(Scaled Dot-Product Attention)。

4. 因果掩码的作用与细节 [Datawhale] [MLNLP]

4.1 为什么必须遮住未来?

GPT 的预训练任务是自回归的:给定 $x_1, \dots, x_t$,预测 $x_{t+1}$。如果第 $t$ 个位置能看到 $x_{t+1}$(即答案),模型会走捷径——直接抄答案而不是学习有用的模式。训练损失虚低,实际能力为零。 [Datawhale]

4.2 实现方式

在 softmax 之前,把上三角位置(不含对角线)的得分设为 $-\infty$:

mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
scores = scores.masked_fill(mask, float('-inf'))
weights = torch.softmax(scores, dim=-1)  # 上三角自动变成 0

为什么用 $-\infty$ 而不是 0?因为 softmax 的公式是 $\frac{e^{x_i}}{\sum_j e^{x_j}}$,只有当 $x_i \to -\infty$ 时 $e^{x_i} \to 0$,对应的权重才严格为零。如果直接设为 0,$e^0 = 1 \neq 0$,“未来”位置仍会泄漏信息。 [MLNLP]

4.3 register_buffer 的意义 [MLNLP]

掩码不是模型参数(不参与梯度更新),但需要随模型一起移动到 GPU、一起保存到 checkpoint。register_buffer 正好满足这个需求——它是”不是参数但需要随模型走”的张量的标准做法。 [MLNLP]

5. 多头注意力的动机与实现 [Datawhale] [MLNLP]

5.1 为什么需要多头?

单头注意力只能从一种视角看待序列关系。但语言中的关系是多样的:

  • 语法关系:主语和谓语的一致性
  • 共指关系:“他”指代前文的”张三”
  • 语义关系:近义词、反义词的关联

多头注意力让模型同时学习多种关系模式,每个头专注于不同方面。 [Datawhale]

5.2 实现策略

把 $d_{model}$ 的维度等分给 $h$ 个头,每个头的维度为 $d_k = d_{model} / h$。各头独立计算注意力后拼接:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W_O$$

其中 $\text{head}_i = \text{Attention}(Q W_Q^{(i)}, K W_K^{(i)}, V W_V^{(i)})$。

5.3 张量形状变化详解

以 GPT-2 Small 为例:$d_{model}=768$, $h=12$, $d_k=64$。

步骤操作形状
输入$X$$(B, T, 768)$
线性投影$Q = XW_Q$$(B, T, 768)$
切头view(B, T, 12, 64)$(B, T, 12, 64)$
转置transpose(1, 2)$(B, 12, T, 64)$
注意力得分$QK^\top$$(B, 12, T, T)$
输出$\alpha V$$(B, 12, T, 64)$
合头transpose + view$(B, T, 768)$
最终投影$\times W_O$$(B, T, 768)$

口诀:“切头→转轴→算注意力→回轴→合头→过 $W_O$”。 [MLNLP]

6. 高效多头注意力的优化技巧 [MLNLP]

6.1 nn.Linear vs nn.Parameter

原书展示了两种实现:

  • v1:用 nn.Parameter(torch.rand(d_in, d_out)) 手动定义权重
  • v2:用 nn.Linear(d_in, d_out, bias=False)

v2 更好,原因:

  1. nn.Linear 使用了优化的权重初始化方案(Kaiming/Xavier),训练更稳定
  2. 注意 nn.Linear 内部以转置形式存储权重:weight.shape = (d_out, d_in),所以前向传播是 x @ W.T 而非 x @ W
  3. 计算效率更高,融合了更多底层优化 [Datawhale]

6.2 PyTorch 内置的融合注意力

生产环境中不要自己写注意力——使用 PyTorch 2.0+ 的融合内核:

import torch.nn.functional as F

# 自动选择 FlashAttention / Memory-Efficient Attention / 数学实现
output = F.scaled_dot_product_attention(Q, K, V, attn_mask=causal_mask)

优势:

  • FlashAttention:通过分块计算避免实例化完整的 $T \times T$ 注意力矩阵,显存从 $O(T^2)$ 降到 $O(T)$
  • IO 感知:减少 HBM(高带宽内存)读写次数,实际速度可提升 2-4×
  • 数值稳定:内置 softmax 数值稳定性处理 [MLNLP]

6.3 DropOut 在注意力中的作用

在注意力权重上加 Dropout(通常 p=0.1~0.25),随机将部分权重置零。作用:

  • 防止模型过度依赖少数特定位置
  • 增加训练时的噪声,起到正则化效果
  • 推理时自动关闭 [Datawhale]

7. 注意力机制的四种变体演变路线 [Datawhale]

本章实际实现了四种逐步增强的注意力变体,理解这个递进关系有助于整体把握:

变体新增特性解决的问题
简化自注意力基本点积 + softmax演示核心思想
带权重的自注意力$W_Q, W_K, W_V$引入可学习参数
因果注意力上三角掩码防止看到未来
多头注意力并行多组注意力捕捉多种关系模式

每一层都是在前一层基础上的最小增量修改,最终版本就是我们塞进 GPT 架构的模块。 [Datawhale]

8. 权重参数 vs 注意力权重 —— 别搞混 [Datawhale]

这个区分看似简单但经常被混淆:

  • 权重参数(Weight Parameters):$W_Q, W_K, W_V$ 这些矩阵,是神经网络的静态学习参数,通过反向传播更新。它们定义了网络连接的基本属性。
  • 注意力权重(Attention Weights):$\alpha_{ij}$,是动态的、随输入变化的值。不同的输入序列会产生不同的注意力权重分布。

前者是”学到的知识”,后者是”应用知识的方式”。 [Datawhale]

延伸阅读

  • MLNLP-World 中文翻译项目还提供了 [高效多头注意力的不同实现变体对比] 和 [PyTorch Buffer 概念详解] 等补充材料 [MLNLP]
  • 原书练习 3.1:尝试将 SelfAttention_v2 的权重迁移到 SelfAttention_v1,验证输出一致(提示:注意 nn.Linear 以转置形式存储权重)[Datawhale]

← 返回第3章 · 返回目录 · 下一章 · 从零实现 GPT →