FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (2205.14135v2)

Published 27 May 2022 in cs.LG

Abstract: Transformers are slow and memory-hungry on long sequences, since the time and memory complexity of self-attention are quadratic in sequence length. Approximate attention methods have attempted to address this problem by trading off model quality to reduce the compute complexity, but often do not achieve wall-clock speedup. We argue that a missing principle is making attention algorithms IO-aware -- accounting for reads and writes between levels of GPU memory. We propose FlashAttention, an IO-aware exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM. We analyze the IO complexity of FlashAttention, showing that it requires fewer HBM accesses than standard attention, and is optimal for a range of SRAM sizes. We also extend FlashAttention to block-sparse attention, yielding an approximate attention algorithm that is faster than any existing approximate attention method. FlashAttention trains Transformers faster than existing baselines: 15% end-to-end wall-clock speedup on BERT-large (seq. length 512) compared to the MLPerf 1.1 training speed record, 3$\times$ speedup on GPT-2 (seq. length 1K), and 2.4$\times$ speedup on long-range arena (seq. length 1K-4K). FlashAttention and block-sparse FlashAttention enable longer context in Transformers, yielding higher quality models (0.7 better perplexity on GPT-2 and 6.4 points of lift on long-document classification) and entirely new capabilities: the first Transformers to achieve better-than-chance performance on the Path-X challenge (seq. length 16K, 61.4% accuracy) and Path-256 (seq. length 64K, 63.1% accuracy).

The paper introduces FlashAttention, an IO-aware algorithm that makes exact attention faster and more memory-efficient. Its central premise is that conventional attention implementations are slowed by excessive memory reads and writes between the GPU's high bandwidth memory (HBM) and its on-chip static RAM (SRAM).

The authors' analysis of FlashAttention's IO complexity shows that it requires fewer HBM accesses than standard attention and is optimal for a range of SRAM sizes. The algorithm also extends to block-sparse attention, yielding an approximate attention algorithm that is faster than existing approximate attention methods. Experiments show that FlashAttention speeds up Transformer training, with substantial end-to-end wall-clock speedups on models such as BERT-large and GPT-2.

Key aspects and contributions detailed in the paper are:

  • IO-Aware Attention Mechanism: FlashAttention uses tiling to avoid materializing the large attention matrix in GPU HBM. By splitting the inputs into blocks and processing them block by block in on-chip SRAM, the algorithm reduces memory I/O.
  • IO Complexity Analysis: A formal analysis of FlashAttention's IO complexity shows that it needs fewer HBM accesses than standard attention.
  • Block-Sparse Extension: An extension to block-sparse attention further increases speed and enables processing of longer sequences.
  • Empirical Validation: Experiments show faster model training, higher model quality, and support for longer contexts in Transformer models.

The authors address the time and memory cost of self-attention in Transformer models [Vaswani et al. 2017], which is quadratic in sequence length. They note that existing approximate attention methods often focus on reducing floating-point operations (FLOPs) while neglecting memory access overheads, and therefore frequently fail to achieve a wall-clock speedup.

The algorithm is designed to avoid reading and writing the attention matrix to and from HBM. This requires:

  • Computing the softmax reduction without access to the whole input.
  • Not storing the large intermediate attention matrix for the backward pass.

Both requirements are met by two techniques: tiling, which splits the inputs into blocks and makes several passes over them so that the softmax is computed incrementally, and recomputation, which stores the softmax normalization factors from the forward pass so that attention can be recomputed quickly on-chip during backpropagation.
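The incremental softmax that tiling relies on can be illustrated with a small, pure-Python sketch. It is not the paper's CUDA kernel: it handles one query row at a time, ignores SRAM capacity, and uses a loop order (queries outside, K/V blocks inside) chosen only for readability, which differs from the paper's Algorithm 1. The function names `naive_attention` and `tiled_attention` are hypothetical. Per row it keeps only a running max `m`, a running normalizer `l`, and an unnormalized output accumulator, rescaling them whenever a new block raises the max.

```python
import math
import random


def naive_attention(Q, K, V):
    """Reference softmax(Q K^T / sqrt(d)) V that builds each full score row."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        z = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(len(V))) / z
                    for c in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block_size=2):
    """Exact attention computed one K/V block at a time with an online softmax.

    Hypothetical illustration: per query row only the running max m, the
    running normalizer l, and an unnormalized accumulator are kept, so no
    full row of scores is held when block_size < len(K).
    """
    scale = 1.0 / math.sqrt(len(Q[0]))
    dv = len(V[0])
    out = []
    for q in Q:
        m, l, acc = -math.inf, 0.0, [0.0] * dv
        for start in range(0, len(K), block_size):
            Kb = K[start:start + block_size]
            Vb = V[start:start + block_size]
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in Kb]
            m_new = max(m, max(s))
            corr = math.exp(m - m_new)  # rescales stats computed under the old max
            p = [math.exp(x - m_new) for x in s]
            l = l * corr + sum(p)
            acc = [acc[c] * corr + sum(p[j] * Vb[j][c] for j in range(len(Vb)))
                   for c in range(dv)]
            m = m_new
        out.append([a / l for a in acc])  # normalize once at the end
    return out


rng = random.Random(0)
N, d = 7, 4
Q = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(N)]
K = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(N)]
V = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(N)]
diff = max(abs(a - b)
           for r1, r2 in zip(naive_attention(Q, K, V), tiled_attention(Q, K, V, 3))
           for a, b in zip(r1, r2))
print(diff)  # floating-point round-off only
```

Because the rescaling factor `exp(m_old - m_new)` is applied to both `l` and the accumulator, the final division yields the same result as the full softmax; the sketch is exact, not approximate.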

FlashAttention requires $O(N^2 d^2 M^{-1})$ HBM accesses, where:

  • $N$ is the sequence length
  • $d$ is the head dimension
  • $M$ is the SRAM size

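This contrasts with the $\Omega(Nd + N^2)$ HBM accesses of standard attention. Since $d^2$ is typically many times smaller than $M$ (for $d$ of 64 to 128 and $M$ around 100KB), FlashAttention needs many times fewer HBM accesses.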
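To make the comparison concrete, the two bounds can be evaluated for sample sizes. This is a back-of-the-envelope sketch: constant factors are dropped, so only the ratio is meaningful, and treating $M$ as a count of elements (rather than bytes) is an assumption. The range check reflects the condition $d \le M \le Nd$ under which the paper states its bound. The function names are hypothetical.

```python
def standard_hbm_accesses(N, d):
    """Standard attention: Theta(N*d + N^2) HBM accesses (constants dropped)."""
    return N * d + N * N


def flash_hbm_accesses(N, d, M):
    """FlashAttention: Theta(N^2 * d^2 / M) HBM accesses (constants dropped).

    The bound is stated for d <= M <= N*d; here M is assumed to count elements.
    """
    if not d <= M <= N * d:
        raise ValueError("bound assumes d <= M <= N*d")
    return N * N * d * d / M


N, d, M = 4096, 64, 100_000  # illustrative sizes, not measurements
ratio = flash_hbm_accesses(N, d, M) / standard_hbm_accesses(N, d)
print(f"FlashAttention / standard ~ {ratio:.3f}")  # prints "FlashAttention / standard ~ 0.040"
```

The ratio works out to $\frac{d^2}{M} \cdot \frac{N}{N + d}$, which is close to $d^2 / M$ whenever $N \gg d$.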

The performance gains come from the reduction in HBM accesses, even though the FLOP count increases because of recomputation. On GPT-2, FlashAttention speeds up the attention computation by 7.6$\times$, because it never reads or writes the large $N \times N$ attention matrix to HBM.

The authors also prove a lower bound: no exact attention algorithm can asymptotically improve on the number of HBM accesses of FlashAttention across all SRAM sizes.

FlashAttention also serves as a useful primitive for realizing the potential of approximate attention algorithms, since it removes their memory access overhead. Block-sparse FlashAttention is 2-4$\times$ faster than FlashAttention and scales to a sequence length of 64K. This is because its IO complexity is better than FlashAttention's by a factor proportional to the sparsity ratio.
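The source of the block-sparse saving can be sketched by adding a block mask to the tiled online softmax: blocks whose mask entry is `False` are never loaded or computed, so the work and memory traffic scale with the fraction of nonzero blocks. This is a hypothetical pure-Python illustration, not the paper's kernel. The block-causal mask below is an arbitrary illustrative pattern, not the sparsity pattern used in the paper's experiments, and the function names are assumptions. A dense reference that sets masked scores to $-\infty$ is included for checking.

```python
import math
import random


def block_sparse_attention(Q, K, V, block_mask, bs):
    """Tiled online-softmax attention that skips zero blocks of block_mask.

    block_mask[i][j] says whether query block i attends to key block j.
    Every query block must have at least one nonzero block.
    Returns (output, number of (query block, key block) pairs visited).
    """
    scale = 1.0 / math.sqrt(len(Q[0]))
    dv = len(V[0])
    out, visited = [], 0
    for bi, qs in enumerate(range(0, len(Q), bs)):
        Qb = Q[qs:qs + bs]
        m = [-math.inf] * len(Qb)
        l = [0.0] * len(Qb)
        acc = [[0.0] * dv for _ in Qb]
        for bj, ks in enumerate(range(0, len(K), bs)):
            if not block_mask[bi][bj]:
                continue  # zero block: no K/V loads, no FLOPs
            visited += 1
            Kb, Vb = K[ks:ks + bs], V[ks:ks + bs]
            for r, q in enumerate(Qb):
                s = [scale * sum(a * b for a, b in zip(q, k)) for k in Kb]
                m_new = max(m[r], max(s))
                corr = math.exp(m[r] - m_new)
                p = [math.exp(x - m_new) for x in s]
                l[r] = l[r] * corr + sum(p)
                acc[r] = [acc[r][c] * corr + sum(p[j] * Vb[j][c] for j in range(len(Vb)))
                          for c in range(dv)]
                m[r] = m_new
        for r in range(len(Qb)):
            out.append([a / l[r] for a in acc[r]])
    return out, visited


def dense_masked_attention(Q, K, V, block_mask, bs):
    """Reference: full score rows with masked-out entries set to -inf."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for i, q in enumerate(Q):
        s = [scale * sum(a * b for a, b in zip(q, k)) if block_mask[i // bs][j // bs]
             else -math.inf for j, k in enumerate(K)]
        mx = max(s)
        p = [math.exp(x - mx) for x in s]
        z = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(len(V))) / z
                    for c in range(len(V[0]))])
    return out


def rand_matrix(rng, rows, cols):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(1)
N, d, bs = 8, 3, 2
Q, K, V = (rand_matrix(rng, N, d) for _ in range(3))
nb = N // bs
mask = [[j <= i for j in range(nb)] for i in range(nb)]  # block-causal pattern
out, visited = block_sparse_attention(Q, K, V, mask, bs)
print(visited, "of", nb * nb, "blocks visited")  # prints "10 of 16 blocks visited"
```

Here 10 of 16 blocks are touched, so the sketch does roughly 10/16 of the dense work, mirroring the "factor proportional to the sparsity ratio" in the text.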

Experiments show that FlashAttention speeds up model training, improves model quality by modeling longer contexts, and reduces the memory footprint. It trains BERT-large 15% faster than the MLPerf 1.1 training speed record, speeds up GPT-2 training by 3$\times$, and achieves a 2.4$\times$ speedup on the long-range arena. It also improves perplexity on GPT-2 by 0.7 and raises long-document classification accuracy by 6.4 points.

The paper includes background on the GPU memory hierarchy, highlighting the gap between compute speed and memory speed. Operations are classified as either compute-bound or memory-bound; the runtime of memory-bound operations is dominated by HBM accesses. Kernel fusion is presented as a way to accelerate memory-bound operations by reducing HBM accesses: when several operations are applied to the same input, the input can be loaded from HBM once instead of once per operation.
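The effect of kernel fusion on memory traffic can be shown with a toy model. `CountingHBM` is a hypothetical stand-in for off-chip memory that only counts elements read and written; it does not model bandwidth, caches, or real kernels. The unfused version launches one "kernel" per elementwise op, each reading its input and writing its output back, while the fused version loads once, applies every op on-chip, and stores once.

```python
class CountingHBM:
    """Toy model of off-chip memory that counts elements read and written."""

    def __init__(self):
        self.reads = 0
        self.writes = 0

    def load(self, buf):
        self.reads += len(buf)
        return list(buf)

    def store(self, buf):
        self.writes += len(buf)
        return list(buf)


def run_unfused(hbm, x, ops):
    """One kernel per op: each reads its input from HBM and writes its output back."""
    for op in ops:
        tile = hbm.load(x)
        x = hbm.store([op(v) for v in tile])
    return x


def run_fused(hbm, x, ops):
    """One fused kernel: load once, apply every op on-chip, store once."""
    tile = hbm.load(x)
    for op in ops:
        tile = [op(v) for v in tile]
    return hbm.store(tile)


ops = [lambda v: 2.0 * v, lambda v: v - 1.0, lambda v: max(v, 0.0)]  # scale, shift, ReLU
x = [float(i) for i in range(1000)]
a, b = CountingHBM(), CountingHBM()
ya, yb = run_unfused(a, x, ops), run_fused(b, x, ops)
print(a.reads + a.writes, b.reads + b.writes)  # prints "6000 2000"
```

Both versions produce the same values; only the traffic differs, which is why fusion helps memory-bound operations but not compute-bound ones.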

The standard attention implementation is examined for its memory inefficiency, in particular the materialization of the matrices $S = QK^\top$ and $P = \mathrm{softmax}(S)$ in HBM. The authors show that the resulting number of HBM accesses grows quadratically with the sequence length $N$.
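The quadratic growth can be checked by tallying element-level HBM traffic for the three passes of standard attention (compute $S$ and write it, read $S$ and write $P$, read $P$ and $V$ and write $O$). This is a simplified accounting under stated assumptions: it ignores re-reads inside the blocked matrix multiplies and constant factors, and `standard_attention_traffic` is a hypothetical helper.

```python
def standard_attention_traffic(N, d):
    """Element-level HBM traffic of three-pass standard attention, where the
    N x N matrices S = Q K^T and P = softmax(S) are both written to HBM."""
    steps = {
        "S = QK^T": {"read": 2 * N * d, "write": N * N},
        "P = softmax(S)": {"read": N * N, "write": N * N},
        "O = PV": {"read": N * N + N * d, "write": N * d},
    }
    total = sum(s["read"] + s["write"] for s in steps.values())
    return steps, total


for N in (1024, 2048, 4096):
    _, total = standard_attention_traffic(N, 64)
    print(N, total)
```

The total is $4Nd + 4N^2$, so once $N \gg d$ doubling the sequence length roughly quadruples the traffic, dominated by the $S$ and $P$ terms.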

The IO-complexity analysis shows a substantial reduction in HBM accesses for FlashAttention, which is optimal for a range of SRAM sizes $M$. The lower bound shows that no exact attention algorithm can asymptotically improve on this number of HBM accesses across all SRAM sizes. Block-sparse FlashAttention has better IO complexity than FlashAttention, with the improvement proportional to the sparsity ratio.

The paper also details memory-efficient forward and backward passes. The key idea is to compute the softmax normalization constants separately, which decouples the columns of $K$ (and $V$) and reduces memory requirements. The full FlashAttention forward and backward passes are given, including saving the pseudo-random number generator state in the forward pass so that the dropout mask can be regenerated in the backward pass instead of being stored.
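Recomputation in the backward pass can be sketched as follows: the forward pass keeps one statistic per row instead of the $N \times N$ matrix $P$, and any block of $P$ (with dropout applied) is rebuilt later from $Q$, $K$, those statistics, and a saved seed. This is a hypothetical illustration. The paper saves the row statistics $m$ and $\ell$; folding them into a single logsumexp $m + \log \ell$ is an assumption made for brevity. The per-entry `random.Random` seeding stands in for a counter-based GPU PRNG, and the forward statistics are computed row-wise rather than inside a tiled loop.

```python
import math
import random


def score_row(q, keys):
    scale = 1.0 / math.sqrt(len(q))
    return [scale * sum(a * b for a, b in zip(q, k)) for k in keys]


def keep(seed, i, j, p_drop):
    """Dropout decision for entry (i, j), derived only from (seed, i, j).

    Stand-in for a counter-based PRNG: any block can regenerate its own
    mask from the saved seed, so the mask never has to be stored.
    """
    return random.Random(f"{seed}:{i}:{j}").random() >= p_drop


def forward_row_stats(Q, K):
    """Forward pass keeps one number per row, logsumexp(S_i), instead of P."""
    stats = []
    for q in Q:
        s = score_row(q, K)
        m = max(s)
        stats.append(m + math.log(sum(math.exp(x - m) for x in s)))
    return stats


def recompute_block(Q, K, lse, rows, cols, seed, p_drop):
    """Backward pass: rebuild one block of the dropped-out attention matrix
    from Q, K, the saved row statistics, and the saved seed."""
    block = []
    for i in rows:
        s = score_row(Q[i], [K[j] for j in cols])
        block.append([math.exp(x - lse[i]) / (1.0 - p_drop)  # inverted-dropout scaling
                      if keep(seed, i, j, p_drop) else 0.0
                      for j, x in zip(cols, s)])
    return block


rng = random.Random(2)
N, d = 6, 4
Q = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(N)]
K = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(N)]
lse = forward_row_stats(Q, K)
full = recompute_block(Q, K, lse, range(N), range(N), seed=42, p_drop=0.0)
print([round(sum(r), 6) for r in full])  # without dropout, each softmax row sums to 1
```

Recomputing the same block twice, or extracting it from a full recomputation, gives identical values, which is the property that lets the backward pass trade extra FLOPs for not storing $P$ or the dropout mask.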

The main limitation of the current approach is that each new attention implementation requires writing a new CUDA kernel, which demands considerable engineering effort and may not transfer across GPU architectures. Future directions include IO-aware implementations of other modules, IO-aware multi-GPU methods, and sparse MLP layers.

Authors (5)
  1. Tri Dao (47 papers)
  2. Daniel Y. Fu (25 papers)
  3. Stefano Ermon (279 papers)
  4. Atri Rudra (55 papers)
  5. Christopher Ré (194 papers)
Citations (1,551)