- The paper introduces MoSA, a novel attention module that replaces dense attention heads with learnable sparse experts, achieving up to 27% perplexity reduction.
- MoSA dynamically selects a small subset of tokens via an expert-choice router, reducing quadratic self-attention costs to O(k² + T) per head.
- Experimental results demonstrate improved resource efficiency and performance in language modeling across various scales and long sequence lengths.
The paper "Mixture of Sparse Attention: Content-Based Learnable Sparse Attention via Expert-Choice Routing" (2505.00315) introduces MoSA, a novel approach to address the quadratic computational and memory cost of self-attention in LLMs. MoSA replaces the dense attention heads in a standard Transformer with multiple sparse attention heads, where each head dynamically selects a small subset of tokens to attend to based on their content.
The core idea is inspired by Mixture-of-Experts (MoE) and Expert-Choice routing. Instead of every token attending to every other token (as in dense attention), each MoSA attention head acts like an "expert" that chooses which tokens from the input sequence it will process. This selection is learned end-to-end via a router mechanism. For a sequence of length T, if each head selects k tokens (k ≪ T), the computational complexity of that head is reduced from O(T²) to O(k² + T). The O(k²) term comes from the attention calculation on the selected tokens, and the O(T) term comes from the routing mechanism, which scores all tokens to decide which k to select.
This reduction in computational cost per head allows the model to use a larger number of attention heads within the same computational budget as a dense Transformer. The authors hypothesize that having more, specialized sparse heads enables better information processing and can lead to improved performance.
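As a rough back-of-the-envelope check of this trade-off (a sketch that counts only attention-score operations and ignores projections; T and k are illustrative values, with k chosen arbitrarily rather than taken from the paper):

```python
T, k = 1024, 128                  # illustrative sequence length and tokens per sparse head

dense_per_head = T * T            # O(T²): all pairwise attention scores
mosa_per_head = k * k + T         # O(k²) attention on selected tokens + O(T) routing

print(dense_per_head)             # 1048576
print(mosa_per_head)              # 17408
print(dense_per_head / mosa_per_head)  # ~60: roughly how many sparse heads fit in one dense head's score budget
```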
Implementation Details
In MoSA, each attention head i has a standard set of projection matrices (W_i^Q, W_i^K, W_i^V, W_i^O) and an additional router weight matrix W_i^r. For an input sequence X ∈ ℝ^(T×h), the router calculates scores r = σ(X W_i^r) ∈ ℝ^T, where σ is the sigmoid function. Using Expert-Choice routing, the head then selects the top k tokens based on these scores, obtaining their indices I ∈ {0, …, T−1}^k.
The query, key, and value matrices (Q, K, V) are then computed only for the selected subset of tokens X_s ∈ ℝ^(k×h):
import torch

# One MoSA head (per-head view). Assumes: X is the (T, h) input, Wr an (h, 1) router
# weight, WQ/WK/WV (h, h_head) and WO (h_head, h) projections, and k the number of selected tokens.
scores = torch.sigmoid(X @ Wr).squeeze(-1)                 # (T,)
topk_scores, selected_indices = torch.topk(scores, k)      # (k,), (k,)
X_selected = X[selected_indices]                           # (k, h)

Q = X_selected @ WQ                                        # (k, h_head)
K = X_selected @ WK                                        # (k, h_head)
V = X_selected @ WV                                        # (k, h_head)

# Causal mask between the selected tokens, based on their original positions.
logits = Q @ K.T / Q.shape[-1] ** 0.5                      # (k, k)
logits = logits.masked_fill(selected_indices[None, :] > selected_indices[:, None], float("-inf"))
attn_out = torch.softmax(logits, dim=-1) @ V               # (k, h_head)

# Weight the output by the router scores (as in expert-choice routing).
attn_out = attn_out * topk_scores.unsqueeze(-1)            # (k, h_head)
X_output_selected = attn_out @ WO                          # (k, h)

# Scatter back to the full sequence; non-selected positions contribute zeros.
Y = torch.zeros_like(X)
Y[selected_indices] = X_output_selected                    # (T, h)
The output is then scattered back to the original sequence length, resulting in a sparse output matrix where only the selected token positions have non-zero contributions. The total MoSA layer output is the sum of the outputs from all MoSA heads. Rotary Positional Embeddings (RoPE) need to be adapted to use the original indices of the selected tokens.
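The paper states only that RoPE must use the original token indices; below is a minimal sketch of how that could look, assuming a standard interleaved RoPE formulation and reusing `selected_indices` from the snippet above (the function name and `base` default are assumptions, not from the paper):

```python
import torch

def apply_rope_with_original_positions(x, positions, base=10000.0):
    # x: (k, h_head) queries or keys of the selected tokens (h_head even).
    # positions: (k,) original sequence indices of those tokens, e.g. selected_indices.
    h_head = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, h_head, 2, dtype=x.dtype) / h_head))
    angles = positions.to(x.dtype)[:, None] * inv_freq[None, :]   # (k, h_head/2)
    cos, sin = torch.cos(angles), torch.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

# e.g. Q = apply_rope_with_original_positions(Q, selected_indices)
#      K = apply_rope_with_original_positions(K, selected_indices)
```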
Comparison with Existing Sparse Methods
Existing sparse attention methods often use static patterns (e.g., fixed stride or block-based) or content-based clustering (like the Routing Transformer [roy2021efficient]).
- Fixed Sparse Attention: Uses predefined patterns (e.g., attending to tokens at fixed intervals). It's efficient but not adaptive to input content.
- Routing Transformer: Clusters tokens within a head using K-means and attends within clusters. It's content-based but relies on clustering, which can have convergence issues and requires computing Q/K for all tokens before clustering.
MoSA's key advantage is its learned, content-based selection per head, which allows flexible and diverse sparsity patterns. Unlike the Routing Transformer, it avoids clustering and computes Q, K, and V only for the selected tokens, leading to greater efficiency, especially in the projection steps. The selection mechanism is optimized directly by the language-modeling objective.
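A toy illustration of how the attended token set is chosen in each case (the values, stride, and router weight `Wr` are placeholders, not from the paper):

```python
import torch

T, h, k, stride = 1024, 512, 128, 8
X = torch.randn(T, h)

# Fixed sparse attention: a static pattern, e.g. every stride-th token,
# identical for every input.
fixed_indices = torch.arange(0, T, stride)             # (T // stride,)

# MoSA: each head's router scores all tokens and keeps the top-k,
# so the selected set depends on the content of X.
Wr = torch.randn(h, 1)
scores = torch.sigmoid(X @ Wr).squeeze(-1)             # (T,)
_, mosa_indices = torch.topk(scores, k)                # (k,), different for every input

# Only these k tokens are projected to Q/K/V, whereas clustering-based methods
# compute Q/K for all T tokens before grouping them.
```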
Experimental Results
The paper evaluates MoSA on language modeling across different model scales (Tiny, Small, Medium, Large) on the C4 dataset.
- IsoFLOP Performance: In a setting where the total computational budget (FLOPs) is matched to a dense Transformer baseline, hybrid models (using 4 dense heads and a variable number of MoSA heads to match FLOPs) consistently outperform the dense baseline and other sparse methods (Fixed, Routing Transformer). MoSA reduces perplexity by up to 27% compared to the dense baseline. Performance improves as sparsity increases (using more sparse heads) up to a certain point (around sparsity 64 for T=1024), after which it declines. This demonstrates that the saved computation can be effectively used by adding more specialized sparse heads.
- Resource Efficiency: When trained to match the perplexity of a dense baseline, MoSA models demonstrate significant practical efficiency gains without specialized CUDA kernels. They show reduced wall-clock time per step, lower GPU memory usage, and a drastic reduction in KV-cache size (over 50%). This highlights MoSA's ability to achieve the same performance with fewer resources.
- Long Sequences: On longer sequences (up to T=8192) combined with local attention (which is standard practice for long-context models), MoSA also outperforms other sparse attention methods, even when using significantly fewer FLOPs than the Routing Transformer on the longest sequences.
- Downstream Tasks: MoSA generally performs well on zero-shot downstream tasks, often outperforming other sparse attention methods and sometimes the dense baseline. However, a performance drop is observed on tasks with very short sequence lengths (like BLiMP). This is attributed to the distribution mismatch where the token selection mechanism trained on long sequences is applied to very short inputs, forcing heads to select a large percentage of tokens they might not specialize in.
Limitations and Future Work
- Non-Autoregressive Routing: The top-k selection in MoSA is non-autoregressive, meaning it requires access to all tokens in the sequence before making a selection. This makes direct application to autoregressive inference challenging, as with other Expert-Choice methods. Adapting the Mixture-of-Depths (MoD) approach of training an auxiliary autoregressive classifier could be a future direction.
- Downstream Performance on Short Sequences: The observed performance drop on very short sequences needs mitigation. Potential solutions include training with truncated sequences or incorporating instruction tuning, which has been shown to improve MoE performance on downstream tasks.
- Further Optimizations: Developing specialized CUDA kernels for the sparse operations could further improve efficiency beyond the current PyTorch implementation. Combining MoSA with other efficient attention techniques such as multi-query attention (MQA), grouped-query attention (GQA), or SwitchHead is another avenue.
Conclusion
MoSA offers a practical and effective approach to mitigating the quadratic cost of self-attention by introducing learned, content-based sparsity via Expert-Choice routing. It demonstrates significant improvements in perplexity and resource efficiency across various scales and sequence lengths compared to traditional dense attention and other sparse methods. While challenges remain, particularly regarding autoregressive inference and performance on very short sequences, MoSA represents a promising direction for building more efficient and capable LLMs.