
SeerAttention: Data-Driven Sparse Attention

Updated 18 March 2026
  • SeerAttention is a data-driven sparse attention mechanism that dynamically predicts block-wise sparsity using a learnable gating network to reduce LLM computational costs.
  • It partitions sequences into blocks and uses pooling with linear gates to form a binary mask, enabling efficient block-sparse FlashAttention without static heuristics.
  • Empirical evaluations show up to 5.47× speedup and minimal accuracy loss, with its autoregressive variant SeerAttention-R further optimizing decoding performance.

SeerAttention is a data-driven sparse attention mechanism for LLMs, designed to address the prohibitive computational and memory costs of dense Transformer attention on long-context sequences. By introducing a learnable gating network that predicts block-wise sparsity in the attention map, SeerAttention achieves high efficiency while maintaining model accuracy. This approach obviates the need for static heuristics or rule-based sparsity and extends naturally to both batch prefill and auto-regressive decoding settings through its decoding-optimized sibling, SeerAttention-R (Gao et al., 2024, Gao et al., 10 Jun 2025).

1. Motivation for Sparse Attention in LLMs

Transformer attention has quadratic complexity, scaling as $\mathcal{O}(n^2)$ in the sequence length $n$. For long-context LLMs (e.g., 32k–128k tokens), this makes both prefill and decoding memory- and compute-bound. Empirically, the resulting attention maps,

$$A = \mathrm{softmax}\left(QK^\top / \sqrt{d}\right) \in \mathbb{R}^{n \times n},$$

display extreme sparsity (more than 90% of entries are near zero at large $n$), indicating that much of the computation is redundant. Previous approaches using static block patterns or heuristics (MInference, MoA) lack adaptability: they are insensitive to sequence, head, or input dynamics and require manual calibration. SeerAttention's innovation is to learn the sparsity pattern directly from the runtime attention distribution via a small parameterized gate trained by self-distillation.
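To make the quadratic scaling concrete, the following minimal sketch counts the entries of one head's attention map and the approximate multiply-accumulate cost of computing $QK^\top$ and $AV$. The head dimension `d = 128` is an illustrative assumption, not a value taken from the source.

```python
def attention_map_entries(n: int) -> int:
    """Number of entries in the n x n attention map of a single head."""
    return n * n


def attention_macs(n: int, d: int) -> int:
    """Approximate multiply-accumulates for one head: QK^T plus AV, each n*n*d."""
    return 2 * n * n * d


d = 128  # assumed head dimension, for illustration only
for n in (32_768, 131_072):
    print(n, attention_map_entries(n), attention_macs(n, d))

# Growing the context by 4x multiplies the attention cost by 16x.
ratio = attention_macs(131_072, d) / attention_macs(32_768, d)
```

If more than 90% of those entries are near zero, most of that work contributes little to the output, which is the opportunity sparse attention targets.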

2. SeerAttention Architecture and Gating Mechanism

SeerAttention inserts an "Attention Gate" in each QKV attention head to dynamically select influential attention matrix blocks. The procedure is as follows:

  • Blocking: Divide the sequence into $T = \lceil n/B \rceil$ non-overlapping blocks (typically $B = 64$), yielding a $T \times T$ block grid.
  • Pooling:
    • $\mathrm{Pool}_Q(Q) \in \mathbb{R}^{T \times d}$: Average pooling of $Q$ in each block.
    • $\mathrm{Pool}_K(K) \in \mathbb{R}^{T \times 2d}$: Concatenation of max and min pooling over $K$ in each block.
  • Linear Gates:
    • $G_Q = \mathrm{Pool}_Q(Q)\,W_Q$ and $G_K = \mathrm{Pool}_K(K)\,W_K$, with $W_Q \in \mathbb{R}^{d \times r}$ and $W_K \in \mathbb{R}^{2d \times r}$ (the pooled $K$ has $2d$ columns), where $r \ll d$.
  • Block Gating Scores:

$$S = \mathrm{softmax}\left(\frac{G_Q G_K^\top}{\sqrt{r}}\right) \in \mathbb{R}^{T \times T}$$

  • Sparse Mask: Select the top-$k$ entries per row of $S$ ($k = \lfloor (1-s)T \rfloor$, where $s$ is the target sparsity) to construct a binary block mask $M$ for the downstream block-sparse attention.

The predicted mask $M$ allows the core attention kernel to skip inactive blocks entirely, reducing attention cost roughly in proportion to the fraction of blocks skipped.
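The gating steps above can be sketched in plain Python on toy inputs. This is a minimal illustration under stated assumptions, not the released implementation: the helper names (`avg_pool`, `maxmin_pool`, `block_mask`) are hypothetical, the gate weights are random rather than trained, causal masking and block-level RoPE are omitted, and keeping at least one block per row is an added safeguard.

```python
import math
import random


def avg_pool(X, B):
    """Average the rows of X (n x d) within each block of B rows -> T x d."""
    blocks = [X[i:i + B] for i in range(0, len(X), B)]
    return [[sum(col) / len(blk) for col in zip(*blk)] for blk in blocks]


def maxmin_pool(X, B):
    """Concatenate per-block column-wise max and min -> T x 2d."""
    out = []
    for i in range(0, len(X), B):
        cols = list(zip(*X[i:i + B]))
        out.append([max(c) for c in cols] + [min(c) for c in cols])
    return out


def matmul(A, W):
    """Plain matrix product of A (m x p) and W (p x q)."""
    return [[sum(a * w for a, w in zip(row, col)) for col in zip(*W)] for row in A]


def softmax(row):
    m = max(row)
    e = [math.exp(v - m) for v in row]
    z = sum(e)
    return [v / z for v in e]


def block_mask(Q, K, W_Q, W_K, B, s):
    """Return (S, M): T x T gate scores and a binary top-k block mask."""
    GQ = matmul(avg_pool(Q, B), W_Q)      # T x r
    GK = matmul(maxmin_pool(K, B), W_K)   # T x r
    T, r = len(GQ), len(GQ[0])
    S = [softmax([sum(a * b for a, b in zip(gq, gk)) / math.sqrt(r) for gk in GK])
         for gq in GQ]
    k = max(1, math.floor((1 - s) * T))   # assumption: never drop every block
    M = []
    for row in S:
        keep = set(sorted(range(T), key=lambda j: row[j], reverse=True)[:k])
        M.append([1 if j in keep else 0 for j in range(T)])
    return S, M


random.seed(0)
n, d, r, B, s = 8, 2, 2, 2, 0.5           # toy sizes, not the source's settings
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
W_Q = [[random.gauss(0, 1) for _ in range(r)] for _ in range(d)]
W_K = [[random.gauss(0, 1) for _ in range(r)] for _ in range(2 * d)]
S, M = block_mask(Q, K, W_Q, W_K, B, s)
```

With $n = 8$ and $B = 2$ the grid is $4 \times 4$, and a target sparsity of 0.5 keeps two blocks in every row of `M`.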

3. Block-Sparse FlashAttention Kernel and Hardware Considerations

The SeerAttention block mask $M$ enables an efficient block-sparse variant of FlashAttention, implemented with custom Triton and TileLang kernels for prefill and decoding, respectively.

  • Block-Sparse Prefill (Dense Batch):
    • Operates only on blocks selected by $M$, loading the relevant Q/K/V segments into on-chip memory and fully avoiding computation for masked-out blocks.
    • Complexity: $\mathcal{O}((1-s)\, n^2 d)$ plus negligible gate overhead ($\mathcal{O}(n d r / B)$), approaching linearity when $s \geq 0.8$.
    • Optimizations: Dataflow mimics FlashAttention-2 tiling; fused top-$k$ selection; warp-level softmax on NVIDIA GPUs.
  • Sparse Decoding (Autoregressive, SeerAttention-R):
    • Query pooling is eliminated to accommodate streaming tokens; block mask is computed using per-token queries and pooled KV blocks, processed in GQA groups for hardware efficiency.
    • The TileLang kernel organizes compute over a 3D grid (batch, heads_kv, block splits), iterating only over the sparse index list $I$ of selected blocks.
    • On H100 GPUs, kernel achieves up to 8.6× speedup over FlashAttention-3 at 90% sparsity, with robust memory and bandwidth utilization (Gao et al., 10 Jun 2025).
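As a reference for what a sparse decoding kernel computes, the sketch below evaluates attention for a single query vector while visiting only the KV blocks named in an index list, using a running maximum and running denominator (online softmax) so that no full score row is stored. It is a scalar, per-token illustration with hypothetical function names; real kernels process tiles in parallel and this code makes no claim about the Triton or TileLang implementations.

```python
import math
import random


def sparse_decode_attention(q, K, V, block_ids, B):
    """Attention output for one query, visiting only the listed KV blocks."""
    d = len(q)
    m = -math.inf              # running maximum of the scores seen so far
    denom = 0.0                # running softmax denominator
    acc = [0.0] * len(V[0])    # running weighted sum of V rows
    for b in block_ids:
        for j in range(b * B, min((b + 1) * B, len(K))):
            score = sum(qi * ki for qi, ki in zip(q, K[j])) / math.sqrt(d)
            m_new = max(m, score)
            scale = math.exp(m - m_new)    # rescale earlier partial results
            p = math.exp(score - m_new)
            denom = denom * scale + p
            acc = [a * scale + p * v for a, v in zip(acc, V[j])]
            m = m_new
    return [a / denom for a in acc]


def dense_attention(q, K, V):
    """Straightforward softmax attention over all rows, used as a check."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K]
    mx = max(scores)
    w = [math.exp(sc - mx) for sc in scores]
    z = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, V)) / z for c in range(len(V[0]))]


random.seed(0)
n, d, B = 8, 4, 2   # toy sizes: 4 KV blocks of 2 tokens each
q = [random.gauss(0, 1) for _ in range(d)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]

out_all = sparse_decode_attention(q, K, V, block_ids=[0, 1, 2, 3], B=B)
out_sparse = sparse_decode_attention(q, K, V, block_ids=[0, 3], B=B)
```

Selecting every block reproduces dense attention, while a shorter index list restricts the softmax to the chosen blocks; the work done is proportional to the number of selected blocks.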

4. Training Paradigm: Self-Distillation for Gate Learning

SeerAttention employs a two-stage, self-distilled gate training regime:

  • Stage 1: Gate Pretraining by Distillation
    • The pretrained model is frozen; full attention maps are computed via FlashAttention for a batch of sequences.
    • For each block, the maximum softmax value within the block constitutes the ground truth $D$.
    • The gate learns to predict block scores $S$ that match $D$, via a mean-squared error or KL-divergence loss.
    • Only gate parameters ($W_Q$, $W_K$, and block-level RoPE) are trained; base model weights are untouched. This stage converges rapidly (about 500 steps).
  • Stage 2: Sparse Fine-Tuning
    • The model is fine-tuned with both the standard language modeling loss and the gate loss, plus an optional sparsity penalty on $\|M\|_1$.
    • Post-distillation, sparsity levels can be increased while maintaining accuracy.
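The distillation target can be illustrated with a small sketch: the attention map is max-pooled over $B \times B$ blocks to form the block-level ground truth, which is then compared to gate scores with a KL divergence. Normalizing each row of the pooled map so that it sums to one is an assumption made here so the target is comparable to a softmax output; the function names and the toy attention map are likewise illustrative.

```python
import math


def block_max_pool(A, B):
    """2D max pooling of an n x n attention map into a T x T block map."""
    n = len(A)
    T = math.ceil(n / B)
    D = [[0.0] * T for _ in range(T)]
    for i in range(n):
        for j in range(n):
            D[i // B][j // B] = max(D[i // B][j // B], A[i][j])
    return D


def normalize_rows(D):
    """Scale each row to sum to 1 (assumed step, to match a softmax output)."""
    return [[v / sum(row) for v in row] for row in D]


def kl_divergence(D, S, eps=1e-12):
    """Mean over rows of KL(D_row || S_row)."""
    total = 0.0
    for d_row, s_row in zip(D, S):
        total += sum(p * math.log((p + eps) / (q + eps))
                     for p, q in zip(d_row, s_row) if p > 0)
    return total / len(D)


# Toy causal attention map (each row sums to 1), pooled with block size 2.
A = [
    [1.00, 0.00, 0.00, 0.00],
    [0.50, 0.50, 0.00, 0.00],
    [0.10, 0.10, 0.80, 0.00],
    [0.25, 0.25, 0.25, 0.25],
]
D = block_max_pool(A, B=2)
target = normalize_rows(D)
uniform_gate = [[0.5, 0.5], [0.5, 0.5]]       # an untrained gate's output
loss_untrained = kl_divergence(target, uniform_gate)
loss_perfect = kl_divergence(target, target)  # gate matches the target exactly
```

The loss falls to zero when the gate reproduces the pooled map, which is the signal that drives gate pretraining while the base model stays frozen.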

Hyperparameters:

  • $B = 64$, $r = 64$
  • Pretraining LR: $10^{-3}$, 500 steps, batch size 16
  • Fine-tuning LR: $10^{-5}$, batch size 8, gate weight $1.0$

5. Empirical Results and Comparative Analysis

SeerAttention (Prefill)

  • Accuracy: On Llama-3.1-8B and Mistral-7B (32k tokens), 90% sparsity increases perplexity by only 0.16–0.17 points; at 50–60% sparsity, perplexity is nearly identical to dense attention.
  • Fine-Tuning: For Llama-3-8B (8k→32k context, PG19), full attention PPL ≈ 8.79, SeerAttention (50% sparsity) ≈ 8.81, (90% sparsity) ≈ 9.16.
  • Speedup: On A100, 32k tokens/90% sparsity yields 5.47× speedup over dense FlashAttention-2; 8k/50% sparsity yields 1.15×.
  • Sparsity Visualization: Learned masks exhibit diverse, context-sensitive patterns (A-shapes, diagonals, hybrid), superior to fixed heuristics.
  • Ablations: Average pooling for Q and max+min for K is optimal; block-RoPE in the gate is crucial for length extrapolation.
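One way to read the reported speedups is against the ideal bound implied by the sparsity level: if only a $(1-s)$ fraction of blocks is computed and nothing else costs time, attention can be at most $1/(1-s)$ times faster. The short calculation below compares the two figures quoted above with that bound; attributing the gap to gate overhead, mask construction, and kernel efficiency is a plausible interpretation rather than a measured breakdown, and the two data points come from different sequence lengths.

```python
def ideal_speedup(s: float) -> float:
    """Upper bound on speedup when only a (1 - s) fraction of blocks is computed."""
    return 1.0 / (1.0 - s)


# sparsity -> speedup over dense FlashAttention-2, as quoted in the text above
reported = {0.9: 5.47, 0.5: 1.15}

for s, measured in reported.items():
    bound = ideal_speedup(s)
    print(f"s={s:.0%}: ideal {bound:.2f}x, reported {measured:.2f}x, "
          f"fraction of ideal {measured / bound:.0%}")
```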

SeerAttention-R (Decoding)

  • Auto-Regressive Adaptation: Query pooling is removed; gating operates per token with compressed KV blocks, supporting Grouped Query Attention (GQA).
  • Accuracy: On AIME24 (4k tokens, Qwen3-14B), full attention ≈76%, SeerAttention-R ≈75%, fixed-mask Quest ≈67%. Robustness is maintained up to block sizes of $B = 128$ with minimal degradation.
  • Kernel Performance: On H100, 90% sparsity and long contexts yield up to 8.6× speedup over FlashAttention-3.
  • Efficiency: Training the gate on 0.4B tokens for 800 steps adapts a model in 10–19 GPU hours, depending on model size.
  • Length and Token-Budgeting: Model achieves both minimal reasoning degradation and competitive answer lengths under extreme sparsity, outperforming static-block alternatives.
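The per-token block selection under a token budget can be sketched as follows. The source states only that gating runs per token over compressed KV blocks and is processed in GQA groups; taking the per-block maximum across the query heads of a group, and converting the budget to a block count by integer division, are assumptions made for this illustration, as is the function name.

```python
def select_blocks_for_group(group_scores, B, token_budget):
    """Pick KV block indices shared by all query heads of one GQA group.

    group_scores holds one list of per-block gate scores per query head.
    """
    num_blocks = len(group_scores[0])
    # Assumption: combine the heads of a group with a per-block maximum.
    shared = [max(head[j] for head in group_scores) for j in range(num_blocks)]
    # Assumption: a budget in tokens maps to floor(budget / B) blocks.
    k = max(1, min(num_blocks, token_budget // B))
    top = sorted(range(num_blocks), key=lambda j: shared[j], reverse=True)[:k]
    return sorted(top)  # ascending order, convenient for sequential KV reads


# Two query heads sharing one KV head, four KV blocks of 64 tokens each.
group_scores = [
    [0.1, 0.6, 0.2, 0.1],
    [0.5, 0.1, 0.3, 0.1],
]
selected = select_blocks_for_group(group_scores, B=64, token_budget=128)
```

Sharing one index list per group means each KV block is loaded once for all query heads that attend through the same KV head.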

6. Integration and Implementation Details

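  • Transformer Stack Modification: Replace the standard attention call

    O = FlashAttention(Q, K, V)

    with

    Qp = PoolAvg(Q, B)                     # T x d block averages of Q
    Kp = [PoolMax(K, B); PoolMin(K, B)]    # T x 2d max/min block pooling of K
    GQ, GK = W_Q(Qp), W_K(Kp)              # linear gates, each T x r
    S = softmax(GQ @ GK.T / sqrt(r))       # T x T block gating scores
    M = TopK_mask(S, k=floor((1-s)*T))     # binary block mask, top-k per row
    O = BlockSparseFlash(Q, K, V, M, B)    # attention over selected blocks only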
  • API and Codebase: Triton and TileLang kernels, gate/training modules, and integration scripts are available at https://github.com/microsoft/SeerAttention.
  • Plug-In Nature: New gate parameters are added without modifying original model weights. Integration into standard multi-head attention layers is supported for both batch and streaming settings.

7. Limitations and Future Directions

Several limitations remain:

  • The gate adds a small overhead, limiting speedups for short sequences ($n$ below 2k tokens) or low sparsity (below 20%).
  • The mechanism currently uses fixed per-head sparsity; an extension to variable per-head or per-layer sparsity is possible.
  • Block size $B = 64$ balances granularity against gate cost: a finer block size increases gate overhead.
  • The original SeerAttention does not address decoding (incremental attention); SeerAttention-R is designed specifically for that setting.
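The block-size trade-off can be made concrete with a rough operation count for the gate. The sketch below counts multiply-accumulates for the two gate projections and the $T \times T$ score matrix and compares them with dense attention; it assumes a head dimension of 128, ignores pooling, top-$k$ selection, memory traffic, and kernel launch costs, and is therefore only an indication of how the gate cost grows as $B$ shrinks.

```python
import math


def gate_macs(n: int, d: int, r: int, B: int) -> int:
    """Rough multiply-accumulate count for one head's gate."""
    T = math.ceil(n / B)
    projections = T * d * r + T * (2 * d) * r   # W_Q on T x d, W_K on T x 2d
    scores = T * T * r                          # G_Q G_K^T
    return projections + scores


def dense_attention_macs(n: int, d: int) -> int:
    """Approximate multiply-accumulates for dense attention (QK^T plus AV)."""
    return 2 * n * n * d


n, d, r = 32_768, 128, 64   # d = 128 is an assumed head dimension
cost_b64 = gate_macs(n, d, r, B=64)
cost_b32 = gate_macs(n, d, r, B=32)
growth = cost_b32 / cost_b64
fraction_of_dense = cost_b64 / dense_attention_macs(n, d)
```

Halving the block size doubles the projection cost and quadruples the score-matrix cost, so the gate grows faster than linearly as blocks get finer, even though its arithmetic remains small next to dense attention at long context.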

This suggests that further architectural advances in plug-in gating strategies and kernel-hardware co-design will continue to expand the practical applicability of learned sparse attention, particularly as LLM sequence lengths increase and efficiency demands grow.
