Token-Level Sharding

Updated 21 April 2026
  • Token-level sharding is a distributed strategy that partitions tokens or activations across devices to overcome memory and compute bottlenecks in large-scale models.
  • MoEShard and Helix Parallelism implement sharding by splitting expert matrices and KV caches, ensuring balanced workloads and sub-linear scaling of resources.
  • This approach minimizes latency and memory overhead while preserving full token retention, improving efficiency in long-context and expert-skew scenarios.

Token-level sharding refers to the class of distributed inference strategies in which model state or activations are partitioned across hardware resources along the token (sequence position) axis, as opposed to traditional approaches that shard parameters or split computation along feature, head, or expert axes. This paradigm has emerged as a critical technique for addressing memory and compute bottlenecks in large-scale LLMs and Mixture of Experts (MoE) networks, particularly in scenarios involving long input sequences or severe expert-routing skew. Leading exemplar systems include MoEShard for encoder-based MoEs and Helix Parallelism for multi-million-token LLM decoding (Balmau et al., 11 Mar 2025, Bhatia et al., 7 Jul 2025).

1. Motivation and Context

Token-level sharding was developed to address two central inefficiencies in large model inference:

  • Imbalanced load in MoEs: Conventional expert-parallel MoE systems assign full experts to devices and route tokens accordingly. Under skewed token-to-expert routing, devices hosting so-called “hot” experts become stragglers that bottleneck throughput while other devices sit idle. Standard approaches such as capacity factors or token dropping do not resolve the underutilization and may harm accuracy (Balmau et al., 11 Mar 2025).
  • Mega-scale KV cache for long-context LLMs: For LLMs with multi-million-token retrieval histories, key-value (KV) caches for attention layers can exceed the memory and read bandwidth capabilities of individual devices. Conventional tensor parallelism (TP) can reduce weight-read costs, but once TP degree exceeds the number of attention heads, KV data must be replicated, leading to inefficient DRAM usage and limited batch scaling (Bhatia et al., 7 Jul 2025).

Token-level sharding breaks these barriers by partitioning tokens, token-activation slices, or intermediate activations across devices, thereby facilitating full hardware utilization and sub-linear scaling of memory/compute costs.

2. Token-Level Sharding in Mixture of Experts: MoEShard

MoEShard is an inference system that achieves perfect load balancing in encoder-based MoEs through a highly granular tensor sharding scheme applied to expert parameters (Balmau et al., 11 Mar 2025). The salient mechanisms are:

  • Expert matrix decomposition: Each expert’s two weight matrices, $W_i \in \mathbb{R}^{h_1 \times h_2}$ and $W_o \in \mathbb{R}^{h_2 \times h_1}$, are sliced across $|G|$ GPUs: $W_i$ by columns, $W_o$ by rows. Each GPU $g$ thus holds a shard $W_i^{(e,g)}$ and $W_o^{(e,g)}$ for every expert $e$.
  • Token broadcast and computation: All input tokens are broadcast to all GPUs. For each token routed to an expert $e$, every GPU computes a partial matrix multiplication using its local shard, producing $y_t^{(e,g)} = x_t W_i^{(e,g)} W_o^{(e,g)}$.
  • Partial result reduction: The global expert output per token is reconstructed by summing the per-shard results across GPUs: $y_t^{(e)} = \sum_{g \in G} y_t^{(e,g)}$. This guarantees each GPU computes an identical workload, enforcing perfect load balance and obviating routing-induced stragglers.
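
The decomposition above can be checked numerically. The following is a minimal NumPy sketch, not MoEShard's implementation: GPUs are simulated as list entries, the helper names `shard_expert` and `sharded_expert_forward` and all sizes are hypothetical, and the expert is taken to be the plain product $x_t W_i W_o$ exactly as written above.

```python
import numpy as np

def shard_expert(W_i, W_o, num_gpus):
    """Split W_i by columns and W_o by rows into num_gpus matching shard pairs."""
    return list(zip(np.array_split(W_i, num_gpus, axis=1),
                    np.array_split(W_o, num_gpus, axis=0)))

def sharded_expert_forward(x, shards):
    """Each simulated GPU multiplies the broadcast tokens by its own shard pair;
    the partial outputs are then summed (the reduction step)."""
    partials = [x @ Wi_g @ Wo_g for Wi_g, Wo_g in shards]
    return np.sum(partials, axis=0)

# Hypothetical sizes chosen only for illustration.
rng = np.random.default_rng(0)
h1, h2, num_gpus = 8, 32, 4
W_i = rng.standard_normal((h1, h2))
W_o = rng.standard_normal((h2, h1))
x = rng.standard_normal((5, h1))          # 5 tokens routed to this expert

full = x @ W_i @ W_o                       # unsharded reference
sharded = sharded_expert_forward(x, shard_expert(W_i, W_o, num_gpus))
print(np.allclose(full, sharded))
```

Because the column slices of $W_i$ line up with the row slices of $W_o$, every shard pair contributes a disjoint part of the inner sum, so the summed partials match the unsharded product up to floating-point rounding.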

The table below summarizes MoEShard’s main distinctions:

| Property | Conventional Expert-Parallel | MoEShard Token-Level Sharding |
|---|---|---|
| Expert assignment | Disjoint on devices | Sharded (col/row) on all GPUs |
| Token routing | Only to "host" GPU | All tokens broadcast to all |
| Load balancing | Skew, bottlenecked | Perfectly balanced |
| Token retention | May drop tokens | 100% accurate, no drop |

MoEShard further fuses per-expert sparse matrix multiplications into two block-sparse kernels, minimizing kernel-launch overhead and optimizing throughput.

3. Token-Level KV Sharding in LLMs: Helix Parallelism

In long-context LLMs, Helix Parallelism shards the attention module’s KV cache across GPUs along the token axis, enabling efficient inference and communication scaling (Bhatia et al., 7 Jul 2025).

  • KV-parallel slicing: The KV cache, whose length equals the sequence length $S$, is partitioned among $N$ KV-parallel GPUs, each storing a segment of roughly $S/N$ tokens. No GPU ever holds more than its local slice.
  • Local attention computation: Upon each query, the query embedding is broadcast to all KV-parallel ranks. Each then performs Q/K/V projection and runs FlashAttention using its local cache fragment, generating a partial context vector and log-sum-exp scalar.
  • All-to-all exchange: A single all-to-all communication exchanges these partial results along the query-head axis. Each rank sums and rescales received fragments to reconstruct the full softmax-normalized attention output; exact attention behavior is preserved.
  • Temporal hybridization with TP: After each layer’s attention phase, Helix swaps the token-sharded ranks into a standard TP configuration for the FFN. This decouples the tradeoffs between KV-cache memory and FFN weight-read scaling.
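
The local-attention and rescaling steps above can be sketched in NumPy. This is an illustrative single-query, single-head model rather than Helix's implementation: the all-to-all over the query-head axis is replaced by an in-process merge, FlashAttention by a plain softmax, and the names `local_attention` and `merge_partials` and all sizes are hypothetical.

```python
import numpy as np

def local_attention(q, K, V):
    """Softmax attention of one query over a single KV shard.

    Returns the shard's normalized partial output and its log-sum-exp scalar.
    """
    scores = K @ q / np.sqrt(q.shape[0])
    m = scores.max()
    w = np.exp(scores - m)                 # stabilized unnormalized weights
    return (w @ V) / w.sum(), m + np.log(w.sum())

def merge_partials(outputs, lses):
    """Rescale per-shard outputs by their share of the global softmax mass."""
    lses = np.asarray(lses)
    m = lses.max()
    global_lse = m + np.log(np.exp(lses - m).sum())
    weights = np.exp(lses - global_lse)    # non-negative, sums to 1
    return (weights[:, None] * np.stack(outputs)).sum(axis=0)

rng = np.random.default_rng(0)
seq_len, d, num_ranks = 64, 16, 4
q = rng.standard_normal(d)
K = rng.standard_normal((seq_len, d))
V = rng.standard_normal((seq_len, d))

# Each simulated KV-parallel rank sees only its own slice of the token axis.
partials = [local_attention(q, K_g, V_g)
            for K_g, V_g in zip(np.array_split(K, num_ranks),
                                np.array_split(V, num_ranks))]
merged = merge_partials([o for o, _ in partials], [l for _, l in partials])

# Reference: attention over the unsharded cache.
reference, _ = local_attention(q, K, V)
print(np.allclose(merged, reference))
```

Each rank only needs to send its partial output and one scalar, which is why the merged result equals exact attention without any rank seeing the full cache.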

Helix’s token-level sharding (KV parallelism) avoids the cache-duplication cliff present in TP-only schemes when TP width exceeds head count, supporting sublinear DRAM scaling with respect to context length.

4. Communication and Computational Workflow

Communication and compute strategies are central for performance in token-level sharding.

  • MoEShard: Only routing metadata (an integer matrix) is exchanged initially. All token embeddings are then broadcast. GPUs execute the same set of fused expert shard computations, and the per-token partial outputs are summed across devices.
  • Helix Parallelism: Each GPU receives every query embedding, performs its local attention computation, and uses an all-to-all exchange to communicate partial context and normalization scalars. Attention output is reconstructed by weighted summation.

Helix’s HOP-B overlap pipeline further hides communication latency by pipelining compute for token $t+1$ while exchanging results for token $t$. As a result, the exposed per-token contribution to token-to-token latency (TTL) is reduced to the maximum of compute and communication time, rather than their sum.
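
The "maximum rather than sum" behaviour can be illustrated with a toy two-stage pipeline model. This is a sketch under idealized assumptions (fixed per-token compute and communication times, one compute engine and one communication link, no other dependencies); the function names and the timing values are hypothetical and are not measurements from Helix.

```python
def serial_time(num_tokens, t_compute, t_comm):
    """No overlap: every token pays compute and communication back to back."""
    return num_tokens * (t_compute + t_comm)

def overlapped_time(num_tokens, t_compute, t_comm):
    """Two-stage pipeline: compute for the next token runs while the
    previous token's partial results are still being exchanged."""
    compute_done = 0
    comm_done = 0
    for _ in range(num_tokens):
        compute_done += t_compute
        # Communication starts once its input is ready and the link is free.
        comm_done = max(compute_done, comm_done) + t_comm
    return comm_done

# Hypothetical timings in arbitrary units.
n, t_compute, t_comm = 1000, 3, 2
print(serial_time(n, t_compute, t_comm) / n)       # per-token cost without overlap
print(overlapped_time(n, t_compute, t_comm) / n)   # per-token cost with overlap
```

In this model the overlapped cost per token approaches `max(t_compute, t_comm)` as the number of tokens grows, with only the first compute step and the last exchange left exposed.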

5. Cost Models and Performance Analysis

Both MoEShard and Helix provide formal cost models predicting their throughput, latency, and scaling profiles.

  • MoEShard: The per-GPU compute load is identical across GPUs and invariant to expert-routing skew. The cost model gives an explicit expression for time-to-first-token (TTFT), with compute ideally reduced compared to conventional expert-parallel inference. In empirical measurements, lower TTFT was observed versus DeepSpeed expert parallelism on A100s under heavy expert skew, and 100% token retention was maintained (Balmau et al., 11 Mar 2025).

  • Helix: Per-layer KV-cache read time scales inversely with the KV-parallel width. Communication per token per layer is independent of sequence length. Helix supports more users at fixed TTL and achieves lower TTL, or higher batch throughput, than TP-only or prior KVP+fixed TP schemes for massive LLMs (Bhatia et al., 7 Jul 2025).

6. Trade-offs, Limitations, and Best-case Scenarios

Token-level sharding introduces several trade-offs:

  • Communication cost: Both approaches require full broadcast of token embeddings or query projections to all relevant devices. For MoEShard the cost is small compared to expert compute on fast interconnects (0.1–0.2 ms per round), but may become a bottleneck with extreme sequence length or slow hardware (Balmau et al., 11 Mar 2025). In Helix, HOP-B overlap hides almost all of this cost.
  • Memory overhead: MoEShard’s replication of token embeddings (each GPU requires a full live copy) doubles the live token memory footprint relative to token-partitioned baselines. In Helix, sharding the KV-cache along tokens ensures no device holds the full cache.
  • Kernel launch efficiency: Without fusion, MoEShard would launch separate kernels for each expert's sharded matrix multiplications; this launch overhead becomes prohibitive when block-sparse fusion support is unavailable.
  • Effectiveness dependence on skew/scale: MoEShard’s advantage is maximized under severe expert-routing skew. If token-to-expert mapping is already balanced or batch size is very small, gains diminish. Helix is most advantageous when context length is extremely large and batch size scaling would otherwise bottleneck on DRAM or KV duplication.
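
The dependence on skew can be made concrete with a small idealized model. The sketch below assumes a round-robin placement of experts on GPUs for the expert-parallel baseline, counts work in tokens, and ignores communication and kernel-launch costs entirely; the function names and token counts are hypothetical.

```python
import numpy as np

def expert_parallel_bottleneck(tokens_per_expert, num_gpus):
    """Token count on the busiest GPU when whole experts are placed
    round-robin on GPUs (assumed placement, for illustration only)."""
    loads = np.zeros(num_gpus, dtype=np.int64)
    for expert_id, count in enumerate(tokens_per_expert):
        loads[expert_id % num_gpus] += count
    return int(loads.max())

def ideal_sharding_speedup(tokens_per_expert, num_gpus):
    """Ratio of the expert-parallel bottleneck to the sharded per-GPU load.

    With sharding, every GPU handles every token at 1/num_gpus of the
    per-token cost, i.e. total / num_gpus token-equivalents of work.
    """
    per_gpu_sharded = sum(tokens_per_expert) / num_gpus
    return expert_parallel_bottleneck(tokens_per_expert, num_gpus) / per_gpu_sharded

balanced = [100] * 8                         # uniform routing over 8 experts
skewed = [730, 10, 10, 10, 10, 10, 10, 10]   # one "hot" expert

print(ideal_sharding_speedup(balanced, num_gpus=4))
print(ideal_sharding_speedup(skewed, num_gpus=4))
```

Under balanced routing the ratio is 1, so sharding buys nothing in this model, while the ratio grows toward the number of GPUs as routing concentrates on a single expert.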

7. Comparison with Alternative Parallelism Strategies

Token-level sharding provides advantages over tensor-parallel (TP), data-parallel (DP), and expert-parallel approaches by fundamentally decoupling the dimensions along which compute, memory, and communication scale.

  • TP-only: Once TP width exceeds the number of KV heads, every rank must store the whole cache, eliminating DRAM scaling benefits. In FFNs, TP helps with weight reads but is limited by the TP/heads constraint.
  • DP: KV cache is fully replicated on each device, forcing linear memory scaling.
  • Helix: Supports arbitrary scaling in KV-parallel (KVP) width for cache, while still leveraging large TP width for FFNs. A single all-to-all per layer over query-head axis suffices for exact attention. This enables simultaneous sublinear KV scaling and linear acceleration of weight-reads, outperforming TP-only and prior KVP approaches by large margins for both TTL and throughput at scale (Bhatia et al., 7 Jul 2025).
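
The memory argument in this comparison can be worked through with simple arithmetic. The sketch below is a rough per-layer model under stated assumptions, not a formula from the cited papers: TP splits the cache by KV head, each rank must hold at least one complete KV head (so the footprint stops shrinking once TP width reaches the KV-head count), KVP splits the token axis evenly, and the function name, parameters, and example sizes are all hypothetical.

```python
import math

def kv_bytes_per_gpu(seq_len, kv_heads, head_dim, tp, kvp=1, bytes_per_elem=2):
    """Approximate per-layer KV-cache bytes held by one GPU.

    tp  : tensor-parallel width (splits KV heads, but not below one head per rank)
    kvp : KV-parallel width (splits the token axis)
    """
    heads_per_rank = math.ceil(kv_heads / min(tp, kv_heads))
    tokens_per_rank = math.ceil(seq_len / kvp)
    # Factor 2 accounts for storing both keys and values.
    return 2 * heads_per_rank * head_dim * tokens_per_rank * bytes_per_elem

S, kv_heads, head_dim = 1_000_000, 8, 128    # hypothetical model shape

print(kv_bytes_per_gpu(S, kv_heads, head_dim, tp=8))          # TP equal to KV heads
print(kv_bytes_per_gpu(S, kv_heads, head_dim, tp=16))         # wider TP: no further saving
print(kv_bytes_per_gpu(S, kv_heads, head_dim, tp=8, kvp=4))   # add token-axis sharding
```

In this model, doubling TP width past the KV-head count leaves the per-GPU footprint unchanged, whereas each increase in KVP width divides it further.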

Empirical results from both MoEShard and Helix confirm that token-level sharding unlocks practical, efficient, and balanced inference for both Mixture of Experts and long-context LLMs, and extends the throughput-latency Pareto frontier for multi-GPU inference.
