FlashAttention Explained: IO-Aware Exact Attention, 2-4x Faster

FlashAttention is an exact attention algorithm that uses tiling and recomputation to cut GPU memory traffic. It delivers a 3x end-to-end training speedup on GPT-2 and a 15% speedup on BERT-large, and its memory grows linearly with sequence length.

Quick answer

FlashAttention is an exact attention algorithm: same math, same outputs as standard attention. It runs the attention computation 2-4x faster by minimizing reads and writes between GPU high-bandwidth memory (HBM) and on-chip SRAM. End to end, the paper reports three training speedups:

- 3x on GPT-2 (sequence length 1K).
- 15% over the MLPerf 1.1 BERT-large training record (length 512).
- 2.4x on Long-Range Arena (length 1K-4K).

Crucially, its memory grows linearly, not quadratically, with sequence length.

The memory bottleneck in attention

Standard attention is slow on long sequences because its compute and memory scale quadratically with length. The usual fix is approximate attention — sparse or low-rank schemes that cut FLOPs. FlashAttention’s central argument is that FLOPs were never the real bottleneck on a GPU: most approximate methods reduce arithmetic but show little or no wall-clock speedup, because the cost is dominated by moving the large N×N attention matrix in and out of HBM.

GPU memory is a hierarchy. SRAM on the chip is fast but tiny; HBM is large but comparatively slow. Standard attention materializes the full softmax matrix in HBM, reading and writing it repeatedly. The operation is memory-bound, so an algorithm that does the same FLOPs but touches HBM far less will run faster. This reframing — attention as an IO problem, not an arithmetic one — is the paper’s real contribution.
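A minimal pure-Python sketch makes the memory argument concrete. The function name `standard_attention` and the toy inputs are illustrative assumptions, not code from the paper. What matters is the shape of the intermediates. The naive version builds the full N×N score matrix S and its softmax P before multiplying by V. On a GPU, those are the matrices a non-fused kernel writes to HBM and then reads back.

```python
import math


def standard_attention(Q, K, V):
    """Naive exact attention: O = softmax(Q K^T / sqrt(d)) V.

    Q, K, V are lists of rows. The full score matrix S and probability
    matrix P (one entry per query/key pair) are built in memory, so the
    extra storage grows as N * N.
    """
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    # S[i][j] = scaled dot product of query i and key j (N x N entries)
    S = [[scale * sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]
    # P = row-wise softmax of S (another N x N matrix)
    P = []
    for row in S:
        m = max(row)  # subtract the row max for numerical stability
        e = [math.exp(x - m) for x in row]
        total = sum(e)
        P.append([x / total for x in e])
    # O = P V: each output row is a P-weighted average of the rows of V
    d_v = len(V[0])
    O = [[sum(pj * vj[c] for pj, vj in zip(p, V)) for c in range(d_v)] for p in P]
    return O, S, P


Q = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5], [-0.2, 0.4]]
K = [[0.2, 0.1], [-0.3, 0.2], [0.5, 0.0], [0.1, -0.4]]
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
O, S, P = standard_attention(Q, K, V)
print(len(S), "x", len(S[0]), "score entries")  # 4 x 4 score entries
```

At N = 4096, each of S and P would already hold 4096² ≈ 16.8 million entries per head. The arithmetic per entry stays tiny, so moving these matrices dominates the cost.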

Tiling and recomputation

FlashAttention never writes the full attention matrix to HBM. It splits Q, K, and V into blocks, loads them into SRAM, and computes attention block by block.

The hard part is softmax, which normally needs the whole row at once. FlashAttention uses the online-softmax trick. It keeps a running max and a running sum for each row, and it rescales the partial results as each new block arrives. This produces the exact softmax without ever materializing a full row of scores, let alone the whole N×N matrix, in HBM.
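The tiling and online-softmax steps can be sketched as follows. This is a hedged, pure-Python illustration under simplifying assumptions. It walks over key/value blocks for one query row at a time, whereas a real kernel also groups query rows into blocks held in SRAM. The names `tiled_attention`, `naive_attention` and `block_size` are hypothetical. Per row, the sketch keeps only a running max `m`, a running sum `l`, and an unnormalized accumulator. It should still match the naive result for every block size.

```python
import math


def naive_attention(Q, K, V):
    """Reference: softmax(Q K^T / sqrt(d)) V, materializing each full score row."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        e = [math.exp(x - m) for x in s]
        z = sum(e)
        out.append([sum(w * v[c] for w, v in zip(e, V)) / z for c in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block_size=2):
    """Exact attention computed one key/value block at a time (online softmax).

    For each query row, only a running max m, a running sum l of
    exp(score - m), and an unnormalized output accumulator are kept,
    never a full row of scores.
    """
    scale = 1.0 / math.sqrt(len(Q[0]))
    d_v = len(V[0])
    out = []
    for q in Q:
        m, l = -math.inf, 0.0
        acc = [0.0] * d_v
        for start in range(0, len(K), block_size):
            k_blk = K[start:start + block_size]
            v_blk = V[start:start + block_size]
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
            m_new = max(m, max(s))
            # Rescale earlier partial results to the new max (exp(-inf) == 0.0).
            alpha = math.exp(m - m_new)
            p = [math.exp(x - m_new) for x in s]
            l = alpha * l + sum(p)
            acc = [alpha * acc[c] + sum(pj * v[c] for pj, v in zip(p, v_blk))
                   for c in range(d_v)]
            m = m_new
        out.append([a / l for a in acc])  # normalize once, after the last block
    return out


Q = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
K = [[0.2, 0.1], [-0.3, 0.2], [0.5, 0.0], [0.1, -0.4], [0.9, 0.3]]
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5], [2.0, 0.1]]
ref = naive_attention(Q, K, V)
for bs in (1, 2, 3, 5):
    got = tiled_attention(Q, K, V, block_size=bs)
    err = max(abs(x - y) for r1, r2 in zip(ref, got) for x, y in zip(r1, r2))
    print("block_size", bs, "max abs error", err)
```

The rescaling factor `alpha = exp(m_old - m_new)` is what keeps the result exact. Partial sums computed against an older, smaller max are shrunk to the new max before the next block is added.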

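For the backward pass, storing every intermediate would defeat the purpose. Instead, FlashAttention recomputes the attention blocks on the fly in SRAM, using Q, K, V and the per-row softmax statistics saved during the forward pass. Recomputation adds FLOPs but removes HBM traffic, a deliberate trade that wins because the kernel is memory-bound.

The authors also analyze the IO complexity, where d is the head dimension and M is the SRAM size:

- Standard attention needs Θ(Nd + N²) HBM accesses.
- FlashAttention needs Θ(N²d²/M).

They further prove that no exact attention algorithm can do asymptotically better for every SRAM size between d and Nd. So the tiling is not just a good heuristic but asymptotically optimal in that sense.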
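The recomputation idea can be sketched for the simplest gradient, dV = Pᵀ dO, where dO is the upstream gradient of the output. This sketch rests on three assumptions. It is pure Python, its helper names are hypothetical, and it saves a single log-sum-exp value L_i = m_i + log(l_i) per row, which carries the same information as the separate max and sum. The backward pass rebuilds each tile of P as exp(S_ij - L_i) from Q, K and that O(N) statistic, instead of reading a stored N×N matrix. dQ and dK follow the same recompute-per-tile pattern and are omitted here.

```python
import math


def tile_scores(q_rows, k_rows):
    """Scaled dot-product scores for one (query block, key block) tile."""
    scale = 1.0 / math.sqrt(len(q_rows[0]))
    return [[scale * sum(a * b for a, b in zip(q, k)) for k in k_rows] for q in q_rows]


def row_logsumexp(Q, K, block_size=2):
    """Per-row statistic kept from the forward pass: L_i = m_i + log(l_i).

    Computed block by block with the same running max/sum as online softmax.
    """
    lse = []
    for q in Q:
        m, l = -math.inf, 0.0
        for start in range(0, len(K), block_size):
            s = tile_scores([q], K[start:start + block_size])[0]
            m_new = max(m, max(s))
            l = math.exp(m - m_new) * l + sum(math.exp(x - m_new) for x in s)
            m = m_new
        lse.append(m + math.log(l))
    return lse


def backward_dV(Q, K, dO, lse, block_size=2):
    """dV = P^T dO, recomputing each tile of P instead of loading it from memory."""
    d_v = len(dO[0])
    dV = [[0.0] * d_v for _ in K]
    for j0 in range(0, len(K), block_size):        # key/value blocks (outer loop)
        k_blk = K[j0:j0 + block_size]
        for i0 in range(0, len(Q), block_size):    # query blocks (inner loop)
            s = tile_scores(Q[i0:i0 + block_size], k_blk)
            for di, row in enumerate(s):
                i = i0 + di
                for dj, x in enumerate(row):
                    p_ij = math.exp(x - lse[i])    # recomputed softmax entry
                    for c in range(d_v):
                        dV[j0 + dj][c] += p_ij * dO[i][c]
    return dV


Q = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
K = [[0.2, 0.1], [-0.3, 0.2], [0.5, 0.0], [0.1, -0.4]]
dO = [[1.0, -1.0], [0.5, 0.0], [0.0, 2.0]]  # toy upstream gradient dL/dO
lse = row_logsumexp(Q, K)
dV = backward_dV(Q, K, dO, lse)
print(dV)
```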

Key results

  • 3x end-to-end speedup training GPT-2 at sequence length 1K.
  • 15% wall-clock speedup on BERT-large (length 512) over the MLPerf 1.1 training record.
  • 2.4x speedup on Long-Range Arena (length 1K-4K).
  • Memory usage linear in sequence length, versus quadratic for standard attention.
  • Longer context yields higher quality: 0.7 lower perplexity on GPT-2 and 6.4 points of lift on long-document classification.
  • The first Transformers to beat chance on Path-X (length 16K, 61.4%) and Path-256 (length 64K, 63.1%) — tasks previously out of reach.
  • Block-sparse FlashAttention is faster than any prior approximate attention method.
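Block-sparse FlashAttention reuses the same tiled loop but skips tiles that a block mask marks as empty, which is where its extra speed comes from. Below is a minimal sketch, assuming a hypothetical `block_mask[bi][bj]` layout (query block bi, key block bj) and hypothetical function names. With an all-True mask it reduces to exact dense attention. With a sparser mask it computes exact attention over the allowed pattern, which only approximates full attention.

```python
import math


def block_sparse_attention(Q, K, V, block_mask, block_size=2):
    """Tiled online-softmax attention that skips masked-out key/value blocks.

    block_mask[bi][bj] (hypothetical layout) says whether query block bi may
    attend to key block bj. Skipped tiles cost neither memory traffic nor
    FLOPs. Every query block needs at least one allowed key block, otherwise
    the running sum stays 0 and normalization fails.
    """
    scale = 1.0 / math.sqrt(len(Q[0]))
    d_v = len(V[0])
    out = []
    for i, q in enumerate(Q):
        bi = i // block_size
        m, l, acc = -math.inf, 0.0, [0.0] * d_v
        for bj, start in enumerate(range(0, len(K), block_size)):
            if not block_mask[bi][bj]:
                continue  # the whole tile is skipped
            k_blk = K[start:start + block_size]
            v_blk = V[start:start + block_size]
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
            m_new = max(m, max(s))
            alpha = math.exp(m - m_new)
            p = [math.exp(x - m_new) for x in s]
            l = alpha * l + sum(p)
            acc = [alpha * acc[c] + sum(pj * v[c] for pj, v in zip(p, v_blk))
                   for c in range(d_v)]
            m = m_new
        out.append([a / l for a in acc])
    return out


Q = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5], [-0.2, 0.4]]
K = [[0.2, 0.1], [-0.3, 0.2], [0.5, 0.0], [0.1, -0.4]]
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
# Block-diagonal pattern: each block of 2 queries sees only its own 2 keys.
diag = [[True, False], [False, True]]
sparse_out = block_sparse_attention(Q, K, V, diag)
dense_out = block_sparse_attention(Q, K, V, [[True, True], [True, True]])
print(sparse_out)
print(dense_out)
```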

Limits and open questions

FlashAttention speeds up attention; it does not change the quadratic compute dependency on sequence length — only the memory footprint becomes linear. So the FLOPs still grow with N², just without the HBM penalty. The speedup is also heavily hardware- and implementation-specific: the gains come from a hand-tuned CUDA kernel targeting a specific SRAM/HBM hierarchy, so they do not transfer for free to new accelerators or to frameworks that cannot fuse the operation. The block-sparse variant reintroduces approximation, with the usual quality caveats. And the original kernel did not fully saturate newer GPUs, which is exactly why FlashAttention-2 and FlashAttention-3 followed. The honest read: this is a systems result as much as an algorithms one, and its value lives or dies on careful engineering.
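A little arithmetic shows why only memory, not compute, becomes linear. The sketch counts forward FLOPs for one head from the two matrix multiplies and ignores the lower-order softmax work, a simplifying assumption. It compares them with the per-row state FlashAttention keeps. The head dimension of 64 is an illustrative choice, and the function names are hypothetical.

```python
def attention_matmul_flops(n, d):
    """Approximate forward FLOPs for one attention head, counting only the two
    matrix multiplies (scaling and softmax add lower-order n * n work).

    S = Q K^T: n * n dot products of length d -> 2 * n * n * d FLOPs
    O = P V:   n * d dot products of length n -> 2 * n * n * d FLOPs
    """
    return 4 * n * n * d


def flash_softmax_state(n):
    """Per-row values kept instead of an n x n matrix: a running max and a
    running sum for each of the n query rows."""
    return 2 * n


d = 64  # illustrative head dimension
for n in (1024, 2048, 4096):
    print(n, attention_matmul_flops(n, d), flash_softmax_state(n))
# 1024 268435456 2048
# 2048 1073741824 4096
# 4096 4294967296 8192
```

Each doubling of N multiplies the FLOPs by 4 but the extra softmax state only by 2. Cheaper memory movement shrinks the constant in front of N², not the N² itself.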

FAQ

Is FlashAttention exact or approximate?

The core FlashAttention algorithm is exact — it produces identical outputs to standard attention, just with far less GPU memory traffic. Only the optional block-sparse variant is approximate.

How much faster is FlashAttention?

The attention computation itself runs roughly 2-4x faster. End-to-end training gains depend on the workload. The paper reports 3x on GPT-2 (length 1K), 15% on BERT-large (length 512) over the MLPerf 1.1 record, and 2.4x on Long-Range Arena (length 1K-4K).

Why does FlashAttention save memory?

Because it never materializes the full N×N attention matrix in HBM. Tiling plus online softmax keeps memory linear in sequence length instead of quadratic, which is what enables 16K-64K-token contexts.
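A back-of-the-envelope sketch of the memory gap follows. It assumes one head and one sequence, fp16 entries for a materialized N×N matrix, and two fp32 statistics per row for the tiled version. All of these are illustrative choices, and the function names are hypothetical.

```python
def score_matrix_bytes(n, bytes_per_entry=2):
    """Bytes for one materialized n x n attention matrix (fp16 by default)."""
    return n * n * bytes_per_entry


def softmax_stats_bytes(n, bytes_per_entry=4):
    """Bytes for per-row statistics (a max and a sum per query row),
    stored in fp32 by default."""
    return 2 * n * bytes_per_entry


MiB = 2 ** 20
for n in (16_384, 65_536):
    print(f"N={n}: N x N matrix {score_matrix_bytes(n) / MiB:.0f} MiB, "
          f"row stats {softmax_stats_bytes(n) / 1024:.0f} KiB")
# N=16384: N x N matrix 512 MiB, row stats 128 KiB
# N=65536: N x N matrix 8192 MiB, row stats 512 KiB
```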

Does FlashAttention reduce FLOPs?

No — it actually adds FLOPs through recomputation in the backward pass. The win comes from cutting slow HBM reads and writes, since attention is memory-bound rather than compute-bound on GPUs.

Who should use FlashAttention?

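Anyone training or serving Transformers on long sequences. It is now built into mainstream training and inference stacks. For example, PyTorch's torch.nn.functional.scaled_dot_product_attention can dispatch to a FlashAttention kernel on supported GPUs. As a result, most practitioners use it indirectly, without calling a FlashAttention library themselves.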

FlashAttention won by treating memory movement, not arithmetic, as the real bottleneck — and that insight reshaped how Transformers run. Read the original: https://arxiv.org/abs/2205.14135