Transformer 与 Self-Attention (整理版)

目标:用“概念 → 公式 → 代码 → 练习”的方式,把 Transformer 的核心机制讲清楚。

目录


1. Transformer 是什么

Transformer 是一种序列建模架构(2017 年提出),它把“序列之间的依赖”主要交给注意力机制来建模,而不是依赖 RNN 的时间步递推。

它之所以重要,核心在两点:

  • 并行性:训练时可以对整段序列并行计算注意力(比 RNN 更容易吃满 GPU/TPU)。
  • 长程依赖:任意两个 token 之间的交互路径更短(自注意力是一次“全连接式”交互)。

一句话直觉:

对于序列中的每个位置,让模型学会“我应该关注哪些位置,以及关注多少”。

2. 编码器/解码器整体结构

经典 Transformer(seq2seq)包含两大块:

  • Encoder(编码器):把输入序列编码为一组上下文表示。
  • Decoder(解码器):在生成第 $t$ 个输出 token 时,只能看见 $t$ 之前已生成的内容,并结合 Encoder 输出进行交互(cross-attention)。

每一块通常由 $N$ 层堆叠(论文里常见 $N=6$)。各层结构相同,但参数不共享。

一个 Encoder Layer 通常包含:

  1. Multi-Head Self-Attention
  2. Add & Norm
  3. Feed-Forward Network (FFN)
  4. Add & Norm

一个 Decoder Layer 通常包含:

  1. Masked Multi-Head Self-Attention(遮住未来)
  2. Add & Norm
  3. Multi-Head Cross-Attention(Q 来自 decoder,K/V 来自 encoder)
  4. Add & Norm
  5. FFN
  6. Add & Norm

3. Self-Attention(自注意力)

3.1 Q/K/V 的含义(非常实用的直觉)

  • Query(Q):我“想找什么信息”。
  • Key(K):我“是什么信息的索引/标签”。
  • Value(V):我“真正携带的内容”。

同一个输入 $X$ 会被线性映射出 $Q,K,V$:

$$
Q = XW_Q,\quad K = XW_K,\quad V = XW_V
$$

其中 $X \in \mathbb{R}^{L\times d_{model}}$(长度 $L$,隐藏维 $d_{model}$)。

3.2 Scaled Dot-Product Attention 公式

注意力权重来自相似度(点积)并做缩放:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

为什么要除以 $\sqrt{d_k}$:当维度较大时点积数值更容易变大,softmax 会更“尖”,梯度更不稳定;缩放能让训练更稳。

3.3 输出在做什么

对每个位置 $i$:

  • 先算它对所有位置 $j$ 的相关性分数 $s_{ij}$
  • softmax 得到权重 $a_{ij}$
  • 用权重对所有 value 做加权求和,得到该位置的新表示

4. Multi-Head Attention(多头注意力)

单头注意力只能在一个“子空间”里做匹配。多头注意力的做法是:

  1. 用 $h$ 组不同的线性映射得到 $Q_i,K_i,V_i$
  2. 每个头独立算 attention 得到 $Z_i$
  3. 把各头拼接后再做一次线性变换

$$
\mathrm{MHA}(X)=\mathrm{Concat}(Z_1,\dots,Z_h)W_O
$$

直觉:不同的头可以分别学“指代关系”“语法依赖”“长距离对齐”等不同模式。


5. 位置编码(Positional Encoding)

自注意力本身对输入顺序不敏感(你把 token 乱序,注意力计算形式不变)。因此需要显式注入位置信息。

常见做法:把位置向量 $PE(pos)$ 与词向量相加:

$$
X' = X + PE
$$

论文中使用的正弦/余弦位置编码:

$$
PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right),\quad
PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)
$$

补充:很多实现也会用可学习位置编码(learnable embeddings),同样有效。


6. Add & Norm(残差 + LayerNorm)

Transformer 里几乎每个子层都采用:

$$
Y = \mathrm{LayerNorm}(X + \mathrm{Sublayer}(X))
$$

  • 残差连接:让信息与梯度更容易流动,深层训练更稳定。
  • LayerNorm:对单样本特征维做归一化,NLP 中通常比 BatchNorm 更合适。

7. FFN(前馈网络)

FFN 是逐位置(position-wise)的两层 MLP:

$$
\mathrm{FFN}(x)=\sigma(xW_1+b_1)W_2+b_2
$$

它不在 token 间交互(交互在 attention 里做),但能增强非线性表达能力。


8. Mask:Padding Mask 与 Causal Mask

8.1 Padding Mask

批处理时序列会 padding 到同一长度。padding token 不应该被关注,因此要把这些位置的 attention logits 设为 $-\infty$(实现里通常是一个足够小的负数)。

8.2 Causal Mask(Decoder 的“不能看未来”)

自回归生成时,第 $t$ 个位置不能看见 $t$ 之后的位置,所以要加上一个上三角 mask。


9. 用 PyTorch 手写一次 Self-Attention

下面用一个小矩阵例子把公式跑通(重点是理解矩阵形状与步骤)。

9.1 准备输入

import torch
from torch.nn.functional import softmax

x = torch.tensor(
  [
    [1, 0, 1, 0],  # token 1 embedding
    [0, 2, 0, 2],  # token 2 embedding
    [1, 1, 1, 1],  # token 3 embedding
  ],
  dtype=torch.float32,
)

print(x)

输出:

tensor([[1., 0., 1., 0.],
    [0., 2., 0., 2.],
    [1., 1., 1., 1.]])

9.2 构造 Q/K/V 映射矩阵

说明:真实模型里 $W_Q,W_K,W_V$ 是可训练参数,这里为了可复现,用手写的小矩阵。

w_key = torch.tensor(
  [
    [0, 0, 1],
    [1, 1, 0],
    [0, 1, 0],
    [1, 1, 0],
  ],
  dtype=torch.float32,
)
w_query = torch.tensor(
  [
    [1, 0, 1],
    [1, 0, 0],
    [0, 0, 1],
    [0, 1, 1],
  ],
  dtype=torch.float32,
)
w_value = torch.tensor(
  [
    [0, 2, 0],
    [0, 3, 0],
    [1, 0, 3],
    [1, 1, 0],
  ],
  dtype=torch.float32,
)

print("w_key\n", w_key)
print("w_query\n", w_query)
print("w_value\n", w_value)

9.3 计算 K/Q/V

keys = x @ w_key
queries = x @ w_query
values = x @ w_value

print("keys\n", keys)
print("queries\n", queries)
print("values\n", values)

9.4 计算注意力分数(logits)

这里用最简单的 $QK^T$(完整版本还要除以 $\sqrt{d_k}$):

attn_logits = queries @ keys.T
print(attn_logits)

9.5 softmax 得到权重

attn_weights = softmax(attn_logits, dim=-1)
print(attn_weights)

为了更直观看清“加权求和”,我们也可以做一个近似版本(教学用):

attn_weights_simple = torch.tensor(
  [
    [0.0, 0.5, 0.5],
    [0.0, 1.0, 0.0],
    [0.0, 0.9, 0.1],
  ],
  dtype=torch.float32,
)

9.6 加权求和得到输出

标准矩阵写法是:

output = attn_weights @ values
print(output)

如果你想看“每个 value 被乘了多少”,可以像下面这样拆开(便于教学观察):

weighted_values = values[:, None] * attn_weights_simple.T[:, :, None]
print(weighted_values)
print("sum over tokens ->", weighted_values.sum(dim=0))

10. 优缺点与常见坑

优点

  • 效果强:尤其在大数据与大模型规模下。
  • 并行友好:训练吞吐高。
  • 长距离依赖更容易学到。

常见坑

  • mask 忘了加:padding token 参与注意力会污染表示;decoder 不加 causal mask 会“偷看答案”。
  • shape 搞混:批次维、头数维、序列长度维容易写错。
  • softmax 维度写错:通常要对最后一维(key 维/序列维)做 softmax。

11. 小练习

  1. 把第 9 节的示例改成带缩放:将 attn_logits 替换为 attn_logits / (dk ** 0.5),其中 dk = keys.size(-1)
  2. 写一个 causal mask(上三角),把未来位置 logits 置为一个很小的负数(如 -1e9),观察输出变化。
  3. values 改大一倍,看看输出是否也线性变大(应该会)。

12. 延伸阅读

  • 《Attention Is All You Need》:Transformer 原论文(arXiv)
  • The Illustrated Transformer(图解 Transformer,直觉非常好)
  • PyTorch 官方 torch.nn.MultiheadAttention 文档与源码(理解工程实现细节)

作者:Smoothcloud润云

标签: none

添加新评论