Token-Level Sharding
- Token-level sharding is a distributed strategy that partitions tokens or activations across devices to overcome memory and compute bottlenecks in large-scale models.
- MoEShard and Helix Parallelism implement sharding by splitting expert matrices and KV caches, ensuring balanced workloads and sub-linear scaling of resources.
- This approach minimizes latency and memory overhead while preserving full token retention, improving efficiency in long-context and expert-skew scenarios.
Token-level sharding refers to the class of distributed inference strategies in which model state or activations are partitioned across hardware resources along the token (sequence position) axis, as opposed to traditional approaches that shard parameters or split computation along feature, head, or expert axes. This paradigm has emerged as a critical technique for addressing memory and compute bottlenecks in large-scale LLMs and Mixture of Experts (MoE) networks, particularly for scenarios involving long input sequences or severe expert-routing skew. Leading exemplar systems include MoEShard for encoder-based MoEs and Helix Parallelism for multi-million-token LLM decoding (Balmau et al., 11 Mar 2025, Bhatia et al., 7 Jul 2025).
1. Motivation and Context
Token-level sharding was developed to address two central inefficiencies in large model inference:
- Imbalanced load in MoEs: Conventional expert-parallel MoE systems assign whole experts to devices and route tokens accordingly. Skewed token-to-expert routing turns the devices hosting so-called “hot” experts into stragglers that bottleneck throughput while other devices sit idle. Standard remedies such as capacity factors or token dropping do not resolve the underutilization and may harm accuracy (Balmau et al., 11 Mar 2025).
- Mega-scale KV cache for long-context LLMs: For LLMs with multi-million-token retrieval histories, key-value (KV) caches for attention layers can exceed the memory and read bandwidth capabilities of individual devices. Conventional tensor parallelism (TP) can reduce weight-read costs, but once TP degree exceeds the number of attention heads, KV data must be replicated, leading to inefficient DRAM usage and limited batch scaling (Bhatia et al., 7 Jul 2025).
Token-level sharding breaks these barriers by partitioning tokens, token-activation slices, or intermediate activations across devices, thereby facilitating full hardware utilization and sub-linear scaling of memory/compute costs.
2. Token-Level Sharding in Mixture of Experts: MoEShard
MoEShard is an inference system that achieves perfect load balancing in encoder-based MoEs through a highly granular tensor sharding scheme applied to expert parameters (Balmau et al., 11 Mar 2025). The salient mechanisms are:
- Expert matrix decomposition: Each expert’s two weight matrices, $W_1^{(e)}$ and $W_2^{(e)}$, are sliced across $G$ GPUs: $W_1^{(e)}$ by columns and $W_2^{(e)}$ by rows. Each GPU $g$ thus holds a shard $W_1^{(e,g)}$ and $W_2^{(e,g)}$ for every expert $e$.
- Token broadcast and computation: All input tokens are broadcast to all GPUs. For each token $x$ routed to an expert $e$, every GPU $g$ computes a partial matrix multiplication using its local shards, producing $y^{(e,g)} = \sigma\left(x W_1^{(e,g)}\right) W_2^{(e,g)}$, where $\sigma$ is the expert’s element-wise activation.
- Partial result reduction: The expert output for each token is reconstructed by summing the per-shard results across GPUs: $y^{(e)} = \sum_{g=1}^{G} y^{(e,g)}$. Because every GPU processes every routed token on an equal-sized shard, each GPU performs an identical workload, which enforces perfect load balance and eliminates routing-induced stragglers.
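The shard-and-sum identity can be checked numerically with a small pure-Python sketch. It simulates a single expert with all tokens routed to it and assumes ReLU as the element-wise activation; the helper names (`matmul`, `shard_slices`, `sharded_expert_ffn`) are hypothetical and are not MoEShard’s code. Routing, broadcast, and the fused kernels are not modeled. Because the activation is element-wise, slicing $W_1$ by columns and $W_2$ by the matching rows gives partial outputs that sum exactly to the unsharded result.

```python
import random

def matmul(a, b):
    """Dense product of matrices stored as lists of rows."""
    cols = len(b[0])
    return [[sum(a_ik * b[k][j] for k, a_ik in enumerate(row)) for j in range(cols)] for row in a]

def relu(m):
    """Element-wise activation (ReLU assumed for illustration)."""
    return [[max(0.0, v) for v in row] for row in m]

def expert_ffn(x, w1, w2):
    """Unsharded expert: relu(x W1) W2."""
    return matmul(relu(matmul(x, w1)), w2)

def shard_slices(width, num_gpus):
    """Contiguous, non-empty index ranges that split `width` across num_gpus."""
    step = -(-width // num_gpus)  # ceiling division
    return [range(s, min(s + step, width)) for s in range(0, width, step)]

def sharded_expert_ffn(x, w1, w2, num_gpus):
    """Simulate G GPUs: GPU g holds columns of W1 and the matching rows of W2."""
    out = [[0.0] * len(w2[0]) for _ in x]
    for idx in shard_slices(len(w1[0]), num_gpus):
        w1_g = [[row[c] for c in idx] for row in w1]  # column shard of W1
        w2_g = [w2[c] for c in idx]                   # matching row shard of W2
        partial = matmul(relu(matmul(x, w1_g)), w2_g)
        # Reduction step: sum the per-GPU partial outputs.
        out = [[o + p for o, p in zip(o_row, p_row)] for o_row, p_row in zip(out, partial)]
    return out

random.seed(0)
d_model, d_ff, num_tokens, num_gpus = 4, 8, 3, 4
x = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(num_tokens)]
w1 = [[random.gauss(0, 1) for _ in range(d_ff)] for _ in range(d_model)]
w2 = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(d_ff)]

full = expert_ffn(x, w1, w2)
sharded = sharded_expert_ffn(x, w1, w2, num_gpus)
max_err = max(abs(a - b) for r_full, r_shard in zip(full, sharded) for a, b in zip(r_full, r_shard))
print(f"max |full - sharded| = {max_err:.2e}")
```

The difference is at floating-point rounding level, since only the summation order changes.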
The table below summarizes MoEShard’s main distinctions:
| Property | Conventional expert parallelism | MoEShard token-level sharding |
|---|---|---|
| Expert placement | Each expert resides on one device (disjoint) | Every expert sharded across all GPUs (columns/rows) |
| Token routing | Tokens sent only to the GPU hosting their expert | All tokens broadcast to all GPUs |
| Load balancing | Routing skew creates stragglers | Perfectly balanced |
| Token retention | May drop tokens | No tokens dropped (100% retained) |
MoEShard further fuses per-expert sparse matrix multiplications into two block-sparse kernels, minimizing kernel-launch overhead and optimizing throughput.
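A compute-only model makes the load-balancing contrast in the table concrete. The sketch below assumes hypothetical token counts for eight experts on four GPUs and a round-robin expert placement for the expert-parallel baseline; it ignores communication entirely, and all function names are illustrative. Load is counted in full-width expert applications, so a GPU holding a $1/G$ shard processes each routed token at $1/G$ of the cost.

```python
def ep_gpu_loads(tokens_per_expert, num_gpus):
    """Conventional expert parallelism: each expert lives on exactly one GPU.

    Placement is round-robin (expert e on GPU e % num_gpus), a hypothetical choice.
    """
    loads = [0.0] * num_gpus
    for expert, n_tokens in enumerate(tokens_per_expert):
        loads[expert % num_gpus] += n_tokens
    return loads

def sharded_gpu_loads(tokens_per_expert, num_gpus):
    """MoEShard-style sharding: every GPU holds a 1/num_gpus slice of every expert,
    so each GPU processes all routed tokens at 1/num_gpus of the expert width."""
    per_gpu = sum(tokens_per_expert) / num_gpus
    return [per_gpu] * num_gpus

def straggler_factor(loads):
    """Slowest device over the average device; 1.0 means perfectly balanced."""
    return max(loads) / (sum(loads) / len(loads))

num_gpus = 4
skewed = [700, 50, 60, 40, 80, 30, 20, 44]  # hypothetical counts, one "hot" expert
balanced = [128] * 8                        # hypothetical uniform routing

for name, counts in (("skewed", skewed), ("balanced", balanced)):
    ep = ep_gpu_loads(counts, num_gpus)
    shard = sharded_gpu_loads(counts, num_gpus)
    print(f"{name:>8}: EP loads {ep}, EP straggler factor {straggler_factor(ep):.2f}, "
          f"sharded loads {shard}")
```

With the skewed counts, the expert-parallel GPU hosting the hot expert carries about three times the mean load, while the sharded layout is flat. With balanced routing both layouts have a factor of 1.0, so sharding gains nothing on compute balance.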
3. Token-Level KV Sharding in LLMs: Helix Parallelism
In long-context LLMs, Helix Parallelism shards the attention module’s KV cache across GPUs along the token axis, enabling efficient inference and communication scaling (Bhatia et al., 7 Jul 2025).
- KV-parallel slicing: The KV cache of a sequence of length $S$ is partitioned along the token axis among $\mathrm{KVP}$ GPUs, each storing a segment of roughly $S/\mathrm{KVP}$ tokens. No GPU ever holds more than its local slice.
- Local attention computation: Upon each query, the query embedding is broadcast to all KV-parallel ranks. Each then performs Q/K/V projection and runs FlashAttention using its local cache fragment, generating a partial context vector and log-sum-exp scalar.
- All-to-all exchange: A single all-to-all communication exchanges these partial results along the query-head axis. Each rank rescales the received fragments using their log-sum-exp scalars and sums them to reconstruct the full softmax-normalized attention output, so exact attention behavior is preserved.
- Temporal hybridization with TP: After each layer’s attention phase, Helix swaps the token-sharded ranks into a standard TP configuration for the FFN. This decouples the tradeoffs between KV-cache memory and FFN weight-read scaling.
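Why a partial context vector plus a log-sum-exp scalar is enough for exact attention can be checked with a small pure-Python sketch. It models one query head attending over a KV cache split into contiguous token segments; FlashAttention is replaced by a plain softmax, the all-to-all itself is not simulated, and `attention_partial` and `merge_partials` are hypothetical names. The merge weights each rank’s locally normalized context by $e^{\mathrm{lse}_r}$ and renormalizes, which recovers full-sequence attention.

```python
import math
import random

def attention_partial(q, keys, values):
    """Softmax attention of one query over a local KV slice.

    Returns (context, lse): the locally normalized context vector and the
    log-sum-exp of the local scores, which a rank would send in the all-to-all.
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)                                   # subtract max for stability
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    context = [sum(w * v[j] for w, v in zip(weights, values)) / denom for j in range(dim)]
    return context, m + math.log(denom)

def merge_partials(partials):
    """Rescale each rank's context by exp(lse) and renormalize: exact global softmax."""
    top = max(lse for _, lse in partials)
    coeffs = [math.exp(lse - top) for _, lse in partials]
    total = sum(coeffs)
    dim = len(partials[0][0])
    return [sum(c * ctx[j] for c, (ctx, _) in zip(coeffs, partials)) / total
            for j in range(dim)]

random.seed(1)
head_dim, seq_len, kvp = 8, 37, 4
q = [random.gauss(0, 1) for _ in range(head_dim)]
keys = [[random.gauss(0, 1) for _ in range(head_dim)] for _ in range(seq_len)]
values = [[random.gauss(0, 1) for _ in range(head_dim)] for _ in range(seq_len)]

# Reference: a single device attends over the whole sequence.
reference, _ = attention_partial(q, keys, values)

# KV parallelism: rank r holds one contiguous token segment.
seg = -(-seq_len // kvp)  # ceiling division: 10 tokens per rank, last rank gets 7
partials = [attention_partial(q, keys[r * seg:(r + 1) * seg], values[r * seg:(r + 1) * seg])
            for r in range(kvp)]
merged = merge_partials(partials)
max_err = max(abs(a - b) for a, b in zip(reference, merged))
print(f"max |reference - merged| = {max_err:.2e}")
```

Any partition of the tokens gives the same merged result, because the merge depends only on each slice’s local sum of exponentials.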
Helix’s token-level sharding (KV parallelism) avoids the cache-duplication cliff present in TP-only schemes when TP width exceeds head count, supporting sublinear DRAM scaling with respect to context length.
4. Communication and Computational Workflow
Communication and compute strategies are central to performance in token-level sharding.
- MoEShard: Only routing metadata (a small integer matrix) is exchanged initially. All token embeddings are then broadcast; every GPU executes the same set of fused expert-shard computations, and the per-token partial outputs are summed across devices.
- Helix Parallelism: Each GPU receives every query embedding, performs its local attention computation, and uses an all-to-all exchange to communicate partial context and normalization scalars. Attention output is reconstructed by weighted summation.
Helix’s HOP-B overlap pipeline further hides communication latency by computing attention for token $i+1$ while the all-to-all for token $i$ is in flight. As a result, the exposed token-to-token latency (TTL) approaches the maximum of compute and communication time rather than their sum.
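The max-versus-sum effect follows from treating overlap as a generic two-stage pipeline, which is an assumption of this sketch rather than a description of HOP-B’s implementation. The per-entry costs are hypothetical, and `pipeline_time` is an illustrative name. Without overlap each batch entry pays compute plus communication; with overlap, only the first compute and the last exchange are fully exposed.

```python
def pipeline_time(compute, comm, num_items, overlap):
    """Wall time for num_items batch entries, each needing attention compute then an all-to-all.

    Without overlap the two phases run back to back for every entry. With overlap
    (modeled as a two-stage pipeline), entry i's exchange runs while entry i+1 computes.
    """
    if num_items == 0:
        return 0
    if not overlap:
        return num_items * (compute + comm)
    return compute + (num_items - 1) * max(compute, comm) + comm

compute_us, comm_us, batch = 10, 6, 8  # hypothetical per-entry costs in microseconds
serial = pipeline_time(compute_us, comm_us, batch, overlap=False)
overlapped = pipeline_time(compute_us, comm_us, batch, overlap=True)
print(f"serial: {serial} us, overlapped: {overlapped} us, "
      f"per entry: {serial / batch:.2f} vs {overlapped / batch:.2f} us")
```

As the batch grows, the overlapped per-entry cost tends to `max(compute_us, comm_us)`, while the serial cost stays at their sum.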
5. Cost Models and Performance Analysis
Both MoEShard and Helix provide formal cost models predicting their throughput, latency, and scaling profiles.
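- MoEShard: Each GPU processes every routed token-expert pair on a $1/G$ shard, so its compute load is $\mathcal{L}_{\text{GPU}} = \frac{1}{G}\sum_{e} n_e \, C_{\text{expert}}$, where $n_e$ is the number of tokens routed to expert $e$ and $C_{\text{expert}}$ is the cost of one full-width expert application. This load is invariant to expert-routing skew. Time-to-first-token (TTFT) decomposes approximately into the phases of the workflow above:
$$\mathrm{TTFT} \approx T_{\text{metadata}} + T_{\text{broadcast}} + T_{\text{compute}} + T_{\text{reduce}},$$
with the compute term reduced, ideally, by the factor $\max_g \mathcal{L}^{\text{EP}}_g / \mathcal{L}_{\text{GPU}}$ (the hottest device’s load under conventional expert parallelism divided by the balanced per-GPU load) compared to conventional expert-parallel inference. In empirical measurements, substantially lower TTFT was observed versus DeepSpeed expert parallelism on A100s under heavy expert skew, and 100% token retention was maintained (Balmau et al., 11 Mar 2025).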
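- Helix: Per-layer KV-cache read time on each rank is
$$T_{\text{KV}} \approx \frac{B \cdot \frac{S}{\mathrm{KVP}} \cdot b_{\text{KV}}}{\mathrm{BW}_{\text{mem}}},$$
where $B$ is the batch size, $S$ the sequence length, $b_{\text{KV}}$ the KV-cache bytes per token held on a rank, and $\mathrm{BW}_{\text{mem}}$ the per-GPU memory bandwidth; it therefore scales inversely with $\mathrm{KVP}$. Communication per decoding step per layer is $\mathcal{O}(B \cdot H)$ for hidden size $H$, independent of sequence length $S$. Helix supports substantially more concurrent users at a fixed TTL and achieves lower TTL, or higher batch throughput, than TP-only or prior KVP-with-fixed-TP schemes for massive LLMs (Bhatia et al., 7 Jul 2025).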
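The two Helix scaling claims (KV read time falls as $1/\mathrm{KVP}$, and all-to-all traffic does not depend on $S$) can be illustrated with a bandwidth-bound back-of-the-envelope model. Every number here (batch size, hidden size, KV bytes per token, memory bandwidth) is hypothetical, and `all_to_all_bytes` is a proportionality model assuming one `hidden`-wide partial output per request, not Helix’s exact message layout.

```python
def kv_read_time_s(batch, seq_len, kv_bytes_per_token, kvp, mem_bytes_per_s):
    """Time for one rank to stream its KV slice for one layer (bandwidth-bound model)."""
    tokens_per_rank = seq_len / kvp
    return batch * tokens_per_rank * kv_bytes_per_token / mem_bytes_per_s

def all_to_all_bytes(batch, hidden, bytes_per_elem=2):
    """Order-of-magnitude all-to-all payload per layer per decode step.

    Assumes one hidden-wide partial output per request; note there is no seq_len argument.
    """
    return batch * hidden * bytes_per_elem

# Hypothetical configuration, chosen only for illustration.
batch, hidden = 8, 8192
kv_bytes, bw = 1024, 8e12  # KV bytes per token per layer on a rank; memory bandwidth in B/s

for seq_len in (1_000_000, 2_000_000):
    for kvp in (8, 16):
        t = kv_read_time_s(batch, seq_len, kv_bytes, kvp, bw)
        print(f"S={seq_len:>9,} KVP={kvp:>2}: KV read {t * 1e6:7.1f} us, "
              f"all-to-all {all_to_all_bytes(batch, hidden):,} B")
```

Doubling $S$ doubles the KV read time, doubling KVP halves it, and the all-to-all payload stays constant throughout.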
6. Trade-offs, Limitations, and Best-case Scenarios
Token-level sharding introduces several trade-offs:
- Communication cost: Both approaches require full broadcast of token embeddings (MoEShard) or query embeddings (Helix) to all relevant devices. For MoEShard the cost is small compared to expert compute on fast interconnects (0.1–0.2 ms per round), but it may become a bottleneck with extreme sequence lengths or slow hardware (Balmau et al., 11 Mar 2025). In Helix, HOP-B overlap hides almost all of this cost.
- Memory overhead: MoEShard replicates token embeddings (each GPU holds a full live copy), which increases the per-GPU live token memory footprint relative to token-partitioned baselines in which each GPU holds only its own tokens. In Helix, sharding the KV cache along tokens ensures that no device holds the full cache.
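- Kernel launch efficiency: Without fusion, the number of kernel launches in MoEShard grows with the number of experts (roughly one per weight matrix per expert), which would be prohibitive when per-expert work is small or on platforms without sparse-kernel fusion support.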
- Effectiveness dependence on skew/scale: MoEShard’s advantage is maximized under severe expert-routing skew. If token-to-expert mapping is already balanced or batch size is very small, gains diminish. Helix is most advantageous when context length is extremely large and batch size scaling would otherwise bottleneck on DRAM or KV duplication.
7. Comparison with Alternative Parallelism Strategies
Token-level sharding provides advantages over tensor-parallel (TP), data-parallel (DP), and expert-parallel approaches by fundamentally decoupling the dimensions along which compute, memory, and communication scale.
- TP-only: Once the TP width exceeds the number of KV heads, every rank must still store the full-sequence cache for at least one head, so the cache is duplicated and additional GPUs bring no DRAM savings. In FFNs, TP reduces weight reads, but in a TP-only scheme the usable TP width is capped by the KV-head count.
- DP: Each replica holds the full KV cache of every sequence it serves, so per-device KV memory grows linearly with context length.
- Helix: Supports arbitrary scaling in KV-parallel (KVP) width for cache, while still leveraging large TP width for FFNs. A single all-to-all per layer over query-head axis suffices for exact attention. This enables simultaneous sublinear KV scaling and linear acceleration of weight-reads, outperforming TP-only and prior KVP approaches by large margins for both TTL and throughput at scale (Bhatia et al., 7 Jul 2025).
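The duplication cliff can be made concrete by counting KV entries (tokens times KV heads) per rank. The sketch assumes a hypothetical model with 8 KV heads and a 1M-token context. For the Helix-style layout it assumes attention ranks form a $\mathrm{KVP} \times \mathrm{TPA}$ grid, with TPA (tensor parallelism within attention) capped at the KV-head count. This grid layout and the function names are illustrative assumptions.

```python
def ceil_div(a, b):
    return -(-a // b)

def tp_only_kv_per_rank(seq_len, kv_heads, tp):
    """KV entries (tokens x KV heads) per rank when attention uses TP alone.

    A rank cannot hold less than one whole head, so once tp > kv_heads
    the per-rank cache stops shrinking (heads are duplicated).
    """
    return seq_len * ceil_div(kv_heads, tp)

def helix_kv_per_rank(seq_len, kv_heads, tpa, kvp):
    """KV entries per rank for a kvp x tpa attention grid (hypothetical Helix-style layout):
    tpa splits KV heads (kept <= kv_heads), kvp splits the sequence."""
    if tpa > kv_heads:
        raise ValueError("tpa beyond the KV-head count would duplicate heads")
    return ceil_div(seq_len, kvp) * ceil_div(kv_heads, tpa)

seq_len, kv_heads = 1_000_000, 8  # hypothetical model and context length
for gpus in (8, 16, 64):
    tp_only = tp_only_kv_per_rank(seq_len, kv_heads, tp=gpus)
    helix = helix_kv_per_rank(seq_len, kv_heads, tpa=kv_heads, kvp=gpus // kv_heads)
    print(f"{gpus:>2} GPUs: TP-only {tp_only:,} entries/rank, Helix {helix:,} entries/rank")
```

Multiplying entries by $2 \times$ head dimension $\times$ bytes per element gives bytes for K and V. Beyond 8 GPUs the TP-only count stays flat, while the KVP split keeps shrinking the per-rank cache.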
Empirical results from both MoEShard and Helix confirm that token-level sharding unlocks practical, efficient, and balanced inference for both Mixture of Experts and long-context LLMs, and extends the throughput-latency Pareto frontier for multi-GPU inference.