
MiniMax Sparse Attention (MSA)

Updated 12 June 2026
  • MiniMax Sparse Attention (MSA) is a blockwise sparse softmax mechanism that restricts dot-product computations to selected key–value blocks, reducing quadratic cost.
  • Its two-branch design uses lightweight index queries and exact main branch softmax, achieving up to 28.4× FLOP reduction at million-token contexts.
  • Optimized GPU kernel co-design and a KL divergence alignment loss ensure minimal loss in fidelity while supporting tasks like code reasoning and multimodal inference.

MiniMax Sparse Attention (MSA) is a blockwise sparse softmax attention mechanism designed for large transformer models to efficiently handle ultra-long context windows ranging from hundreds of thousands to millions of tokens. MSA substantially reduces the quadratic computational cost of standard softmax attention by restricting exact dot-product computations to a small, content-adaptive subset of key–value blocks, enabling scalable deployment for agentic workflows, code reasoning, persistent memory, and multimodal tasks without significant loss in model expressivity or accuracy (Lai et al., 11 Jun 2026).

1. Motivation and Context

The escalating demand for ultra-long-context capabilities in LLMs is driven by use cases such as agentic workflows, repository-scale code reasoning, and persistent dialogue memory. Standard softmax attention computes an $N \times N$ similarity matrix for a sequence of length $N$, incurring a prohibitive $\Theta(N^2)$ cost in both computation and memory for million-token contexts. Conventional approximations, such as low-rank or linear attention, either degrade modeling fidelity or impose constraints on head or representation dimensions. MSA addresses these bottlenecks by sparsifying attention adaptively: the model retains exact softmax computation, but only over the tokens deemed relevant for each query.

2. Architecture and Mechanism

MSA is built on Grouped Query Attention (GQA), in which $H_q$ query heads are partitioned into $H_{kv}$ groups, each with shared key–value projections, giving $G = H_q/H_{kv}$ heads per group. Each MSA layer employs a two-branch structure:

  • Index Branch: For each query group, lightweight $d_{\text{idx}}$-dimensional projections (index queries and keys) are computed. For every position $i$ and group $r$, block-pooled scores are calculated over all previously seen key blocks, after which the top-$k$ scoring blocks (plus the local block) are selected as the candidate set for sparse attention.
  • Main Branch: For each query head within group $r$, only the key–value vectors from the selected blocks are assembled, and exact softmax attention is performed over these tokens.

The blockwise structure operates on blocks of size $B$; at each layer, selection and computation are performed at this granularity. GQA provides a natural granularity for block-level sparse selection, substantially reducing both memory bandwidth and compute.
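The index arithmetic behind this grouping can be sketched in a few lines of Python. This is only an illustration: the helper names are hypothetical, and the contiguous assignment of query heads to groups is an assumption, not something the source specifies.

```python
def group_of_head(head: int, num_q_heads: int, num_kv_groups: int) -> int:
    """Return the GQA group index r of a query head (assumes contiguous grouping)."""
    heads_per_group = num_q_heads // num_kv_groups  # G = H_q / H_kv
    return head // heads_per_group


def block_of_token(pos: int, block_size: int) -> int:
    """Return the index of the key-value block that contains token position pos."""
    return pos // block_size


def visible_blocks(pos: int, block_size: int) -> list[int]:
    """Blocks holding at least one key at or before pos (causal visibility)."""
    return list(range(block_of_token(pos, block_size) + 1))
```

For example, with 8 query heads and 2 key–value groups, head 5 falls in group 1; with a block size of 4, position 10 lies in block 2 and can see blocks 0 through 2.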

3. Mathematical Formulation

Let a sequence of length $N$ be split into $M = \lceil N/B \rceil$ blocks of size $B$, with the $b$th block denoted $\mathcal{B}_b$. For each query position $i$ and group $r$:

  • The index queries $q^{\text{idx}}_{i,r}$ and a single shared index key $k^{\text{idx}}_j$ are projected from the hidden states.
  • The token-level index score for $j \le i$ is

$$s^{(r)}_{i,j} = \langle q^{\text{idx}}_{i,r},\, k^{\text{idx}}_j \rangle$$

  • Block scores are calculated by pooling:

$$S^{(r)}_{i,b} = \max_{j \in \mathcal{B}_b,\; j \le i} s^{(r)}_{i,j}$$

Blocks with no eligible $j$ receive $-\infty$.

  • Top-$k$ block selection for $(i, r)$ is performed as:

$$\mathcal{T}_{i,r} = \operatorname{TopK}_b\big(S^{(r)}_{i,b},\, k\big)$$

always including the local block containing $i$.

  • Main Branch sparse attention for query head $h$ in group $r$ is computed as:

$$o_{i,h} = \operatorname{softmax}\!\left(\frac{q_{i,h} K_{\mathcal{T}_{i,r}}^{\top}}{\sqrt{d}}\right) V_{\mathcal{T}_{i,r}}$$

Because $kB \ll N$, the cost per query scales as $O(kBd)$ rather than $O(Nd)$.
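The selection and attention steps for a single query can be sketched in plain Python. This is a minimal reference under stated assumptions, not the released kernels: the function names are hypothetical, block scores use max-pooling over raw dot-product index logits, and the main branch uses the usual $1/\sqrt{d}$ scaling.

```python
import heapq
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def select_blocks(idx_q, idx_keys, pos, block_size, top_k):
    """Pick the top_k highest-scoring causal blocks for one query, plus the local block."""
    local = pos // block_size
    block_scores = {}
    for b in range(local + 1):  # later blocks hold no eligible key (score -inf)
        start = b * block_size
        end = min(start + block_size, pos + 1)  # causal: keys j <= pos only
        # Max-pool the raw index logits; ranking needs no exponentiation.
        block_scores[b] = max(dot(idx_q, idx_keys[j]) for j in range(start, end))
    chosen = set(heapq.nlargest(top_k, block_scores, key=block_scores.get))
    chosen.add(local)  # the local block is always kept
    return sorted(chosen)


def sparse_attention(q, keys, values, blocks, pos, block_size):
    """Exact softmax attention restricted to the tokens of the selected blocks."""
    tokens = [j for b in blocks
              for j in range(b * block_size, min((b + 1) * block_size, pos + 1))]
    scale = 1.0 / math.sqrt(len(q))
    logits = [scale * dot(q, keys[j]) for j in tokens]
    m = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(l - m) for l in logits]
    total = sum(weights)
    return [sum(w * values[j][c] for w, j in zip(weights, tokens)) / total
            for c in range(len(values[0]))]
```

When `top_k` is at least the number of causal blocks, every block is selected and `sparse_attention` reduces to dense causal attention for that query, which is a convenient correctness check.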

To align the Index Branch selection with the Main Branch's actual attention distribution, an auxiliary KL divergence loss is minimized between the indexer's softmax over the selected tokens and the Main Branch softmax, with gradients detached from the latter.
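A minimal sketch of such an alignment term follows. The KL direction (Main Branch distribution as the target) and the function names are assumptions for illustration; plain Python has no autograd, so the detachment is only noted in a comment.

```python
import math


def log_softmax(logits):
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def kl_alignment_loss(index_logits, main_logits):
    """KL(p_main || p_idx) over the same selected tokens.

    p_main is the target; in an autodiff framework it would be detached so
    that only the indexer receives gradients. The direction is an assumption.
    """
    log_p = log_softmax(main_logits)   # target distribution (detached)
    log_q = log_softmax(index_logits)  # indexer distribution (trained)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
```

The loss is zero when the indexer's logits match the Main Branch logits up to an additive constant, and positive otherwise.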

4. GPU Kernel Co-Design and Execution Path

Translating theoretical complexity reductions to actual hardware speedups necessitates GPU-specific optimizations:

  • Exp-Free Top-$k$: Because softmax is monotonic in its logits, the top-$k$ blocks can be selected directly from the raw scores, eliminating exponentiation. A warp-based min-heap and a parallel merge implement the selection.
  • KV-Outer Sparse Kernels: Rather than iterating over queries, kernels are implemented to iterate over selected key–value blocks, gathering all querying positions that reference a block and launching large tensor-core matrix-multiplies. This improves arithmetic intensity and overcomes low-utilization patterns of query-outer layouts.
  • Two-Phase Softmax + Split-K Combine: Since each query gathers partial outputs for its selected blocks from different thread blocks (CTAs), local logits and partial sums are stored in HBM buffers. A combine kernel then computes the global LogSumExp and normalizes the results, aggregating the partial outputs.
  • Load Balancing and Chunking: "Sink" blocks popular across queries are duplicated across CTAs; each output write is mapped to a preassigned slot, removing the need for atomic operations.

These strategies ensure that block-granular execution uses GPU tensor cores efficiently, with the design tailored to hardware such as the NVIDIA H800.
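The two-phase softmax with a split-K combine can be illustrated on the CPU. The sketch below is an assumption-laden stand-in for the GPU kernels: each "partial" stores a local max, a local sum of exponentials, and an unnormalized output, which is one common way to organize the buffers and may differ from the actual layout; the function names are hypothetical.

```python
import math


def partial_attention(q, keys, values):
    """Phase 1: attention over one key-value block, kept unnormalized."""
    scale = 1.0 / math.sqrt(len(q))
    logits = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    local_max = max(logits)
    weights = [math.exp(l - local_max) for l in logits]
    acc = [sum(w * v[c] for w, v in zip(weights, values))
           for c in range(len(values[0]))]
    return local_max, sum(weights), acc


def combine(partials):
    """Phase 2: merge per-block partials using a global log-sum-exp rescaling."""
    global_max = max(m for m, _, _ in partials)
    denom = 0.0
    out = [0.0] * len(partials[0][2])
    for local_max, local_sum, acc in partials:
        factor = math.exp(local_max - global_max)  # rescale to the global max
        denom += factor * local_sum
        out = [o + factor * a for o, a in zip(out, acc)]
    return [o / denom for o in out]
```

Combining the partials of several blocks gives the same result as a single softmax over the concatenated blocks, so the blocks can be processed independently and in any order.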

5. Computational Complexity and Efficiency

Let $N$ be the context length, $H_q$ the number of query heads, $H_{kv}$ the number of key–value groups, $d$ the head dimension, $d_{\text{idx}}$ the index projection size, $B$ the block size, and $k$ the number of selected blocks. Up to constant factors:

  • Full GQA Attention:

$$\text{FLOPs}_{\text{GQA}} = O\!\left(N^2 H_q\, d\right)$$

  • MSA FLOP Cost:

$$\text{FLOPs}_{\text{MSA}} = O\!\left(N^2 H_{kv}\, d_{\text{idx}}\right) + O\!\left(N k B\, H_q\, d\right)$$

The first term (Index Branch) retains $O(N^2)$ scaling but with $d_{\text{idx}} \ll d$; the second term (Main Branch) is linear in $N$, since $kB \ll N$.

For the reported model hyperparameters ($H_q$, $H_{kv}$, $d$, $d_{\text{idx}}$, $B$, $k$), MSA reduces the per-token attention FLOPs by up to $28.4\times$ at million-token context. Empirically, with the fused GPU kernels, both prefill and decoding are accelerated at million-token context lengths on NVIDIA H800 (Lai et al., 11 Jun 2026).
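How such a ratio arises can be explored with a rough FLOP model. Everything in the sketch is an assumption made for illustration: the constant factors (2 FLOPs per multiply-add, no causal halving), the per-group indexer cost, and the hyperparameter values in the usage note are hypothetical, so it is not expected to reproduce the reported figure.

```python
def full_gqa_flops(n, h_q, d):
    """Dense attention: QK^T and PV, each 2 * n * n * d FLOPs per head."""
    return 4 * n * n * h_q * d


def msa_flops(n, h_q, h_kv, d, d_idx, block, k):
    """Index Branch (quadratic, small dimension) plus Main Branch (sparse)."""
    index = 2 * n * n * h_kv * d_idx        # one index score per (query, key, group)
    attended = min(k * block, n)            # tokens each query attends to
    main = 4 * n * attended * h_q * d
    return index + main


def flop_reduction(n, h_q, h_kv, d, d_idx, block, k):
    """Ratio of dense to sparse FLOPs; values above 1 favor MSA."""
    return full_gqa_flops(n, h_q, d) / msa_flops(n, h_q, h_kv, d, d_idx, block, k)
```

Calling `flop_reduction(1_000_000, h_q=32, h_kv=4, d=128, d_idx=32, block=64, k=32)` with these hypothetical values shows the ratio growing with context length, while for short sequences (where `k * block >= n`) the indexer makes MSA slightly more expensive than dense attention.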

6. Empirical Performance and Benchmarks

MSA was evaluated on a 109B-parameter Mixture-of-Experts model with native multimodal pretraining over 3 trillion tokens:

  • Pretraining Quality: MSA, both trained from scratch (MSA-PT) and by continued pretraining (MSA-CPT), matched or slightly outperformed GQA full attention across tasks such as MMLU, BBH, code, math, vision, video, and long-context RULER/HELMET benchmarks. No significant perplexity degradation was observed.
  • Long-Context Retention: After extension with 140B long-context tokens, MSA-CPT remained close to full attention on HELMET-128K and RULER-128K.
  • Efficiency Gains: Measured wall-clock speedups in both prefill and decode on NVIDIA H800 at 1M tokens, with additional reductions in memory footprint from not materializing the full key–value cache per head.

A production-grade, natively multimodal MSA-powered model (109B parameters) has been released: https://huggingface.co/MiniMaxAI/MiniMax-M3. Inference kernels are available at: https://github.com/MiniMax-AI/MSA.

7. Integration and Deployment Considerations

MSA is deliberately minimal and modular. To insert MSA into an existing transformer architecture:

  • Add two small Index Branch projections per layer (for the index queries and index keys).
  • Compute block-max-pooled scores for the Index Branch, followed by exp-free Top-$k$ block selection.
  • Route the conventional Matrix-Multiply–Softmax–GEMM computation path over the gathered sparse key–value subset via a sparse gather operation.
  • During training, incorporate the KL-alignment loss with gradient detachment from the Main Branch softmax, and warm up with dense attention before activating sparsity.

This streamlined structure enables seamless adoption in any model supporting GQA or multi-query attention mechanisms, facilitating the deployment of exact softmax attention even in million-token contexts with minimal regressions in model utility and substantial acceleration (Lai et al., 11 Jun 2026).
