Transformer 与 Self-Attention
Transformer 是一种序列建模架构(2017 年提出),它把“序列之间的依赖”主要交给注意力机制来建模,而不是依赖 RNN 的时间步递推。 它之所以重要,核心在两点: 一句话直觉: 经典 Transformer(seq2seq)包含两大块: 每一块通常由 $N$ 层堆叠(论文里常见 $N=6$)。各层结构相同,但参数不共享。 一个 Encoder Layer 通常包含: 一个 Decoder Layer 通常包含: 同一个输入 $X$ 会被线性映射出 $Q,K,V$: $$ 其中 $X \in \mathbb{R}^{L\times d_{model}}$(长度 $L$,隐藏维 $d_{model}$)。 注意力权重来自相似度(点积)并做缩放: $$ 为什么要除以 $\sqrt{d_k}$:当维度较大时点积数值更容易变大,softmax 会更“尖”,梯度更不稳定;缩放能让训练更稳。 对每个位置 $i$: 单头注意力只能在一个“子空间”里做匹配。多头注意力的做法是: $$ 直觉:不同的头可以分别学“指代关系”“语法依赖”“长距离对齐”等不同模式。 自注意力本身对输入顺序不敏感(你把 token 乱序,注意力计算形式不变)。因此需要显式注入位置信息。 常见做法:把位置向量 $PE(pos)$ 与词向量相加: $$ 论文中使用的正弦/余弦位置编码: $$ 补充:很多实现也会用可学习位置编码(learnable embeddings),同样有效。 Transformer 里几乎每个子层都采用: $$ FFN 是逐位置(position-wise)的两层 MLP: $$ 它不在 token 间交互(交互在 attention 里做),但能增强非线性表达能力。 批处理时序列会 padding 到同一长度。padding token 不应该被关注,因此要把这些位置的 attention logits 设为 $-\infty$(实现里通常是一个足够小的负数)。 自回归生成时,第 $t$ 个位置不能看见 $t$ 之后的位置,所以要加上一个上三角 mask。 下面用一个小矩阵例子把公式跑通(重点是理解矩阵形状与步骤)。 输出: 说明:真实模型里 $W_Q,W_K,W_V$ 是可训练参数,这里为了可复现,用手写的小矩阵。 这里用最简单的 $QK^T$(完整版本还要除以 $\sqrt{d_k}$): 为了更直观看清“加权求和”,我们也可以做一个近似版本(教学用): 标准矩阵写法是: 如果你想看“每个 value 被乘了多少”,可以像下面这样拆开(便于教学观察):Transformer 与 Self-Attention (整理版)
目标:用“概念 → 公式 → 代码 → 练习”的方式,把 Transformer 的核心机制讲清楚。
目录
1. Transformer 是什么
对于序列中的每个位置,让模型学会“我应该关注哪些位置,以及关注多少”。
2. 编码器/解码器整体结构
3. Self-Attention(自注意力)
3.1 Q/K/V 的含义(非常实用的直觉)
Q = XW_Q,\quad K = XW_K,\quad V = XW_V
$$3.2 Scaled Dot-Product Attention 公式
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$3.3 输出在做什么
4. Multi-Head Attention(多头注意力)
\mathrm{MHA}(X)=\mathrm{Concat}(Z_1,\dots,Z_h)W_O
$$5. 位置编码(Positional Encoding)
X' = X + PE
$$
PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right),\quad
PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)
$$6. Add & Norm(残差 + LayerNorm)
Y = \mathrm{LayerNorm}(X + \mathrm{Sublayer}(X))
$$7. FFN(前馈网络)
\mathrm{FFN}(x)=\sigma(xW_1+b_1)W_2+b_2
$$8. Mask:Padding Mask 与 Causal Mask
8.1 Padding Mask
8.2 Causal Mask(Decoder 的“不能看未来”)
9. 用 PyTorch 手写一次 Self-Attention
9.1 准备输入
import torch
from torch.nn.functional import softmax
x = torch.tensor(
[
[1, 0, 1, 0], # token 1 embedding
[0, 2, 0, 2], # token 2 embedding
[1, 1, 1, 1], # token 3 embedding
],
dtype=torch.float32,
)
print(x)tensor([[1., 0., 1., 0.],
[0., 2., 0., 2.],
[1., 1., 1., 1.]])9.2 构造 Q/K/V 映射矩阵
w_key = torch.tensor(
[
[0, 0, 1],
[1, 1, 0],
[0, 1, 0],
[1, 1, 0],
],
dtype=torch.float32,
)
w_query = torch.tensor(
[
[1, 0, 1],
[1, 0, 0],
[0, 0, 1],
[0, 1, 1],
],
dtype=torch.float32,
)
w_value = torch.tensor(
[
[0, 2, 0],
[0, 3, 0],
[1, 0, 3],
[1, 1, 0],
],
dtype=torch.float32,
)
print("w_key\n", w_key)
print("w_query\n", w_query)
print("w_value\n", w_value)9.3 计算 K/Q/V
keys = x @ w_key
queries = x @ w_query
values = x @ w_value
print("keys\n", keys)
print("queries\n", queries)
print("values\n", values)9.4 计算注意力分数(logits)
attn_logits = queries @ keys.T
print(attn_logits)9.5 softmax 得到权重
attn_weights = softmax(attn_logits, dim=-1)
print(attn_weights)attn_weights_simple = torch.tensor(
[
[0.0, 0.5, 0.5],
[0.0, 1.0, 0.0],
[0.0, 0.9, 0.1],
],
dtype=torch.float32,
)9.6 加权求和得到输出
output = attn_weights @ values
print(output)weighted_values = values[:, None] * attn_weights_simple.T[:, :, None]
print(weighted_values)
print("sum over tokens ->", weighted_values.sum(dim=0))10. 优缺点与常见坑
优点
常见坑
11. 小练习
attn_logits 替换为 attn_logits / (dk ** 0.5),其中 dk = keys.size(-1)。-1e9),观察输出变化。values 改大一倍,看看输出是否也线性变大(应该会)。12. 延伸阅读
torch.nn.MultiheadAttention 文档与源码(理解工程实现细节)