
FlashAttention-2: Optimized Transformer Attention

Updated 24 December 2025
  • FlashAttention-2 is an advanced attention kernel that enhances transformer efficiency by optimizing memory access patterns and reducing scalar computation overhead.
  • It introduces innovations such as online softmax fusion, 2D tiling for enhanced GPU occupancy, and split-Q partitioning to eliminate costly shared memory reductions.
  • Empirical benchmarks show substantial speedups: roughly 2× the throughput of FlashAttention and up to 3× that of standard attention baselines on modern GPUs.

FlashAttention-2 is an optimized attention kernel designed to address inefficiencies in transformer attention computation, particularly for long sequence processing. It extends the IO-aware tiling and online softmax methodology of FlashAttention, with significant enhancements in parallelism, memory access patterns, and arithmetic efficiency. Empirical benchmarks demonstrate substantial kernel- and model-level speedups across a range of hardware and sequence lengths, closing the gap to raw matrix multiplication efficiency on modern GPUs (Dao, 2023).

1. Background and Baseline: FlashAttention

The transformer architecture computes, for each attention head and sequence of length $N$ with head dimension $d$:

  • $S = QK^\intercal \in \mathbb{R}^{N \times N}$
  • $P = \mathrm{softmax}(S) \in \mathbb{R}^{N \times N}$
  • $O = PV \in \mathbb{R}^{N \times d}$

A naïve implementation requires $O(N^2)$ memory for $S$ and $P$, with $O(N^2 d)$ FLOPs and $O(N^2)$ memory moves dominating runtime. FlashAttention executes attention in an on-chip tiled loop, avoiding redundant off-chip memory traffic by fusing the online softmax computation with GEMM and never materializing $S$ or $P$ in off-chip memory. This reduces extra memory usage to $O(Nd)$ and provides a 2–4× speedup over tuned PyTorch/cuDNN baselines. However, on NVIDIA A100 GPUs, this approach achieves only 25–40% of peak TFLOP/s due to limited occupancy (one thread block per head) and inefficient intra-block shared memory reductions (split-K partitioning).
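For concreteness, the minimal NumPy sketch below implements the standard computation above and makes the $N \times N$ intermediates explicit (the usual $1/\sqrt{d}$ scaling, omitted from the formulas above, is included); shapes and names are illustrative.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Standard single-head attention: materializes the full N x N score and
    probability matrices, which is exactly what FlashAttention avoids."""
    d = Q.shape[-1]
    S = (Q @ K.T) / np.sqrt(d)             # (N, N) scores: O(N^2) memory
    S -= S.max(axis=-1, keepdims=True)      # numerically stable softmax
    P = np.exp(S)
    P /= P.sum(axis=-1, keepdims=True)      # (N, N) probabilities
    return P @ V                            # (N, d) output

N, d = 1024, 64
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
print(naive_attention(Q, K, V).shape)       # (1024, 64)
```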

2. Core Algorithmic Enhancements in FlashAttention-2

FlashAttention-2 introduces three principal modifications targeting arithmetic overhead and parallel scheduling to close the efficiency gap:

2.1. Reduction of Non-Matmul Computation

Scalar operations such as $\exp$, $\log$, and pointwise scaling are an order of magnitude slower than fused matmul on GPUs. FlashAttention-2 combines per-block scalings into a single final normalization and stores only a combined log-sum-exp accumulator per output row. Specifically, it defines

  • $\tilde{O}^{(j)} = \exp(m^{(j-1)} - m^{(j)})\,\tilde{O}^{(j-1)} + \exp(S^{(j)} - m^{(j)})\,V^{(j)}$
  • $O = \mathrm{diag}(\ell^{(T_c)})^{-1}\,\tilde{O}^{(T_c)}$

where only the log-sum-exp $L^{(j)} = m^{(j)} + \log \ell^{(j)}$ is stored per output row for the backward pass. This saves one scaling multiply per block and reduces the total number of scalar FLOPs by $O(Nd)$ per forward pass.
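As a hedged illustration of this accumulation scheme (not the CUDA kernel itself), the NumPy sketch below maintains the running max $m$, running sum $\ell$, and un-normalized accumulator $\tilde{O}$ over key/value blocks, applies the normalization once at the end, and returns the log-sum-exp $L$ that would be saved for the backward pass. The block size and variable names are illustrative.

```python
import numpy as np

def online_softmax_attention(Q, K, V, block_size=128):
    """FA2-style online softmax for one head: the un-normalized accumulator
    O_tilde is rescaled only when the running max m changes, and division by
    the running sum l happens exactly once, after the last block."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O_tilde = np.zeros((N, d))
    m = np.full((N, 1), -np.inf)          # running row max
    l = np.zeros((N, 1))                  # running row sum of exp(S - m)
    for j in range(0, N, block_size):
        Kj, Vj = K[j:j + block_size], V[j:j + block_size]
        S = (Q @ Kj.T) * scale            # (N, Bc) block of scores
        m_new = np.maximum(m, S.max(axis=-1, keepdims=True))
        alpha = np.exp(m - m_new)         # rescale factor for the old state
        P = np.exp(S - m_new)
        l = alpha * l + P.sum(axis=-1, keepdims=True)
        O_tilde = alpha * O_tilde + P @ Vj
        m = m_new
    L = m + np.log(l)                     # log-sum-exp, kept for backward
    return O_tilde / l, L                 # single final normalization
```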

2.2. Exposing Parallelism Over Sequence Blocks

Rather than launching a single thread block per (batch, head) pair, FA2 partitions the $N$ rows of $Q$ into $T_r = \lceil N/B_r \rceil$ row blocks, with each block handled independently. A thread block is launched for each combination of (row block, head, batch), maximizing GPU occupancy even when the product of batch size and head count is small. The backward pass similarly partitions over column blocks of $K$/$V$. This 2D tiling yields $T_r \cdot \text{batch} \cdot \text{heads}$ thread blocks, saturating the available multiprocessors when $N/B_r \gg 1$.
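The occupancy effect can be seen with a small back-of-the-envelope calculation. The sketch below counts the thread blocks FA2 would launch for the forward pass; the block size of 128 rows and the A100's 108 SMs are assumptions for the example.

```python
import math

def fa2_forward_grid(seq_len, batch, heads, block_rows=128, num_sms=108):
    """Illustrative launch-grid arithmetic: one thread block per
    (row block, head, batch) triple, compared against the number of SMs."""
    t_r = math.ceil(seq_len / block_rows)
    blocks = t_r * batch * heads
    return blocks, blocks / num_sms        # total blocks, waves per SM

# With a small batch * heads product (e.g. long-context shapes), parallelizing
# only over (batch, head) would give 2 * 16 = 32 blocks -- fewer than 108 SMs.
print(fa2_forward_grid(seq_len=8192, batch=2, heads=16))   # (2048, ~19 waves)
```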

2.3. Warp Partitioning: Split-Q Instead of Split-K

Within each thread block, the original FlashAttention used a split-K strategy, dividing the columns of $K$ among warps and performing explicit reductions in shared memory. FlashAttention-2 instead assigns disjoint subsets of $Q$ rows to each warp, while all warps access the full $K$ and $V$ blocks. This eliminates intra-block shared-memory communication and synchronization, reducing cross-warp traffic by $O(W B_r d)$ per block and increasing throughput by up to 15–20%. The same split-Q partitioning is applied in the backward pass to maintain locality and efficiency.
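The communication difference can be sketched in NumPy by modeling warps as slices of a block (softmax normalization and the running-max bookkeeping are omitted so the data-movement pattern stays visible); this is a schematic of the partitioning, not the kernel.

```python
import numpy as np

def split_k_block(Q_blk, K_blk, V_blk, num_warps=4):
    """Split-K (FA1-style): each 'warp' sees a slice of the K/V columns and
    produces a partial, un-normalized output for ALL query rows; the partials
    must then be combined -- the shared-memory reduction FA2 removes."""
    cols = np.array_split(np.arange(K_blk.shape[0]), num_warps)
    partials = [np.exp(Q_blk @ K_blk[c].T) @ V_blk[c] for c in cols]
    return sum(partials)                   # cross-warp reduction required

def split_q_block(Q_blk, K_blk, V_blk, num_warps=4):
    """Split-Q (FA2-style): each 'warp' owns a disjoint slice of query rows
    and reads the full K/V block, so its output rows are already complete and
    no cross-warp combination is needed."""
    rows = np.array_split(np.arange(Q_blk.shape[0]), num_warps)
    return np.concatenate(
        [np.exp(Q_blk[r] @ K_blk.T) @ V_blk for r in rows], axis=0)

rng = np.random.default_rng(1)
Qb, Kb, Vb = (rng.standard_normal((128, 64)) * 0.05 for _ in range(3))
assert np.allclose(split_k_block(Qb, Kb, Vb), split_q_block(Qb, Kb, Vb))
```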

3. Complexity, Memory, and Asymptotic Analysis

The computational and memory complexity of the main attention variants is given in the following table:

| Variant | Total FLOPs | Extra Memory | Main Bottleneck |
|---|---|---|---|
| Standard attention | $O(N^2 H d)$ | $O(N^2 H)$ | Quadratic memory I/O |
| FlashAttention | $O(4 N^2 H d)$ | $O(N H d)$ | Scalar FLOPs, shared-memory reductions |
| FlashAttention-2 | $O(N^2 H d)$ matmuls | $O(N H d)$ | Pointwise $\exp$/$\log$ + DRAM sync |

FlashAttention-2 reduces the constant factors in both scalar arithmetic and shared memory traffic by approximately $2\times$ relative to FlashAttention, while preserving matmul efficiency. The per-output-row state storage requirement is also reduced: one float ($L$) instead of two ($m$ and $\ell$).
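A small helper, under the same conventions as the table (two GEMMs of $N^2 d$ multiply-adds per head, 2 FLOPs per multiply-add, lower-order terms ignored), makes the constant factors concrete; the numbers are illustrative rather than measured.

```python
def attention_costs(N, H, d, batch=1):
    """Rough cost model matching the table above: matmul FLOPs for QK^T and
    PV, quadratic extra elements for standard attention (S and P per head)
    vs. linear O(N*H*d) state for FlashAttention(-2)."""
    matmul_flops = batch * H * 4 * N**2 * d      # 2 GEMMs x 2 FLOPs per MAC
    standard_extra = batch * H * N**2            # S and P materialized
    flash_extra = batch * H * N * d              # O(N d) state per head
    return matmul_flops, standard_extra, flash_extra

# ~1.1e12 FLOPs; ~2.1e9 extra elements (standard) vs ~3.4e7 (FlashAttention):
print(attention_costs(N=8192, H=32, d=128))
```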

4. Empirical Performance and Model-Level Results

FlashAttention-2 achieves substantial speedups on the NVIDIA A100-80GB GPU in both kernel-level and end-to-end training settings.

4.1. Kernel-Level TFLOP/s

Measured at $B_r = B_c = 128$, $d = 64$ or $128$, with and without causal masking:

  • Forward: 50–73% of the 312 TFLOP/s peak (160–230 TFLOP/s)
  • Backward: 30–63% of peak (95–195 TFLOP/s)
  • Combined: 100–200 TFLOP/s end-to-end

This represents roughly $2\times$ the performance of FlashAttention and 1.3–2$\times$ that of Triton/xFormers implementations, with further gains under causal masking and long sequences.
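These throughput figures follow the usual accounting of $4N^2d$ FLOPs per head for the forward pass (halved under a causal mask). A small converter from a measured runtime to achieved TFLOP/s is sketched below; the shapes and timing are purely hypothetical.

```python
def attention_tflops(N, d, runtime_s, heads=1, batch=1, causal=False):
    """Convert a measured forward-pass runtime into achieved TFLOP/s,
    counting 4*N^2*d FLOPs per head (halved with causal masking)."""
    flops = 4 * N * N * d * heads * batch
    if causal:
        flops /= 2                        # only the lower triangle is computed
    return flops / runtime_s / 1e12

# Hypothetical 16-head, batch-8, N=4096, d=128 forward pass taking 5 ms:
print(f"{attention_tflops(4096, 128, 5e-3, heads=16, batch=8):.0f} TFLOP/s")
# ~220 TFLOP/s, i.e. ~70% of the 312 TFLOP/s A100 peak
```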

4.2. GPT Training Throughput

End-to-end model FLOP/s (per GPU) on GPT-3 models:

| Model & Context | No FlashAttention | FlashAttention | FlashAttention-2 (% of 312 TFLOP/s peak) |
|---|---|---|---|
| 1.3B @ 2K tokens | 142 TFLOP/s | 189 TFLOP/s | 196 TFLOP/s (~62%) |
| 1.3B @ 8K tokens | 72 TFLOP/s | 170 TFLOP/s | 220 TFLOP/s (~70%) |
| 2.7B @ 2K tokens | 149 TFLOP/s | 189 TFLOP/s | 205 TFLOP/s (~65%) |
| 2.7B @ 8K tokens | 80 TFLOP/s | 175 TFLOP/s | 225 TFLOP/s (~72%) |

FlashAttention-2 is consistently about $1.3\times$ faster than FlashAttention and up to $3\times$ faster than baselines without FlashAttention.

5. Efficiency Relative to Matrix Multiplication and Remaining Obstacles

Whereas optimized GEMM reaches 80–90% of theoretical peak due to fine-grained register/SRAM usage and lack of scalar operations, FlashAttention-2 achieves 50–73%. The dominant limiting factors are:

  • Residual non-matmul operations ($\exp$/$\log$/pointwise scaling), which run at roughly 19 TFLOP/s (see the estimate sketched below)
  • Atomic additions for $dQ$ in the backward pass, which introduce DRAM stalls
  • Register pressure and shared-memory capacity, which restrict block sizes to 64–128
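As referenced in the first bullet, a rough estimate shows why even a small scalar-op count matters: each non-matmul FLOP costs roughly $312/19 \approx 16\times$ more time than a matmul FLOP. The sketch below assumes on the order of $N^2$ scalar operations per head against $4N^2 d$ matmul FLOPs; the constants are illustrative.

```python
def nonmatmul_time_fraction(N, d, matmul_tflops=312.0, scalar_tflops=19.0):
    """Back-of-envelope share of ideal kernel time spent on scalar ops:
    4*N^2*d matmul FLOPs at the fp16 tensor-core rate vs. ~N^2 softmax
    exp/rescale ops at the much lower scalar rate."""
    matmul_time = 4 * N * N * d / (matmul_tflops * 1e12)
    scalar_time = N * N / (scalar_tflops * 1e12)
    return scalar_time / (scalar_time + matmul_time)

# ~3% of ideal time per head, even though scalar ops are ~0.2% of the FLOPs:
print(f"{nonmatmul_time_fraction(8192, 128):.1%}")
```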

Hardware advances such as H100’s TMA (async DMA), fourth-generation Tensor Cores with fp8, and faster exp/log primitives are projected to further reduce these inefficiencies. A plausible implication is that algorithmic improvements combined with such hardware support will approach the efficiency of pure GEMM.

6. Future Directions and Applications

FlashAttention-2 provides a reference kernel for high-throughput transformer computations and sets a foundation for further enhancements. Integrating sparse and local attention patterns into the FA2 kernel may yield even greater effective throughput at the model level, especially as sequence lengths scale. The design principles of work partitioning, online normalization, and hardware-aligned computation established by FlashAttention-2 serve as a basis for next-generation attention algorithms targeting emerging deep learning hardware (Dao, 2023).

References

Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arXiv:2307.08691.
