FlashAttention 详解:IO 感知精确注意力,提速 2-4 倍
FlashAttention 是精确注意力算法,用分块和重计算压缩 GPU 内存搬运:GPT-2 提速 3 倍、BERT-large 提速 15%,显存随序列长度线性增长。
快速答案
FlashAttention 是一种精确注意力算法——数学和输出都与标准注意力完全一致——但通过最小化 GPU 高带宽显存(HBM)与片上 SRAM 之间的读写,把速度提升了 2-4 倍。论文报告:GPT-2(序列长度 1K)端到端提速 3 倍,BERT-large(长度 512)相比 MLPerf 1.1 记录快 15%,Long-Range Arena(长度 1K-4K)提速 2.4 倍。更关键的是,它的显存随序列长度线性增长,而非平方增长。
注意力真正的内存瓶颈
标准注意力在长序列上慢,是因为计算量和显存都随长度平方增长。常规思路是用近似注意力——稀疏或低秩方案来削减 FLOPs。FlashAttention 的核心判断是:在 GPU 上,FLOPs 从来都不是真正的瓶颈。大多数近似方法削减了算术量,却几乎没有 wall-clock 加速,因为成本主要来自把巨大的 N×N 注意力矩阵反复搬进搬出 HBM。
GPU 显存是一个层级结构:片上 SRAM 快但极小,HBM 大但相对慢。标准注意力会在 HBM 里物化完整的 softmax 矩阵,并反复读写它。这个操作是内存受限(memory-bound)的,所以一个做同样 FLOPs、但极少触碰 HBM 的算法就会更快。把注意力重新理解成一个 IO 问题、而不是算术问题——这才是这篇论文真正的贡献。
分块与重计算
FlashAttention 从不把完整注意力矩阵写进 HBM。它把 Q、K、V 切成块,加载到 SRAM,逐块计算注意力。难点在 softmax,因为它通常需要一整行数据。FlashAttention 用了 online softmax 技巧:维护运行中的最大值和求和统计量,每来一个新块就对部分结果重新缩放,从而在不持有完整行的前提下,算出精确的 softmax。
反向传播时,如果存下所有中间结果就违背了初衷,于是 FlashAttention 借助保存的 softmax 统计量,实时重新计算注意力块。重计算增加了 FLOPs,却省掉了 HBM 搬运——这是一个有意为之的取舍,而且因为 kernel 是内存受限的,这笔买卖划算。作者还证明了它的 IO 复杂度在一系列 SRAM 大小下是最优的,所以这不只是个好启发式,而是接近紧的下界。
关键结果
- GPT-2 序列长度 1K,端到端训练提速 3 倍。
- BERT-large(长度 512)相比 MLPerf 1.1 训练记录,wall-clock 提速 15%。
- Long-Range Arena(长度 1K-4K)提速 2.4 倍。
- 显存随序列长度线性增长,而标准注意力是平方增长。
- 更长上下文带来更高质量:GPT-2 困惑度降低 0.7,长文档分类提升 6.4 个点。
- 首批在 Path-X(长度 16K,61.4%)和 Path-256(长度 64K,63.1%)上超过随机水平的 Transformer——这些任务此前遥不可及。
- 块稀疏版 FlashAttention 比此前任何近似注意力方法都快。
局限与存疑
FlashAttention 加速了注意力,但没有改变随序列长度平方增长的计算依赖——只是显存变成了线性。也就是说,FLOPs 仍然随 N² 增长,只是去掉了 HBM 的惩罚。它的加速也高度依赖硬件和实现:收益来自针对特定 SRAM/HBM 层级手工调优的 CUDA kernel,因此无法免费迁移到新加速器,或迁移到无法做算子融合的框架上。块稀疏变体重新引入了近似,带着惯常的质量隐患。而且原始 kernel 并未完全吃满更新的 GPU,这正是后续 FlashAttention-2 和 FlashAttention-3 出现的原因。诚实地讲:这与其说是算法成果,不如说一半是系统工程成果,它的价值成败都系于精细的工程实现。
常见问题
FlashAttention 是精确的还是近似的?
核心 FlashAttention 算法是精确的——它产生与标准注意力完全相同的输出,只是 GPU 内存搬运少得多。只有可选的块稀疏变体才是近似的。
FlashAttention 到底快多少?
论文报告 2-4 倍提速:GPT-2(长度 1K)3 倍,BERT-large(长度 512)相比 MLPerf 1.1 记录快 15%,Long-Range Arena(长度 1K-4K)2.4 倍。
FlashAttention 为什么能省显存?
因为它从不在 HBM 里物化完整的 N×N 注意力矩阵。分块加 online softmax 让显存随序列长度线性增长而非平方增长,这正是支持 16K-64K token 上下文的关键。
FlashAttention 减少了 FLOPs 吗?
没有——它在反向传播中通过重计算反而增加了 FLOPs。收益来自削减缓慢的 HBM 读写,因为注意力在 GPU 上是内存受限而非计算受限的。
谁应该用 FlashAttention?
任何在长序列上训练或部署 Transformer 的人。它现在已被主流训练和推理栈内置,所以大多数从业者其实是在间接使用,不必自己调用接口。
FlashAttention 的胜利,在于它把内存搬运、而非算术量,当成了真正的瓶颈来攻——而这个洞见重塑了 Transformer 的运行方式。原文见:https://arxiv.org/abs/2205.14135