Fast and Simplex: 2-Simplicial Attention in Triton
(2507.02754v1)
Published 3 Jul 2025 in cs.LG and cs.AI
Abstract: Recent work has shown that training loss scales as a power law with both model size and the number of tokens, and that achieving compute-optimal models requires scaling model size and token count together. However, these scaling laws assume an infinite supply of data and apply primarily in compute-bound settings. As modern LLMs increasingly rely on massive internet-scale datasets, the assumption that they are compute-bound is becoming less valid. This shift highlights the need for architectures that prioritize token efficiency. In this work, we investigate the use of the 2-simplicial Transformer, an architecture that generalizes standard dot-product attention to trilinear functions through an efficient Triton kernel implementation. We demonstrate that the 2-simplicial Transformer achieves better token efficiency than standard Transformers: for a fixed token budget, similarly sized models outperform their dot-product counterparts on tasks involving mathematics, coding, reasoning, and logic. We quantify these gains by demonstrating that 2-simplicial attention changes the exponent in the scaling laws for knowledge and reasoning tasks compared to dot-product attention.
Summary
The paper introduces a trilinear 2-simplicial attention mechanism that increases scaling-law exponents by up to 20.2% across benchmarks such as GSM8k, MMLU, MMLU-pro, and MBPP.
It employs sliding window locality and grouped query attention using custom Triton kernels to mitigate the O(n³) complexity of higher-order interactions.
Empirical results on MoE models show that 2-simplicial attention enhances token efficiency and outperforms standard Transformers on reasoning, math, and coding tasks.
2-Simplicial Attention: Scaling Laws and Efficient Implementation in Triton
The paper "Fast and Simplex: 2-Simplicial Attention in Triton" (2507.02754) presents a comprehensive study of 2-simplicial attention as a generalization of standard dot-product attention, with a focus on both theoretical scaling properties and practical implementation. The authors demonstrate that 2-simplicial attention, when efficiently implemented, can yield improved token efficiency and more favorable scaling exponents for reasoning, mathematics, and coding tasks, particularly under token-constrained regimes.
Theoretical Contributions
The core theoretical advancement is the extension of attention from bilinear (dot-product) to trilinear forms, moving from 1-simplices (edges) to 2-simplices (triangles) in the attention mechanism. This is formalized as:
Standard attention: $A_{ij} = \langle Q_i, K_j \rangle / \sqrt{d}$
2-simplicial attention: $A_{ijk} = \langle Q_i, K_j, K'_k \rangle / \sqrt{d}$
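To make the two formulas concrete, here is a small pure-Python sketch with made-up vectors (the numbers are illustrative, not from the paper). It evaluates the bilinear logit $A_{ij}$ and the trilinear logit $A_{ijk}$ for one query and checks a useful sanity property: if $K'_k$ is the all-ones vector, the trilinear form collapses back to the ordinary dot product.

```python
import math

def dot(a, b):
    """Bilinear form <a, b> = sum_l a_l * b_l."""
    return sum(x * y for x, y in zip(a, b))

def trilinear(a, b, c):
    """Trilinear form <a, b, c> = sum_l a_l * b_l * c_l."""
    return sum(x * y * z for x, y, z in zip(a, b, c))

d = 4
q = [1.0, 2.0, 0.5, -1.0]        # Q_i   (toy values)
k = [0.5, -1.0, 2.0, 1.0]        # K_j   (toy values)
k_prime = [2.0, 1.0, 1.0, 3.0]   # K'_k  (toy values)

standard_logit = dot(q, k) / math.sqrt(d)                    # A_ij
simplicial_logit = trilinear(q, k, k_prime) / math.sqrt(d)   # A_ijk

# With K'_k = all ones, the trilinear form reduces to the dot product.
ones = [1.0] * d
assert trilinear(q, k, ones) == dot(q, k)
```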
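Here $\langle a, b, c \rangle = \sum_{l} a_l b_l c_l$ is the elementwise trilinear form, and the softmax is taken jointly over all key pairs $(j, k)$. The trilinear form thus lets each query attend over pairs of keys, capturing higher-order interactions; the output aggregates $V_j \circ V'_k$ (elementwise product) over those pairs. The authors further explore rotation-invariant trilinear forms based on determinants, which are relevant for generalizing positional encodings such as RoPE.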
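The determinant idea can be illustrated as follows. This is a minimal sketch, assuming (as an illustration, not necessarily the authors' exact construction) that $d$-dimensional vectors are split into 3-dimensional chunks and the $3 \times 3$ determinants of the chunks are summed; `chunked_det_form`, `rotate_z`, and `rotate_chunks` are hypothetical helper names. Because a rotation $R$ has $\det R = 1$, rotating all three vectors by the same $R$ leaves each determinant, and hence the sum, unchanged.

```python
import math

def det3(a, b, c):
    """Determinant of the 3x3 matrix with rows a, b, c (scalar triple product)."""
    return (a[0] * (b[1] * c[2] - b[2] * c[1])
            - a[1] * (b[0] * c[2] - b[2] * c[0])
            + a[2] * (b[0] * c[1] - b[1] * c[0]))

def chunked_det_form(q, k, kp):
    """Hypothetical form: sum of 3x3 determinants over consecutive 3-dim chunks."""
    assert len(q) % 3 == 0, "dimension must be divisible by 3"
    return sum(det3(q[s:s + 3], k[s:s + 3], kp[s:s + 3]) for s in range(0, len(q), 3))

def rotate_z(v, theta):
    """Rotate a 3-vector about the z-axis by angle theta."""
    c, s = math.cos(theta), math.sin(theta)
    return [c * v[0] - s * v[1], s * v[0] + c * v[1], v[2]]

def rotate_chunks(v, theta):
    """Apply the same rotation to every 3-dim chunk of v."""
    return [x for s in range(0, len(v), 3) for x in rotate_z(v[s:s + 3], theta)]

q = [1.0, 2.0, 0.5, -1.0, 0.0, 3.0]
k = [0.5, -1.0, 2.0, 1.0, 1.0, 0.0]
kp = [2.0, 1.0, 1.0, 0.0, -2.0, 1.0]

before = chunked_det_form(q, k, kp)
after = chunked_det_form(rotate_chunks(q, 0.7), rotate_chunks(k, 0.7), rotate_chunks(kp, 0.7))
assert math.isclose(before, after, abs_tol=1e-9)  # invariant under a shared rotation
```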
A key empirical finding is that 2-simplicial attention increases the scaling-law exponent $\alpha$ in the loss–parameter relationship $L(N) = E' + A / N^{\alpha}$, compared to standard Transformers. This is quantified across several benchmarks (GSM8k, MMLU, MMLU-pro, MBPP), with improvements in $\alpha$ ranging from 6.8% to 20.2%. The improvement is most pronounced on reasoning-heavy and less saturated tasks.
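Under this law, $\alpha$ can be read off as a slope: $\log(L - E') = \log A - \alpha \log N$. The sketch below assumes $E'$ is already known (the paper's actual fitting procedure may differ) and recovers $\alpha$ by ordinary least squares on synthetic losses that are not from the paper. `relative_gain` shows how a percentage improvement in $\alpha$, like the 6.8% to 20.2% figures, would be computed, assuming those figures are relative changes.

```python
import math

def fit_alpha(params, losses, e_prime):
    """Estimate alpha in L(N) = E' + A / N**alpha via least squares on
    log(L - E') = log(A) - alpha * log(N), assuming E' is known."""
    xs = [math.log(n) for n in params]
    ys = [math.log(loss - e_prime) for loss in losses]
    x_mean = sum(xs) / len(xs)
    y_mean = sum(ys) / len(ys)
    slope = (sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
             / sum((x - x_mean) ** 2 for x in xs))
    return -slope

def relative_gain(alpha_new, alpha_base):
    """Percentage change of a scaling exponent relative to a baseline."""
    return 100 * (alpha_new - alpha_base) / alpha_base

# Synthetic losses generated with E' = 0.2, A = 5.0, alpha = 0.15 (illustrative only).
sizes = [1e9, 2e9, 3.5e9]
losses = [0.2 + 5.0 / n ** 0.15 for n in sizes]
alpha = fit_alpha(sizes, losses, e_prime=0.2)
```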
Practical Implementation in Triton
Implementing 2-simplicial attention naively incurs $O(n^3)$ complexity, which is prohibitive for long sequences. The authors address this by:
Sliding-Window Locality: Reducing the cost to $O(n w_1 w_2)$ by having each query attend only to a local causal window of $w_1$ keys from $K$ and $w_2$ keys from $K'$, analogous to local attention in standard Transformers.
Triton Kernels: Custom forward and backward kernels are developed in Triton, leveraging 2D tiling and online softmax. The forward pass fuses elementwise and matmul operations to maximize CUDA and Tensor Core utilization. The backward pass is split into two kernels to avoid atomic operation overhead, with further optimizations for small w2.
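A quick way to see what the windowing buys is to count trilinear logits. The sketch below uses hypothetical helper names and assumes each window includes the query's own position; it compares the naive causal count $\sum_i (i+1)^2 \approx n^3/3$ with the windowed count, which is bounded by $n w_1 w_2$, using the window sizes (512, 32) and the 32k context mentioned for the latency benchmark.

```python
def full_causal_triples(n):
    """Number of (i, j, k) logits with j, k <= i (naive causal 2-simplicial attention)."""
    return sum((i + 1) ** 2 for i in range(n))

def windowed_triples(n, w1, w2):
    """Number of (i, j, k) logits when j lies in the last w1 and k in the last w2 positions."""
    return sum(min(i + 1, w1) * min(i + 1, w2) for i in range(n))

n, w1, w2 = 32_768, 512, 32
naive = full_causal_triples(n)
windowed = windowed_triples(n, w1, w2)
print(f"naive: {naive:,}  windowed: {windowed:,}  ratio: {naive / windowed:.0f}x")
```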
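```python
import numpy as np

def two_simplicial_attention(Q, K, V, Kp, Vp, w1, w2):
    # Slow NumPy reference (not a Triton kernel).
    # Q, K, Kp, V, Vp: [batch, seq, heads, dim]
    # Causal sliding windows: query i sees the last w1 positions of K/V
    # and the last w2 positions of Kp/Vp (position i included).
    b, n, h, d = Q.shape
    out = np.zeros_like(V)
    for i in range(n):
        js = np.arange(max(0, i - w1 + 1), i + 1)
        ks = np.arange(max(0, i - w2 + 1), i + 1)
        # Trilinear logits <Q_i, K_j, Kp_k> / sqrt(d): [batch, heads, |js|, |ks|]
        logits = np.einsum("bhd,bjhd,bkhd->bhjk", Q[:, i], K[:, js], Kp[:, ks]) / np.sqrt(d)
        # Numerically stable softmax jointly over all (j, k) pairs
        flat = logits.reshape(b, h, -1)
        p = np.exp(flat - flat.max(axis=-1, keepdims=True))
        p = (p / p.sum(axis=-1, keepdims=True)).reshape(logits.shape)
        # Aggregate output: sum_{j,k} p_ijk * (V[j] * Vp[k]) (elementwise product)
        out[:, i] = np.einsum("bhjk,bjhd,bkhd->bhd", p, V[:, js], Vp[:, ks])
    return out
```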
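The online softmax used in the kernels lets the $(j, k)$ pairs in the loop above be processed as a stream instead of materializing all logits for a query. Below is a minimal pure-Python sketch of the standard running-max / running-sum technique (as popularized by FlashAttention), applied to the trilinear scores and the values $V_j \circ V'_k$ of one query; it is not the authors' Triton kernel, and the toy tensors are made up. The streamed result matches a conventional two-pass softmax.

```python
import math

def online_softmax_weighted_sum(scores, values):
    """One pass over (score, value) pairs with a running max m, running
    denominator s, and running weighted sum acc; returns sum_t softmax_t * value_t."""
    m, s = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    for x, v in zip(scores, values):
        m_new = max(m, x)
        scale = math.exp(m - m_new)  # rescales old partial sums; exp(-inf) = 0 on the first step
        p = math.exp(x - m_new)
        s = s * scale + p
        acc = [a * scale + p * vi for a, vi in zip(acc, v)]
        m = m_new
    return [a / s for a in acc]

def two_pass(scores, values):
    """Reference: materialize all softmax weights, then take the weighted sum."""
    mx = max(scores)
    w = [math.exp(x - mx) for x in scores]
    z = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, values)) / z for c in range(len(values[0]))]

# Toy window for one query: 3 keys in K and 2 keys in K' give 6 (j, k) pairs.
q = [0.3, -1.2, 0.8]
K = [[1.0, 0.5, -0.5], [0.2, 0.1, 0.9], [-0.7, 1.1, 0.4]]
Kp = [[0.6, -0.3, 1.0], [1.5, 0.2, -0.4]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
Vp = [[2.0, 1.0], [1.0, 3.0]]
d = len(q)
pairs = [(j, k) for j in range(len(K)) for k in range(len(Kp))]
scores = [sum(a * b * c for a, b, c in zip(q, K[j], Kp[k])) / math.sqrt(d) for j, k in pairs]
values = [[vj * vk for vj, vk in zip(V[j], Vp[k])] for j, k in pairs]
streamed = online_softmax_weighted_sum(scores, values)
reference = two_pass(scores, values)
```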
Triton Kernel Considerations
Memory Layout: Efficient tiling along the query and key axes is critical for throughput.
Online Softmax: Reduces memory overhead and enables streaming computation.
Backward Pass: Decomposed to avoid atomics, with two-stage computation for dQ, dK, dK′, dV, dV′.
Latency: With window sizes $(w_1, w_2) = (512, 32)$, the latency is comparable to that of standard attention at long context lengths (e.g., 110 ms for 32k tokens).
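Why atomics appear in the backward pass follows from the gradients. The derivation below is our own (not taken from the paper), using $S_{ijk} = \operatorname{softmax}_{(j,k)}(A_{ijk})$ and output $O_i = \sum_{j,k} S_{ijk}\,(V_j \circ V'_k)$. The gradient $dQ_i$ sums only over the key pairs of query $i$, but $dK_j$, $dV_j$, $dK'_k$, and $dV'_k$ each sum over many queries, so a single kernel parallelized over query tiles would need atomic adds to accumulate them; splitting the work so that each kernel owns a disjoint set of outputs is one way to avoid those conflicting writes.

```latex
\begin{aligned}
dS_{ijk} &= \langle dO_i, V_j, V'_k \rangle, \qquad
dA_{ijk} = S_{ijk}\bigl(dS_{ijk} - \langle dO_i, O_i \rangle\bigr), \\
dQ_i &= \tfrac{1}{\sqrt{d}} \sum_{j,k} dA_{ijk}\,(K_j \circ K'_k), \\
dK_j &= \tfrac{1}{\sqrt{d}} \sum_{i,k} dA_{ijk}\,(Q_i \circ K'_k), \qquad
dK'_k = \tfrac{1}{\sqrt{d}} \sum_{i,j} dA_{ijk}\,(Q_i \circ K_j), \\
dV_j &= \sum_{i,k} S_{ijk}\,(dO_i \circ V'_k), \qquad
dV'_k = \sum_{i,j} S_{ijk}\,(dO_i \circ V_j).
\end{aligned}
```

The term $\langle dO_i, O_i \rangle$ comes from the softmax Jacobian: $\sum_{j',k'} S_{ij'k'}\, dS_{ij'k'} = \langle dO_i, \sum_{j',k'} S_{ij'k'} (V_{j'} \circ V'_{k'}) \rangle = \langle dO_i, O_i \rangle$.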
Empirical Results
The authors train MoE models with up to 3.5B active parameters and 176B total parameters, interleaving 2-simplicial attention every fourth layer. On fixed token budgets, 2-simplicial models outperform standard Transformers on math, reasoning, and coding tasks, with the performance gap widening as model size increases.
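The interleaving can be pictured as a simple per-layer layout. This is a hypothetical sketch: the helper name, which layer index starts the pattern, and what the remaining layers use are assumptions here (they are shown as plain dot-product attention).

```python
def layer_attention_types(num_layers, every=4):
    """Hypothetical layout: every `every`-th layer (1-based) uses 2-simplicial attention."""
    return ["2-simplicial" if (layer + 1) % every == 0 else "dot-product"
            for layer in range(num_layers)]

layout = layer_attention_types(8)
```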
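| Model | Active params | GSM8k (NLL) | MMLU (NLL) | MMLU-pro (NLL) | MBPP (NLL) |
|---|---|---|---|---|---|
| Transformer | ~3.5B | 0.2781 | 0.5543 | 0.7858 | 0.2203 |
| 2-simplicial | ~3.5B | 0.2718 | 0.5484 | 0.7689 | 0.2193 |
| Δ (%) | | -2.27% | -1.06% | -2.15% | -0.45% |

Lower NLL is better; Δ is the relative change of the 2-simplicial model versus the Transformer baseline.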
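The Δ row is the relative change in NLL, $100 \cdot (\text{NLL}_{\text{2s}} - \text{NLL}_{\text{T}}) / \text{NLL}_{\text{T}}$. The short check below reproduces it from the table values.

```python
# (Transformer NLL, 2-simplicial NLL) per benchmark, copied from the table.
nll = {
    "GSM8k": (0.2781, 0.2718),
    "MMLU": (0.5543, 0.5484),
    "MMLU-pro": (0.7858, 0.7689),
    "MBPP": (0.2203, 0.2193),
}
delta_pct = {name: round(100 * (s - t) / t, 2) for name, (t, s) in nll.items()}
print(delta_pct)
# {'GSM8k': -2.27, 'MMLU': -1.06, 'MMLU-pro': -2.15, 'MBPP': -0.45}
```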
The scaling exponent α is consistently higher for 2-simplicial attention, indicating more efficient parameter utilization under token constraints.
Implications and Future Directions
Practical Implications:
Token Efficiency: 2-simplicial attention is particularly advantageous when high-quality data is scarce, as it enables better scaling with limited tokens.
Reasoning Tasks: The architecture is especially beneficial for tasks requiring higher-order interactions, such as logic, mathematics, and code generation.
Hardware Co-design: While the Triton implementation is performant for prototyping, further optimization (e.g., via CUTLASS or custom hardware) is necessary for production-scale deployment.
Theoretical Implications:
Scaling Laws: The results challenge the prevailing view that architectural changes do not affect scaling exponents, providing evidence that higher-order attention can fundamentally alter scaling behavior.
Expressivity: 2-simplicial attention expands the class of functions representable by a single layer, as formalized in the Match3 problem and related VC-dimension arguments.
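For reference, Match3, in the formulation assumed here, asks for every position $i$ whether some pair $(j, k)$ satisfies $x_i + x_j + x_k \equiv 0 \pmod M$. The brute-force sketch below (the helper name `match3` is hypothetical) makes the structure explicit: each output depends on a search over key pairs, which is exactly the kind of triple interaction a single trilinear logit $A_{ijk}$ can score. It assumes $j$ and $k$ range over all positions, including $i$.

```python
def match3(xs, M):
    """Brute force: out[i] = 1 iff some j, k satisfy (xs[i] + xs[j] + xs[k]) % M == 0."""
    n = len(xs)
    return [int(any((xs[i] + xs[j] + xs[k]) % M == 0
                    for j in range(n) for k in range(n)))
            for i in range(n)]
```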
Future Work:
Production-Ready Kernels: Co-designing kernels for specific accelerators to further reduce latency and memory overhead.
Hybrid Architectures: Exploring combinations of 2-simplicial and standard attention, or integrating with other efficient attention mechanisms.
Scaling to Larger Models: Investigating the behavior of 2-simplicial attention at even larger scales and in multi-modal or multi-task settings.
Conclusion
This work demonstrates that 2-simplicial attention, when efficiently implemented, provides a viable path to improved scaling and token efficiency for reasoning-intensive tasks. The combination of theoretical analysis, empirical validation, and practical kernel engineering establishes a foundation for further exploration of higher-order attention mechanisms in large-scale LLMs.