MHLA: Restoring Expressivity of Linear Attention via Token-Level Multi-Head
Published 12 Jan 2026 in cs.CV and cs.AI | (2601.07832v1)
Abstract: While the Transformer architecture dominates many fields, its quadratic self-attention complexity hinders its use in large-scale applications. Linear attention offers an efficient alternative, but its direct application often degrades performance, with existing fixes typically re-introducing computational overhead through extra modules (e.g., depthwise separable convolution) that defeat the original purpose. In this work, we identify a key failure mode in these methods: global context collapse, where the model loses representational diversity. To address this, we propose Multi-Head Linear Attention (MHLA), which preserves this diversity by computing attention within divided heads along the token dimension. We prove that MHLA maintains linear complexity while recovering much of the expressive power of softmax attention, and verify its effectiveness across multiple domains, achieving a 3.6% improvement on ImageNet classification, a 6.3% gain on NLP, a 12.6% improvement on image generation, and a 41% enhancement on video generation under the same time complexity.
The paper introduces MHLA, a method that restores query-conditioned expressivity in linear attention by mixing local key-value summaries across token blocks.
MHLA achieves superior performance in image classification, generation, and language modeling, improving metrics such as Top-1 accuracy, FID, and LongBench scores.
By maintaining linear complexity with learned mixing coefficients, MHLA overcomes global context collapse to enable scalable, granular token interaction.
Multi-Head Linear Attention: Reinstating Expressivity under Linear Complexity
Motivation for Linear Attention and Its Deficiencies
Transformer self-attention, while foundational across domains such as vision, NLP, and generative modeling, is bottlenecked by quadratic time and memory complexity with respect to sequence length N. This limits scaling to high-resolution or long-horizon tasks. Linear attention mechanisms reduce this complexity by kernelizing the softmax, i.e., replacing the exponential similarity $\exp(q_i^\top k_j / \sqrt{d})$ with a feature-map inner product $\phi(q_i)^\top \phi(k_j)$, so that the sum over keys can be computed once and reused by every query. However, this kernelization aggregates all keys and values into a single global summary $S = \sum_j \phi(k_j) v_j^\top$, yielding a drastic loss of query-conditioned selectivity and representational diversity, a phenomenon termed "global context collapse." As N increases, the fixed-size summary saturates, capping both the rank of the attention matrix (at most the feature dimension, independent of N) and its sparsity, and degrading performance most notably on tasks demanding nuanced, token-level relationships.
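To make the collapse concrete, here is a minimal NumPy sketch of non-causal linear attention. It checks that the linear-time computation equals applying an explicit, row-normalized N × N matrix, and that this matrix's rank cannot exceed the feature dimension d. The feature map ELU(x) + 1 and the function names are assumptions chosen for illustration; the summary does not say which $\phi$ the baselines use.

```python
import numpy as np

def phi(x):
    # ELU(x) + 1: a positive feature map, assumed here for illustration
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V):
    """Non-causal linear attention built on one global summary."""
    fK = phi(K)                      # (N, d)
    S = fK.T @ V                     # (d, d_v): the single global key-value summary
    z = fK.sum(axis=0)               # (d,): the single global normalizer
    fQ = phi(Q)                      # (N, d)
    return (fQ @ S) / (fQ @ z)[:, None]

def implied_attention_matrix(Q, K):
    """N x N matrix that linear attention implicitly applies to V."""
    A = phi(Q) @ phi(K).T            # factorizes through only d dimensions
    return A / A.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
N, d = 256, 16
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))

A = implied_attention_matrix(Q, K)
assert np.allclose(A @ V, linear_attention(Q, K, V))
print(np.linalg.matrix_rank(A))      # at most d = 16, however large N is
```

Every query reads the same d × d summary, so adding more tokens cannot raise the rank above d; only the individual query vectors differentiate the outputs.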
MHLA: Formulation and Theoretical Properties
To address the aforementioned deficiencies, the paper introduces Multi-Head Linear Attention (MHLA), which restores expressive capacity without recourse to convolutions, gating, or quadratic fallback modules. MHLA divides the token dimension into non-overlapping blocks (termed "heads"), computes independent local key–value summaries per block, and introduces a learnable multi-head mixing mechanism: each query block computes a query-dependent mixture over all local summaries via block-specific, normalized, learnable coefficients. This approach is illustrated in Figure 1.
Figure 1: MHLA splits tokens into multiple heads along the sequence, then mixes KV summaries per query block, recovering query-conditioned selectivity and token-level diversity with linear complexity.
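For a query token $t$ belonging to block $i$, the output is
$$o_t = \frac{\phi(q_t)^\top \sum_{b=1}^{M} m_{i,b} S_b}{\phi(q_t)^\top \sum_{b=1}^{M} m_{i,b} z_b}, \qquad S_b = \sum_{j \in \mathcal{B}_b} \phi(k_j) v_j^\top, \qquad z_b = \sum_{j \in \mathcal{B}_b} \phi(k_j),$$
where $\mathcal{B}_b$ is the set of tokens in block $b$, $S_b$ and $z_b$ are the local key–value summary and normalizer of block $b$, $\sum_b m_{i,b} z_b$ is the mixed normalizer for query block $i$, and $m_{i,b}$ are learnable mixing coefficients, subject to nonnegativity and normalization constraints. This two-stage mixture (first across blocks via $m_{i,b}$, then within each block via the inner product with $\phi(q_t)$) restores per-token adaptive weighting and markedly increases the achievable attention-matrix rank compared to standard linear attention.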
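The block-mixing mechanism can be sketched in NumPy as follows. This is a minimal, non-causal illustration under stated assumptions, not the authors' implementation: blocks are contiguous and equal-sized, $\phi$ is ELU(x) + 1, and the coefficients $m_{i,b}$ are passed in as an M × M matrix `mix` whose rows are assumed nonnegative and summing to one. Each query block i forms its own mixture $\sum_b m_{i,b} S_b$ (and likewise for the normalizers), and each token then reads that mixture through $\phi(q_t)$. The helper `mhla_attention_matrix` (a hypothetical name) materializes the equivalent N × N matrix only to check the linear-time path.

```python
import numpy as np

def phi(x):
    # ELU(x) + 1: a positive feature map, assumed here for illustration
    return np.where(x > 0, x + 1.0, np.exp(x))

def mhla(Q, K, V, mix):
    """Non-causal MHLA-style block mixing (illustrative sketch).

    Q, K: (N, d); V: (N, d_v); mix: (M, M) with mix[i, b] = m_{i,b}.
    Assumes N is divisible by M and blocks are contiguous and equal-sized.
    """
    N, d = K.shape
    M = mix.shape[0]
    n = N // M                                   # tokens per block
    fK = phi(K).reshape(M, n, d)
    Vb = V.reshape(M, n, -1)
    S = np.einsum("bnd,bne->bde", fK, Vb)        # local KV summaries S_b: (M, d, d_v)
    z = fK.sum(axis=1)                           # local normalizers z_b: (M, d)
    S_mix = np.einsum("ib,bde->ide", mix, S)     # one mixture per query block: O(M^2 d^2)
    z_mix = mix @ z                              # mixed normalizers: (M, d)
    fQ = phi(Q).reshape(M, n, d)
    num = np.einsum("ind,ide->ine", fQ, S_mix)   # (M, n, d_v)
    den = np.einsum("ind,id->in", fQ, z_mix)     # (M, n)
    return (num / den[..., None]).reshape(N, -1)

def mhla_attention_matrix(Q, K, mix):
    """Explicit N x N matrix that mhla() applies to V (O(N^2); for checking only)."""
    N, M = Q.shape[0], mix.shape[0]
    blk = np.arange(N) // (N // M)               # block index of every token
    A = (phi(Q) @ phi(K).T) * mix[blk][:, blk]   # entry (t, j) scaled by m_{blk(t), blk(j)}
    return A / A.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
N, d, M = 256, 16, 8
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
logits = rng.standard_normal((M, M))
mix = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # row-wise softmax

out = mhla(Q, K, V, mix)
assert np.allclose(out, mhla_attention_matrix(Q, K, mix) @ V)
```

With every entry of `mix` set to 1/M, each mixed summary equals the global summary scaled by 1/M; the scale cancels in the ratio, and the sketch reduces to standard linear attention.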
Key properties include:
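Asymptotic complexity is $O(Nd^2 + M^2 d^2)$: building the block summaries and producing the per-token outputs costs $O(Nd^2)$, and mixing $M$ summaries for each of the $M$ query blocks costs $O(M^2 d^2)$. Choosing $M^2 \ll N$, which is feasible in practice, keeps the total cost linear in $N$.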
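The maximal attainable rank of the attention matrix is $\min(N, \sum_{b=1}^{M} \min(n_b, d))$, where $n_b$ is the number of tokens in block $b$. This bound grows additively with the number of heads, whereas standard linear attention is capped at rank $d$, so MHLA resists the diversity bottleneck.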
Query-conditioned sparsity is empirically higher than for linear attention and commonly exceeds that of softmax-based schemes in entropy-based metrics.
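The rank and complexity properties can be probed numerically. The sketch below keeps the same illustrative assumptions (contiguous equal-size blocks, $\phi$ = ELU + 1, row-stochastic mixing from random logits) and compares the numerical rank of the implied attention matrix under all-ones mixing, which is plain linear attention, with random mixing and with the bound $\min(N, \sum_b \min(n_b, d))$. `rank_bound` and `leading_cost` are hypothetical helpers; the latter counts only leading-order terms with constants dropped.

```python
import numpy as np

def phi(x):
    # ELU(x) + 1 feature map (illustrative assumption)
    return np.where(x > 0, x + 1.0, np.exp(x))

def attention_matrix(Q, K, mix):
    """Row-normalized N x N matrix implied by block mixing.

    mix: (M, M) nonnegative coefficients; contiguous equal-size blocks assumed.
    An all-ones mix recovers standard linear attention.
    """
    N, M = Q.shape[0], mix.shape[0]
    blk = np.arange(N) // (N // M)
    A = (phi(Q) @ phi(K).T) * mix[blk][:, blk]
    return A / A.sum(axis=1, keepdims=True)

def rank_bound(N, d, block_sizes):
    # min(N, sum_b min(n_b, d)), the bound stated in the property list
    return min(N, sum(min(n_b, d) for n_b in block_sizes))

def leading_cost(N, d, M):
    # Leading-order multiply-adds only: N d^2 (summaries and outputs)
    # plus M^2 d^2 (mixing); softmax attention costs about N^2 d
    return {"mhla": N * d**2 + M**2 * d**2, "softmax": N**2 * d}

rng = np.random.default_rng(0)
N, d, M = 256, 16, 8
Q, K = rng.standard_normal((N, d)), rng.standard_normal((N, d))
logits = rng.standard_normal((M, M))
mix = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

r_linear = np.linalg.matrix_rank(attention_matrix(Q, K, np.ones((M, M))))
r_mhla = np.linalg.matrix_rank(attention_matrix(Q, K, mix))
bound = rank_bound(N, d, [N // M] * M)
print(r_linear, r_mhla, bound)   # r_linear <= d, and r_mhla <= bound
```

With N = 256, d = 16, and M = 8 blocks of 32 tokens, the bound is 128, eight times the rank-16 ceiling of plain linear attention. In `leading_cost`, the mixing term $M^2 d^2$ stays below the $N d^2$ term whenever $M^2 < N$.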
Experimental Analysis: Discriminative and Generative Domains
MHLA's expressivity and efficiency are validated across multiple tasks:
Image Classification: Integrated into DeiT and VLT architectures, MHLA delivers superior accuracy under identical computational budgets. For DeiT-Tiny, Top-1 accuracy rises from 72.2% (self-attention) and 69.8% (linear attention) to 75.8% (MHLA), matching or surpassing recent SOTA linear and hybrid attention variants with minimal parameter overhead.
Image Generation: Within the DiT and DiG class-to-image frameworks, MHLA consistently achieves the lowest FID at each scale. For DiT-XL/2 (256px), FID drops to 19.17 (MHLA w/ CPE+gating), outperforming vanilla self-attention (19.47) and significantly surpassing kernelized baselines (28.63). This is accomplished while approximately doubling throughput relative to the softmax baseline at 512px resolution. These results are achieved solely through attention re-design; no auxiliary convolution or self-attention modules are introduced for expressive compensation.
Text-to-Image Synthesis: Fine-tuning SANA with MHLA yields an FID of 5.90, outperforming the SANA baseline (6.10), PixArt-α (6.14), and PixArt-Σ (6.34). Loss also converges faster, suggesting favorable optimization dynamics and rapid adaptation; example generations are shown in Figure 2.
Figure 2: SANA-MHLA generation results, indicating high sample fidelity and diversity in text-to-image tasks.
Video Generation: When applied to Wan2.1-1.3B for video diffusion (sequences up to 31,500 tokens), MHLA matches FlashAttention’s quality (Total: 82.62 vs. 83.31) at 2.1x lower inference latency and vastly outperforms vanilla linear attention, which suffers severe context collapse.
Autoregressive Language Modeling: In 0.3–0.35B scale models trained on FineWeb-Edu, MHLA performs on par with or modestly better than SOTA linearized and SSM architectures (Mamba, GDN, GLA, Mamba2) on perplexity, MMLU, and LongBench. It achieves the highest average LongBench score (7.41), demonstrating superior context utilization in long-sequence NLP.
Ablation Studies and Structural Insights
Ablations confirm that locality-biased initialization for the mixing coefficients accelerates and stabilizes learning, with additional gains from allowing adaptive learning of these coefficients. The token head number M influences the trade-off between expressivity and runtime; linear complexity and optimal FID are achieved with moderate M ($M^2 < N$). MHLA’s design demonstrates that per-block mixing suffices for expressive recovery, and the empirical entropy and rank analyses show substantial restoration of both token-level diversity and selectivity.
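The summary does not specify how the mixing coefficients are initialized, so the following is only a hypothetical parameterization consistent with the description. Logits start at $-|i - b| / \tau$, so each query block initially favors its own and neighboring blocks; a learnable offset (zero at initialization) lets training move away from that prior; and a row-wise softmax enforces nonnegativity and normalization. The names `locality_biased_mixing`, `tau`, and `learned_offset` are illustrative. For image tokens arranged on a 2D grid, a 2D distance between block positions would be a natural substitute for $|i - b|$.

```python
import numpy as np

def locality_biased_mixing(M, tau=1.0, learned_offset=None):
    """Hypothetical parameterization of the M x M coefficients m_{i,b}.

    Logits start at -|i - b| / tau (a locality prior); an optional learned
    offset, zero at initialization, lets training adapt the coefficients.
    A row-wise softmax keeps them nonnegative with rows summing to one.
    """
    idx = np.arange(M)
    logits = -np.abs(idx[:, None] - idx[None, :]) / tau
    if learned_offset is not None:
        logits = logits + learned_offset
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    w = np.exp(logits)
    return w / w.sum(axis=1, keepdims=True)

mix = locality_biased_mixing(M=8, tau=1.0)
print(mix[0].round(3))   # block 0: largest weight on itself, decaying with distance
```

A smaller `tau` sharpens the initial locality bias, while a very large `tau` approaches uniform mixing, which (as with all-equal coefficients) behaves like standard linear attention.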
Practical and Theoretical Implications
MHLA advances the state of linear attention by providing a theoretically principled and empirically validated mechanism for restoring softmax-like expressivity. It achieves this without auxiliary computational costs or partial re-introduction of $O(N^2)$ terms, presenting a direct path to efficient scaling in domains where quadratic attention is infeasible.
On the practical axis, MHLA’s ability to drop into established backbones (DeiT, DiT, SANA, Wan) and immediately improve accuracy, generative quality, and training stability, while maintaining (or even improving) hardware efficiency, establishes a new baseline for long sequence, high-resolution, and multimodal Transformer research.
On the theoretical axis, MHLA directly addresses the "global context collapse" limitation, demonstrating that the core bottleneck was not just the absence of depth or architectural trickery, but a structural lack of query-conditioned, block-wise interaction in global kernel aggregation schemes. The methodology points toward further generalizations, such as adaptive or hierarchical block partitions, multi-scale block mixing, or integration with SSMs/RNNs for extended temporal reasoning.
Conclusion
MHLA represents a principled and effective means of reinstating query-conditioned expressivity to linear-time attention, enabling Transformer architectures to scale efficiently without sacrificing the qualitative modeling benefits of full self-attention. This work lays the groundwork for future research into more granular token-block interaction mechanisms, as well as for their adoption in high-throughput computer vision, sequence generation, and long-horizon multimodal tasks.