第3章补充材料:注意力机制扩展
本文是「第3章:注意力机制」的补充阅读材料,内容综合自 [MLNLP-World 中文翻译项目] 和 [Datawhale 中文翻译项目],对核心概念做更深度的直觉解释和数学展开。标注 [MLNLP] 表示来自 MLNLP-World 项目,[Datawhale] 表示来自 Datawhale 项目。
1. 为什么需要注意力机制?——从翻译问题说起 [Datawhale]
在 Transformer 出现之前,机器翻译的主流方案是编码器-解码器 RNN。编码器逐词读入源语言句子,不断更新隐藏状态,最终把整个句子的含义压缩进一个固定大小的向量。解码器只能从这个最终隐藏状态出发,逐词生成译文。
这就有问题了:句子一长,编码器不可能把所有信息塞进一个向量。尤其是远距离依赖——比如德语中动词经常在句末,但翻译成英语时它对应句首的主语——RNN 很容易丢失这种关系。
注意力机制的诞生动机就是打破这个瓶颈:让解码器在每一步都能”回头看”编码器的所有隐藏状态,按需取用信息,而不是被迫从一个压缩向量里猜。 [Datawhale]
虽然 GPT 只用了解码器一侧,但自注意力的核心思想是一样的——每个位置都可以直接和序列中其他位置交互,不需要信息经过漫长的递归传递。
2. 注意力机制的直觉类比 [Datawhale] [MLNLP]
2.1 信息检索类比
Q/K/V 的命名来自信息检索系统:
- Query(查询):你在搜索框里输入的关键词,代表”我想要什么”。
- Key(键):数据库中每条记录的索引标签,代表”我这里有什么”。
- Value(值):记录的实际内容。
Query 和 Key 做匹配(点积 → 相似度),匹配度越高,对应的 Value 被赋予越大的权重。最终输出是所有 Value 的加权和。 [Datawhale]
2.2 鸡尾酒会类比
想象你在一个嘈鸡尾酒会上:
- 你的Query 是”谁在叫我名字?”
- 每个人的声音是一对 Key-Value:Key 是音色/方向(用于判断相关性),Value 是实际说的话。
- 你的大脑自动给不同声源分配不同的注意力权重,滤掉无关的噪音,聚焦于和你相关的对话。
自注意力就是让序列中的每个 token 都像酒会上的一个人,同时既在”说话”(提供 Key 和 Value),又在”倾听”(发出 Query)。 [MLNLP]
3. 自注意力的完整数学推导 [Datawhale]
3.1 从简化版到完整版
最简化版的自注意力直接用输入向量做点积:
$$\omega_{ij} = x_i^\top x_j, \quad \alpha_{ij} = \frac{\exp(\omega_{ij})}{\sum_k \exp(\omega_{ik})}, \quad z_i = \sum_j \alpha_{ij} x_j$$
问题:没有可学习参数,无法调整注意力的”视角”。 [Datawhale]
3.2 引入可训练权重矩阵
为每个输入 $x_i$ 计算三个投影向量:
$$q_i = x_i W_Q, \quad k_i = x_i W_K, \quad v_i = x_i W_V$$
其中 $W_Q, W_K, W_V \in \mathbb{R}^{d_{in} \times d_{out}}$ 是可训练参数。
批量矩阵形式(处理整个序列):
$$Q = X W_Q, \quad K = X W_K, \quad V = X W_V$$
其中 $X \in \mathbb{R}^{T \times d_{in}}$。 [Datawhale]
3.3 缩放点积注意力的完整公式
$$\text{Attention}(Q, K, V) = \text{softmax}!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$
逐步展开:
- 计算注意力得分:$\omega = QK^\top$,形状 $(T \times T)$
- 缩放:$\omega’ = \omega / \sqrt{d_k}$
- 归一化:$\alpha = \text{softmax}(\omega’)$,每行之和为 1
- 加权聚合:$Z = \alpha V$,形状 $(T \times d_{out})$
3.4 为什么要除以 $\sqrt{d_k}$?——数学深入 [Datawhale]
假设 $q$ 和 $k$ 的每个分量独立采样自均值为 0、方差为 1 的分布,则点积 $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ 的均值为 0、方差为 $d_k$。
当 $d_k = 1024$(GPT-2 的典型值)时,未缩放点积的方差就是 1024,标准差约 32。这么大的值喂给 softmax 会产生接近 one-hot 的分布:
$$\text{softmax}([32, 0, 0, \dots]) \approx [1.0, 0.0, 0.0, \dots]$$
梯度几乎为零——模型学不到东西。除以 $\sqrt{d_k}$ 把方差拉回 1,softmax 的输入分布在合理范围内,梯度流动正常。 [Datawhale]
这就是为什么这种机制被叫做”缩放点积注意力”(Scaled Dot-Product Attention)。
4. 因果掩码的作用与细节 [Datawhale] [MLNLP]
4.1 为什么必须遮住未来?
GPT 的预训练任务是自回归的:给定 $x_1, \dots, x_t$,预测 $x_{t+1}$。如果第 $t$ 个位置能看到 $x_{t+1}$(即答案),模型会走捷径——直接抄答案而不是学习有用的模式。训练损失虚低,实际能力为零。 [Datawhale]
4.2 实现方式
在 softmax 之前,把上三角位置(不含对角线)的得分设为 $-\infty$:
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
scores = scores.masked_fill(mask, float('-inf'))
weights = torch.softmax(scores, dim=-1) # 上三角自动变成 0
为什么用 $-\infty$ 而不是 0?因为 softmax 的公式是 $\frac{e^{x_i}}{\sum_j e^{x_j}}$,只有当 $x_i \to -\infty$ 时 $e^{x_i} \to 0$,对应的权重才严格为零。如果直接设为 0,$e^0 = 1 \neq 0$,“未来”位置仍会泄漏信息。 [MLNLP]
4.3 register_buffer 的意义 [MLNLP]
掩码不是模型参数(不参与梯度更新),但需要随模型一起移动到 GPU、一起保存到 checkpoint。register_buffer 正好满足这个需求——它是”不是参数但需要随模型走”的张量的标准做法。 [MLNLP]
5. 多头注意力的动机与实现 [Datawhale] [MLNLP]
5.1 为什么需要多头?
单头注意力只能从一种视角看待序列关系。但语言中的关系是多样的:
- 语法关系:主语和谓语的一致性
- 共指关系:“他”指代前文的”张三”
- 语义关系:近义词、反义词的关联
多头注意力让模型同时学习多种关系模式,每个头专注于不同方面。 [Datawhale]
5.2 实现策略
把 $d_{model}$ 的维度等分给 $h$ 个头,每个头的维度为 $d_k = d_{model} / h$。各头独立计算注意力后拼接:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W_O$$
其中 $\text{head}_i = \text{Attention}(Q W_Q^{(i)}, K W_K^{(i)}, V W_V^{(i)})$。
5.3 张量形状变化详解
以 GPT-2 Small 为例:$d_{model}=768$, $h=12$, $d_k=64$。
| 步骤 | 操作 | 形状 |
|---|---|---|
| 输入 | $X$ | $(B, T, 768)$ |
| 线性投影 | $Q = XW_Q$ | $(B, T, 768)$ |
| 切头 | view(B, T, 12, 64) | $(B, T, 12, 64)$ |
| 转置 | transpose(1, 2) | $(B, 12, T, 64)$ |
| 注意力得分 | $QK^\top$ | $(B, 12, T, T)$ |
| 输出 | $\alpha V$ | $(B, 12, T, 64)$ |
| 合头 | transpose + view | $(B, T, 768)$ |
| 最终投影 | $\times W_O$ | $(B, T, 768)$ |
口诀:“切头→转轴→算注意力→回轴→合头→过 $W_O$”。 [MLNLP]
6. 高效多头注意力的优化技巧 [MLNLP]
6.1 nn.Linear vs nn.Parameter
原书展示了两种实现:
- v1:用
nn.Parameter(torch.rand(d_in, d_out))手动定义权重 - v2:用
nn.Linear(d_in, d_out, bias=False)
v2 更好,原因:
nn.Linear使用了优化的权重初始化方案(Kaiming/Xavier),训练更稳定- 注意
nn.Linear内部以转置形式存储权重:weight.shape = (d_out, d_in),所以前向传播是x @ W.T而非x @ W - 计算效率更高,融合了更多底层优化 [Datawhale]
6.2 PyTorch 内置的融合注意力
生产环境中不要自己写注意力——使用 PyTorch 2.0+ 的融合内核:
import torch.nn.functional as F
# 自动选择 FlashAttention / Memory-Efficient Attention / 数学实现
output = F.scaled_dot_product_attention(Q, K, V, attn_mask=causal_mask)
优势:
- FlashAttention:通过分块计算避免实例化完整的 $T \times T$ 注意力矩阵,显存从 $O(T^2)$ 降到 $O(T)$
- IO 感知:减少 HBM(高带宽内存)读写次数,实际速度可提升 2-4×
- 数值稳定:内置 softmax 数值稳定性处理 [MLNLP]
6.3 DropOut 在注意力中的作用
在注意力权重上加 Dropout(通常 p=0.1~0.25),随机将部分权重置零。作用:
- 防止模型过度依赖少数特定位置
- 增加训练时的噪声,起到正则化效果
- 推理时自动关闭 [Datawhale]
7. 注意力机制的四种变体演变路线 [Datawhale]
本章实际实现了四种逐步增强的注意力变体,理解这个递进关系有助于整体把握:
| 变体 | 新增特性 | 解决的问题 |
|---|---|---|
| 简化自注意力 | 基本点积 + softmax | 演示核心思想 |
| 带权重的自注意力 | $W_Q, W_K, W_V$ | 引入可学习参数 |
| 因果注意力 | 上三角掩码 | 防止看到未来 |
| 多头注意力 | 并行多组注意力 | 捕捉多种关系模式 |
每一层都是在前一层基础上的最小增量修改,最终版本就是我们塞进 GPT 架构的模块。 [Datawhale]
8. 权重参数 vs 注意力权重 —— 别搞混 [Datawhale]
这个区分看似简单但经常被混淆:
- 权重参数(Weight Parameters):$W_Q, W_K, W_V$ 这些矩阵,是神经网络的静态学习参数,通过反向传播更新。它们定义了网络连接的基本属性。
- 注意力权重(Attention Weights):$\alpha_{ij}$,是动态的、随输入变化的值。不同的输入序列会产生不同的注意力权重分布。
前者是”学到的知识”,后者是”应用知识的方式”。 [Datawhale]
延伸阅读
- MLNLP-World 中文翻译项目还提供了 [高效多头注意力的不同实现变体对比] 和 [PyTorch Buffer 概念详解] 等补充材料 [MLNLP]
- 原书练习 3.1:尝试将
SelfAttention_v2的权重迁移到SelfAttention_v1,验证输出一致(提示:注意nn.Linear以转置形式存储权重)[Datawhale]