标签 FLOPs 下的文章

编者按: 为什么在强化学习(RL)中,模型往往需要消耗比有监督学习多出数个数量级的计算资源,却只能换来看似微薄的性能提升,且常常陷入训练不稳定的泥潭?

本文从信息论角度出发,对比了有监督学习与强化学习在单位样本中可获取信息量的根本差异:前者通过明确的正确标签直接提供高信息密度的学习信号,而后者仅依赖二元的成功/失败反馈,其信息熵在通过率极低或极高时趋近于零。作者进一步指出,只有当模型的“通过率”处于约 50% 的“金发姑娘区”时,RL 才能高效学习,而这通常只出现在训练末期。此外,文章还剖析了 RL 中梯度估计方差巨大、容易被简单启发式策略主导、难以培养通用推理能力等深层问题,并反思了人类学习机制与当前 model-free RL 的本质差距。

这篇文章提醒我们:若想让强化学习真正释放其潜力,不能仅靠堆算力,而必须重新思考如何设计更密集、更结构化的反馈机制 —— 否则,我们可能只是在用极其昂贵的方式,重复确认一个早已写在预训练权重里的答案。

作者 | Dwarkesh Patel

编译 | 岳扬

最近,人们[1]一直在讨论[2]:在强化学习(RL)中生成单个样本所需的计算量(FLOPs)远高于有监督学习(supervised learning)。在预训练阶段,模型对每一个用于训练的 token 都能立即获得一个学习信号;而在 RL 中,必须展开一整条长达数万 tokens 的推理思维链,才能在最后得到一个奖励信号(例如,我写的代码单元测试是否通过?这道数学题的答案是否正确?等等)。

但这只是问题的一半。这里有一种简单的方法可以比较强化学习与有监督学习的学习效率:

Bits/FLOP = Samples/Flop × Bits/Sample

我还没听到有人讨论我们公式中的这一项:Bits/Sample(每个样本包含多少有用信息)。而且在训练的大部分阶段,强化学习的每一个样本所包含的“有效学习信息量”比有监督学习要低得多。

01 用大白话来说

在有监督学习(也就是预训练)中,模型只是在疯狂吸收信息(bits)。每一个 token 都像是一条线索,它不仅能帮你理解语言本身的构造,还能让你窥见创造这段语言的思维过程,以及那个思维所感知的现实世界。在训练初期,当你用一个完全随机初始化的模型时,你对这些内容都处于最大程度的不确定状态。因此,每个 token 都会让你“恍然大悟”。而且你会立刻得到一个精确的信号,知道自己对正确答案的预测错得多离谱,以及需要调整哪些参数来减少错误。

假设你从一个随机初始化的模型开始,并启动训练。如果你使用有监督学习对 “The sky is” 这个短语做 next-token-prediction,那么训练循环会这样工作:“正确答案其实是 ‘blue’。你预测 ‘blue’ 的概率只有 0.001%。现在,请大幅加强那些本该指向 ‘blue’ 的连接权重。好了,下一个 token。”

而在使用策略梯度(policy gradient)的强化学习中,你会增加所有回答正确的轨迹的权重,并降低所有回答错误的轨迹的权重。但问题是,一个还没怎么学会东西的模型,几乎不可能凭运气就答对。

如果你用 RL 来做“The sky is”的 next-token-prediction,训练循环大概会是这样:“好吧,‘halcyon’ 是错的,别再做导致输出‘halcyon’的操作了…… 好吧,‘serendipity’ 也是错的……” 然后就这样反复试错,猜错的次数差不多得有词汇表总量那么多(约 10 万次)。

02 详细分析

让我们思考一下:随着通过率(p)的变化,每个样本所能获得的最大信息量(bits/sample)会如何变化。这里的“通过率”指的是你给出正确答案的概率。 为简化起见,我们假设答案长度只有一个词元。那么,对于一个完全未经训练的模型,其通过率仅仅是 1/(词汇表大小)。

在有监督学习中,每个样本都会明确告诉你正确标签是什么。你学到的新信息量,取决于你看到正确答案时有多“惊讶” —— 你的通过率越低(即正确答案的先验概率越小),你从这个标签中学到的东西就越多。信息熵的基本公式告诉我们:在有监督学习中,你从每个样本中最多可以学到 -log(p) bits 的信息。

而在强化学习中,你只会被告知答案是否正确。你能从中提取的信息量,受限于你对这个二元结果(对/错)的不确定性。如果你几乎总是通过(p ≈ 1)或几乎总是失败(p ≈ 0),那么每次试验都很难让你感到意外。当通过的概率像抛硬币一样时(p ≈ 0.5),你学到的东西最多。 对于一个二元随机变量,其信息量的上限由熵公式给出:在 RL 中,你从每个样本中最多能学到 Entropy(p) = -p log(p) - (1-p) log(1-p)1 bits 的信息。

好,我们来画图。

看起来还不算太糟。是的,在通过率前 50% 的范围内,预训练明显更好,但在后 50% 的范围内,强化学习表现更佳。然而,这张图极具误导性。根据缩放定律(scaling laws)中的幂律关系,每当你想把“通过率”(pass rate)提升一个数量级,你都需要投入大致相同量级的计算资源。 如果你花了 X FLOPs 将通过率从 1/100,000 提升到 1/10,000,那么你也需要 X FLOPs 才能将通过率从 1/10,000 提升到 1/1,000。因此,我们应该使用对数刻度来表示通过率 —— 以便使 X 轴的每一单位增量对应于相同数量的计算开销(FLOPs)。

这张图看起来真令人沮丧。强化学习在样本信息密度上与预训练相当的区域,仅仅是训练末期的一小段,而且此时模型本身已经相当不错了。

再次强调,这一问题完全独立于另一个观点:即从强化学习中获取单个样本(也就是在得到任何信号前必须完整展开一整条推理轨迹)可能需要耗费高出数百万倍的计算量。

03 方差(variance)让实际情况甚至比这更糟

训练初期的强化学习,实际情况其实比上面描述的更为严峻。当通过率很低时,对梯度的估计会变得极其混乱且难以预测。 要么在当前 batch 生成的样本中,根本就没有采样到正确答案,在这种情况下,几乎得不到任何有用的学习信号。要么碰巧采样到了一次,然后就会得到一个巨大的梯度峰值。模型的训练过程会被剧烈地、不规则地“拉扯”(梯度忽大忽小、方向混乱),如果要追求高效、稳定的训练,这样是非常糟糕的。2

有趣的是,预训练的问题恰好相反,方差(variance)在训练末期会变得非常高。随着预训练的推进,你会逐渐耗尽那些可约损失(reducible loss,即模型实际能从数据中学到的东西)。剩下的主要都是不可约损失(irreducible loss),不可约损失指的是网络文本数据固有的不可预测性。

提示词 “Bob’s favorite color is” 应该怎么结尾?这完全取决于 Bob 是谁。对于这种问题,并不存在什么标准正确答案能让你的超级智能模型通过训练达到很高的预测准确率。但是,模型仍然会根据某人在网上留下的随机答案,获得梯度更新(gradient update)。而这种噪音,会淹没当前 batch 中少数几个真正可学习的词元为我们提供的真实信号。我不知道这是否准确,但预训练阶段末期出现的这种方差激增,似乎与为什么在预训练过程中需要增大 batch sizes 有关。

04 进入 RL 的“金发姑娘区”(Goldilocks zone)

如果 RL 在通过率远高于 1% 时效果最佳,那么这就引出了一个问题:我们该如何设计 RL 训练过程,才能让模型进入并维持在这个高效学习的状态中?

例如,在进行强化学习(RL)时,我们可以通过“预训练更多的数据”和“增加推理时的计算量(比如让模型想得更久)”这两种方式,来让模型变得更聪明、回答得更准确,提高模型的“通过率”,从而让每个样本带来更多的有效信息(bits)。

有观点指出,课程学习(curriculum learning)在预训练中作用不大[3],但在 RL 中却常常不可或缺[4]。这完全说得通 —— 因为 RL 只有在通过率处于这个“金发姑娘区”时,每个样本才能带来有意义的信息量。因此,为了训练效果好,你必须精心安排学习内容的顺序,要保证问题的难度是随着模型能力的提升而同步加难的,不要一下子给太难的题,也不要一直做太简单的题。

作者提出的“通过率”理论可以很好地解释为什么“自我对弈”(像 AlphaGo 那样自己跟自己下棋)在强化学习历史上特别管用。因为当你跟一个水平旗鼓相当的对手比赛时,你赢的概率大约就是 50%。在这个理论中,50%是一个最佳状态,意味着每次比赛结果(输或赢)带给你的信息量是最大的,能让你学得最快。

但自我对弈并不是唯一能让训练过程中保持高通过率的方法。我们还可以设计出一种“proxy evaluation”机制,这种机制能提供更密集的反馈信息。这里的“密集”具体指以下两种情况之一:

1)Samples/FLOP 密度:通过“proxy evaluation”方法,我们可以在一个强化学习回合刚开始不久时就估算出最终的奖励,而不必真的把整个过程跑完,从而省去了后续的大量计算消耗。这种机制其实就是所谓的“价值函数”。

2)Bits/Sample 密度:我们可以设计一个比最终目标更易达成的 proxy objectives 来指导模型。我能想到的最简单例子是过程奖励模型(process-reward model),它会这样说:“嘿,这次生成的答案虽然错了,但我看得出来,它一开始的推理方向是对的。那我们就给这些早期的 token 增加一点权重。”

Deepseek R1[5] 论文的 4.2 节讨论并解释了,为什么直到现在,要为大语言模型开发出像这样好用的 proxy objectives 依然是一件很难的事情。

05 信息量虽少,但价值高

虽然在强化学习中,每单位计算量(FLOP)学到的 bits 确实少得多,但这些 bits 却非常重要,它们与预训练中获得的 bits 信息不能简单地相提并论。 这其中主要有两个关键原因:

  • 预训练就像是让模型把互联网上现有的数据全记下来,但这种知识与“如何完成具有经济价值的任务”只有部分且间接的关联;而强化学习则是直接教模型怎么去解决那些真正有用、能产生价值的实际问题。
  • 即使预训练语料中包含了完成某项任务的“操作说明”(比如教程、具体步骤或答案),它也缺少一种关键的东西 —— “思维轨迹”(thinking trace)。也就是说,数据里没有展示模型犯错时是怎么自我纠正的,也没有展示如何利用模型独特的、非人类的方式去组合技能来解决问题。而这些深层的思考痕迹,正是强化学习能提供的东西。

反驳的观点认为,虽然这些信息很有价值,但它们只在一个非常窄的通过率范围内(比如模型已经挺聪明了,但还没完全学会的时候)才能被获取。之所以要强调这一点,是因为在训练的大部分时间里,模型的通过率都极低(接近0),在对数尺度上看,这些低通过率的阶段占据了很大的比重,这意味着真正能高效学习的窗口期其实很短。

现在我们就能理解那些关于 RLHF/RL 仅能激发预训练模型中已有的潜在能力的说法了[6]。事实当然如此。如果预训练模型初始的通过率不够高,那么强化学习的 bits/sample 就会低得可怜,从而根本无法进行有效学习。 围棋对战中的“第 37 手”是一个非常著名的案例,它证明了强化学习确实能教给模型一种全新的、前所未有的策略。值得注意的是,AlphaGo 是通过自我对弈训练出来的(见上文关于自我对弈如何提高通过率的论述),而且以当时的标准来看[7],其计算消耗之巨令人吃惊。

06 强化学习的不均衡

人们指出,从经验上看,RLVR(强化学习 + 可验证奖励)实际上只是让模型将某种思维模式与特定问题类型关联起来,而并未真正培养出一种更通用的策略 —— 比如先退一步,再仔细思考最佳解法。

仔细想想。怎么会有模型在国际编程竞赛中达到世界顶尖水平,却同时在代码库中留下了大量本可预见的 Bug 和技术债务?

这种奇怪的不均衡该如何解释?也许 RLVR 无法区分一条成功的推理轨迹到底是模型通过某种通用的推理能力(举一反三)做出来的,还是仅仅靠死记硬背某种特定的解题模板(“看到这个形状就用这个套路”)做出来的。因为它没法区分这两种过程,所以模型可能学会了后者(简单的套路),而不是前者(通用的能力)。

当你使用策略梯度(policy gradient)进行 rollout(即让模型生成完整的行为序列)时,那种更复杂、更具泛化能力的策略几乎不可能被采样到;而简单的启发式策略却很容易被采样到,并随着训练不断被强化,出现频率越来越高,最终完全主导模型的行为(即达到“固定”状态)。与此同时,真正的通用策略则越来越难以被观察到,逐渐从训练过程中消失。

那么问题来了,我们该如何搭建一座“短桥”,把简单的启发式解法,和那种更复杂、更具泛化能力的通用策略连接起来?而且,这座桥会不会随着任务时间跨度(time horizons)自然拉长而自动出现 —— 从而迫使模型发展出真正的泛化能力?

我担心的是,那种“先退一步、基于对世界的理解做出明智判断”的通用策略,即使在更长周期的任务中,也依然很难通过“可验证的奖励”(verifiable rewards)被有效识别和强化。因此,要解决这种不均衡问题,不能只靠扩大 RLVR 的规模,而必须设计更鲁棒的训练方法。

07 人类的学习方式

本节我们讨论的只是 model-free RL —— 也就是仅从一个强化学习周期结束时的二元结果(成功/失败)中获得的信息量(bits/sample)。但显然,人类的学习效率远高于此。想想假如有一位连续创业者,我们会说她拥有大量来之不易的智慧和经验。而这些学习成果中,极少部分真正来自上一次创业的“one bit”结果(即创业成功与否)。

目前还不清楚,在机器学习中,人类这种从经验中学习的方式对应的是什么机制。 显然,我们的观察与反思会不断更新我们的世界模型(world model) —— 而且这种更新并不依赖于最终结果是成功还是失败。这在人类学习过程中起着非常重要的作用。

也许我们不该只是想着“如何把 model-free RL 的通过率调到 50% 左右,因为这样做仅仅是试图从一个单一的“成功/失败”结果中,挤出那么一点点微薄的信息。也许我们应该转换思路,去研究人类是如何从环境中获取海量信息的。人类并不像现在的机器那样,只盯着最终的结果(成功或失败),而是能从过程、观察和反思中吸收大量的经验和教训。

1 这个公式的意思是:从一个二元结果中学到的信息量 =p(样本正确) × (样本正确时获得的信息量) +p(样本错误) × (样本错误时获得的信息量)。

2 感谢 Lukas Berglund 指出我此前在这一点上的阐述有误。

END

本期互动内容 🍻

❓人类从失败中能学到远不止“0/1”的反馈——你觉得 AI 系统要如何模拟这种过程性反思能力?

文中链接

[1]https://www.tobyord.com/writing/inefficiency-of-reinforcement...

[2]https://thinkingmachines.ai/blog/lora/#how-much-capacity-is-n...

[3]https://arxiv.org/pdf/2012.03107

[4]https://arxiv.org/pdf/1707.05300

[5]https://arxiv.org/abs/2501.12948

[6]https://arxiv.org/abs/2510.07364v3

[7]https://epoch.ai/data/ai-models

原文链接:

https://www.dwarkesh.com/p/bits-per-sample

注1:本文系"每天一个多模态知识点"专栏文章。本专栏致力于对多模态大模型/CV领域的高频高难面试题进行深度拆解。本期攻克的难题是:FlashAttention

注2:本文Markdown源码可提供下载,详情见文末

关注"每天一个多模态知识点"公众号,每天一个知识点的深度解析!


知识点13 | FlashAttention深度攻略:从IO复杂度理论到CUDA kernel实现的完整解析

面试原题复现

面试官提问:

"请解释FlashAttention算法的核心思想,并分析它相比传统注意力算法在计算复杂度和内存复杂度上的差异。为什么它能在不牺牲精度的情况下显著提升性能?请从IO复杂度、Tiling算法和在线Softmax三个维度进行详细阐述。"

关键回答(The Hook)

核心直觉:

FlashAttention的核心突破在于从计算思维转向IO思维。传统注意力算法关注浮点运算数量(FLOPs),而FlashAttention认识到在现代GPU上,内存访问带宽才是真正的性能瓶颈。通过精心设计的IO-aware算法,FlashAttention将注意力计算从compute-bound转变为memory-bound场景下的性能杀手,在保持数学精确性的前提下,将HBM访问量从O(N²)降低到O(N),从而实现了2-4倍的端到端加速和显著的长序列处理能力提升。


深度原理解析(The Meat)

1. 硬件性能瓶颈分析

首先需要理解现代GPU的内存层次结构。以NVIDIA A100为例:

GPU内存层次结构

GPU内存层次(以A100为例):

  • HBM(高带宽内存): 40-80GB容量,带宽1.5-2.0 TB/s
  • SRAM(片上静态随机存取存储器): 每个SM约192KB,带宽约19 TB/s

关键洞察: SRAM的带宽是HBM的10倍以上,但容量仅为HBM的1/1000。FlashAttention的核心策略就是:尽可能将数据驻留在SRAM中完成计算,减少HBM的访问次数

2. 标准注意力实现的性能分析

标准注意力计算公式:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

其中 $Q, K, V \in \mathbb{R}^{N \times d_k}$,N为序列长度,$d_k$为注意力头维度。

标准实现的内存访问模式:

# 标准注意力实现
def standard_attention(Q, K, V):
    # Step 1: 计算注意力矩阵 S = QK^T
    S = Q @ K.T  # O(N²d_k) FLOPs, O(N²) 内存
    # Step 2: 计算 softmax
    P = softmax(S)  # O(N²) 内存
    # Step 3: 加权求和
    O = P @ V  # O(N²d_k) FLOPs, O(N²) 内存
    return O

问题分析:

  1. 中间结果巨大: 注意力矩阵S和P都是N×N的,对于N=4096的序列,需要约67MB存储
  2. 多次HBM读写: 每个步骤都需要读写HBM,造成严重的带宽瓶颈
  3. 内存复杂度: 需要O(N² + Nd_k)的HBM访问

IO复杂度计算:

标准注意力需要:

  • 读取Q, K, V:$3Nd_k$ 次HBM访问
  • 写入和读取S:$2N^2$ 次HBM访问
  • 读取V和写入O:$Nd_k$ 次HBM访问

总计:Ω(N² + Nd_k) 次HBM访问

3. FlashAttention的核心创新:Tiling算法

FlashAttention通过分块策略,将整个注意力矩阵的计算分解为在SRAM中完成的小块计算:

FlashAttention分块架构

分块策略:

设块大小为$B_r$(行块)和$B_c$(列块),满足$B_r \cdot B_c \cdot d_k \leq \text{SRAM\_capacity}$

将Q, K, V分别划分为:

  • $Q = [Q_1, Q_2, ..., Q_{T_r}]$,其中 $Q_i \in \mathbb{R}^{B_r \times d_k}$
  • $K = [K_1, K_2, ..., K_{T_c}]$,其中 $K_j \in \mathbb{R}^{B_c \times d_k}$
  • $V = [V_1, V_2, ..., V_{T_c}]$,其中 $V_j \in \mathbb{R}^{B_c \times d_k}$

算法流程:

def flash_attention(Q, K, V):
    # 初始化
    O = zeros_like(Q)  # 输出矩阵
    m = -inf * ones(N)  # 每行的最大值
    l = zeros(N)  # 每行的指数和(归一化因子)
    
    # 外层循环:遍历K和V的块
    for j in range(T_c):
        K_block = K[j*B_c:(j+1)*B_c, :]
        V_block = V[j*B_c:(j+1)*B_c, :]
        
        # 内层循环:遍历Q的块
        for i in range(T_r):
            Q_block = Q[i*B_r:(i+1)*B_r, :]
            
            # 在SRAM中计算
            S_block = Q_block @ K_block.T  # B_r × B_c
            
            # 在线softmax更新
            m_new = max(m[i*B_r:(i+1)*B_r], rowmax(S_block))
            l_new = l[i*B_r:(i+1)*B_r] * exp(m[i*B_r:(i+1)*B_r] - m_new) + \
                     sum(exp(S_block - m_new), dim=1)
            
            # 缩放累加输出
            O[i*B_r:(i+1)*B_r, :] = O[i*B_r:(i+1)*B_r, :] * \
                                       exp(m[i*B_r:(i+1)*B_r] - m_new) + \
                                       softmax(S_block) @ V_block
            
            # 更新统计量
            m[i*B_r:(i+1)*B_r] = m_new
            l[i*B_r:(i+1)*B_r] = l_new
    
    # 最终归一化
    O = O / l.reshape(-1, 1)
    return O

关键优势:

  1. 避免存储完整的注意力矩阵: 仅在SRAM中计算小块的注意力得分
  2. 减少HBM访问: 每个块只加载一次到SRAM,完成所有计算
  3. 精确计算: 通过在线softmax技巧,保证数学等价性

4. 在线Softmax的数学推导

这是FlashAttention最精妙的数学创新。传统softmax需要先知道全局最大值,而分块计算打破了这一依赖。

传统softmax:

$$\text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$$

数值稳定版本:

$$\text{softmax}(x)_i = \frac{\exp(x_i - \max_j x_j)}{\sum_j \exp(x_j - \max_j x_j)}$$

在线softmax推导:

假设已经处理了前j-1个块,已知:

  • $m_{j-1} = \max_{k=1}^{j-1} x_k$(前j-1个块的全局最大值)
  • $l_{j-1} = \sum_{k=1}^{j-1} \exp(x_k - m_{j-1})$(归一化因子)

现在处理第j个块,得到局部统计量:

  • $m_{local}^{(j)} = \max_{k \in \text{block}_j} x_k$
  • $l_{local}^{(j)} = \sum_{k \in \text{block}_j} \exp(x_k - m_{local}^{(j)})$

更新全局统计量:

新的全局最大值:
$$m_j = \max(m_{j-1}, m_{local}^{(j)})$$

新的归一化因子需要将前j-1个块的结果缩放到新的最大值下:

$$l_j = \sum_{k=1}^{j} \exp(x_k - m_j) = \sum_{k=1}^{j-1} \exp(x_k - m_j) + \sum_{k \in \text{block}_j} \exp(x_k - m_j)$$

利用指数的性质:

$$\sum_{k=1}^{j-1} \exp(x_k - m_j) = \sum_{k=1}^{j-1} \exp(x_k - m_{j-1}) \cdot \exp(m_{j-1} - m_j) = l_{j-1} \cdot \exp(m_{j-1} - m_j)$$

$$\sum_{k \in \text{block}_j} \exp(x_k - m_j) = \sum_{k \in \text{block}_j} \exp(x_k - m_{local}^{(j)}) \cdot \exp(m_{local}^{(j)} - m_j) = l_{local}^{(j)} \cdot \exp(m_{local}^{(j)} - m_j)$$

最终更新公式:

$$l_j = l_{j-1} \cdot \exp(m_{j-1} - m_j) + l_{local}^{(j)} \cdot \exp(m_{local}^{(j)} - m_j)$$

Softmax算法可视化

输出累积更新:

对于输出$O = \sum_i \text{softmax}(x_i) V_i$,同样需要在线更新:

$$O_{j} = O_{j-1} \cdot \frac{l_{j-1}}{l_j} \cdot \exp(m_{j-1} - m_j) + \text{softmax}_{local}^{(j)} \cdot V_{local}^{(j)}$$

这个技巧保证了分块计算的数值稳定性和数学精确性。

5. IO复杂度分析

FlashAttention的IO复杂度:

每个Q块需要遍历所有K/V块,因此:

  • HBM访问次数:$T_r \times T_c \times B_r \times d_k = N \times T_c \times d_k$

其中 $T_c = N / B_c$,块大小 $B_c \times B_r \approx M/d_k$(M为SRAM容量)

因此:
$$\text{HBM访问} = N \times \frac{N}{B_c} \times d_k = \frac{N^2 d_k^2}{M}$$

对比:

算法IO复杂度典型参数下的性能
标准注意力Ω(N² + Nd_k)基准
FlashAttentionO(N²d_k²/M)减少2-4倍HBM访问

对于典型参数:$d_k = 64$,$M = 192$KB:
$$\frac{N^2 d_k^2}{M} \approx \frac{N^2 \times 4096}{192 \times 1024} \approx 0.02 N^2$$

这解释了FlashAttention为何能实现如此显著的加速。

6. 反向传播的重计算策略

传统注意力需要存储中间结果用于反向传播:

$$\frac{\partial L}{\partial Q} = \left(\frac{\partial L}{\partial O} \cdot V^T - \sum_j P_{ij} \frac{\partial L}{\partial O}_{ij}\right) \cdot P \cdot \frac{1}{\sqrt{d_k}}$$

这需要存储完整的注意力矩阵P(O(N²)内存)。

FlashAttention的反向传播优化:

在正向传播中,只存储:

  • 输出矩阵O:O(Nd_k)
  • Softmax统计量m和l:O(2N)

反向传播时,重新计算注意力矩阵,而不是从HBM读取。虽然增加了计算量(2×FLOPs),但由于避免了O(N²)的HBM读取,反而更快

面试追问: 为什么重计算反而更快?

因为GPU的浮点运算速度远快于内存访问速度。以A100为例:

  • FP16矩阵乘法:312 TFLOPs/s
  • FP32浮点运算:19.5 TFLOPs/s
  • HBM带宽:1.5 TB/s

重计算增加的计算开销远小于减少的HBM访问开销。

7. 算子融合优化

FlashAttention将以下四个步骤融合为一个CUDA kernel:

  1. QK^T矩阵乘法
  2. Softmax计算(含掩码、Dropout)
  3. 与V矩阵乘法
  4. 激活函数等后处理

融合的优势:

  • 避免中间结果的HBM读写
  • 减少kernel launch开销
  • 充分利用Tensor Core
  • 提高数据局部性

代码手撕环节(Live Coding)

FlashAttention核心实现

import torch
import math

class FlashAttention(torch.nn.Module):
    def __init__(self, head_dim, block_size=128):
        super().__init__()
        self.head_dim = head_dim
        self.block_size = block_size
        self.scale = head_dim ** -0.5
        
    def forward(self, q, k, v, causal_mask=False):
        """
        q, k, v: (batch_size, seq_len, num_heads, head_dim)
        输出: (batch_size, seq_len, num_heads, head_dim)
        """
        batch_size, seq_len, num_heads, head_dim = q.shape
        
        # 重塑为 (batch_size * num_heads, seq_len, head_dim)
        q = q.transpose(1, 2).contiguous()
        k = k.transpose(1, 2).contiguous()
        v = v.transpose(1, 2).contiguous()
        
        q = q.view(-1, seq_len, head_dim)
        k = k.view(-1, seq_len, head_dim)
        v = v.view(-1, seq_len, head_dim)
        
        # 初始化输出和统计量
        o = torch.zeros_like(q)
        m = torch.full((q.shape[0], q.shape[1]), float('-inf'), 
                       device=q.device, dtype=q.dtype)
        l = torch.zeros((q.shape[0], q.shape[1]), 
                       device=q.device, dtype=torch.float32)
        
        # 分块计算
        num_blocks = (seq_len + self.block_size - 1) // self.block_size
        
        for i in range(num_blocks):
            start_i, end_i = i * self.block_size, min((i + 1) * self.block_size, seq_len)
            q_block = q[:, start_i:end_i, :]  # (B, B_r, d)
            
            for j in range(num_blocks):
                start_j, end_j = j * self.block_size, min((j + 1) * self.block_size, seq_len)
                
                # 因果掩码
                if causal_mask and j > i:
                    continue
                
                k_block = k[:, start_j:end_j, :]  # (B, B_c, d)
                v_block = v[:, start_j:end_j, :]  # (B, B_c, d)
                
                # 在SRAM中计算注意力分数 (B, B_r, B_c)
                s_block = torch.matmul(q_block, k_block.transpose(-2, -1)) * self.scale
                
                # 因果掩码(如果需要)
                if causal_mask:
                    mask = torch.triu(torch.ones(s_block.shape[1], s_block.shape[2]), 
                                     diagonal=1).bool().to(s_block.device)
                    s_block = s_block.masked_fill(mask, float('-inf'))
                
                # 在线softmax更新
                m_new = torch.maximum(m[:, start_i:end_i], s_block.max(dim=-1).values)
                l_new = l[:, start_i:end_i] * torch.exp(m[:, start_i:end_i] - m_new) + \
                        torch.exp(s_block - m_new.unsqueeze(-1)).sum(dim=-1)
                
                # 更新输出
                o[:, start_i:end_i, :] = o[:, start_i:end_i, :] * torch.exp(
                    m[:, start_i:end_i] - m_new).unsqueeze(-1) + \
                    torch.exp(s_block - m_new.unsqueeze(-1)).unsqueeze(-1) * v_block.unsqueeze(1)
                
                # 更新统计量
                m[:, start_i:end_i] = m_new
                l[:, start_i:end_i] = l_new
        
        # 最终归一化
        o = o / l.unsqueeze(-1)
        
        # 重塑回原始形状
        o = o.view(batch_size, num_heads, seq_len, head_dim)
        o = o.transpose(1, 2).contiguous()
        
        return o

# 使用示例
if __name__ == "__main__":
    batch_size = 2
    seq_len = 512
    num_heads = 8
    head_dim = 64
    
    q = torch.randn(batch_size, seq_len, num_heads, head_dim, 
                   device='cuda', dtype=torch.float16)
    k = torch.randn(batch_size, seq_len, num_heads, head_dim, 
                   device='cuda', dtype=torch.float16)
    v = torch.randn(batch_size, seq_len, num_heads, head_dim, 
                   device='cuda', dtype=torch.float16)
    
    flash_attn = FlashAttention(head_dim=head_dim, block_size=128).cuda()
    output = flash_attn(q, k, v, causal_mask=True)
    
    print(f"输出形状: {output.shape}")  # 应为 (2, 512, 8, 64)

标准注意力对比

def standard_attention(q, k, v, causal_mask=False):
    """
    标准注意力实现(用于对比)
    """
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    
    if causal_mask:
        mask = torch.triu(torch.ones(scores.shape[-2], scores.shape[-1]), 
                        diagonal=1).bool().to(scores.device)
        scores = scores.masked_fill(mask, float('-inf'))
    
    attn_weights = torch.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, v)
    
    return output

# 验证数值等价性
def test_equivalence():
    # 创建小规模测试数据
    batch_size, seq_len, num_heads, head_dim = 2, 64, 4, 32
    
    q = torch.randn(batch_size, seq_len, num_heads, head_dim, 
                   dtype=torch.float32, requires_grad=True)
    k = torch.randn(batch_size, seq_len, num_heads, head_dim, 
                   dtype=torch.float32, requires_grad=True)
    v = torch.randn(batch_size, seq_len, num_heads, head_dim, 
                   dtype=torch.float32, requires_grad=True)
    
    # FlashAttention
    flash_attn = FlashAttention(head_dim=head_dim, block_size=32)
    flash_output = flash_attn(q, k, v, causal_mask=True)
    
    # 标准注意力
    q_std = q.transpose(1, 2)
    k_std = k.transpose(1, 2)
    v_std = v.transpose(1, 2)
    std_output = standard_attention(q_std, k_std, v_std, causal_mask=True).transpose(1, 2)
    
    # 比较数值
    diff = (flash_output - std_output).abs().max()
    print(f"最大数值差异: {diff.item()}")
    
    # 应该小于1e-5(允许数值误差)
    assert diff < 1e-5, f"数值差异过大: {diff}"
    
    print("✓ FlashAttention与标准注意力数值等价")

test_equivalence()

进阶追问与展望

1. 面试官可能的深度追问

追问1:FlashAttention-2相比FlashAttention-1有哪些改进?

回答要点:

  • 减少非matmul FLOPs: 优化了在线softmax的实现,减少缩放操作
  • 改进并行策略: 在序列维度上并行,提高GPU利用率
  • 优化工作负载分配: 从split-K改为split-Q,减少共享内存通信

追问2:FlashAttention的局限性和改进方向?

局限性:

  • 不适用于序列长度极大的情况(如N > 10^6)
  • 对硬件有特定要求,需要足够的SRAM
  • 实现复杂度高,需要精细的CUDA优化

改进方向:

  • FlashAttention-2/3: 进一步优化并行策略和硬件适配
  • Sparse FlashAttention: 结合稀疏注意力,处理超长序列
  • 硬件协同设计: 为特定GPU架构定制优化

追问3:如何选择最优的块大小?

考虑因素:

  • SRAM容量限制
  • GPU利用率
  • 寄存器压力
  • 共享内存bank冲突

经验法则:

$$B_{optimal} \approx \sqrt{\frac{M}{d_k \times 4}}$$

其中4是考虑数据类型(float16)和额外的存储开销。

2. 与其他优化方法的对比

方法核心思想精度适用场景
FlashAttentionIO-aware tiling + 在线softmax精确通用长序列处理
Linear Attention低秩近似近似超长序列(N > 10^5)
Sparse Attention结构化稀疏模式近似长序列 + 局部依赖
ReformerLocality Sensitive Hashing近似大规模序列建模
Performer随机特征近似近似理论研究、小规模应用

3. 前沿研究方向

当前热点:

  1. FlashAttention-3: 针对H100架构的深度优化,利用新的硬件特性
  2. 异步计算: 将计算与内存传输重叠,进一步提高吞吐量
  3. 混合精度优化: 结合FP8等低精度格式,进一步提升性能
  4. 多模态应用: 扩展到视觉-语言多模态注意力计算

应用前景:

  • 超大规模语言模型训练(GPT-4级别)
  • 长文档理解与问答
  • 实时视频处理
  • 生物学序列分析(如蛋白质折叠)

面试避坑指南

常见错误1: 认为FlashAttention是近似算法

纠正: FlashAttention是精确算法,它不进行任何近似,通过精确的数学推导保证与标准注意力数值等价。

常见错误2: 忽略IO复杂度,只谈论FLOPs

纠正: 现代GPU上,IO瓶颈往往比计算瓶颈更严重。FlashAttention的核心贡献是将注意力算法的设计焦点从FLOPs转向IO复杂度。

常见错误3: 混淆Tiling和Block-Sparse Attention

纠正: Tiling是实现技巧,用于减少内存访问;Block-Sparse是算法近似,用于降低计算复杂度。FlashAttention使用Tiling但不进行近似。

常见错误4: 认为重计算总是更慢

纠正: 在计算密集型和内存受限的场景下,重计算可能更快,因为避免了昂贵的内存访问。


总结与展望

FlashAttention代表了深度学习算法设计的范式转变:从单纯追求计算效率到统筹考虑硬件特性的系统级优化。它不仅仅是一个算法改进,更是算法-硬件协同设计的典范。

核心价值:

  1. 理论突破: 引入IO复杂度分析,建立新的算法评价标准
  2. 工程创新: Tiling + 在线softmax + 算子融合的组合优化
  3. 实用价值: 在保持精度的前提下,显著提升实际训练和推理性能
  4. 方法论影响: 启发了后续IO-aware算法的研究方向

面试要点回顾:

  • 理解GPU内存层次结构和性能瓶颈
  • 掌握Tiling算法的核心思想
  • 能够推导在线softmax的更新公式
  • 理解IO复杂度分析方法
  • 清晰对比FlashAttention与传统注意力的差异
  • 了解FlashAttention-2/3的改进方向

FlashAttention的成功证明了:优秀的算法设计需要深入理解硬件特性,在数学理论和工程实现之间找到最佳平衡点。 这也是系统AI研究的重要方向。


参考文献

  1. Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.
  2. Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.
  3. Saha, B., & Ye, C. (2024). The I/O Complexity of Attention, or How Optimal is Flash Attention?
  4. NVIDIA A100 GPU Architecture Whitepaper
  5. Hong, S., & Kim, H. (2020). An Analysis of Deep Learning Neural Networks with PyTorch.

谢谢阅读~

关注"每天一个多模态知识点"公众号,回复"FlashAttention"即可下载本文markdown源码

本文由mdnice多平台发布