SeerAttention-R: Sparse Attention Adaptation for Long Reasoning (2506.08889v1)
Abstract: We introduce SeerAttention-R, a sparse attention framework specifically tailored for the long decoding of reasoning models. Extended from SeerAttention, SeerAttention-R retains the design of learning attention sparsity through a self-distilled gating mechanism, while removing query pooling to accommodate auto-regressive decoding. With a lightweight plug-in gating, SeerAttention-R is flexible and can be easily integrated into existing pretrained models without modifying the original parameters. We demonstrate that SeerAttention-R, trained on just 0.4B tokens, maintains near-lossless reasoning accuracy with a 4K token budget on the AIME benchmark under large sparse attention block sizes (64/128). Using TileLang, we develop a highly optimized sparse decoding kernel that achieves near-theoretical speedups of up to 9x over FlashAttention-3 on an H100 GPU at 90% sparsity. Code is available at: https://github.com/microsoft/SeerAttention.
Summary
- The paper demonstrates a novel sparse attention adaptation that reduces the quadratic cost of full attention during long autoregressive decoding.
- It introduces a lightweight AttnGate module, trained via self-distillation, that efficiently predicts block-level importance scores over the KV cache.
- Empirical results confirm near-lossless accuracy with a 4k token budget (as low as 2k under oracle sparsity) and substantial speedups using a custom sparse decoding kernel.
The paper "SeerAttention-R: Sparse Attention Adaptation for Long Reasoning" (2506.08889) introduces a sparse attention framework specifically designed to improve the efficiency of long decoding sequences in LLMs, particularly those focused on reasoning tasks. Building upon the prior SeerAttention work (2410.13276), SeerAttention-R maintains the core idea of learning attention sparsity through a self-distilled gating mechanism but adapts it for the auto-regressive nature of decoding.
The key challenge addressed is the growing computational and memory cost of long decoding: the KV cache and the per-token attention cost grow linearly with sequence length, so the total attention cost of a generation grows quadratically. Sparse attention mitigates this by attending to only a subset of important tokens. The paper shows empirically that attention in reasoning models is inherently sparse, meaning only a fraction of tokens are critical for maintaining performance (Section 2.2, Figure 2); the difficulty lies in identifying these important tokens efficiently.
SeerAttention-R tackles this with a lightweight, plug-in Attention Gate (AttnGate) that can be added to existing pre-trained Transformer models without modifying their original weights. This makes it a post-training adaptation method.
Architecture and Mechanism:
The AttnGate in SeerAttention-R differs from its SeerAttention predecessor primarily in order to support auto-regressive decoding. It removes sequence-level pooling on the Query (Q) tensor and processes each query token individually. To align with Grouped Query Attention (GQA) (2305.13245), commonly used in modern LLMs to reduce KV cache size, the Q branch of the AttnGate uses a linear layer to project the query heads within a GQA group down to a single head. This yields a shared sparsity decision across the group, which improves hardware efficiency.
For the Key (K) tensor, SeerAttention-R retains pooling-based compression along the sequence dimension, similar to the original SeerAttention. It uses a combination of Max, Min, and Average pooling over blocks of tokens, concatenating their outputs before a linear layer. This pooling aims to capture diverse information within a block.
The AttnGate calculates block-level scores $\mathbf{S}$ from the processed Q and K tensors, mirroring standard attention computation (RoPE (2104.09864), matrix multiplication, softmax):
$$
\begin{aligned}
\mathbf{Q}_{\text{gate}} &= \mathrm{RoPE}\bigl(\mathbf{W}^{q}_{\text{gate}}\,\operatorname{reshape}(\mathbf{Q}_{\text{nope}},\,[\dots,\,g\cdot d])\bigr) \\
\mathbf{K}_{\text{gate}} &= \mathrm{RoPE}\bigl(\mathbf{W}^{k}_{\text{gate}}\,\operatorname{concat}\bigl[\operatorname{P}_{\max}(\mathbf{K}_{\text{nope}}),\,\operatorname{P}_{\min}(\mathbf{K}_{\text{nope}}),\,\operatorname{P}_{\text{avg}}(\mathbf{K}_{\text{nope}})\bigr]\bigr) \\
\mathbf{S} &= \operatorname{softmax}\bigl(\mathbf{Q}_{\text{gate}}\,\mathbf{K}_{\text{gate}}^{\top}/\sqrt{d_{\text{gate}}}\bigr)
\end{aligned}
$$
where $\mathbf{Q}_{\text{nope}}$ and $\mathbf{K}_{\text{nope}}$ are the query and key tensors before RoPE, $g$ is the GQA group size, $d$ is the hidden dimension per head, and $d_{\text{gate}}$ is the AttnGate's head dimension.
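To make the gate concrete, here is a minimal PyTorch sketch of the AttnGate forward pass at a single decoding step. It is a sketch under simplifying assumptions: RoPE is omitted, shapes are illustrative, and the names (`AttnGateSketch`, `w_q`, `w_k`) are hypothetical rather than taken from the released code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttnGateSketch(nn.Module):
    """Illustrative AttnGate: block scores for one decoding token (RoPE omitted)."""
    def __init__(self, g: int, d: int, d_gate: int, block_size: int):
        super().__init__()
        self.block_size = block_size
        # Q branch: fold the g query heads of a GQA group into one gate head.
        self.w_q = nn.Linear(g * d, d_gate, bias=False)
        # K branch: max/min/avg pooled blocks are concatenated -> 3*d inputs.
        self.w_k = nn.Linear(3 * d, d_gate, bias=False)

    def forward(self, q_nope: torch.Tensor, k_nope: torch.Tensor) -> torch.Tensor:
        # q_nope: [batch, n_kv_heads, g, d] query of the current token (pre-RoPE)
        # k_nope: [batch, n_kv_heads, seq_len, d] cached keys (pre-RoPE)
        b, h_kv, g, d = q_nope.shape
        q_gate = self.w_q(q_nope.reshape(b, h_kv, g * d))        # [b, h_kv, d_gate]

        # Pool keys block-wise along the sequence dimension and concatenate.
        n_blocks = k_nope.shape[2] // self.block_size
        k = k_nope[:, :, : n_blocks * self.block_size].reshape(
            b, h_kv, n_blocks, self.block_size, d)
        pooled = torch.cat([k.amax(dim=3), k.amin(dim=3), k.mean(dim=3)], dim=-1)
        k_gate = self.w_k(pooled)                                # [b, h_kv, n_blocks, d_gate]

        # Block-level scores, shared by all query heads of the GQA group.
        scores = torch.einsum("bhd,bhnd->bhn", q_gate, k_gate) / k_gate.shape[-1] ** 0.5
        return F.softmax(scores, dim=-1)                         # [b, h_kv, n_blocks]
```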
Training (Distillation):
SeerAttention-R trains the AttnGate via self-distillation, with the original full-attention model serving as the teacher. The goal is for the AttnGate to predict the teacher's block-level attention scores. For decoding, the ground truth is generated by 1D max-pooling the teacher's attention map column-wise, i.e., pooling the key positions within each block for every query row (Figure 2a). To accommodate shared sparsity under GQA, this column-pooled map is further max-pooled within each query-head subgroup. The AttnGate is trained with a Kullback-Leibler (KL) divergence loss against this normalized ground truth.
A key aspect of the training is efficiency. The original model weights are frozen. The paper proposes an efficient kernel that generates the ground truth and the attention output directly during training, avoiding explicit materialization of the memory-intensive full attention map (Figure 2b). Training is lightweight, requiring only 0.4B tokens from a dataset such as OpenR1-MATH-220k (2501.12948) and relatively few GPU hours (e.g., ~12 GPU hours for an 8B model on MI300X) (Table 2).
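For illustration, the distillation target and loss can be sketched as follows in PyTorch. This simplified version materializes the full teacher attention map (which the paper's fused kernel avoids), and the function name and tensor shapes are assumptions for readability.

```python
import torch
import torch.nn.functional as F

def distill_loss_sketch(attn_probs, gate_scores, block_size, gqa_group):
    """Block-level distillation target and KL loss (simplified, non-fused).

    attn_probs:  [batch, n_q_heads, q_len, k_len]      teacher softmax attention
    gate_scores: [batch, n_kv_heads, q_len, n_blocks]  AttnGate output
    """
    b, hq, q_len, k_len = attn_probs.shape
    n_blocks = k_len // block_size

    # 1) Max-pool the key dimension into blocks -> per-query block scores.
    target = attn_probs[..., : n_blocks * block_size]
    target = target.reshape(b, hq, q_len, n_blocks, block_size).amax(dim=-1)

    # 2) Max-pool across the query heads of each GQA group (shared sparsity).
    target = target.reshape(b, hq // gqa_group, gqa_group, q_len, n_blocks).amax(dim=2)

    # 3) Renormalize the target and compute KL divergence to the gate output.
    target = target / target.sum(dim=-1, keepdim=True).clamp_min(1e-8)
    return F.kl_div(torch.log(gate_scores.clamp_min(1e-8)), target,
                    reduction="batchmean")
```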
Inference:
During inference, the AttnGate outputs (S) are converted into binary block masks or indices that select the important KV blocks. Two methods are discussed (a sketch of both follows the list):
- Token Budget: Select the Top-k blocks based on their scores, where k is derived from a predefined token budget (e.g., 4k, 8k).
- Threshold: Select blocks whose scores exceed a fixed threshold. This method is simpler but might result in variable sparsity ratios.
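A minimal sketch of both selection modes, assuming gate scores of shape `[batch, n_kv_heads, n_blocks]`; the function name and defaults are illustrative, not the repository's API.

```python
from typing import Optional

import torch

def select_blocks_sketch(scores: torch.Tensor, block_size: int,
                         token_budget: Optional[int] = None,
                         threshold: Optional[float] = None) -> torch.Tensor:
    """Turn AttnGate block scores into a boolean block mask (illustrative only).

    scores: [batch, n_kv_heads, n_blocks], softmax output of the gate.
    Exactly one of token_budget / threshold is expected.
    """
    if token_budget is not None:
        # Token-budget mode: keep the top-k blocks, k = budget / block_size.
        k = min(max(token_budget // block_size, 1), scores.shape[-1])
        topk_idx = scores.topk(k, dim=-1).indices
        mask = torch.zeros_like(scores, dtype=torch.bool)
        mask.scatter_(-1, topk_idx, torch.ones_like(topk_idx, dtype=torch.bool))
    else:
        # Threshold mode: keep every block whose score exceeds the threshold,
        # so the realized sparsity ratio varies per head and per step.
        mask = scores > threshold
    # The newest (possibly partial) block is always kept active
    # (see the K Compression Cache discussion below).
    mask[..., -1] = True
    return mask
```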
To speed up AttnGate prediction during inference, a K Compression Cache stores the compressed K representation. This cache is updated only once every block-size tokens (e.g., once per 64 generated tokens), minimizing AttnGate overhead (Figure 3). The most recent block is always activated while its compressed entry has not yet been written, preventing accuracy loss. The K Compression Cache is much smaller than the main KV cache (<1% overhead at block size 64), which also enables potential KV cache offloading strategies.
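As an illustration of this bookkeeping, the sketch below maintains one pooled entry per completed key block and only does pooling work once every block-size steps; the class and method names are assumptions, not the actual implementation.

```python
import torch

class KCompressionCacheSketch:
    """Illustrative compressed-K cache: one pooled entry per full block of keys."""
    def __init__(self, block_size: int):
        self.block_size = block_size
        self.pooled_blocks = []   # list of [batch, n_kv_heads, 3*d] tensors
        self.pending = []         # keys of the still-incomplete newest block

    def append(self, k_nope_token: torch.Tensor) -> None:
        # k_nope_token: [batch, n_kv_heads, d] key of the newly generated token.
        self.pending.append(k_nope_token)
        if len(self.pending) == self.block_size:
            # Only once per block_size tokens: pool the finished block and store it.
            blk = torch.stack(self.pending, dim=2)               # [b, h_kv, B, d]
            pooled = torch.cat(
                [blk.amax(dim=2), blk.amin(dim=2), blk.mean(dim=2)], dim=-1)
            self.pooled_blocks.append(pooled)
            self.pending.clear()

    def compressed_k(self) -> torch.Tensor:
        # Returns [batch, n_kv_heads, n_full_blocks, 3*d]. The partial newest block
        # is not represented here, which is why that block is always kept active.
        assert self.pooled_blocks, "no complete block has been generated yet"
        return torch.stack(self.pooled_blocks, dim=2)
```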
A specialized block-sparse flash decoding kernel was developed using TileLang and Triton (1907.10168). This kernel processes only the selected sparse KV blocks, exploiting GQA's structure and optimizing memory access and computation on GPUs such as the H100 (Section 3.3, Figure 4).
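Functionally, the kernel computes the following, shown here as a non-fused PyTorch reference for a single decoding step. The real TileLang/Triton kernel loads only the selected blocks from the KV cache, whereas this reference computes dense scores and masks them, so it illustrates correctness rather than the memory savings.

```python
import torch
import torch.nn.functional as F

def sparse_decode_reference(q, k_cache, v_cache, block_mask, block_size):
    """Block-sparse decoding attention, dense reference (single step, GQA layout).

    q:                [batch, n_kv_heads, group, d]
    k_cache, v_cache: [batch, n_kv_heads, seq_len, d]
    block_mask:       [batch, n_kv_heads, n_blocks] boolean selection from the gate
    Assumes at least one block is selected per KV head.
    """
    b, h_kv, g, d = q.shape
    seq_len = k_cache.shape[2]
    # Expand the block-level mask to a per-token mask over the full sequence.
    token_mask = block_mask.repeat_interleave(block_size, dim=-1)[..., :seq_len]
    # Every query head in a GQA group shares the same selected blocks.
    scores = torch.einsum("bhgd,bhsd->bhgs", q, k_cache) / d ** 0.5
    scores = scores.masked_fill(~token_mask[:, :, None, :], float("-inf"))
    probs = F.softmax(scores, dim=-1)
    return torch.einsum("bhgs,bhsd->bhgd", probs, v_cache)
```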
Experimental Results:
The paper evaluates SeerAttention-R on reasoning benchmarks (AIME24, AIME25, MATH-500 (2009.03300), GPQA-Diamond (2401.15025)) using Qwen3 (2505.09388) and DeepSeek-R1 (2501.12948) models. It compares SeerAttention-R against full attention and against Quest (2406.10774), a training-free heuristic sparse attention method.
Key findings include:
- Oracle Sparsity: Attention is indeed sparse in reasoning models; using oracle sparsity, near-lossless accuracy is achievable with token budgets as low as 2k, especially with smaller block sizes (32, 64) (Figure 2).
- Accuracy: SeerAttention-R consistently outperforms Quest and maintains near-lossless accuracy with 4k token budgets, particularly on larger models (14B variants) which are more tolerant to sparsity (Figure 5). The accuracy gap between SeerAttention-R and the dense baseline is smaller than that for Quest, especially with larger block sizes.
- Kernel Speedup: The custom block sparse kernel, especially the TileLang implementation, achieves significant speedups over FlashAttention-3 (2404.15792) (up to ~9x at 90% sparsity) and the Triton implementation, particularly at larger sequence lengths and batch sizes where I/O becomes the bottleneck (Figure 6).
- Block Size: SeerAttention-R is more robust to larger block sizes (64, 128) compared to Quest, which sees accuracy degradation. Larger block sizes are generally more hardware efficient (Figure 7).
- Hybrid Layers: While Quest benefits significantly from keeping the first few layers dense, SeerAttention-R sees only marginal gains, suggesting its learned sparsity is effective even in early layers (Figure 8).
- Sparsity Method: Token budget and threshold methods show different activated token distributions, with the threshold method exhibiting slightly better sparsity-accuracy trade-off in high sparsity regions (Figure 9).
- Generation Length: Inaccurate sparse attention (Quest, or SeerAttention-R with too small a budget) can lead to longer reasoning paths, potentially undermining the efficiency gains; accurate sparsity selection is therefore crucial (Table 1).
Limitations and Future Work:
The paper notes that achieving full end-to-end system speedup requires integration with inference frameworks (vLLM (2312.06177), SGLang (2407.02854), Lserve (2502.14866)) and supporting technologies like PagedAttention (2312.06177) and KV cache offloading (2402.04617). Determining optimal and adaptive sparsity ratios dynamically is also an open challenge. Finally, unifying sparse attention for both prefill and decoding phases, currently handled separately by SeerAttention and SeerAttention-R, remains an important direction for future research. Techniques like multi-token prediction or speculative decoding could potentially facilitate this unification by introducing query-level parallelism during decoding.