TokenWeave: Optimized LLM Inference
- TokenWeave is a methodology for efficient compute-communication overlap in distributed LLM inference, utilizing smart token-splitting for balanced workload distribution.
- It employs a pipelined approach with a fused AllReduce+RMSNorm GPU kernel that minimizes communication overhead and runs up to 40% faster than separate AllReduce and RMSNorm steps.
- Experimental benchmarks on 8×H100 systems demonstrate significant improvements in latency (up to 29%) and throughput (up to 26%) across various token batch sizes.
TokenWeave is a methodology for efficient compute-communication overlap in distributed LLM inference. It addresses the costly overhead (up to 20%) that arises in tensor-parallel LLM execution on systems with high-speed interconnects, such as NVLink-connected NVIDIA H100 GPUs. TokenWeave combines a minimal two-way, wave-aware token split with a fused AllReduce + RMSNorm GPU kernel that leverages Hopper-specific multimem instructions, enabling overlap between communication and computation steps. This approach recovers much of the latency and throughput lost to communication bottlenecks and, in some settings, even outperforms a counterfactual configuration with all communication removed (Gond et al., 16 May 2025).
1. Token-Splitting: Wave-Aware Bifurcation
TokenWeave’s foundation is a token-splitting strategy that partitions an inference batch of tokens into two subsets of nearly equal computational work. Naïve halving can add an extra compute “wave” on the streaming multiprocessors (SMs) because of wave quantization, doubling the wave count for a kernel that would otherwise fit in a single wave. TokenWeave instead determines split sizes using GPU occupancy analysis.
Given a large kernel (e.g., GEMM or attention) requiring $N$ cooperative thread arrays (CTAs) on a GPU with $S$ SMs, with each CTA occupying a single SM, the unsplit kernel executes in $\lceil N/S \rceil$ waves. Split sizes $N_0, N_1$ (with $N_0 + N_1 = N$) are selected such that $\lceil N_0/S \rceil + \lceil N_1/S \rceil = \lceil N/S \rceil$. Practically, $N_0$ is chosen as a multiple of $S$, packing only full waves, while $N_1 = N - N_0$, so the sum of waves equals $\lceil N/S \rceil$. This “smart-splitting” eliminates most quantization overhead, which is especially notable for small batch sizes, as demonstrated in Figure 1 of the reference (Gond et al., 16 May 2025).
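The wave-counting argument can be checked with a small sketch. It works directly in CTA units and assumes one CTA per SM; the function names and the choice of giving the first split about half of the waves are illustrative assumptions, not the paper's exact heuristic, and the mapping from tokens to CTAs (which depends on kernel tile sizes) is left out.

```python
from math import ceil


def waves(num_ctas: int, num_sms: int) -> int:
    """Waves needed to run num_ctas CTAs when each CTA occupies one SM."""
    return ceil(num_ctas / num_sms)


def naive_split(num_ctas: int) -> tuple[int, int]:
    """Halve the CTA count without regard to wave boundaries."""
    first = num_ctas // 2
    return first, num_ctas - first


def wave_aware_split(num_ctas: int, num_sms: int) -> tuple[int, int]:
    """Give the first split a whole number of waves, the second the rest.

    Assumption for illustration: the first split takes half of the waves,
    rounded down. A kernel that fits in one wave is left unsplit (0, N).
    """
    first_waves = waves(num_ctas, num_sms) // 2
    first = first_waves * num_sms  # multiple of num_sms, so only full waves
    return first, num_ctas - first


if __name__ == "__main__":
    n, s = 300, 132  # hypothetical CTA count and SM count
    for name, (a, b) in [("naive", naive_split(n)),
                         ("wave-aware", wave_aware_split(n, s))]:
        print(name, (a, b), "waves:", waves(a, s) + waves(b, s))
```

With 300 CTAs on 132 SMs the unsplit kernel needs 3 waves; the naive 150/150 split needs 4, while the 132/168 split stays at 3.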
2. Pipelined Overlap of Computation and Communication
With the batch split into two “wave-aware” halves (“split 0” and “split 1”), TokenWeave pipelines computation for one split while overlapping it with communication for the other. The method maps compute (Attention→RMSNorm→FFN) onto a dedicated CUDA stream (computeStream), and the corresponding AllReduce-focused communication (fused with RMSNorm) onto a separate stream (commStream).
Scheduling is synchronized with CUDA events: while computeStream processes Attention or FFN on split 1, commStream executes the corresponding AllReduce + RMSNorm for split 0, and vice versa. This overlap is maintained per layer, with kernel granularity and split sizes kept large enough to keep the GPUs efficiently occupied. The communication kernel is constrained to 2–8 SMs, which minimizes its SM footprint and leaves abundant resources for computation.
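The effect of this two-stream schedule can be illustrated with a toy timeline model. This is a sketch under simplifying assumptions, not the actual CUDA implementation: every compute step and every communication step has a fixed duration, the two streams never contend for SMs, and a "layer" here means one compute block followed by its fused communication step. The function names and the timings are hypothetical.

```python
def pipelined_makespan(num_layers, compute_time, comm_time, num_splits=2):
    """Finish time when compute and communication run on separate streams.

    Each (layer, split) pair runs a compute step, then a communication step.
    The next layer's compute for a split waits for that split's communication,
    which is the dependency the CUDA events enforce in the text.
    """
    compute_free = 0.0               # when the compute stream is next idle
    comm_free = 0.0                  # when the communication stream is next idle
    comm_done = [0.0] * num_splits   # latest communication finish per split
    for _ in range(num_layers):
        for s in range(num_splits):
            start = max(compute_free, comm_done[s])
            compute_free = start + compute_time
            comm_start = max(comm_free, compute_free)
            comm_free = comm_start + comm_time
            comm_done[s] = comm_free
    return max(compute_free, comm_free)


def serial_makespan(num_layers, compute_time, comm_time, num_splits=2):
    """Finish time when every step runs back to back on one stream."""
    return num_layers * num_splits * (compute_time + comm_time)


if __name__ == "__main__":
    print(serial_makespan(2, 3.0, 2.0))     # 20.0
    print(pipelined_makespan(2, 3.0, 2.0))  # 14.0
```

In this model, when communication is no longer than compute, only the final communication step remains exposed; everything else is hidden behind the other split's compute.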
3. RMSNorm Reordering and Communication Minimization
Standard tensor-parallel LLM inference places an RMSNorm after each Attention and FFN layer. Conventionally, this involves (A) local compute, (B) an AllReduce so each GPU holds the full hidden vector for every token, and (C) a locally redundant RMSNorm on all tokens on every GPU. TokenWeave exploits the fact that AllReduce decomposes into ReduceScatter followed by AllGather: by placing RMSNorm between the two, each GPU normalizes only its local tokens, which reduces the per-GPU normalization workload by a factor equal to the number of GPUs and eliminates the redundancy.
This refactoring is only beneficial if further optimized, since naively performing ReduceScatter + AllGather is typically less efficient than a monolithic AllReduce. TokenWeave addresses this by fusing normalization, collective communication, and residual-addition within a single kernel, tightly integrating memory-bound RMSNorm steps and collective primitives.
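The equivalence behind this reordering can be checked with a CPU-only simulation. The sketch below is an illustration under assumptions, not the paper's kernel: ranks are plain Python lists, the strided token ownership is an arbitrary choice, and all function names are hypothetical.

```python
import math


def rmsnorm(row, weight, eps=1e-6):
    """RMSNorm of one token's hidden vector."""
    mean_sq = sum(x * x for x in row) / len(row)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [x * scale * w for x, w in zip(row, weight)]


def reduce_token(partials, t):
    """Sum token t's hidden vector across all ranks' partial results."""
    return [sum(vals) for vals in zip(*(p[t] for p in partials))]


def allreduce_then_norm(partials, weight):
    """Conventional path: AllReduce, then RMSNorm over every token.

    In a real system each rank repeats this same normalization.
    """
    num_tokens = len(partials[0])
    return [rmsnorm(reduce_token(partials, t), weight)
            for t in range(num_tokens)]


def reduce_scatter_norm_allgather(partials, weight):
    """Reordered path: each rank reduces and normalizes only its own tokens."""
    num_ranks, num_tokens = len(partials), len(partials[0])
    gathered = [None] * num_tokens
    for rank in range(num_ranks):
        # ReduceScatter: rank owns tokens rank, rank + num_ranks, ... (assumed)
        for t in range(rank, num_tokens, num_ranks):
            local = rmsnorm(reduce_token(partials, t), weight)
            gathered[t] = local  # AllGather: result becomes visible to all
    return gathered


if __name__ == "__main__":
    # partials[rank][token][hidden]: 3 ranks, 5 tokens, hidden size 4
    partials = [[[float(r + t + h) for h in range(4)] for t in range(5)]
                for r in range(3)]
    weight = [1.0, 0.5, 2.0, 1.5]
    a = allreduce_then_norm(partials, weight)
    b = reduce_scatter_norm_allgather(partials, weight)
    print(a == b)
```

Both paths produce the same normalized activations, but in the reordered path each token is normalized once in total rather than once per rank.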
4. Fused AllReduce + RMSNorm GPU Kernel
TokenWeave leverages PTX multimem instructions—exposed via NVLink SHARP/NVLS engines on Hopper GPUs—to implement a custom kernel that merges ReduceScatter, RMSNorm, and AllGather. The main operational flow is:
- Launch multimem-based in-network reduction (multimem_ld_reduce_add) to obtain partial results per token and accumulate residuals.
- Accumulate the sum of squares during reduction to compute variance on-the-fly, obviating a second read from HBM.
- Perform RMSNorm scaling and weighting.
- Write the normalized output directly to the AllGather buffer with multimem_st, saving a memory store.
- Complete the operation with minimal SM usage.
Notably, the entire fused routine consumes only 2–8 SMs (see Figure 2), ensuring the majority of SMs are available for the main compute kernel. This tight coupling yields a speedup of up to 40% over running AllReduce and RMSNorm as separate steps, with communication and normalization cost effectively hidden behind computation (Gond et al., 16 May 2025).
| Aspect | Conventional Approach | TokenWeave Approach |
|---|---|---|
| Collectives | AllReduce, separate kernel | Fused via multimem in custom CTA kernel |
| RMSNorm application | All tokens, redundant across GPUs | Local, only (total tokens / number of GPUs) tokens per GPU |
| SM Utilization | ≥16 (legacy) | 2–8 (on Hopper, with multimem) |
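The single-pass structure of the fused routine can be sketched on the CPU for one token. This is only an illustration of the data flow described above, not the CUDA kernel: the cross-rank sum stands in for the multimem in-network reduction, the final list stands in for the store into the AllGather buffer, and returning the updated residual is an assumption modeled on common fused add-RMSNorm conventions. The function name is hypothetical.

```python
import math


def fused_reduce_rmsnorm(rank_partials, residual, weight, eps=1e-6):
    """Reduce across ranks, add the residual, and RMSNorm in a single pass.

    rank_partials: one hidden vector per rank for the same token.
    Returns (normalized, new_residual).
    """
    hidden = len(residual)
    new_residual = [0.0] * hidden
    sum_sq = 0.0
    for i in range(hidden):
        # Stand-in for the in-network reduction across ranks, plus residual.
        v = sum(p[i] for p in rank_partials) + residual[i]
        new_residual[i] = v
        # Sum of squares is accumulated while reducing, so the reduced
        # values never need to be re-read to compute the norm.
        sum_sq += v * v
    scale = 1.0 / math.sqrt(sum_sq / hidden + eps)
    # Scaling and weighting; in the kernel this is written straight to the
    # AllGather buffer.
    normalized = [v * scale * w for v, w in zip(new_residual, weight)]
    return normalized, new_residual


if __name__ == "__main__":
    out, res = fused_reduce_rmsnorm(
        rank_partials=[[1.0, 2.0], [2.0, 2.0]],
        residual=[0.0, 0.0],
        weight=[1.0, 1.0],
    )
    print(res)  # [3.0, 4.0]
    print(out)
```

An unfused version would first materialize the reduced vector, then read it again to compute the mean square, then read it a third time to scale it; the loop above touches each element once before scaling.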
5. Quantitative Performance Evaluation
TokenWeave was benchmarked on 8×H100 DGX systems with fourth-generation NVLink and NVLink SHARP (NVLS), integrated into vLLM V1:
- Standard AllReduce communication accounted for up to 23% of end-to-end inference latency.
- RMSNorm steps added 5–9% overhead.
- On Llama-3.3-70B (8 GPUs):
  - At 1k tokens, TokenWeave reduced latency by 18% against a vLLM-multimem baseline.
  - At 4k or more tokens, latency improved up to 29%, exceeding a “no-comm” counterfactual (vLLM-nocomm).
- Single-layer benchmarks showed per-layer latency gains of 20–38% across token lengths of 1k to 8k.
- End-to-end throughput (hybrid prefill/decode, 2k token chunk) increased by 20–26%.
- Throughput gains persisted across all chunk sizes from 1k to 8k (15–26%).
- Compared with NanoFlow, TokenWeave recovered approximately 20% of communication overhead, whereas NanoFlow attained 5–8% on H100 (Gond et al., 16 May 2025).
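Latency reductions and throughput gains are not on the same scale, so a small conversion helps when comparing the figures above. The sketch assumes that "latency reduced by x" means the new latency is (1 − x) times the baseline; the helper name is hypothetical.

```python
def speedup_from_latency_reduction(reduction: float) -> float:
    """Convert a fractional latency reduction into a speedup factor.

    Assumes new_latency = (1 - reduction) * baseline_latency.
    """
    if not 0.0 <= reduction < 1.0:
        raise ValueError("reduction must be in [0, 1)")
    return 1.0 / (1.0 - reduction)


if __name__ == "__main__":
    for r in (0.18, 0.29):
        print(f"{r:.0%} lower latency -> {speedup_from_latency_reduction(r):.2f}x")
```

Under that assumption, an 18% latency reduction corresponds to roughly a 1.22x speedup and a 29% reduction to roughly 1.41x.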
6. Limitations and Potential Extensions
TokenWeave depends on fourth-generation NVLink with NVLink SHARP (NVLS) hardware and the corresponding multimem instructions, currently exposed through PyTorch SymmetricMemory. On older GPUs or RDMA-based clusters lacking multimem, TokenWeave falls back to conventional NCCL AllReduce, which requires at least 16 SMs and diminishes the degree of overlap achievable.
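A rough SM-budget calculation shows why the communication kernel's footprint matters for overlap. The total of 132 SMs is an assumption (the SM count of an H100 SXM GPU, which the text does not state), and the helper name is hypothetical; the 2, 8, and 16 SM figures come from the text.

```python
H100_SXM_SMS = 132  # assumption: SM count of an H100 SXM GPU


def compute_sm_share(comm_sms: int, total_sms: int = H100_SXM_SMS) -> float:
    """Fraction of SMs left for compute kernels while communication runs."""
    if not 0 <= comm_sms <= total_sms:
        raise ValueError("comm_sms must be between 0 and total_sms")
    return (total_sms - comm_sms) / total_sms


if __name__ == "__main__":
    for label, sms in [("fused kernel, low", 2),
                       ("fused kernel, high", 8),
                       ("NCCL AllReduce fallback", 16)]:
        print(f"{label}: {compute_sm_share(sms):.1%} of SMs left for compute")
```

Under this assumption the fused kernel leaves about 94–98% of the SMs for compute, while the NCCL fallback leaves at most about 88%.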
Potential extensions include:
- Applying similar fused in-network collectives on AMD CDNA3 (with ROCm’s in-network primitives) or on InfiniBand via GPUDirect RDMA, though the granularity of SM allocation and in-network reduction capabilities will differ.
- Generalizing token splitting to more than two waves (e.g., three-way pipelining) for exceptionally large batches or for mixture-of-experts scheduling.
- Extending the fused normalization–collective kernel paradigm to gather-scatter operations in data-parallel training or parameter-server inference architectures.
In summary, TokenWeave offers a compelling strategy for eliminating dominant communication and normalization bottlenecks in distributed LLM inference, integrating seamlessly with modern CUDA/PyTorch workloads and affording substantial empirical gains in both latency and throughput at current 70B-scale model sizes (Gond et al., 16 May 2025).