SageAttention3: Microscaling FP4 Attention for Inference and An Exploration of 8-Bit Training
(2505.11594v1)
Published 16 May 2025 in cs.LG, cs.AI, cs.AR, cs.CV, and cs.PF
Abstract: The efficiency of attention is important due to its quadratic time complexity. We enhance the efficiency of attention through two key contributions: First, we leverage the new FP4 Tensor Cores in Blackwell GPUs to accelerate attention computation. Our implementation achieves 1038 TOPS on RTX5090, which is a 5x speedup over the fastest FlashAttention on RTX5090. Experiments show that our FP4 attention can accelerate inference of various models in a plug-and-play way. Second, we pioneer low-bit attention to training tasks. Existing low-bit attention works like FlashAttention3 and SageAttention focus only on inference. However, the efficiency of training large models is also important. To explore whether low-bit attention can be effectively applied to training tasks, we design an accurate and efficient 8-bit attention for both forward and backward propagation. Experiments indicate that 8-bit attention achieves lossless performance in fine-tuning tasks but exhibits slower convergence in pretraining tasks. The code will be available at https://github.com/thu-ml/SageAttention.
Summary
The paper introduces SageAttention3, which uses FP4 microscaling quantization to accelerate inference, reaching up to a 5× kernel speedup over FlashAttention2 and 11× over xformers on the RTX5090.
It also explores INT8 attention for training via SageBwd, which quantizes most of attention's matrix multiplications while preserving gradient accuracy by keeping the sensitive $d\mathbf{P} = d\mathbf{O}\mathbf{V}^\top$ multiplication in FP16.
The kernels are implemented in CUTLASS/CUDA (SageAttention3, on RTX5090) and Triton (SageBwd, on RTX4090); they reduce inference latency and training time, and SageBwd matches BF16 accuracy in fine-tuning.
The paper "SageAttention3: Microscaling FP4 Attention for Inference and An Exploration of 8-Bit Training" (2505.11594) addresses the critical need for efficient attention mechanisms in large generative models, primarily focusing on improving both inference and training speed through low-bit quantization. The core motivation stems from the quadratic time complexity of attention with respect to sequence length and the potential of modern hardware with low-bit Tensor Cores to accelerate computation.
The work presents two main contributions:
SageAttention3: A novel FP4 attention implementation designed specifically for accelerating inference by leveraging the new FP4 Tensor Cores available in GPUs like the RTX5090 (Blackwell architecture).
SageBwd: An exploration into applying low-bit (INT8) quantization to attention for training tasks, proposing an efficient and accurate approach for both forward and backward passes.
SageAttention3 for Inference
SageAttention3 focuses on accelerating the two matrix multiplications in attention, QK⊤ and PV, using FP4 microscaling quantization. The paper highlights key challenges when applying low-bit quantization to attention and proposes solutions:
Challenge C1 (FP4 value limitation): FP4 has only 15 representable values, making standard per-tensor or per-token quantization insufficient for accuracy.
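Solution: The paper uses microscaling FP4 quantization with a group size of 1×16 for both QK⊤ and PV. This fine-grained scaling helps contain the effect of outliers and improves accuracy compared to coarser quantization granularities such as per-tensor or per-channel. The quantization $\phi$ and dequantization $\phi^{-1}$ for a matrix block $\mathbf{X}_{ij} \in \mathbb{R}^{1 \times n}$ (with $n = 16$) are defined as:
$\phi: \quad s_{ij} = \max(|\mathbf{X}_{ij}|)/6, \quad \hat{\mathbf{X}}_{ij} = \lceil \mathbf{X}_{ij} / s_{ij} \rfloor$
$\phi^{-1}: \quad \mathbf{X}'_{ij} = s_{ij} \times \hat{\mathbf{X}}_{ij}$
Here $\lceil \cdot \rfloor$ rounds to the nearest representable FP4 value, and 6 is the largest magnitude FP4 (E2M1) can represent, so each block's largest element maps to ±6.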
The implementation uses the NVFP4 format (E2M1 values with E4M3 scale factors) rather than MXFP4 (E2M1 values with E8M0 scale factors, which can only be powers of two), because NVFP4 gave higher accuracy for attention in the paper's experiments.
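As a rough illustration of $\phi$ and $\phi^{-1}$, the Python sketch below enumerates the 16 E2M1 bit patterns (the two zeros coincide, leaving 15 values) and fake-quantizes a row in 1×16 groups. It is a simplified simulation under stated assumptions, not the paper's kernel: the scale factors stay in full precision instead of being rounded to E4M3, rounding ties go to the lower grid value, and the function names are hypothetical.

```python
def decode_e2m1(bits: int) -> float:
    """Decode a 4-bit E2M1 pattern: 1 sign, 2 exponent (bias 1), 1 mantissa bit."""
    sign = -1.0 if bits & 0b1000 else 1.0
    exp, man = (bits >> 1) & 0b11, bits & 0b1
    if exp == 0:  # subnormal: 0 or 0.5
        return sign * man * 0.5
    return sign * (1 + man / 2) * 2.0 ** (exp - 1)

# 16 bit patterns, but +0 and -0 compare equal, leaving 15 distinct values
FP4_GRID = sorted({decode_e2m1(b) for b in range(16)})

def round_to_fp4(x: float) -> float:
    # nearest grid value (ties go to the lower value); |x| > 6 saturates to +/-6
    return min(FP4_GRID, key=lambda g: abs(g - x))

def quantize_fp4_groups(row, group=16):
    """phi: per 1x16 group, s = max|X| / 6 and X_hat = round_to_fp4(X / s)."""
    codes, scales = [], []
    for i in range(0, len(row), group):
        block = row[i:i + group]
        s = max(abs(v) for v in block) / 6 or 1.0  # all-zero group: any scale works
        scales.append(s)  # kept in full precision here (assumption, not E4M3)
        codes.extend(round_to_fp4(v / s) for v in block)
    return codes, scales

def dequantize_fp4_groups(codes, scales, group=16):
    """phi^-1: X' = s * X_hat, using the scale of each element's group."""
    return [c * scales[i // group] for i, c in enumerate(codes)]

row = [i / 8 for i in range(16)] + [24.0] + [0.0] * 15  # second group holds an outlier
codes, scales = quantize_fp4_groups(row)
restored = dequantize_fp4_groups(codes, scales)
print(scales)  # [0.3125, 4.0], one scale per 1x16 group
```

With a single per-tensor scale of 24/6 = 4.0 instead, the FP4 grid spacing near zero would be 0.5 × 4.0 = 2.0, so every value in the first group below 1.0 would round to zero; the 1×16 groups confine the outlier's influence to its own group.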
Challenge C2 (Attention map scale factor range): The attention map P (after softmax) has values in [0, 1]. Direct FP4 quantization of P therefore yields small scale factors, $\max(\mathbf{P})/6 \in [0, 0.167]$, which are then stored in FP8 (E4M3). These scale factors occupy only a narrow slice of E4M3's representable range (whose largest finite value is 448), leading to significant accuracy loss.
Solution: A two-level quantization for P is introduced. First, each row of P is divided by a per-row factor $\mathbf{s}_{P_1} = \mathrm{rowmax}(\mathbf{P})/(448 \times 6)$, which stretches the row so that its maximum becomes 448 × 6 = 2688; call the result $\mathbf{P}_2$. Then standard FP4 microscaling quantization is applied to $\mathbf{P}_2$, producing $\hat{\mathbf{P}}_2$ and scale factors $\mathbf{s}_{P_2}$. Because group maxima can now reach 2688, the second-level scale factors reach up to 2688/6 = 448 and are spread over a much wider, more effective part of the E4M3 range, improving accuracy.
The final PV computation becomes $\mathbf{O} = \mathrm{FP4MM}(\hat{\mathbf{P}}_2, \mathbf{s}_{P_2}, \hat{\mathbf{V}}, \mathbf{s}_V) \times \mathbf{s}_{P_1}$.
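The effect of the two-level scheme on the E4M3 scale factors can be simulated in a few lines of Python. The sketch below assumes an unnormalized attention row with values in [0, 1], rounds each group's scale to E4M3 (bias 7, 3 mantissa bits, largest finite value 448) with a simple round-to-nearest rule, and keeps $\mathbf{s}_{P_1}$ in full precision; these choices and the function names are assumptions for illustration, not details taken from the paper's kernel. Because $\mathbf{s}_{P_1}$ is constant along each row, it can be factored out of the matrix multiplication and applied after FP4MM, which is why it appears as a trailing factor in the PV formula.

```python
import math

E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # non-negative FP4 values (P >= 0)
E4M3_MAX = 448.0

def round_e4m3(x: float) -> float:
    """Round a non-negative scale to FP8 E4M3 (bias 7, 3 mantissa bits), saturating at 448."""
    if x >= E4M3_MAX:
        return E4M3_MAX
    # subnormals below 2^-6 are spaced 2^-9 apart; normals have 8 steps per binade
    step = 2.0 ** -9 if x < 2.0 ** -6 else 2.0 ** (math.floor(math.log2(x)) - 3)
    return min(round(x / step) * step, E4M3_MAX)

def fp4_microscale(row, group=16):
    """NVFP4-style fake quantization: one E4M3 scale per 1x16 group, E2M1 values."""
    out = []
    for i in range(0, len(row), group):
        block = row[i:i + group]
        s = round_e4m3(max(block) / 6)
        if s == 0.0:  # the scale underflowed in E4M3, so the whole group is lost
            out.extend([0.0] * len(block))
            continue
        # nearest E2M1 value; codes above 6 (possible after scale rounding) saturate to 6
        out.extend(s * min(E2M1, key=lambda g: abs(g - v / s)) for v in block)
    return out

def direct_quantize(row):
    return fp4_microscale(row)

def two_level_quantize(row):
    """Level 1: divide by s_P1 = rowmax / (448 * 6); level 2: FP4 microscaling."""
    s_p1 = max(row) / (E4M3_MAX * 6)  # kept in full precision (assumption)
    if s_p1 == 0.0:
        return [0.0] * len(row)
    p2 = [v / s_p1 for v in row]  # the row maximum becomes 448 * 6 = 2688
    return [s_p1 * v for v in fp4_microscale(p2)]

row = [1.0] + [0.0] * 15 + [1e-4] * 16  # two groups: one dominant entry, then tiny values
err = lambda q: sum(abs(a - b) for a, b in zip(q, row))
print(err(direct_quantize(row)), err(two_level_quantize(row)))
```

In this example the second group's direct scale, $10^{-4}/6$, is smaller than the smallest E4M3 subnormal ($2^{-9}$), so it rounds to zero and the whole group is lost; after the first-level rescaling, the same group gets a scale in the E4M3 normal range.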
SageAttention3 builds on FlashAttention's tiled approach and incorporates techniques from SageAttention2, such as smoothing Q and K. Hardware-level optimizations implemented in CUTLASS and CUDA include permuting K's columns so that the accumulator layout of the QK⊤ result matches the register layout the FP4 matmul expects when P is used as an operand, reusing the maximum values from online softmax to compute the P quantization scales cheaply, and a producer warp epilogue design that overlaps computation with memory operations within register constraints.
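The reuse of online-softmax maxima can be sketched for a single attention row in plain Python. This is one reading of the optimization, assuming the quantity being derived is the first-level P scale $\mathbf{s}_{P_1}$: for a block of scores $\mathbf{S}_j$ with running maximum $m_{\text{new}}$, the unnormalized probabilities are $\tilde{\mathbf{P}}_j = \exp(\mathbf{S}_j - m_{\text{new}})$, and since exp is monotonic their maximum equals $\exp(\mathrm{rowmax}(\mathbf{S}_j) - m_{\text{new}})$. Both maxima are already computed by the online softmax, so no extra reduction over $\tilde{\mathbf{P}}_j$ is needed. The block size and function name are hypothetical.

```python
import math

def online_softmax_row(scores, block=4):
    """Online softmax over one row, returning per-block P scales without extra max passes."""
    m, l = -math.inf, 0.0  # running max and running denominator
    p_blocks, p_scales = [], []
    for start in range(0, len(scores), block):
        s_j = scores[start:start + block]
        block_max = max(s_j)  # computed by online softmax anyway
        m_new = max(m, block_max)
        p_tilde = [math.exp(x - m_new) for x in s_j]
        l = l * math.exp(m - m_new) + sum(p_tilde)  # rescale the old denominator
        # rowmax(P_tilde_j) = exp(block_max - m_new): derived from maxima already at hand
        p_scales.append(math.exp(block_max - m_new) / (448 * 6))
        p_blocks.append(p_tilde)  # relative to the max at this step, as in FlashAttention
        m = m_new
    return m, l, p_blocks, p_scales

# the second block's maximum (3.3) is below the running maximum (4.1)
scores = [0.3, -1.2, 4.1, 0.0, 2.5, -0.7, 1.9, 3.3]
m, l, p_blocks, p_scales = online_softmax_row(scores)
```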
Implementation and Performance (Inference):
Implemented using CUTLASS and CUDA.
Achieves up to 1038 TOPS on RTX5090, a 5× kernel speedup over FlashAttention2 and an 11× speedup over xformers on the same GPU.
Maintains end-to-end quality metrics across various models (CogVideoX, HunyuanVideo, Mochi, Flux, Stable Diffusion 3.5) with negligible degradation.
Provides significant end-to-end inference latency reduction, achieving ~2.4× to 3× speedups for video generation models.
SageBwd for Training
SageBwd explores applying INT8 quantization to the attention mechanism during both forward and backward passes, which is unprecedented in prior low-bit attention research that focused solely on inference. The primary challenge here is the sensitivity of gradients to quantization errors.
Challenge C3 (Gradient sensitivity): Quantization errors, particularly in the backward pass, can accumulate and significantly degrade the accuracy of gradients, especially dQ and dK. The paper identifies $d\mathbf{P} = d\mathbf{O}\mathbf{V}^\top$ as a particularly sensitive multiplication.
Solution:
Forward Pass: Uses INT8 per-block quantization for QK⊤. For PV, it employs INT8 per-token quantization for P and INT8 per-block for V. It reuses max values from the online softmax for efficient P quantization.
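Backward Pass: Attention contains seven matrix multiplications across both passes: $\mathbf{Q}\mathbf{K}^\top$ and $\mathbf{P}\mathbf{V}$ in the forward pass, and the recomputed $\mathbf{Q}\mathbf{K}^\top$, $d\mathbf{O}\mathbf{V}^\top$, $\mathbf{P}^\top d\mathbf{O}$, $d\mathbf{S}\,\mathbf{K}$, and $d\mathbf{S}^\top\mathbf{Q}$ in the backward pass, where $d\mathbf{S}$ is the gradient with respect to the attention scores. SageBwd quantizes six of them to INT8, using per-block quantization in the backward pass. The exception is $d\mathbf{P} = d\mathbf{O}\mathbf{V}^\top$, which is kept in FP16. This decision is based on empirical findings: quantizing this specific multiplication severely degrades the accuracy of dQ and dK, and because these gradients are accumulated block by block along the sequence, the error compounds during the backward pass. Keeping $d\mathbf{O}\mathbf{V}^\top$ in FP16 prevents this error accumulation and maintains gradient accuracy.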
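The INT8 per-block quantization $\psi$ for a block $\mathbf{X}$ is $s_X = \max(|\mathbf{X}|)/127,\ \hat{\mathbf{X}} = \lceil \mathbf{X}/s_X \rfloor$, where $\lceil \cdot \rfloor$ rounds to the nearest integer.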
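To make the split of the seven matrix multiplications concrete, the Python sketch below writes out a naive (untiled, single-head) attention forward and backward pass and routes each multiplication through a precision setting. It is a numerical simulation under stated assumptions, not SageBwd itself: `psi` follows the INT8 formula above with exact integer accumulation, each toy matrix is treated as one quantization block (so per-block reduces to per-tensor), FP16 is emulated with the half-precision format of Python's `struct` module, and the plan dictionaries and function names are hypothetical.

```python
import math
import random
import struct

# Hypothetical precision plan following the text: six of the seven attention
# matmuls run in INT8, while dP = dO @ V^T stays in FP16.
SAGEBWD_PLAN = {"fwd_qk": "int8", "fwd_pv": "int8", "bwd_qk": "int8", "bwd_dp": "fp16",
                "bwd_dv": "int8", "bwd_dq": "int8", "bwd_dk": "int8"}
EXACT_PLAN = {name: "exact" for name in SAGEBWD_PLAN}  # full-precision reference

def transpose(a):
    return [list(col) for col in zip(*a)]

def matmul(a, b):
    cols = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def scaled(x, c):
    return [[c * v for v in row] for row in x]

def psi(x):
    """INT8 quantization of one block: s = max|X| / 127, X_hat = round(X / s)."""
    s = max(abs(v) for row in x for v in row) / 127 or 1.0  # guard all-zero blocks
    return [[round(v / s) for v in row] for row in x], s

def to_fp16(x):
    # emulate FP16 storage with the IEEE half-precision struct format "e"
    return [[struct.unpack("e", struct.pack("e", v))[0] for v in row] for row in x]

def mm(a, b, mode, a_per_token=False):
    """One attention matmul in the precision named by `mode`."""
    if mode == "exact":
        return matmul(a, b)
    if mode == "fp16":  # FP16 inputs, higher-precision accumulation
        return matmul(to_fp16(a), to_fp16(b))
    b_hat, s_b = psi(b)  # each toy matrix is one block (assumption)
    if a_per_token:  # one scale per row of A, as for P in the forward PV
        pairs = [psi([row]) for row in a]
        a_hat, s_a = [q[0] for q, _ in pairs], [s for _, s in pairs]
    else:
        a_hat, s = psi(a)
        s_a = [s] * len(a)
    acc = matmul(a_hat, b_hat)  # exact integer sums, standing in for INT32 accumulators
    return [[sa * s_b * v for v in row] for sa, row in zip(s_a, acc)]

def softmax_rows(s):
    out = []
    for row in s:
        m = max(row)
        e = [math.exp(v - m) for v in row]
        z = sum(e)
        out.append([v / z for v in e])
    return out

def forward(q, k, v, plan):
    c = 1 / math.sqrt(len(q[0]))
    p = softmax_rows(scaled(mm(q, transpose(k), plan["fwd_qk"]), c))
    return mm(p, v, plan["fwd_pv"], a_per_token=True)

def backward(q, k, v, o, do, plan):
    c = 1 / math.sqrt(len(q[0]))
    p = softmax_rows(scaled(mm(q, transpose(k), plan["bwd_qk"]), c))  # recompute P
    dv = mm(transpose(p), do, plan["bwd_dv"])
    dp = mm(do, transpose(v), plan["bwd_dp"])  # the sensitive dP = dO V^T
    d = [sum(x * y for x, y in zip(ro, rd)) for ro, rd in zip(o, do)]  # rowsum(dO * O)
    ds = [[pij * (dpij - di) for pij, dpij in zip(pr, dpr)] for pr, dpr, di in zip(p, dp, d)]
    dq = scaled(mm(ds, k, plan["bwd_dq"]), c)
    dk = scaled(mm(transpose(ds), q, plan["bwd_dk"]), c)
    return dq, dk, dv

def rand_matrix(n, d):
    return [[random.uniform(-1, 1) for _ in range(d)] for _ in range(n)]

random.seed(0)
Q, K, V, dO = (rand_matrix(4, 3) for _ in range(4))
dQ_ref, _, _ = backward(Q, K, V, forward(Q, K, V, EXACT_PLAN), dO, EXACT_PLAN)
dQ_int8, _, _ = backward(Q, K, V, forward(Q, K, V, SAGEBWD_PLAN), dO, SAGEBWD_PLAN)
```

With the full-precision plan the gradients can be checked against finite differences; setting `"bwd_dp"` to `"int8"` is one way to probe the sensitivity the paper reports, although a toy problem of this size is not expected to reproduce its findings.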
Implementation and Performance (Training):
Implemented using OpenAI Triton.
Achieves up to a 1.67× kernel speedup over FlashAttention2 (CUDA) on RTX4090 for the combined forward and backward pass.
Shows lossless performance in fine-tuning tasks on Qwen2.5 and Llama3.2 models across various benchmarks (GSM8K, DROP, MMLU, HellaSwag), matching the accuracy of BF16 training.
Converges more slowly than BF16 in pretraining tasks on a smaller Llama model, indicating that SageBwd is not yet suitable for pretraining.
Achieves end-to-end training acceleration, speeding up one iteration for Llama 1B with 8K/16K sequence lengths by about 1.15× on RTX4090.
Implementation Considerations
SageAttention3 requires hardware support for FP4 Tensor Cores (e.g., Blackwell GPUs like RTX5090).
SageBwd leverages INT8 Tensor Cores, available on many modern GPUs (e.g., RTX4090).
Implementing these requires deep knowledge of low-level GPU programming, utilizing frameworks like CUTLASS and Triton to orchestrate efficient matrix multiplications, data movement, and quantization kernels.
The choice between NVFP4 and MXFP4 (for FP4) and per-block/per-token/two-level scaling is critical for balancing speed and accuracy.
The decision to keep $d\mathbf{O}\mathbf{V}^\top$ in FP16 during the backward pass of SageBwd highlights the non-trivial challenge of quantizing gradients and the need for careful analysis of error propagation.
Limitations and Future Work
The paper identifies two key areas for future work:
Optimizing the Triton kernels for SageBwd to further close the gap between current performance and theoretical peak speed.
Further research into the application of low-bit attention for pretraining tasks to address the slower convergence observed with SageBwd.
In summary, the paper provides practical methods for accelerating attention computation using low-bit quantization, offering state-of-the-art FP4 performance for inference on new hardware and demonstrating the feasibility and challenges of applying low-bit attention to accelerate fine-tuning.