第3章:注意力机制

学习目标

  • 用一个最朴素的”加权平均”版本理解注意力的本质。
  • 推导 Q/K/V 的来源,掌握缩放点积注意力的完整公式。
  • 实现因果掩码,让自注意力不能”偷看未来”。
  • 把单头扩展到多头注意力。

3.1 从”上下文向量”开始的直觉

考虑一个序列 $x_1, x_2, \dots, x_T$,我们希望第 $t$ 个位置能”看一眼”前面的所有位置,融合出一个上下文向量 $z_t$。最朴素的做法:

$$z_t = \sum_{i=1}^{t} \alpha_{t,i}, x_i, \quad \sum_i \alpha_{t,i} = 1$$

权重 $\alpha_{t,i}$ 由”第 $t$ 个位置和第 $i$ 个位置的相似度”决定:

$$\alpha_{t,i} = \frac{\exp(\text{score}(x_t, x_i))}{\sum_{j} \exp(\text{score}(x_t, x_j))}$$

最朴素的相似度就是点积:$\text{score}(x_t, x_i) = x_t^\top x_i$。

3.2 引入 Query / Key / Value

直接用 $x$ 自己点自己有几个缺陷:

  1. 没有可学习参数,无法调整”注意力的视角”;
  2. “查询者""被查询者""信息载体”被混在同一个向量里。

解决方法:用三个线性变换,把同一个 $x_t$ 投影到三个不同的角色:

$$Q = X W_Q, \quad K = X W_K, \quad V = X W_V$$

其中 $X \in \mathbb{R}^{T \times d}$,$W_Q, W_K, W_V \in \mathbb{R}^{d \times d_k}$。直观理解:

  • Query (Q):当前 token “想问什么问题”;
  • Key (K):每个 token “能回答什么问题”;
  • Value (V):每个 token “真正要传出去的内容”。

注意力公式因此变成:

$$\text{Attn}(Q, K, V) = \mathrm{softmax}!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

为什么除以 $\sqrt{d_k}$? 当 $d_k$ 较大时,$Q K^\top$ 的方差也大,softmax 容易进入梯度极小的饱和区。除以 $\sqrt{d_k}$ 让 logits 的方差大致回到 1。

3.3 自己写一个最小自注意力

import torch
import torch.nn as nn

class SingleHeadAttention(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_q = nn.Linear(d_in, d_out, bias=False)
        self.W_k = nn.Linear(d_in, d_out, bias=False)
        self.W_v = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):                # x: (B, T, d_in)
        Q = self.W_q(x)                  # (B, T, d_out)
        K = self.W_k(x)
        V = self.W_v(x)
        scores = Q @ K.transpose(-2, -1) # (B, T, T)
        scores = scores / (K.size(-1) ** 0.5)
        weights = torch.softmax(scores, dim=-1)
        return weights @ V               # (B, T, d_out)

这一段不到 15 行,但已经是 GPT 的核心。

3.4 因果掩码:不能偷看未来

预训练任务是”根据前面预测下一个”。如果第 $t$ 个位置的注意力可以看到 $t+1, t+2, \dots$ 的 token,模型相当于直接从答案抄;训练损失会迅速逼近 0,但模型什么都没学到。

解决:在 softmax 之前把”未来位置”对应的 score 置为 $-\infty$,让 softmax 后的权重恰好为 0。

class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout=0.0):
        super().__init__()
        self.W_q = nn.Linear(d_in, d_out, bias=False)
        self.W_k = nn.Linear(d_in, d_out, bias=False)
        self.W_v = nn.Linear(d_in, d_out, bias=False)
        self.dropout = nn.Dropout(dropout)
        # 上三角 mask(不含对角线),形状 (T, T)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()
        )

    def forward(self, x):
        B, T, _ = x.shape
        Q, K, V = self.W_q(x), self.W_k(x), self.W_v(x)
        scores = Q @ K.transpose(-2, -1) / (K.size(-1) ** 0.5)
        scores = scores.masked_fill(self.mask[:T, :T], float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        return weights @ V

为什么 mask 用 register_buffer 而不是普通属性? 因为它需要随模型一起 to(device) 和保存,但又不参与反向传播。register_buffer 正是为这种张量准备的。

3.5 从单头到多头

单头注意力只能从一种”视角”看待序列关系。多头允许模型并行学习多种关系(语法、共指、语义等)。

做法:

  1. 把 $d_{model}$ 切成 $h$ 份,每份 $d_k = d_{model} / h$;
  2. 各头独立计算注意力;
  3. 把所有头的输出沿最后一维拼接,再过一个 W_o 线性层。
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, context_length, dropout=0.0):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()
        )

    def forward(self, x):
        B, T, C = x.shape
        H, D = self.num_heads, self.head_dim

        # (B, T, C) -> (B, T, H, D) -> (B, H, T, D)
        Q = self.W_q(x).view(B, T, H, D).transpose(1, 2)
        K = self.W_k(x).view(B, T, H, D).transpose(1, 2)
        V = self.W_v(x).view(B, T, H, D).transpose(1, 2)

        scores = Q @ K.transpose(-2, -1) / (D ** 0.5)        # (B, H, T, T)
        scores = scores.masked_fill(self.mask[:T, :T], float("-inf"))
        weights = self.dropout(torch.softmax(scores, dim=-1))

        out = weights @ V                                    # (B, H, T, D)
        out = out.transpose(1, 2).contiguous().view(B, T, C) # (B, T, C)
        return self.W_o(out)

形状变化口诀: “切头—转轴—算注意力—回轴—合头—过 W_o”。

3.6 计算与显存代价

设 batch=B, 序列=T, 维度=C:

  • 时间复杂度:$O(B \cdot T^2 \cdot C)$,主要在 $QK^\top$;
  • 注意力矩阵显存:$O(B \cdot H \cdot T^2)$。

T 翻倍,显存涨 4 倍——这就是为什么”上下文长度”是 LLM 工程里的硬骨头,也是 FlashAttention、滑窗注意力等优化的发力点。本课程实现的是教学版,生产环境请使用 torch.nn.functional.scaled_dot_product_attention,它会自动选择 FlashAttention 内核。

检查清单

  • 我能在不看代码的情况下写出 Attn(Q,K,V) 的公式。
  • 我能解释 mask 为什么放在 softmax 之前。
  • 我能讲清楚多头注意力中张量形状的每一步变化。

练习题

  1. 不使用 mask,直接用 Q @ K.T 计算后取下三角,再做 softmax。结果和”先 mask 后 softmax”是否相同?为什么?
  2. num_heads 设为 1,验证 MultiHeadAttentionCausalAttention 输出形状一致。
  3. 把 mask 改成”只能看前 8 个位置”的滑窗形式,写出新的 mask 构造代码(提示:torch.triu + torch.tril 组合)。

📖 第3章补充材料 → — 注意力直觉理解、自注意力数学推导、高效多头注意力


← 上一章 · 返回目录 · 下一章 · 从零实现 GPT →