第3章:注意力机制
学习目标
- 用一个最朴素的”加权平均”版本理解注意力的本质。
- 推导 Q/K/V 的来源,掌握缩放点积注意力的完整公式。
- 实现因果掩码,让自注意力不能”偷看未来”。
- 把单头扩展到多头注意力。
3.1 从”上下文向量”开始的直觉
考虑一个序列 $x_1, x_2, \dots, x_T$,我们希望第 $t$ 个位置能”看一眼”前面的所有位置,融合出一个上下文向量 $z_t$。最朴素的做法:
$$z_t = \sum_{i=1}^{t} \alpha_{t,i}, x_i, \quad \sum_i \alpha_{t,i} = 1$$
权重 $\alpha_{t,i}$ 由”第 $t$ 个位置和第 $i$ 个位置的相似度”决定:
$$\alpha_{t,i} = \frac{\exp(\text{score}(x_t, x_i))}{\sum_{j} \exp(\text{score}(x_t, x_j))}$$
最朴素的相似度就是点积:$\text{score}(x_t, x_i) = x_t^\top x_i$。
3.2 引入 Query / Key / Value
直接用 $x$ 自己点自己有几个缺陷:
- 没有可学习参数,无法调整”注意力的视角”;
- “查询者""被查询者""信息载体”被混在同一个向量里。
解决方法:用三个线性变换,把同一个 $x_t$ 投影到三个不同的角色:
$$Q = X W_Q, \quad K = X W_K, \quad V = X W_V$$
其中 $X \in \mathbb{R}^{T \times d}$,$W_Q, W_K, W_V \in \mathbb{R}^{d \times d_k}$。直观理解:
- Query (Q):当前 token “想问什么问题”;
- Key (K):每个 token “能回答什么问题”;
- Value (V):每个 token “真正要传出去的内容”。
注意力公式因此变成:
$$\text{Attn}(Q, K, V) = \mathrm{softmax}!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$
为什么除以 $\sqrt{d_k}$? 当 $d_k$ 较大时,$Q K^\top$ 的方差也大,softmax 容易进入梯度极小的饱和区。除以 $\sqrt{d_k}$ 让 logits 的方差大致回到 1。
3.3 自己写一个最小自注意力
import torch
import torch.nn as nn
class SingleHeadAttention(nn.Module):
def __init__(self, d_in, d_out):
super().__init__()
self.W_q = nn.Linear(d_in, d_out, bias=False)
self.W_k = nn.Linear(d_in, d_out, bias=False)
self.W_v = nn.Linear(d_in, d_out, bias=False)
def forward(self, x): # x: (B, T, d_in)
Q = self.W_q(x) # (B, T, d_out)
K = self.W_k(x)
V = self.W_v(x)
scores = Q @ K.transpose(-2, -1) # (B, T, T)
scores = scores / (K.size(-1) ** 0.5)
weights = torch.softmax(scores, dim=-1)
return weights @ V # (B, T, d_out)
这一段不到 15 行,但已经是 GPT 的核心。
3.4 因果掩码:不能偷看未来
预训练任务是”根据前面预测下一个”。如果第 $t$ 个位置的注意力可以看到 $t+1, t+2, \dots$ 的 token,模型相当于直接从答案抄;训练损失会迅速逼近 0,但模型什么都没学到。
解决:在 softmax 之前把”未来位置”对应的 score 置为 $-\infty$,让 softmax 后的权重恰好为 0。
class CausalAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout=0.0):
super().__init__()
self.W_q = nn.Linear(d_in, d_out, bias=False)
self.W_k = nn.Linear(d_in, d_out, bias=False)
self.W_v = nn.Linear(d_in, d_out, bias=False)
self.dropout = nn.Dropout(dropout)
# 上三角 mask(不含对角线),形状 (T, T)
self.register_buffer(
"mask",
torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()
)
def forward(self, x):
B, T, _ = x.shape
Q, K, V = self.W_q(x), self.W_k(x), self.W_v(x)
scores = Q @ K.transpose(-2, -1) / (K.size(-1) ** 0.5)
scores = scores.masked_fill(self.mask[:T, :T], float("-inf"))
weights = torch.softmax(scores, dim=-1)
weights = self.dropout(weights)
return weights @ V
为什么 mask 用 register_buffer 而不是普通属性?
因为它需要随模型一起 to(device) 和保存,但又不参与反向传播。register_buffer 正是为这种张量准备的。
3.5 从单头到多头
单头注意力只能从一种”视角”看待序列关系。多头允许模型并行学习多种关系(语法、共指、语义等)。
做法:
- 把 $d_{model}$ 切成 $h$ 份,每份 $d_k = d_{model} / h$;
- 各头独立计算注意力;
- 把所有头的输出沿最后一维拼接,再过一个
W_o线性层。
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, context_length, dropout=0.0):
super().__init__()
assert d_model % num_heads == 0
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
self.register_buffer(
"mask",
torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()
)
def forward(self, x):
B, T, C = x.shape
H, D = self.num_heads, self.head_dim
# (B, T, C) -> (B, T, H, D) -> (B, H, T, D)
Q = self.W_q(x).view(B, T, H, D).transpose(1, 2)
K = self.W_k(x).view(B, T, H, D).transpose(1, 2)
V = self.W_v(x).view(B, T, H, D).transpose(1, 2)
scores = Q @ K.transpose(-2, -1) / (D ** 0.5) # (B, H, T, T)
scores = scores.masked_fill(self.mask[:T, :T], float("-inf"))
weights = self.dropout(torch.softmax(scores, dim=-1))
out = weights @ V # (B, H, T, D)
out = out.transpose(1, 2).contiguous().view(B, T, C) # (B, T, C)
return self.W_o(out)
形状变化口诀: “切头—转轴—算注意力—回轴—合头—过 W_o”。
3.6 计算与显存代价
设 batch=B, 序列=T, 维度=C:
- 时间复杂度:$O(B \cdot T^2 \cdot C)$,主要在 $QK^\top$;
- 注意力矩阵显存:$O(B \cdot H \cdot T^2)$。
T 翻倍,显存涨 4 倍——这就是为什么”上下文长度”是 LLM 工程里的硬骨头,也是 FlashAttention、滑窗注意力等优化的发力点。本课程实现的是教学版,生产环境请使用 torch.nn.functional.scaled_dot_product_attention,它会自动选择 FlashAttention 内核。
检查清单
- 我能在不看代码的情况下写出
Attn(Q,K,V)的公式。 - 我能解释 mask 为什么放在 softmax 之前。
- 我能讲清楚多头注意力中张量形状的每一步变化。
练习题
- 不使用 mask,直接用
Q @ K.T计算后取下三角,再做 softmax。结果和”先 mask 后 softmax”是否相同?为什么? - 把
num_heads设为 1,验证MultiHeadAttention与CausalAttention输出形状一致。 - 把 mask 改成”只能看前 8 个位置”的滑窗形式,写出新的 mask 构造代码(提示:
torch.triu+torch.tril组合)。
📖 第3章补充材料 → — 注意力直觉理解、自注意力数学推导、高效多头注意力