RadixMLP: Deduplicating Transformer Computations
- RadixMLP is a stateless, single-pass technique that deduplicates position-wise computations in causal Transformers by leveraging shared prefixes.
- It builds a prefix trie to identify token positions with identical causal history, reducing redundant computation and achieving notable throughput and latency improvements.
- RadixMLP integrates with ragged-layout inference engines using efficient CPU-GPU coordination and maintains near-perfect autograd compatibility.
RadixMLP is a stateless, single-pass procedure for reducing redundant position-wise computation in causal Transformer batch inference by leveraging prefix commonality among input sequences. Unlike approaches based on stateful KV caches, RadixMLP constructs a prefix trie over the batch to deduplicate computations for token positions that share identical causal history, enabling efficient reuse of embeddings, LayerNorms, linear projections, and multi-layer perceptron (MLP) activations. RadixMLP is compatible with ragged-layout inference engines and provides substantial throughput and latency improvements in both synthetic and real-world workloads, notably in retrieval and reranking scenarios involving common prefixes such as shared system prompts or few-shot exemplars (Feil et al., 21 Jan 2026).
1. Position-wise Computation and Redundancy in Causal Transformers
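Causal Transformer architectures typically alternate self-attention and MLP sublayers within each block (shown here in the common pre-norm form):
$$\tilde h^{(\ell)} = h^{(\ell)} + \mathrm{Attn}\big(\mathrm{LN}_1(h^{(\ell)})\big), \qquad h^{(\ell+1)} = \tilde h^{(\ell)} + \mathrm{MLP}\big(\mathrm{LN}_2(\tilde h^{(\ell)})\big).$$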
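For a batch of sequences, suppose two tokens share the same causal history: the same token sequence up to and including themselves, at the same position indices. Then, by induction over layers, their hidden states at every layer $\ell$ are identical, and so are all position-wise computations on them (embeddings, LayerNorms, linear projections, and MLP outputs). Attention is the only operation that reads the full preceding context, and that context is identical for both tokens by assumption. For gated MLP blocks such as SwiGLU:
$$\mathrm{MLP}(x) = W_{\mathrm{down}}\big(\mathrm{SiLU}(W_{\mathrm{gate}}\, x) \odot W_{\mathrm{up}}\, x\big),$$
where $x \in \mathbb{R}^{d}$, $W_{\mathrm{gate}}, W_{\mathrm{up}} \in \mathbb{R}^{d_{\mathrm{ff}} \times d}$, and $W_{\mathrm{down}} \in \mathbb{R}^{d \times d_{\mathrm{ff}}}$. The computational cost per token is $O(d\, d_{\mathrm{ff}})$, where generally $d_{\mathrm{ff}} > d$, typically by a small multiple. In large Transformer models, position-wise operations dominate prefill computation and account for most of its floating-point operations.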
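To make the per-token cost concrete, the sketch below counts the matrix-multiply FLOPs of one SwiGLU MLP for a single token. It assumes the usual convention of 2 FLOPs (multiply plus add) per weight and ignores the cheap elementwise SiLU and product; the dimensions are illustrative assumptions, not values from the source.

```python
def swiglu_flops_per_token(d: int, d_ff: int) -> int:
    """Approximate forward FLOPs of one SwiGLU MLP for a single token.

    Counts only the three matrix-vector products, each costing 2 * d * d_ff
    FLOPs (one multiply and one add per weight). The SiLU and the elementwise
    product are O(d_ff) and are ignored.
    """
    gate = 2 * d * d_ff  # W_gate @ x
    up = 2 * d * d_ff    # W_up @ x
    down = 2 * d_ff * d  # W_down @ (SiLU(W_gate @ x) * (W_up @ x))
    return gate + up + down


# Illustrative dimensions (an assumption, not taken from the source).
flops = swiglu_flops_per_token(d=1024, d_ff=3072)
print(flops)  # 18874368
```

Because this cost is paid once per token that enters the MLP, halving the number of tokens that reach it halves this term exactly.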
2. Prefix Trie Construction, Gather, and Scatter Operations
RadixMLP deduplicates position-wise computation by constructing, on the CPU, a prefix trie over a batch of $B$ sequences $s_1, \dots, s_B$ with $N = \sum_{b=1}^{B} |s_b|$ total tokens. Each trie node represents a unique (token, position) pair reached through a unique prefix: shared prefixes map to existing nodes, while divergent branches create new nodes. If the trie contains $N'$ unique nodes, then $r = N'/N$ is the compact-token ratio and $N/N'$ is the compression ratio.
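The definition of $N'$ can be checked directly: it equals the number of distinct non-empty prefixes in the batch, since a prefix fixes both the token and its position. The sketch below computes $N$, $N'$, and $r$ that way. It is a definitional illustration under that assumption, quadratic in sequence length, whereas a trie reaches the same count in a single $O(N)$ pass.

```python
from typing import List, Tuple


def compact_ratio(batch: List[List[int]]) -> Tuple[int, int, float]:
    """Return (N, N', r) for a batch of token-ID sequences.

    N' is the number of distinct non-empty prefixes, i.e. the number of
    prefix-trie nodes: two tokens share a node exactly when their token
    sequences agree up to and including that position.
    """
    n = sum(len(seq) for seq in batch)
    prefixes = {tuple(seq[: t + 1]) for seq in batch for t in range(len(seq))}
    n_prime = len(prefixes)
    return n, n_prime, n_prime / n


# Two requests sharing a 3-token prefix (e.g. a common system prompt).
n, n_prime, r = compact_ratio([[1, 2, 3, 10, 11], [1, 2, 3, 20]])
print(n, n_prime, n / n_prime)  # 9 6 1.5
```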
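On the GPU, the trie is never materialized; only two index maps are created: a gather index $\mathbf{g} \in \{0, \dots, N-1\}^{N'}$ and a scatter index $\mathbf{s} \in \{0, \dots, N'-1\}^{N}$. The gather index selects one representative position per unique node, $X_c = X[\mathbf{g}]$; the scatter index maps compacted results back to every original position, $Y = Y_c[\mathbf{s}]$.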
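The following sketch applies the two index maps to a toy ragged batch. The plain-Python `index_select` helper is a hypothetical stand-in for a tensor index-select, and the index values are hand-written for the example. It also checks the invariant that scattering a gathered row returns it to its own node.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def index_select(rows: Sequence[T], idx: Sequence[int]) -> List[T]:
    """Pure-Python stand-in for a row-wise index_select: out[k] = rows[idx[k]]."""
    return [rows[i] for i in idx]


# Flat ragged layout with N = 5 rows: sequence A = rows 0-2, sequence B = rows 3-4.
# B is a prefix of A, so rows 3 and 4 share trie nodes with rows 0 and 1.
x = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.1, 0.2], [0.3, 0.4]]
g = [0, 1, 2]        # N' = 3: one representative row per unique node
s = [0, 1, 2, 0, 1]  # compact node of every original row

# Consistency: scattering a gathered row returns it to its own node.
assert all(s[g[j]] == j for j in range(len(g)))

x_c = index_select(x, g)                       # X_c = X[g]
y_c = [[2.0 * v for v in row] for row in x_c]  # any position-wise f on N' rows
y = index_select(y_c, s)                       # Y = Y_c[s]
assert y == [[2.0 * v for v in row] for row in x]  # same as f on all N rows
```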
The trie and index construction, implemented in Rust/C++, runs in under a millisecond for tens of thousands of tokens and executes asynchronously with respect to GPU work.
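The original Rust/C++ routine is not reproduced here. As a minimal Python sketch of the same idea, assuming the trie is represented implicitly by a dictionary keyed on (parent node, token), one pass over the flat batch produces the gather index, the scatter index, and the compact position IDs together. The function name `build_radix_indices` is hypothetical.

```python
from typing import Dict, List, Tuple


def build_radix_indices(batch: List[List[int]]) -> Tuple[List[int], List[int], List[int]]:
    """Build gather/scatter index maps and compact position IDs for a ragged batch.

    Returns:
      gather_idx:  length N', flat index of the first occurrence of each trie node
      scatter_idx: length N,  compact index for every flat (ragged) token position
      compact_pos: length N', position ID (trie depth) of each compact token
    """
    node_of: Dict[Tuple[int, int], int] = {}  # (parent node, token) -> compact id
    gather_idx: List[int] = []
    scatter_idx: List[int] = []
    compact_pos: List[int] = []
    flat = 0
    for seq in batch:
        parent = -1  # virtual root
        for pos, tok in enumerate(seq):
            key = (parent, tok)
            node = node_of.get(key)
            if node is None:  # first time this prefix is seen: new trie node
                node = len(gather_idx)
                node_of[key] = node
                gather_idx.append(flat)
                compact_pos.append(pos)
            scatter_idx.append(node)
            parent = node
            flat += 1
    return gather_idx, scatter_idx, compact_pos


batch = [[1, 2, 3, 10], [1, 2, 3, 20, 21]]
g, s, pos = build_radix_indices(batch)
print(g)    # [0, 1, 2, 3, 7, 8]
print(s)    # [0, 1, 2, 3, 0, 1, 2, 4, 5]
print(pos)  # [0, 1, 2, 3, 3, 4]
```

Because nodes are numbered in order of first appearance, the gather index comes out sorted, which keeps the gather a forward-only memory access.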
3. Formal Operations and Integration
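Let $X \in \mathbb{R}^{N \times d}$ represent the ragged hidden states (all sequences concatenated along the token axis), and let $f$ be any position-wise layer. Standard inference computes $Y_i = f(X_i)$ for $i = 0, \dots, N-1$. Using the trie's gather index $\mathbf{g}$ and scatter index $\mathbf{s}$, RadixMLP instead computes
$$X_c = X[\mathbf{g}] \in \mathbb{R}^{N' \times d}, \qquad Y_c = f(X_c), \qquad Y = Y_c[\mathbf{s}],$$
which equals the standard output because rows that share a trie node are identical, so $X_i = X_c[\mathbf{s}_i]$ for every $i$.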
For each position-wise sublayer (MLP, LayerNorm, embeddings, etc.), only the compact buffer is processed, and results are expanded back to the full batch via scatter only where a full-layout operation needs them. Attention layers operate on the expanded representations. Position IDs for RoPE and other position-dependent operations are compacted with the same gather index.
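As a toy check of the claim that compacting both token IDs and position IDs reproduces the full computation, the sketch below runs a hypothetical position-dependent embedding followed by a parameter-free LayerNorm on all $N$ tokens and on the $N'$ compact tokens, then compares after expansion. The `embed` function is invented for illustration; real models use learned embeddings and RoPE inside attention.

```python
import math
from typing import List


def layer_norm(row: List[float], eps: float = 1e-5) -> List[float]:
    """Position-wise LayerNorm without affine parameters."""
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)
    return [(v - mean) / math.sqrt(var + eps) for v in row]


def embed(token: int, pos: int, d: int = 4) -> List[float]:
    """Hypothetical deterministic embedding plus a position-dependent term."""
    return [math.sin(token * (j + 1)) + 0.1 * pos * j for j in range(d)]


# Flat ragged batch: sequences [7, 8, 9] and [7, 8, 5] share a 2-token prefix.
tokens = [7, 8, 9, 7, 8, 5]
positions = [0, 1, 2, 0, 1, 2]
gather_idx = [0, 1, 2, 5]         # N' = 4 unique trie nodes
scatter_idx = [0, 1, 2, 0, 1, 3]

# Standard path: apply the position-wise stack to all N tokens.
full = [layer_norm(embed(t, p)) for t, p in zip(tokens, positions)]

# Compact path: gather token IDs and position IDs, compute N' rows, expand.
c_tok = [tokens[i] for i in gather_idx]
c_pos = [positions[i] for i in gather_idx]
compact = [layer_norm(embed(t, p)) for t, p in zip(c_tok, c_pos)]
expanded = [compact[j] for j in scatter_idx]

assert expanded == full  # identical inputs through identical ops
```

In a full model, later layers also depend on attention outputs, which are identical for tokens with identical causal history, so the same equality carries through by induction.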
4. Computational Complexity and Expected Speedup
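In conventional inference, position-wise layers incur $O(N\, c_{\mathrm{pw}})$ FLOPs, where $c_{\mathrm{pw}} = O(d^2 + d\, d_{\mathrm{ff}})$ is the per-token cost of the projections and MLP, while self-attention costs $O(d \sum_b L_b^2)$ for sequence lengths $L_b$. RadixMLP reduces the position-wise cost to $O(N' c_{\mathrm{pw}})$, where $N'$ is the number of unique trie nodes, plus two memory-bound gather/scatter passes costing $O((N + N')\, d)$. Ignoring that overhead and writing $r = N'/N$, the ideal speedup follows Amdahl's law:
$$S = \frac{1}{(1 - p) + p\, r},$$
where $p$ is the fraction of FLOPs attributable to position-wise layers:
$$p = \frac{F_{\mathrm{pw}}}{F_{\mathrm{pw}} + F_{\mathrm{attn}}}, \qquad F_{\mathrm{pw}} = N\,(\alpha\, d^2 + \beta\, d\, d_{\mathrm{ff}}), \qquad F_{\mathrm{attn}} = \gamma\, d \sum_b L_b^2,$$
with $\alpha$, $\beta$, $\gamma$ constants determined by the MLP and attention architecture. For a hidden size $d$ that is large relative to the sequence lengths, $p \to 1$ and the speedup approaches $1/r = N/N'$.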
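The speedup formula can be evaluated directly. The sketch below implements $S = 1/((1-p) + p\,r)$ and ignores the gather/scatter cost, as the formula does; the input values are hypothetical.

```python
def amdahl_speedup(p: float, r: float) -> float:
    """Ideal speedup when a fraction p of FLOPs shrinks by the factor r = N'/N.

    Ignores the memory-bound gather/scatter passes.
    """
    if not (0.0 <= p <= 1.0 and 0.0 < r <= 1.0):
        raise ValueError("expected 0 <= p <= 1 and 0 < r <= 1")
    return 1.0 / ((1.0 - p) + p * r)


# Hypothetical values: 90% position-wise FLOPs, half of the tokens unique.
print(round(amdahl_speedup(0.9, 0.5), 3))  # 1.818
print(amdahl_speedup(1.0, 0.5))            # 2.0, the 1 / r limit as p -> 1
```

The second call shows why the attainable speedup is capped by the compression ratio $N/N'$ even when position-wise work dominates.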
5. Empirical Results
RadixMLP delivers substantial performance improvements. In synthetic microbenchmarks on Qwen3 models (0.6B, 4B, and 8B parameters) with shared prefixes of $32$–$2048$ tokens and suffixes of $256$–$1024$ tokens, the largest speedups are observed for the 8B model with a 2048-token prefix. For end-to-end serving on MS MARCO v1.1 using TEI with Qwen3-0.6B/4B/8B (ragged layout, FlashAttention-2), median latency improves by $1.44\times$, $1.56\times$, and $1.59\times$, respectively.
| Model | Base Latency (s) | RadixMLP Latency (s) | Speedup |
|---|---|---|---|
| 0.6B | 0.78 | 0.54 | 1.44× |
| 4B | 3.76 | 2.42 | 1.56× |
| 8B | 5.96 | 3.74 | 1.59× |
Amdahl-style analytic predictions using the observed position-wise FLOP fraction $p$ and compact-token ratio $r$ closely match the measured speedups.
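The same formula can be run in reverse to ask which position-wise fraction $p$ a measured speedup implies at a given $r$. Solving $1/S = 1 - p\,(1 - r)$ gives $p = (1 - 1/S)/(1 - r)$. The sketch below uses hypothetical inputs, not the table's measurements.

```python
def implied_pw_fraction(speedup: float, r: float) -> float:
    """Solve S = 1 / ((1 - p) + p * r) for p.

    From 1/S = 1 - p * (1 - r) it follows that p = (1 - 1/S) / (1 - r).
    Requires r < 1 (some deduplication); otherwise p is undetermined.
    """
    if not 0.0 <= r < 1.0:
        raise ValueError("r must satisfy 0 <= r < 1")
    return (1.0 - 1.0 / speedup) / (1.0 - r)


# Hypothetical inputs: a measured 1.5x speedup at r = 0.5.
p = implied_pw_fraction(1.5, 0.5)
print(round(p, 3))  # 0.667
```

An implied $p$ well above a profiled value would suggest the measurement includes effects outside this simple model.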
6. Implementation and Practical Considerations
RadixMLP is stateless (no persistent KV cache or eviction logic) and integrates into existing ragged-layout inference engines with minimal modification: two index-select calls (a scatter before and a gather after) around each attention sublayer. The memory overhead is $N + N'$ 32-bit integers ($4(N + N')$ bytes), negligible relative to the activations ($2Nd$ bytes per hidden-state buffer in fp16). Dedicated CUDA gather/scatter kernels are faster than a naive index_select. The CPU scheduler (Rust) generates the trie and indices in under a millisecond, roughly three orders of magnitude below typical GPU forward-pass times.
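A quick calculation shows how small the index overhead is. The sketch compares $4(N + N')$ bytes of int32 indices with one $2Nd$-byte fp16 hidden-state buffer; the batch size, compression, and hidden size are hypothetical.

```python
def index_bytes(n: int, n_prime: int) -> int:
    """Bytes for the int32 scatter (N entries) and gather (N' entries) maps."""
    return 4 * (n + n_prime)


def activation_bytes(n: int, d: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one N x d hidden-state buffer (fp16 by default)."""
    return bytes_per_elem * n * d


# Hypothetical batch: 32k flat tokens, half of them unique, hidden size 4096.
n, n_prime, d = 32_768, 16_384, 4096
idx = index_bytes(n, n_prime)
act = activation_bytes(n, d)
print(idx, act, f"{100 * idx / act:.3f}%")  # 196608 268435456 0.073%
```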
A threshold $\tau$ on the compact-token ratio guards against overhead: if $r = N'/N$ exceeds $\tau$, the batch has too little redundancy to benefit and RadixMLP is skipped. For integration, compute compact position IDs for RoPE, perform all position-wise operations in compact space, scatter QKV to the full layout, run attention, then gather the attention output back to compact space for the subsequent steps.
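The integration steps can be sketched with toy stand-ins: a causal running mean plays the role of attention on the ragged layout, a scalar affine map plays the QKV projection, and a residual add plays the rest of the block. All of these, plus the helper names and the value of `tau`, are hypothetical. The sketch checks that the compact path, once expanded, matches the path without deduplication.

```python
from typing import List, Sequence, Tuple

Row = List[float]


def causal_mean_attention(rows: List[Row], seq_lens: Sequence[int]) -> List[Row]:
    """Toy stand-in for causal attention on a ragged layout: each token receives
    the mean of all rows at or before it within its own sequence."""
    out: List[Row] = []
    start = 0
    for length in seq_lens:
        acc = [0.0] * len(rows[0])
        for t in range(length):
            acc = [a + v for a, v in zip(acc, rows[start + t])]
            out.append([a / (t + 1) for a in acc])
        start += length
    return out


def plan(n: int, g: List[int], s: List[int], tau: float = 0.9) -> Tuple[List[int], List[int]]:
    """Fall back to identity maps (no deduplication) when r = N'/N exceeds tau."""
    if len(g) / n > tau:
        ident = list(range(n))
        return ident, ident
    return g, s


def radix_block(x_c: List[Row], g: Sequence[int], s: Sequence[int],
                seq_lens: Sequence[int]) -> List[Row]:
    """One toy block: position-wise work on N' compact rows, attention on N rows."""
    qkv_c = [[2.0 * v + 1.0 for v in row] for row in x_c]  # compact "QKV projection"
    qkv_full = [qkv_c[j] for j in s]                       # scatter to full layout
    attn_full = causal_mean_attention(qkv_full, seq_lens)  # needs full causal context
    attn_c = [attn_full[i] for i in g]                     # gather back to compact
    # Residual add stands in for the output projection, residual, and MLP.
    return [[h + a for h, a in zip(hr, ar)] for hr, ar in zip(x_c, attn_c)]


# Sequences [a, b, c] and [a, b, d] occupy rows 0-2 and 3-5 of the flat layout.
seq_lens = [3, 3]
x = [[1.0], [2.0], [3.0], [1.0], [2.0], [4.0]]  # embeddings, equal on the shared prefix
g, s = plan(6, [0, 1, 2, 5], [0, 1, 2, 0, 1, 3])  # r = 4/6 <= tau, so deduplicate

y_c = radix_block([x[i] for i in g], g, s, seq_lens)
y_radix = [y_c[j] for j in s]                     # expanded only for comparison
ident = list(range(6))
y_full = radix_block(x, ident, ident, seq_lens)   # baseline without deduplication
assert y_radix == y_full
```

Keeping hidden states compact between layers means only QKV and the attention output cross the layout boundary, which matches the two index-select calls per attention sublayer described above.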
RadixMLP is compatible with autograd: each gather or scatter in the forward pass is an index_select, whose backward pass is a scatter_add that accumulates gradients at duplicated locations. Empirical checks show only small gradient discrepancies in multi-layer models, with the larger differences attributable to numerical variation between attention kernels.
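The backward rule can be illustrated without an autograd library. For the expansion $Y = Y_c[\mathbf{s}]$, each compact row's gradient is the sum of the upstream gradients of every full row copied from it. The sketch below implements that scatter-add by hand; the helper name `expand_backward` is hypothetical.

```python
from typing import List, Sequence


def expand_backward(grad_full: List[List[float]], scatter_idx: Sequence[int],
                    n_compact: int) -> List[List[float]]:
    """Gradient of Y = Y_c[scatter_idx] with respect to Y_c.

    Each compact row receives the sum of the upstream gradients of every full
    row that was copied from it (a scatter_add), so shared-prefix rows
    accumulate gradient from all sequences that reuse them.
    """
    d = len(grad_full[0])
    grad_c = [[0.0] * d for _ in range(n_compact)]
    for i, j in enumerate(scatter_idx):
        grad_c[j] = [a + g for a, g in zip(grad_c[j], grad_full[i])]
    return grad_c


scatter_idx = [0, 1, 2, 0, 1, 3]
grad_full = [[1.0], [1.0], [1.0], [2.0], [2.0], [2.0]]
print(expand_backward(grad_full, scatter_idx, 4))  # [[3.0], [3.0], [1.0], [2.0]]
```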
By horizontally deduplicating position-wise compute for shared prefixes, RadixMLP achieves substantial acceleration in synthetic settings and $1.4$–$1.6\times$ in real reranking inference, with only $O(N)$ index memory overhead and stateless CPU-GPU coordination (Feil et al., 21 Jan 2026).