
FlashAttention-4: Advanced GPU Attention Kernel

Updated 8 March 2026
  • FlashAttention-4 is an advanced attention kernel designed to overcome asymmetric hardware scaling, achieving up to 2.7× speedup on large-scale Transformer workloads.
  • It employs a fully asynchronous MMA pipeline with on-SM tensor memory and software-emulated exponentials to mitigate shared memory and non-MMA bottlenecks.
  • Implemented in Python-embedded CuTe-DSL, the approach drastically reduces kernel compile times and enables agile, high-performance co-design on Blackwell GPUs.

FlashAttention-4 is an advanced attention kernel and algorithmic framework designed to address the new computational and memory bottlenecks arising from asymmetric hardware scaling in NVIDIA's Blackwell-class datacenter GPUs (e.g., B200, GB200). It is a significant evolution over FlashAttention-3, explicitly targeting systems in which tensor-core throughput outpaces resources that do not scale in step, such as shared-memory bandwidth and exponential units. The co-design leverages specific architectural features of Blackwell GPUs (fully asynchronous matrix-multiply-accumulate (MMA) operations, on-SM tensor memory (TMEM), polynomial-emulated exponential computation, and a 2-CTA MMA mode) to restore compute/memory balance and maximize utilization, achieving up to 1.3× speedup over cuDNN 9.13 and 2.7× over Triton on large-scale Transformer workloads (Zadouri et al., 5 Mar 2026). Developed entirely in Python-embedded CuTe-DSL, FlashAttention-4 also drastically reduces kernel compile times relative to C++ metaprogramming, enabling faster research and deployment cycles.

1. Motivation: Asymmetric Hardware Scaling on Blackwell GPUs

Unlike prior architectures where most functional units scaled uniformly, Blackwell GPUs exhibit "asymmetric scaling":

  • Tensor core throughput: doubled from 4096 to 8192 BF16 FLOPs/clock/SM compared to Hopper (H100).
  • Shared memory bandwidth: remains ~128 bytes/clock/SM.
  • Exponential (MUFU.EX2) units: fixed at 16 ops/clock/SM; no scaling.

Under this regime, traditional attention kernels become bottlenecked not by MMA operations but by shared-memory traffic and the throughput of the non-MMA operations required for softmax (i.e., exponentiation and rescaling). Roofline analysis for M = N = d = 128 shows compute and exponential time (≈ 1024 cycles each) slightly outpacing shared-memory traffic (≈ 768 cycles), while at larger tiles memory cost dominates, constraining the efficiency gains available from increased compute throughput alone (Zadouri et al., 5 Mar 2026).

2. Algorithmic and Kernel Innovations

FlashAttention-4 implements three primary strategies to address these bottlenecks:

a. Fully Asynchronous MMA Pipeline

By exploiting Blackwell's TMEM, accumulator tiles (128×128) are written directly to on-SM tensor memory asynchronously, freeing register files and enabling tiles twice the area of Hopper's 64×128. The pipeline involves:

  1. Tile loading: Q/K blocks are staged from global to shared memory via TMA.
  2. Asynchronous MMA launch: Q·Kᵀ products are computed into TMEM.
  3. Softmax computation: Consumer warps process the previous tile from TMEM, performing row-max, exponentiation (emulated or hardware), and summation, then writing probabilities back to TMEM.
  4. Correction and output MMA: Conditional softmax rescaling is applied off the critical path when needed, followed by the P·V MMA for the output.

Synchronization primitives (__syncthreads) are minimized, relying on the natural overlap enabled by hardware asynchrony and inter-warp scheduling.

b. Software-Emulated Exponential and Conditional Softmax Rescaling

Given limited hardware exponential throughput, FlashAttention-4 introduces a pipelined emulation:

  • Exponent range reduction: x is split into integer and fractional parts, with 2^x = 2^⌊x⌋ · 2^(x_f), where x_f = x − ⌊x⌋.
  • Bit manipulation and polynomial evaluation: 2^(x_f) is approximated using a degree-3 polynomial and FMA instructions, minimizing register usage.
  • Selective emulation: Only 10–25% of entries are emulated; the rest use hardware MUFU.EX2, balancing throughput and resource usage.
  • Softmax rescaling: Rescaling is performed only when the running-max change Δm exceeds a threshold τ (e.g., τ = 8.0), skipping unnecessary vector multiplications while maintaining exact output normalization.
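The range-reduction-plus-polynomial scheme can be sketched in NumPy. The coefficients below come from a quick least-squares fit and are stand-ins for the kernel's hand-tuned values (which the text does not list); `np.ldexp` plays the role of writing ⌊x⌋ into the floating-point exponent field:

```python
import numpy as np

# Fit a degree-3 polynomial to 2^f on [0, 1). Stand-in coefficients:
# the kernel's actual, hand-tuned values are not given in the text.
f_grid = np.linspace(0.0, 1.0, 1001)
coeffs = np.polyfit(f_grid, 2.0 ** f_grid, deg=3)

def exp2_emulated(x):
    """2^x via range reduction + cubic polynomial, mirroring the scheme above."""
    xi = np.floor(x)
    f = x - xi                        # fractional part in [0, 1)
    poly = np.polyval(coeffs, f)      # Horner form: three FMAs per element
    # ldexp scales by 2^floor(x): the software analogue of injecting
    # floor(x) into the FP exponent field ("bit manipulation").
    return np.ldexp(poly, xi.astype(int))

x = np.linspace(-30.0, 0.0, 10001)    # softmax inputs are <= 0 after max-subtraction
rel_err = np.abs(exp2_emulated(x) - 2.0 ** x) / (2.0 ** x)
print(rel_err.max() < 1e-3)           # True: ample accuracy for BF16 probabilities
```

The selective-emulation split (10–25% emulated, the rest on MUFU.EX2) and the τ-thresholded rescaling are scheduling decisions on top of this arithmetic and are not shown here.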

c. Optimized Backward Pass: TMEM & 2-CTA MMA

Backward attention involves five MMAs per block and significant shared memory traffic. To address this:

  • TMEM usage: Four accumulator tiles stored in TMEM to reduce register pressure and enable reuse in paired MMAs.
  • Three-stage pipeline: Prologue, main loop, and epilogue structure overlaps MMAs and elementwise kernels.
  • 2-CTA cooperative MMA: Two CTAs issue a single collective MMA on a 256×128 tile, halving shared memory loads and atomic adds.
  • Distributed shared-memory exchange: DSMEM exchanges ensure each CTA has a global reduction perspective, thus maintaining algorithm correctness and halving global atomic operations.
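The halving claimed for the 2-CTA mode follows from simple byte counting. The sketch below is an illustrative model (not a profile), assuming BF16 operands and one 128×128 K tile plus one 128×128 V tile staged into shared memory per pipeline stage:

```python
# Toy byte-counting for the 2-CTA collective MMA (illustrative assumptions:
# BF16 = 2 bytes/element, one 128x128 K tile + one 128x128 V tile per stage).
N_kv, d, bytes_per_el = 128, 128, 2
kv_bytes = 2 * N_kv * d * bytes_per_el   # one K tile + one V tile

# Independent CTAs: each of the two CTAs loads its own K/V copy.
loads_independent = 2 * kv_bytes
# 2-CTA mode: the pair issues one collective MMA on a 256x128 tile and
# shares a single K/V copy via distributed shared memory (DSMEM).
loads_cooperative = kv_bytes

print(loads_independent // loads_cooperative)  # 2: shared-memory loads halved
```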

3. Implementation in CuTe-DSL and Development Workflow

FlashAttention-4 is implemented entirely in NVIDIA’s Python-embedded CuTe-DSL, which provides direct PTX-level abstractions aligned with CUTLASS:

  • JIT Compilation Efficiency: Forward kernel compiles in 2.5 s and backward in 1.4 s, compared to 55 s and 45 s for C++ CUTLASS templates.
  • Low-level Control: Maintains full access to custom PTX intrinsics for optimal scheduling and resource utilization.
  • Modularity: Primitives for tiling, masking, variable-length scheduling, and block sparsity compose orthogonally, facilitating extensions to novel layouts and attention variants.

This approach enables faster developer iteration and supports research agility without compromising kernel performance or hardware utilization (Zadouri et al., 5 Mar 2026).

4. Performance Evaluation and Complexity Analysis

Quantitative Benchmarks

Evaluated on B200 180 GB SXM6 (CUDA 13.1, BF16):

  • Forward Pass: 1.1–1.3× faster than cuDNN 9.13; 2.1–2.7× over Triton; up to 1613 TFLOPs/s (71% of Blackwell’s peak of 2250 TFLOPs/s).
  • Backward Pass: 1.2–1.4× over cuDNN; 2-CTA mode yields up to 14% additional speed; deterministic backward (with semaphore ordering) achieves ∼75% of nondeterministic speed.
  • Latency Example: For a 32K-token batch (head_dim=128, 16 heads), FlashAttention-4 achieves 5.4 ms (forward+backward) vs 7.0 ms for cuDNN.

Theoretical Cost Estimates

Per block (cycles/SM):

  • T_MMA = 4MNd / C_tensor
  • T_smem = 3Md / (C_smem / b)
  • T_exp = MN / C_exp

with C_tensor = 8192 FLOPs/clock, C_smem = 128 bytes/clock, b = 2 bytes/element (BF16), and C_exp = 16 ops/clock.

Ideal pipeline overlap leads to wall time approaching max(T_MMA, T_smem, T_exp). The 2-CTA mode further reduces shared-memory load in the backward pass by doubling tile width.
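Plugging M = N = d = 128 into this cost model (taking shared-memory traffic as the three M×d BF16 tiles Q, K, V) reproduces the roofline numbers quoted in Section 1:

```python
# Per-block cycle estimates with M = N = d = 128, using the Blackwell
# constants quoted in the text: C_tensor = 8192 BF16 FLOPs/clock/SM,
# C_smem = 128 bytes/clock/SM, C_exp = 16 ops/clock/SM, b = 2 bytes (BF16).
M = N = d = 128
C_tensor, C_smem, C_exp, b = 8192, 128, 16, 2

T_mma  = 4 * M * N * d / C_tensor   # QK^T and PV MMAs: 4MNd FLOPs total
T_smem = 3 * M * d * b / C_smem     # Q, K, V tile bytes over smem bandwidth
T_exp  = M * N / C_exp              # one exponential per score entry

print(T_mma, T_smem, T_exp)  # 1024.0 768.0 1024.0
```

Compute and exponential cycles tie at 1024, confirming that on Blackwell the exponential unit, not the tensor core, is a co-equal bottleneck at this tile size.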

5. Relation to 3D Spatial Co-Design

A key context for FlashAttention-4's development is the broader trend identified by 3D-Flow and fine-grained 3D-FlashAttention co-designs (Yu et al., 11 Feb 2026):

  • On-chip SRAM bottleneck: As off-chip traffic is reduced via algorithmic fusion, on-chip SRAM energy can dominate attention costs.
  • 3D spatial architectures: Hybrid-bonded, register-to-register communication across vertical PE tiers (using sub-10 μm TSVs) achieves cycle-exact, bubble-free pipelining. 3D-FlashAttention maps the QKᵀ/softmax/PV stages directly onto a 4-layer systolic cube, obviating unnecessary buffer round-trips and achieving up to 7.6× speedup and 93% energy savings versus 2D baselines.
  • Future directions: Prospects include fusing additional operations and exploiting heterogeneous bit-widths per tier. The algorithm/hardware boundary becomes blurred as attention primitives move toward fine-grained pipelined operators rather than memory-bound routines.

6. Comparison with Alternative Attention Accelerators

FlashAttention-4's approach is distinguished from alternative attention accelerators such as hybrid floating-point/logarithmic hardware (H-FA):

  • H-FA (Alexandridis et al., 31 Oct 2025) converts the softmax-multiply step into a Logarithmic Number System (LNS), eliminating wide FP multiplies/divides and explicit exponent computation, yielding area and power reductions (>25% and >23% respectively) with negligible accuracy loss. By contrast, FlashAttention-4 targets rapid adaptation to asymmetric hardware and software-level pipelining on general purpose GPUs, rather than custom ASIC datapaths.
  • 3D-FlashAttention (Yu et al., 11 Feb 2026) targets custom 3D NPUs where dataflow is vertically pipelined through register-only links, mapping FlashAttention's stages directly onto physical compute tiers, further pushing memory efficiency and resource utilization.

| Technique | Target Hardware | Core Innovation |
|---|---|---|
| FlashAttention-4 | Blackwell GPUs | Async MMA, TMEM, CuTe-DSL, 2-CTA |
| H-FA | Custom ASIC/NPU | LNS datapath, fused softmax/MV |
| 3D-FlashAttention | 3D-stacked NPU | Vertical pipelining, PE tier-mapping |

This table summarizes distinguishing properties among representative attention acceleration methods as reported in the respective references.
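The core LNS trick behind H-FA can be illustrated in a few lines: storing a positive value by its log2 turns multiplication into addition, so the softmax-weight products need no wide FP multiplier, and since logits are already (scaled) logarithms of the softmax weights, explicit exponentiation drops out. This is a toy floating-point demo; H-FA's actual datapath is fixed-point and handles signs and accumulation differently:

```python
import numpy as np

# Toy log-number-system (LNS) demo: positive values stored as log2(v).
def lns_encode(v):
    return np.log2(v)            # representation: the exponent itself

def lns_mul(a_log, b_log):
    return a_log + b_log         # multiply = add in the log domain

def lns_decode(v_log):
    return 2.0 ** v_log

a, b = 0.37, 5.0
prod = lns_decode(lns_mul(lns_encode(a), lns_encode(b)))
print(abs(prod - a * b) < 1e-12)  # True
```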

7. Implications and Future Prospects

The co-design of FlashAttention-4 typifies the convergence of hardware-aware algorithm engineering, leveraging fully asynchronous MMA, programmable on-SM storage, and selective functional unit emulation to restore kernel balance on next-generation architectures (Zadouri et al., 5 Mar 2026). Emerging lines of research suggest further gains via 3D spatial co-design, vertical dataflow, and dynamic tier resource allocation (Yu et al., 11 Feb 2026). A plausible implication is that future attention kernels will increasingly depend on precise interplay with underlying hardware, eschewing strict separation between algorithmic logic and memory management in favor of globally scheduled, pipelined operators.

FlashAttention-4’s modular, Python-based implementation framework also signals a shift toward more agile, rapid iteration on complex kernel designs without sacrificing low-level performance or control. As asymmetric and heterogeneous hardware scaling become commonplace, such flexible methodologies and tightly co-designed primitives are expected to underlie high-performance attention in both general-purpose and specialized compute environments.
