SeerAttention: Data-Driven Sparse Attention
- SeerAttention is a data-driven sparse attention mechanism that dynamically predicts block-wise sparsity using a learnable gating network to reduce LLM computational costs.
- It partitions sequences into blocks and uses pooling with linear gates to form a binary mask, enabling efficient block-sparse FlashAttention without static heuristics.
- Empirical evaluations show up to 5.47× speedup and minimal accuracy loss, with its autoregressive variant SeerAttention-R further optimizing decoding performance.
SeerAttention is a data-driven sparse attention mechanism for LLMs, designed to address the prohibitive computational and memory costs of dense Transformer attention on long-context sequences. By introducing a learnable gating network that predicts block-wise sparsity in the attention map, SeerAttention achieves high efficiency while maintaining model accuracy. This approach removes the need for static heuristics or rule-based sparsity and extends to both batch prefill and autoregressive decoding, the latter through its decoding-optimized sibling, SeerAttention-R (Gao et al., 2024; Gao et al., 10 Jun 2025).
1. Motivation for Sparse Attention in LLMs
Transformer attention modules have quadratic complexity, scaling as $O(L^2)$ for sequence length $L$. For long-context LLMs (e.g., 32k–128k tokens), this makes both the prefill and decoding steps memory- and compute-bound. Empirically, the resulting attention maps,
$$A = \mathrm{softmax}\left(QK^\top / \sqrt{d}\right),$$
display extreme sparsity (90%+ near-zero entries at large $L$), indicating that much of the computational work is redundant. Previous approaches using static block patterns or heuristics (MInference, MoA) lack adaptability: they are insensitive to sequence, head, or input dynamics and require manual calibration. SeerAttention's innovation is to learn the sparsity pattern directly from the runtime attention distribution via a small parameterized gate trained by self-distillation.
2. SeerAttention Architecture and Gating Mechanism
SeerAttention inserts an "Attention Gate" in each QKV attention head to dynamically select influential attention matrix blocks. The procedure is as follows:
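- Blocking: Divide the sequence into non-overlapping blocks of size $B$, yielding a $T \times T$ block grid, where $T$ is the sequence length divided by $B$.
- Pooling:
- $Q_p = \mathrm{PoolAvg}(Q, B)$: Average pooling of Q in each block.
- $K_p = [\mathrm{PoolMax}(K, B); \mathrm{PoolMin}(K, B)]$: Concatenation of max and min pooling over K in each block.
- Linear Gates:
- $G_Q = W_Q(Q_p)$ and $G_K = W_K(K_p)$, with $W_Q$ and $W_K$ learnable linear projections into a shared gate dimension $r$.
- Block Gating Scores: $S = \mathrm{softmax}\left(G_Q G_K^\top / \sqrt{r}\right)$
- Sparse Mask: Select the top-$k$ entries per row of $S$ ($k = \lfloor (1-s)T \rfloor$, where $s$ is the target sparsity) to construct a binary block mask $M$ for the downstream block-sparse attention.
The learned $M$ allows the core attention kernel to ignore inactive blocks entirely during computation, dramatically reducing cost.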
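The gating steps above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the released implementation: the helper names (`pool_blocks`, `gate_block_mask`), the toy dimensions, the random weights, the `max(1, ...)` floor on `k`, and the small epsilon guarding against floating-point error are all choices made for this example. Causal masking and block-level RoPE are omitted.

```python
import math
import random

def pool_blocks(X, B, reduce_fn):
    """Reduce every run of B consecutive rows of X column-wise with reduce_fn."""
    return [[reduce_fn(col) for col in zip(*X[i:i + B])]
            for i in range(0, len(X), B)]

def mean(values):
    return sum(values) / len(values)

def matmul(A, W):
    """Plain (n x m) @ (m x p) product on nested lists."""
    cols = list(zip(*W))
    return [[sum(a * w for a, w in zip(row, col)) for col in cols] for row in A]

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def gate_block_mask(Q, K, W_Q, W_K, B, s):
    """Return block scores S (T x T) and a binary top-k block mask M (T x T)."""
    Qp = pool_blocks(Q, B, mean)                        # average-pooled queries, T x d
    Kp = [hi + lo for hi, lo in zip(pool_blocks(K, B, max),
                                    pool_blocks(K, B, min))]  # max ++ min, T x 2d
    GQ, GK = matmul(Qp, W_Q), matmul(Kp, W_K)           # gate projections, T x r
    T, r = len(Qp), len(W_Q[0])
    logits = matmul(GQ, [list(col) for col in zip(*GK)])  # GQ @ GK^T
    S = [softmax([x / math.sqrt(r) for x in row]) for row in logits]
    # k = floor((1 - s) * T); epsilon and the floor of 1 are assumptions of this sketch.
    k = max(1, math.floor((1 - s) * T + 1e-9))
    M = []
    for row in S:
        keep = set(sorted(range(T), key=lambda j: row[j], reverse=True)[:k])
        M.append([1 if j in keep else 0 for j in range(T)])
    return S, M

# Toy sizes chosen only for illustration.
random.seed(0)
seq_len, d, r, B, s = 16, 4, 2, 4, 0.5
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(seq_len)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(seq_len)]
W_Q = [[random.gauss(0, 1) for _ in range(r)] for _ in range(d)]
W_K = [[random.gauss(0, 1) for _ in range(r)] for _ in range(2 * d)]
S, M = gate_block_mask(Q, K, W_Q, W_K, B, s)
for row in M:
    print(row)
```

With 16 tokens and a block size of 4, the gate works on a 4 × 4 grid, and a target sparsity of 0.5 keeps two blocks per row.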
3. Block-Sparse FlashAttention Kernel and Hardware Considerations
The SeerAttention block mask enables an efficient block-sparse variant of FlashAttention, implemented with custom Triton and TileLang kernels for prefill and decoding, respectively.
- Block-Sparse Prefill (Dense Batch):
- Operates only on blocks selected by $M$, loading relevant Q/K/V segments into on-chip memory and fully avoiding computation for masked-out blocks.
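- Complexity: negligible gate overhead ($O(T^2 r)$ for the $T \times T$ block scores, with $T$ blocks per sequence and gate dimension $r$), approaching linearity in sequence length when the number of selected blocks per row is held fixed.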
- Optimizations: Dataflow mimics FlashAttention-2 tiling; fused top-$k$ selection; warp-level softmax on NVIDIA GPUs.
- Sparse Decoding (Autoregressive, SeerAttention-R):
- Query pooling is eliminated to accommodate streaming tokens; block mask is computed using per-token queries and pooled KV blocks, processed in GQA groups for hardware efficiency.
- TileLang kernel organizes compute over a 3D grid (batch, heads_kv, block splits), iterating only over the sparse index list of selected blocks.
- On H100 GPUs, kernel achieves up to 8.6× speedup over FlashAttention-3 at 90% sparsity, with robust memory and bandwidth utilization (Gao et al., 10 Jun 2025).
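The logic of the block-sparse prefill kernel can be illustrated with a slow reference implementation. The sketch below is a plain-Python stand-in, not the Triton or TileLang kernels: it skips masked-out blocks and accumulates the softmax with a FlashAttention-style running maximum and normalizer, updated per key rather than per tile for readability. The function names and toy sizes are hypothetical, causal masking is omitted, and every row of the mask is assumed to keep at least one block.

```python
import math
import random

def block_sparse_attention(Q, K, V, M, B):
    """Non-causal reference: attend only to key blocks where M[query_block][key_block] == 1."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for i, q in enumerate(Q):
        m_run, l_run = -math.inf, 0.0          # running max and softmax normalizer
        acc = [0.0] * len(V[0])                # running weighted sum of values
        for jb, active in enumerate(M[i // B]):
            if not active:
                continue                       # masked-out block: nothing loaded or computed
            for j in range(jb * B, min((jb + 1) * B, len(K))):
                score = scale * sum(a * b for a, b in zip(q, K[j]))
                m_new = max(m_run, score)
                corr = math.exp(m_run - m_new)  # rescale earlier partial sums
                p = math.exp(score - m_new)
                l_run = l_run * corr + p
                acc = [a * corr + p * v for a, v in zip(acc, V[j])]
                m_run = m_new
        out.append([a / l_run for a in acc])
    return out

def dense_row(q, K, V):
    """Ordinary softmax attention for one query over the given keys and values."""
    scale = 1.0 / math.sqrt(len(q))
    w = [math.exp(scale * sum(a * b for a, b in zip(q, k))) for k in K]
    z = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, V)) / z for c in range(len(V[0]))]

random.seed(0)
n, d, B = 8, 4, 2
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]

full_mask = [[1] * (n // B) for _ in range(n // B)]
sparse_mask = [[1, 0, 0, 1], [0, 1, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1]]
out_full = block_sparse_attention(Q, K, V, full_mask, B)
out_sparse = block_sparse_attention(Q, K, V, sparse_mask, B)
```

With an all-ones mask the result matches dense attention; with a sparse mask each query sees only the keys of its selected blocks, so the work scales with the number of active blocks rather than with the full grid.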
4. Training Paradigm: Self-Distillation for Gate Learning
SeerAttention employs a two-stage, self-distilled gate training regime:
- Stage 1: Gate Pretraining by Distillation
- The pretrained model is frozen; full attention maps are computed via FlashAttention for a batch of sequences.
- For each block, the maximum softmax value within the block constitutes the ground-truth block score.
- The gate learns to predict block scores $S$ that match this ground truth, via a mean-squared error or KL-divergence loss.
- Only gate parameters ($W_Q$, $W_K$, and block-level RoPE) are trained; base model weights are untouched. This stage converges rapidly (∼500 steps).
- Stage 2: Sparse Fine-Tuning
- The model is fine-tuned with both the standard language modeling loss and the gate loss, plus an optional sparsity penalty on the gate output.
- Post-distillation, sparsity levels can be increased while maintaining accuracy.
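The Stage 1 distillation target can be illustrated with a small worked example. The sketch below max-pools a toy 4 × 4 attention map into a 2 × 2 block-level ground truth and evaluates both loss options against a uniform gate prediction. The function names are hypothetical, and normalizing the pooled target row-wise before comparing it with the gate's softmax output is an assumption of this sketch; no gradient step is shown.

```python
import math

def block_max_pool(A, B):
    """2D max-pool a square attention map into a block-level map."""
    T = len(A) // B
    return [[max(A[i][j]
                 for i in range(bi * B, (bi + 1) * B)
                 for j in range(bj * B, (bj + 1) * B))
             for bj in range(T)]
            for bi in range(T)]

def normalize_rows(X):
    """Rescale each row to sum to 1 (assumed here so rows are comparable to a softmax)."""
    return [[x / sum(row) for x in row] for row in X]

def mse_loss(pred, target):
    diffs = [(p - t) ** 2 for pr, tr in zip(pred, target) for p, t in zip(pr, tr)]
    return sum(diffs) / len(diffs)

def kl_loss(pred, target, eps=1e-12):
    """Mean over rows of KL(target || pred) for row-stochastic inputs."""
    total = 0.0
    for pr, tr in zip(pred, target):
        total += sum(t * math.log((t + eps) / (p + eps)) for p, t in zip(pr, tr))
    return total / len(pred)

# Toy attention map from a frozen model: each row is a softmax distribution.
A = [
    [0.70, 0.10, 0.10, 0.10],
    [0.20, 0.60, 0.10, 0.10],
    [0.05, 0.05, 0.80, 0.10],
    [0.10, 0.10, 0.20, 0.60],
]
target = block_max_pool(A, B=2)           # block-level ground truth
target_dist = normalize_rows(target)
pred = [[0.5, 0.5], [0.5, 0.5]]           # an untrained, uniform gate output

print("ground truth:", target)
print("MSE:", mse_loss(pred, target_dist))
print("KL :", kl_loss(pred, target_dist))
```

In this toy map the diagonal blocks carry the largest attention weights, so the pooled target concentrates there and a uniform gate output incurs a positive loss under either criterion.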
Hyperparameters:
- ,
- Pretraining LR: , 500 steps, batch size 16
- Fine-tuning LR: , batch size 8, gate weight
5. Empirical Results and Comparative Analysis
SeerAttention (Prefill)
- Accuracy: On Llama-3.1-8B and Mistral-7B (32k tokens), 90% sparsity increases perplexity by only 0.16–0.17 points; at 50–60% sparsity, perplexity is nearly identical to dense attention.
- Fine-Tuning: For Llama-3-8B (8k→32k context, PG19), full attention PPL ≈ 8.79, SeerAttention (50% sparsity) ≈ 8.81, (90% sparsity) ≈ 9.16.
- Speedup: On A100, 32k tokens/90% sparsity yields 5.47× speedup over dense FlashAttention-2; 8k/50% sparsity yields 1.15×.
- Sparsity Visualization: Learned masks exhibit diverse, context-sensitive patterns (A-shapes, diagonals, hybrid), superior to fixed heuristics.
- Ablations: Average pooling for Q and max+min for K is optimal; block-RoPE in the gate is crucial for length extrapolation.
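To put the reported speedups in context, they can be compared with a simple upper bound. If attention cost were exactly proportional to the number of computed blocks, and the gate and kernel overheads were free, a sparsity of $s$ would allow at most a $1/(1-s)$ speedup. The snippet below applies that idealized bound (an assumption of this sketch, not a claim from the papers) to the two A100 figures quoted above.

```python
def ideal_speedup(sparsity):
    """Upper bound if cost scaled exactly with the fraction of blocks computed."""
    return 1.0 / (1.0 - sparsity)

# (setting, sparsity, measured speedup over dense FlashAttention-2) as quoted above.
reported = [
    ("32k tokens, 90% sparsity", 0.90, 5.47),
    ("8k tokens, 50% sparsity", 0.50, 1.15),
]

for label, sparsity, measured in reported:
    bound = ideal_speedup(sparsity)
    print(f"{label}: bound {bound:.2f}x, measured {measured:.2f}x, "
          f"ratio {measured / bound:.3f}")
```

Under this bound the measured figures reach a little over half of the idealized limit in both settings, which is consistent with gate and kernel overheads accounting for the remaining gap.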
SeerAttention-R (Decoding)
- Auto-Regressive Adaptation: Query pooling is removed; gating operates per token with compressed KV blocks, supporting Grouped Query Attention (GQA).
- Accuracy: On AIME24 (4k tokens, Qwen3-14B), full attention ≈76%, SeerAttention-R ≈75%, fixed-mask Quest ≈67%. Robustness is maintained up to block sizes with minimal degradation.
- Kernel Performance: On H100, 90% sparsity and long contexts yield up to 8.6× speedup over FlashAttention-3.
- Efficiency: Training on 0.4B tokens for 800 steps adapts a model in 10–19 GPU hours, depending on model size.
- Length and Token-Budgeting: Model achieves both minimal reasoning degradation and competitive answer lengths under extreme sparsity, outperforming static-block alternatives.
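The per-token gating used in decoding can be sketched as follows. This is an illustrative reading of the description above, not the SeerAttention-R code: the function name `select_blocks_for_token`, summing the per-head scores so that all query heads in a GQA group share one block selection, and the hand-picked toy weights are assumptions of this example.

```python
import math

def pool_keys(K, B):
    """Max- and min-pool the cached keys per block and concatenate: T x 2d."""
    pooled = []
    for i in range(0, len(K), B):
        cols = list(zip(*K[i:i + B]))
        pooled.append([max(c) for c in cols] + [min(c) for c in cols])
    return pooled

def project(x, W):
    """Row vector x (length m) times matrix W (m x r)."""
    return [sum(a * w for a, w in zip(x, col)) for col in zip(*W)]

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def select_blocks_for_token(q_group, K_cache, W_Q, W_K, B, k):
    """Pick k key blocks for one decoding step, shared by all query heads in the group."""
    r = len(W_Q[0])
    GK = [project(kp, W_K) for kp in pool_keys(K_cache, B)]   # T x r
    group_score = [0.0] * len(GK)
    for q in q_group:                       # one un-pooled query per head
        gq = project(q, W_Q)
        scores = softmax([sum(a * b for a, b in zip(gq, gk)) / math.sqrt(r)
                          for gk in GK])
        # Assumption: heads in a group vote by summing their block scores.
        group_score = [g + s for g, s in zip(group_score, scores)]
    top = sorted(range(len(GK)), key=lambda j: group_score[j], reverse=True)[:k]
    return sorted(top)

# Toy KV cache: 8 keys of dimension 2, block size 2, so 4 blocks.
K_cache = [[0, 0], [0, 0], [5, 0], [4, 0], [0, 0], [0, 1], [0, 6], [0, 5]]
W_Q = [[1, 0], [0, 1]]                       # identity gate for queries
W_K = [[1, 0], [0, 1], [0, 0], [0, 0]]       # uses only the max-pooled half
q_group = [[1, 0], [0, 1]]                   # two query heads sharing one KV head
selected = select_blocks_for_token(q_group, K_cache, W_Q, W_K, B=2, k=2)
print(selected)
```

Because the key side is pooled once per block and reused, only the new token's query needs to be projected at each step; the sparse decoding kernel then iterates over the returned block indices.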
6. Integration and Implementation Details
- Transformer Stack Modification: Replace the standard attention call

```
O = FlashAttention(Q, K, V)
```

with

```
Qp = PoolAvg(Q, B)
Kp = [PoolMax(K, B); PoolMin(K, B)]
GQ, GK = W_Q(Qp), W_K(Kp)
S = softmax(GQ * GK.T / sqrt(r))
M = TopK_mask(S, k=floor((1-s)T))
O = BlockSparseFlash(Q, K, V, M, B)
```

- API and Codebase: Triton and TileLang kernels, gate/training modules, and integration scripts are available at https://github.com/microsoft/SeerAttention.
- Plug-In Nature: New gate parameters are added without modifying original model weights. Integration into standard multi-head attention layers is supported for both batch and streaming settings.
7. Limitations and Future Directions
Several limitations remain:
- The gate adds a small overhead, which limits speedups for short sequences or low sparsity (for example, only 1.15× at 8k tokens and 50% sparsity).
- The mechanism currently uses fixed per-head sparsity; an extension to variable per-head or per-layer sparsity is possible.
- Block size trades off granularity against gate cost: a finer block size increases gate overhead.
- The original SeerAttention does not address decoding (incremental attention); SeerAttention-R is designed specifically for this setting.
This suggests that further architectural advances in plug-in gating strategies and kernel-hardware co-design will continue to expand the practical applicability of learned sparse attention, particularly as LLM sequence lengths increase and efficiency demands grow.
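The block-size limitation noted above can be made concrete by counting the entries of the block score grid that the gate must produce. The block sizes and the 32k sequence length below are illustrative choices, and the count ignores the pooling and projection work, which grows only linearly with sequence length.

```python
def gate_grid_entries(seq_len, block_size):
    """Number of block-pair scores the gate computes for one head."""
    blocks = -(-seq_len // block_size)   # ceiling division
    return blocks * blocks

seq_len = 32 * 1024
for block_size in (32, 64, 128):
    entries = gate_grid_entries(seq_len, block_size)
    print(f"B={block_size:3d}: {entries:,} gate scores per head")
```

Halving the block size quadruples the score grid, which is why finer-grained masks come at a higher gating cost.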
References
- SeerAttention: Learning Intrinsic Sparse Attention in Your LLMs (Gao et al., 2024)
- SeerAttention-R: Sparse Attention Adaptation for Long Reasoning (Gao et al., 10 Jun 2025)