Tile-Aware Token Rounding in Sparse MoE Models
- Tile-aware token rounding is an algorithm that adjusts token-to-expert assignments so that each expert's token count is a multiple of the hardware tile size, eliminating padding waste in grouped GEMMs.
- It rounds each expert's count to a neighboring tile multiple using a rounding scheme that minimally perturbs top-K routing, yielding throughput gains with little distributional shift.
- Empirical results on H100 GPUs show up to 1.16× grouped-GEMM speedup and 10%–26% forward-pass TFLOPS gains in highly sparse MoE models, with no significant loss in output quality.
Tile-aware token rounding is an algorithmic strategy for eliminating the computational inefficiency that arises when sparse token routing meets hardware matrix tiling in Mixture-of-Experts (MoE) architectures, specifically in the grouped GEMM operations executed on GPUs. By minimally perturbing token-to-expert assignments so that every expert's batch is a tile multiple, the method removes zero-padded rows. This raises useful arithmetic throughput and avoids the associated memory overhead while preserving model quality. The idea has conceptual analogues in geometric rounding theory, in particular deterministic partition schemes that bound rounding error under tiling constraints (Guo et al., 16 Dec 2025, Woude et al., 2022).
1. Motivation and Problem Formulation
In fine-grained and highly sparse MoE layers, each input token is routed to a small subset of experts, so the per-expert batch sizes $f_e$ are variable and often small. Grouped GEMM kernels on modern accelerators such as the NVIDIA H100 (Hopper) pad each expert's batch (the M-dimension of the matrix) up to a multiple of a fixed tile size (e.g., 128 rows). The result is "tile-quantization" waste: floating-point operations on zero-padded rows that contribute nothing to the model output. Because each expert pads by up to one tile minus one row, the wasted fraction grows as the average per-expert batch size shrinks toward the tile size; at high sparsity, padding can account for 20–30% of GEMM FLOPs. Tile-aware token rounding ("TR") eliminates this waste by adjusting each expert's count to a tile multiple, thereby improving throughput (Guo et al., 16 Dec 2025).
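As a rough illustration of tile-quantization waste, the sketch below pads hypothetical per-expert counts up to the next tile multiple and reports the padded fraction of executed rows. It assumes all experts share the same GEMM shape, so wasted rows are proportional to wasted FLOPs; the function name `padding_waste` and the example counts are illustrative, not taken from the source.

```python
import math


def padding_waste(counts, tile=128):
    """Fraction of grouped-GEMM rows that are zero padding.

    counts: per-expert token counts f_e under vanilla top-K routing.
    tile:   M-dimension tile size (128 rows, the example size in the text).
    Assumes every expert GEMM has the same N/K shape, so rows ~ FLOPs.
    """
    # Each expert's batch is padded up to the next multiple of the tile size.
    padded = [math.ceil(c / tile) * tile for c in counts]
    total_padded = sum(padded)
    if total_padded == 0:
        return 0.0
    wasted = total_padded - sum(counts)
    return wasted / total_padded
```

For counts `[129, 200, 64, 256]` with a 128-row tile, 247 of the 896 executed rows (about 28%) are padding, while counts that are already tile multiples waste nothing.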
2. Mathematical and Algorithmic Foundations
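Let $T$ be the microbatch size (number of tokens), $E$ the total number of experts, $K$ the number of active experts per token, and $M_{\text{tile}}$ the GEMM tile size. For each expert $e$, let $f_e$ denote its token count under vanilla top-$K$ routing, so that $\sum_{e} f_e = TK$, and let its neighboring tile multiples be
$$\lfloor f_e \rfloor_{M} = M_{\text{tile}} \left\lfloor \frac{f_e}{M_{\text{tile}}} \right\rfloor, \qquad \lceil f_e \rceil_{M} = M_{\text{tile}} \left\lceil \frac{f_e}{M_{\text{tile}}} \right\rceil.$$
The tile-aware rounding problem is to choose a rounded count $\tilde f_e \in \{\lfloor f_e \rfloor_{M}, \lceil f_e \rceil_{M}\}$ for every expert. This reduces to minimizing the $\ell_1$ (or $\ell_2$) deviation from the original counts subject to tile multiplicity and (approximate) preservation of the global token count:
$$\min_{\{\tilde f_e\}} \sum_{e=1}^{E} \left|\tilde f_e - f_e\right| \quad \text{s.t.} \quad \tilde f_e \in \{\lfloor f_e \rfloor_{M}, \lceil f_e \rceil_{M}\}, \qquad \sum_{e=1}^{E} \tilde f_e \approx \sum_{e=1}^{E} f_e = TK.$$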
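The sketch below implements three candidate rounding subroutines for the per-expert counts. Nearest rounding minimizes each $|\tilde f_e - f_e|$ independently; stochastic rounding rounds up with probability proportional to the remainder, so each rounded count equals $f_e$ in expectation. The "balanced" variant here is an assumption (a largest-remainder rule that hits the tile multiple nearest to the original total), not necessarily the rule used by Guo et al. All function names are hypothetical.

```python
import random


def tile_floor(c, tile):
    """Largest multiple of `tile` that is <= c."""
    return (c // tile) * tile


def tile_ceil(c, tile):
    """Smallest multiple of `tile` that is >= c."""
    return -(-c // tile) * tile


def round_nearest(counts, tile):
    """Per-expert nearest tile multiple; exact ties round up."""
    out = []
    for c in counts:
        lo, hi = tile_floor(c, tile), tile_ceil(c, tile)
        out.append(lo if c - lo < hi - c else hi)
    return out


def round_stochastic(counts, tile, rng=None):
    """Round up with probability (c - floor) / tile, so E[result] == c."""
    rng = rng or random.Random()
    out = []
    for c in counts:
        lo = tile_floor(c, tile)
        out.append(lo + tile if rng.random() < (c - lo) / tile else lo)
    return out


def round_balanced(counts, tile):
    """Hypothetical largest-remainder rule: start from the floors, then round
    up the experts with the largest remainders until the total equals the
    tile multiple nearest to sum(counts)."""
    floors = [tile_floor(c, tile) for c in counts]
    target_total = ((sum(counts) + tile // 2) // tile) * tile
    n_up = (target_total - sum(floors)) // tile
    by_remainder = sorted(range(len(counts)),
                          key=lambda e: counts[e] - floors[e], reverse=True)
    out = list(floors)
    for e in by_remainder[:n_up]:
        out[e] += tile
    return out
```

On `[129, 200, 64, 256]` with a 128-row tile, nearest rounding gives `[128, 256, 128, 256]` (total 768, well above the original 649), while the balanced rule gives `[128, 256, 0, 256]` (total 640). This shows the trade-off between per-expert deviation and global token-count preservation.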
Routing is then reallocated by re-ranking each expert's tokens under modified router scores, which promote the original top-$K$ tokens and demote all others by a fixed margin, and then selecting the top $\tilde f_e$ tokens per expert (Guo et al., 16 Dec 2025).
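One way to see why the margin guarantees minimal perturbation: assume router scores $s_{t,e} \in [0,1]$ (e.g., softmax probabilities), and write $\mathcal{K}(t)$ for token $t$'s original top-$K$ expert set and $\delta$ for the margin (both notations are introduced here for illustration). Then

$$s'_{t,e} = s_{t,e} + \delta \, \mathbb{1}[e \in \mathcal{K}(t)], \qquad \delta > 1 \;\Longrightarrow\; s'_{t,e} > 1 \ge s'_{u,e} \quad \text{whenever } e \in \mathcal{K}(t),\ e \notin \mathcal{K}(u).$$

Every originally routed token therefore outranks every non-routed token in expert $e$'s list. Taking the top $\tilde f_e$ entries drops only the lowest-scoring routed tokens when $\tilde f_e \le f_e$, and keeps all routed tokens plus the $\tilde f_e - f_e$ highest-scoring non-routed ones when $\tilde f_e > f_e$. Subtracting $\delta$ from non-routed tokens instead yields the same ranking.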
3. Algorithm Implementation and Complexity
The standard TR algorithm consists of:
- Initial Top-K Routing: For each token $t$, select its $K$ highest-scoring experts, recording both the scores and the indices.
- Per-Expert Tile Multiplicity Calculation: For each expert $e$, count the routed tokens $f_e$ and compute the nearest tile multiples below and above it, $\lfloor f_e \rfloor_{M}$ and $\lceil f_e \rceil_{M}$.
- Expert-wise Token Re-ranking: Construct per-expert score lists that rank the original top-K assignments first and penalize all other tokens, then sort each expert's token list.
- Rounding Pattern Selection: For each expert, apply a deterministic or stochastic rounding subroutine (nearest, stochastic, or balanced) to pick $\tilde f_e$ from these two multiples, keeping the deviation from $f_e$ small while approximately preserving the global token count.
- Routing Mask Reconstruction: Assign each expert the top $\tilde f_e$ tokens from its sorted list, so that its batch is a tile multiple, and update the overall routing mask.
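The five steps can be combined into a compact NumPy sketch. It assumes router scores in $[0,1]$, uses nearest rounding (ties round up) for step 4, and picks a margin of 2 so that originally routed tokens always outrank the rest. The function name and defaults are hypothetical; this is a reference sketch, not an optimized implementation.

```python
import numpy as np


def token_rounding_mask(scores, k, tile, margin=2.0):
    """Tile-aware token rounding sketch (nearest-rounding variant).

    scores: (T, E) router scores, assumed to lie in [0, 1].
    Returns a boolean (T, E) mask whose column sums are multiples of `tile`.
    """
    T, E = scores.shape
    # Step 1: vanilla top-K routing, one row of K expert indices per token.
    topk_idx = np.argpartition(-scores, k - 1, axis=1)[:, :k]
    topk_mask = np.zeros((T, E), dtype=bool)
    np.put_along_axis(topk_mask, topk_idx, True, axis=1)

    # Step 2: per-expert counts f_e and the tile multiples below/above them.
    counts = topk_mask.sum(axis=0)
    lo = (counts // tile) * tile
    hi = np.where(counts % tile == 0, lo, lo + tile)
    hi = np.where(hi > T, lo, hi)  # never ask for more tokens than exist

    # Step 3: promote originally routed tokens by `margin` (> 1 suffices for
    # scores in [0, 1]) and sort each expert's column, best token first.
    adjusted = scores + margin * topk_mask
    order = np.argsort(-adjusted, axis=0, kind="stable")

    # Step 4: nearest rounding, ties round up.
    target = np.where(counts - lo < hi - counts, lo, hi)

    # Step 5: rank of every token within each expert's sorted list; keep the
    # top target[e] tokens for expert e.
    ranks = np.empty_like(order)
    np.put_along_axis(ranks, order, np.repeat(np.arange(T)[:, None], E, axis=1), axis=0)
    return ranks < target[None, :]
```

Replacing the nearest-rounding `target` rule with a stochastic or balanced subroutine changes only step 4; steps 3 and 5 stay the same.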
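The per-microbatch cost is dominated by the per-expert sorts (at most $O(ET \log T)$ when each of the $E$ lists covers all $T$ tokens), and the additional router overhead is small (5% of router time on H100 for typical settings). Memory overhead is dominated by the per-expert score lists, $O(TE)$, which is negligible compared with full MoE activation storage (Guo et al., 16 Dec 2025).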
4. Empirical Outcomes and Performance Impacts
Tile-aware token rounding eliminates padding waste by construction. At high sparsity, vanilla top-K incurs padding waste up to 30% of total GEMM compute, whereas TR reduces waste to zero. On H100 hardware, TR yields:
- Forward-pass TFLOPS improvements of 10%–26%
- Backward-pass TFLOPS improvements of 6%–12%
- Overall grouped GEMM speedup of up to 1.16× in large, sparse MoEs
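For context, an idealized upper bound on the GEMM speedup from removing padding follows from the waste fraction $w$ alone. This assumes GEMM time scales linearly with executed FLOPs, including padded rows, an assumption that ignores memory-bound phases and fixed kernel overheads:

$$S_{\max} = \frac{\text{executed FLOPs}}{\text{useful FLOPs}} = \frac{1}{1-w}, \qquad w = 0.2 \Rightarrow S_{\max} = 1.25, \qquad w = 0.3 \Rightarrow S_{\max} \approx 1.43.$$

The reported 1.16× grouped-GEMM speedup sits below this bound.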
End-to-end, a 7B-parameter model trained with SonicMoE reaches 213B tokens/day on 64 H100s, compared with 225B tokens/day for ScatterMoE on 96 H100s. That is comparable throughput with one-third fewer GPUs, or roughly 3.3B versus 2.3B tokens/day per GPU (Guo et al., 16 Dec 2025).
Quality metrics remain essentially unaffected. Models trained with TR and evaluated with vanilla top-K routing at inference show no significant loss on downstream tasks, and sometimes minor improvements (a small difference in validation perplexity and parity in benchmark accuracy), even at extreme sparsity.
5. Theoretical Context: Geometric Rounding and Deterministic Partitioning
Tile-aware rounding algorithms are closely related to deterministic geometric rounding, which is equivalent to partitioning $\mathbb{R}^d$ into "secluded" hypercube tiles such that every $\ell_\infty$-ball of radius $\varepsilon$ intersects at most $k$ tiles (Woude et al., 2022). Explicit constructions achieve $k = d+1$ with $\varepsilon = 1/(2d)$, and $k = d+1$ is provably the best possible for a broad class of partitions that includes hypercube partitions. The common thread in both domains is a balance between minimal error (measured by neighborhood overlap or assignment deviation) and structural constraints (tile multiplicity, token conservation, hardware alignment).
The geometric results show that no deterministic scheme based on a partition with bounded cells can achieve $k < d+1$, and they constrain how large the error tolerance $\varepsilon$ can be in high dimensions.
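To see why $k = d+1$ is a strong guarantee, compare it with naive coordinate-wise rounding, whose partition is the axis-aligned unit grid. The sketch below (not the Woude et al. construction; the function name is hypothetical) counts how many grid cells $[n, n+1)^d$ a closed $\ell_\infty$-ball of radius $\varepsilon$ meets. Because both the grid and the ball are products of intervals, the count is a product of per-axis counts.

```python
import math


def cells_hit_by_box(p, eps):
    """Number of unit cells [n, n+1)^d of the standard grid that the closed
    l_inf ball of radius eps around point p intersects."""
    count = 1
    for x in p:
        # Along each axis the ball covers [x - eps, x + eps], which meets the
        # half-open unit intervals indexed floor(x - eps) .. floor(x + eps).
        count *= math.floor(x + eps) - math.floor(x - eps) + 1
    return count
```

Around a grid vertex such as the origin, any $\varepsilon \in (0, 1)$ gives $2^d$ cells (1024 for $d = 10$, versus $d + 1 = 11$), while at a cell center with $\varepsilon < 1/2$ the ball stays inside a single cell.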
6. Applicability, Limitations, and Trade-offs
Tile-aware token rounding is most advantageous when only a small fraction of experts is active per token (i.e., MoE sparsity is high), because that is where the proportion of padding waste, and thus the performance gain, is largest. It is agnostic to expert granularity and applies to any microbatch size for which the average per-expert token count is at least one tile, with robust empirical behavior in the evaluated configurations. The method is used only during training; reverting to vanilla top-K at inference incurs negligible loss because the distributional shift is minimal.
Degradation can occur when per-expert token counts fall below the tile size, because rounding then perturbs routing substantially: rounding down a count smaller than one tile removes all of that expert's tokens. The additional router overhead is negligible compared with the avoided GEMM waste.
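A quick way to gauge exposure to this failure mode is to compare each expert's count with the tile size. The hypothetical helper below reports the average per-expert count in units of tiles and lists the experts holding fewer than one tile of tokens, which rounding down would empty entirely. It is a diagnostic sketch, not part of the published method.

```python
def rounding_risk(counts, tile):
    """Return (average count in tiles, experts with 0 < count < tile)."""
    # Average per-expert load measured in tiles; values near or below 1.0
    # mean rounding decisions move a large share of each expert's tokens.
    avg_in_tiles = sum(counts) / (len(counts) * tile)
    # Experts that would receive zero tokens if their count were rounded down.
    below_one_tile = [e for e, c in enumerate(counts) if 0 < c < tile]
    return avg_in_tiles, below_one_tile
```

For `[129, 200, 64, 256]` and a 128-row tile, the average is about 1.27 tiles per expert, and expert 2 (64 tokens) is below one tile.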
7. Relation to Broader Rounding Algorithms
Generic tile-aware approaches resemble the secluded hypercube partition method for deterministic rounding in $\mathbb{R}^d$, which is realized via an invertible upper-triangular matrix defining a lattice tiling of $\mathbb{R}^d$. The associated rounding algorithm achieves the optimal neighborhood overlap $k = d+1$, is efficiently computable per sample, and comes with rigorous bounds on worst-case error and partition coloring (Woude et al., 2022). While these methods motivate and inform computational rounding for MoE, implementation constraints in deep learning favor specialized, batch-aligned schemes that tie error control directly to hardware-imposed tile structures.
In summary, tile-aware token rounding adapts ideas from rounding theory to high-performance sparse MoE training. By aligning token assignments with hardware matrix tiling, it delivers measured efficiency gains while keeping quality loss empirically negligible (Guo et al., 16 Dec 2025, Woude et al., 2022).