
RadixMLP: Deduplicating Transformer Computations

Updated 29 January 2026
  • RadixMLP is a stateless, single-pass technique that deduplicates position-wise computations in causal Transformers by leveraging shared prefixes.
  • It builds a prefix trie over the batch so that token positions with identical causal history share one compact row. This yields up to 5× speedups in synthetic benchmarks and 1.44–1.59× median latency gains in end-to-end reranking.
  • RadixMLP integrates with ragged-layout inference engines through lightweight CPU-side index construction. It is autograd-compatible, with gradient discrepancies below $2\times 10^{-5}$.

RadixMLP is a stateless, single-pass procedure for reducing redundant position-wise computation in causal Transformer batch inference by leveraging prefix commonality among input sequences. Unlike approaches based on stateful KV caches, RadixMLP constructs a prefix trie over the batch to deduplicate computations for token positions that share identical causal history, enabling efficient reuse of embeddings, LayerNorms, linear projections, and multi-layer perceptron (MLP) activations. RadixMLP is compatible with ragged-layout inference engines and provides substantial throughput and latency improvements in both synthetic and real-world workloads, notably in retrieval and reranking scenarios involving common prefixes such as shared system prompts or few-shot exemplars (Feil et al., 21 Jan 2026).

1. Position-wise Computation and Redundancy in Causal Transformers

In causal Transformer architectures, each block typically applies a self-attention sublayer followed by an MLP sublayer (pre-LayerNorm form):

  • $H'_l = H_l + \mathrm{Attn}(\mathrm{LayerNorm}(H_l))$
  • $H_{l+1} = H'_l + \mathrm{MLP}(\mathrm{LayerNorm}(H'_l))$

For a batch of sequences, suppose two tokens share the same causal history: the same token sequence $(t_1, \ldots, t_k)$ at the same position indices. Their embeddings are then identical. By induction over layers, their hidden states stay identical, because causal attention at a position reads only its own prefix, and that prefix is the same for both tokens. Consequently, every position-wise computation (embeddings, LayerNorms, projections, and MLP outputs) produces identical results for the two tokens. Only attention depends on the preceding context. For gated MLP blocks such as SwiGLU:

$$\mathrm{MLP}(h) = W_{\mathrm{down}}\bigl(\sigma(W_{\mathrm{gate}}h)\odot (W_{\mathrm{up}}h)\bigr)$$

where $h \in \mathbb{R}^d$, $W_{\mathrm{up}}, W_{\mathrm{gate}} \in \mathbb{R}^{d_{\mathrm{int}} \times d}$, $W_{\mathrm{down}} \in \mathbb{R}^{d \times d_{\mathrm{int}}}$, $\sigma$ is the SiLU (Swish) activation, and $\odot$ denotes element-wise multiplication. The cost per token is $\mathcal{O}(d\,d_{\mathrm{int}})$. This is $\mathcal{O}(d^2)$, because $d_{\mathrm{int}}$ is typically a small constant multiple of $d$. In large Transformer models, position-wise operations dominate prefill computation, accounting for up to $92\%$ of floating-point operations.
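The following minimal NumPy sketch of the SwiGLU formula illustrates the property RadixMLP relies on. The MLP acts on each row independently, so duplicated input rows yield matching output rows. The dimensions are toy values chosen for illustration, not taken from any real model. `swiglu_flops_per_token` is a hypothetical helper that counts one multiply-add as 2 FLOPs.

```python
import numpy as np

def silu(x):
    # SiLU (Swish): x * sigmoid(x), the gate activation in SwiGLU
    return x / (1.0 + np.exp(-x))

def swiglu_mlp(h, w_gate, w_up, w_down):
    """Row-wise SwiGLU: W_down (silu(W_gate h) * (W_up h)) for each row h.

    h: (n, d); w_gate, w_up: (d_int, d); w_down: (d, d_int).
    """
    return (silu(h @ w_gate.T) * (h @ w_up.T)) @ w_down.T

def swiglu_flops_per_token(d, d_int):
    # Three d x d_int matrix-vector products, 2 FLOPs per multiply-add
    return 3 * 2 * d * d_int

rng = np.random.default_rng(0)
d, d_int = 8, 32  # toy sizes for illustration
w_gate = rng.standard_normal((d_int, d))
w_up = rng.standard_normal((d_int, d))
w_down = rng.standard_normal((d, d_int))

x = rng.standard_normal(d)
batch = np.stack([x, x, rng.standard_normal(d)])  # rows 0 and 1 share the same input
out = swiglu_mlp(batch, w_gate, w_up, w_down)
print(out.shape)                    # (3, 8)
print(np.allclose(out[0], out[1]))  # True
```

Because the output for row 1 is fully determined by row 0's input, computing it a second time is pure redundancy.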

2. Prefix Trie Construction, Gather, and Scatter Operations

RadixMLP deduplicates position-wise computation by building a prefix trie on the CPU. The trie spans a batch of $B$ sequences $\{S_i\}$ with $N = \sum_i L_i$ total tokens, where $L_i$ is the length of $S_i$. Each trie node is keyed by a $(\mathrm{token\_id}, \mathrm{position\_id})$ pair under its parent, so every node identifies one unique causal prefix. Shared prefixes reuse existing nodes, and divergent suffixes create new branches. If the trie contains $N'$ nodes, then $\gamma = N'/N$ is the compact-token ratio and $r = N/N' = 1/\gamma$ is the compression ratio.

On the GPU, the trie is not materialized; only the index maps $I_{\mathrm{gather}}$ and $I_{\mathrm{scatter}}$ are created. $I_{\mathrm{gather}} \in \{0, \ldots, N-1\}^{N'}$ selects one representative position per unique node: $X_{\mathrm{unique}}[i] = X_{\mathrm{orig}}[I_{\mathrm{gather}}[i]]$. $I_{\mathrm{scatter}} \in \{0, \ldots, N'-1\}^{N}$ maps every original position back to its compact row: $Y_{\mathrm{restored}}[j] = Y_{\mathrm{unique}}[I_{\mathrm{scatter}}[j]]$.
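The trie construction and both index maps can be sketched in a few lines of Python. This is an illustrative version, not the paper's Rust scheduler. `build_radix_indices` is a hypothetical name, the token ids are made up, and position ids are assumed to restart at 0 for each sequence.

```python
def build_radix_indices(sequences):
    """Map a ragged batch to (gather, scatter) index lists via a prefix trie.

    sequences: list of token-id lists, laid out back to back (ragged layout).
    A trie node is keyed by (token_id, position_id) under its parent, so each
    node stands for one unique causal prefix. Node k >= 1 owns compact row k - 1.
    """
    children = [{}]          # children[node] maps (token_id, position_id) -> child
    gather, scatter = [], []
    j = 0                    # global token index in the ragged layout
    for seq in sequences:
        node = 0             # start at the root for every sequence
        for pos, tok in enumerate(seq):
            key = (tok, pos)  # assumption: position ids restart at 0 per sequence
            child = children[node].get(key)
            if child is None:
                child = len(children)
                children[node][key] = child
                children.append({})
                gather.append(j)       # first occurrence represents the new node
            scatter.append(child - 1)  # every token points at its node's compact row
            node = child
            j += 1
    return gather, scatter

gather, scatter = build_radix_indices([[1, 2, 3], [1, 2, 4], [5]])
print(gather)                      # [0, 1, 2, 5, 6]
print(scatter)                     # [0, 1, 2, 0, 1, 3, 4]
print(len(gather) / len(scatter))  # gamma = N'/N = 5/7
```

The two sequences sharing the prefix `[1, 2]` collapse 7 tokens into 5 compact rows. For every compact row `i`, `scatter[gather[i]] == i`.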

Trie and index construction (presented as Rust/C++ pseudocode) runs on the CPU in under a millisecond for tens of thousands of tokens. It executes asynchronously with GPU work.

3. Formal Operations and Integration

Let $X \in \mathbb{R}^{N \times d}$ denote the ragged hidden states, and let $f : \mathbb{R}^d \to \mathbb{R}^{d'}$ be any position-wise layer. Standard inference computes $Y = f(X)$ row by row, $Y_k = f(X_k)$ for $k = 1, \ldots, N$. RadixMLP instead computes:

$$X' = X[I_{\mathrm{gather}}] \in \mathbb{R}^{N' \times d}$$
$$Y' = f(X') \in \mathbb{R}^{N' \times d'}$$
$$Y = Y'[I_{\mathrm{scatter}}] \in \mathbb{R}^{N \times d'}$$

Because rows of $X$ that share a causal prefix are identical, $Y$ equals the standard result, while $f$ is evaluated on only $N'$ rows.
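The three-step computation maps directly onto NumPy integer-array indexing. In this sketch, `radix_apply` is a hypothetical helper and the index maps are hard-coded for the example batch `[1, 2, 3]`, `[1, 2, 4]`, `[5]`. The layer `f` is an arbitrary toy position-wise map.

```python
import numpy as np

def radix_apply(f, x, gather, scatter):
    """Evaluate a position-wise layer f on unique rows only, then expand.

    x: (N, d) ragged hidden states; gather: (N',) ints; scatter: (N,) ints.
    """
    x_unique = x[gather]      # X' = X[I_gather], shape (N', d)
    y_unique = f(x_unique)    # Y' = f(X'), shape (N', d')
    return y_unique[scatter]  # Y = Y'[I_scatter], shape (N, d')

gather = np.array([0, 1, 2, 5, 6])         # example batch [1,2,3], [1,2,4], [5]
scatter = np.array([0, 1, 2, 0, 1, 3, 4])
rng = np.random.default_rng(0)
x = rng.standard_normal((5, 4))[scatter]   # N = 7 rows; shared-prefix rows identical
w = rng.standard_normal((4, 6))

def f(z):
    # toy position-wise layer R^4 -> R^6
    return np.tanh(z @ w)

y = radix_apply(f, x, gather, scatter)
print(y.shape)               # (7, 6)
print(np.allclose(y, f(x)))  # True, with f evaluated on 5 rows instead of 7
```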

Position-wise sublayers (embeddings, LayerNorms, projections, MLPs) process only the compact buffer of $N'$ rows. Attention operates on the expanded layout, so representations are scattered to the full batch only where attention needs them. The attention output is then gathered back into compact space. Position IDs for RoPE and other position-dependent operations are compacted in the same way.
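The compact-space schedule can be checked on a toy model: position-wise work runs on $N'$ rows, and the layout is expanded only around attention. The sketch below is a hypothetical single-head, pre-norm block in NumPy. `block_compact`, `rms_norm` (standing in for LayerNorm), and `causal_attention_ragged` are illustrative names, and RoPE is omitted. The index maps are hard-coded for the hypothetical sequences `[1, 2, 3]`, `[1, 2, 4]`, `[5]`. Passing identity index maps turns the same function into an ordinary block, which serves as the reference.

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # Stand-in for LayerNorm: purely position-wise (row by row)
    return x / np.sqrt((x * x).mean(axis=1, keepdims=True) + eps)

def causal_attention_ragged(q, k, v, seq_lens):
    """Single-head causal attention over a ragged (concatenated) batch."""
    out = np.empty_like(v)
    start = 0
    for n in seq_lens:
        sl = slice(start, start + n)
        scores = q[sl] @ k[sl].T / np.sqrt(q.shape[1])
        scores += np.triu(np.full((n, n), -np.inf), k=1)  # mask future tokens
        p = np.exp(scores - scores.max(axis=1, keepdims=True))
        out[sl] = (p / p.sum(axis=1, keepdims=True)) @ v[sl]
        start += n
    return out

def block_compact(h_c, w_qkv, w_o, w_mlp, gather, scatter, seq_lens):
    """One pre-norm block whose residual stream stays in compact space.

    h_c has N' rows. Norms, projections, and the MLP run on N' rows; only the
    attention sublayer sees the full N-row ragged layout.
    """
    d = h_c.shape[1]
    qkv = (rms_norm(h_c) @ w_qkv)[scatter]        # compact QKV, then scatter to N rows
    q, k, v = qkv[:, :d], qkv[:, d:2 * d], qkv[:, 2 * d:]
    attn = causal_attention_ragged(q, k, v, seq_lens)
    h_c = h_c + attn[gather] @ w_o                # gather back to N' rows
    return h_c + np.tanh(rms_norm(h_c) @ w_mlp)   # toy position-wise MLP

rng = np.random.default_rng(0)
d = 4
gather = np.array([0, 1, 2, 5, 6])          # example batch [1,2,3], [1,2,4], [5]
scatter = np.array([0, 1, 2, 0, 1, 3, 4])
seq_lens = [3, 3, 1]
w_qkv, w_o, w_mlp = (rng.standard_normal((d, m)) for m in (3 * d, d, d))
h_full = rng.standard_normal((5, d))[scatter]  # duplicated prefix rows are identical

# Reference: identity index maps reduce block_compact to an ordinary block
ident = np.arange(len(scatter))
ref = block_compact(h_full, w_qkv, w_o, w_mlp, ident, ident, seq_lens)
out_c = block_compact(h_full[gather], w_qkv, w_o, w_mlp, gather, scatter, seq_lens)
print(np.allclose(out_c[scatter], ref))  # True
```

Keeping the residual stream compact means consecutive position-wise ops never pay for the expansion. Only the QKV scatter and the attention-output gather touch the full layout.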

4. Computational Complexity and Expected Speedup

In conventional inference, position-wise layers cost $\mathcal{O}(N d^2)$ FLOPs, while self-attention costs $\mathcal{O}(N L d)$, where $L$ is the sequence length. RadixMLP reduces the position-wise cost to $\mathcal{O}(N' d^2)$ and adds two memory-bound gather/scatter passes costing $\mathcal{O}(N d)$. Attention still runs on the full $N$-token layout, so its cost is unchanged. Ignoring the gather/scatter overhead, the ideal speedup is:

$$\text{Speedup} \approx \frac{\alpha N d^2 + \beta N L d}{\alpha N' d^2 + \beta N L d} = \frac{1}{(1 - f_c) + f_c / r}$$

where $f_c$ is the fraction of FLOPs attributable to position-wise layers:

$$f_c = \frac{\alpha d^2}{\alpha d^2 + \beta L d}$$

and $\alpha, \beta$ are architecture-dependent constants. When $d$ is large relative to $L$, $f_c \to 1$ and the speedup approaches $r$.
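The speedup model is easy to evaluate numerically. The helper names below are hypothetical. The inputs ($\alpha = \beta = 1$, $\gamma = 0.25$, $f_c = 0.92$) are illustrative values rather than measurements from the paper.

```python
def position_wise_fraction(d, seq_len, alpha, beta):
    # f_c = alpha d^2 / (alpha d^2 + beta L d)
    pw = alpha * d * d
    return pw / (pw + beta * seq_len * d)

def ideal_speedup(f_c, r):
    # Amdahl-style bound 1 / ((1 - f_c) + f_c / r); gather/scatter cost ignored
    return 1.0 / ((1.0 - f_c) + f_c / r)

gamma = 0.25                 # illustrative compact-token ratio N'/N
r = 1.0 / gamma              # compression ratio r = N/N' = 4
print(position_wise_fraction(4096, 1024, 1.0, 1.0))  # 0.8 (illustrative alpha = beta = 1)
print(ideal_speedup(1.0, r))             # 4.0: all FLOPs position-wise, speedup equals r
print(round(ideal_speedup(0.92, r), 2))  # 3.23
```

Even at $f_c = 0.92$, the remaining 8% of non-deduplicated work caps the gain noticeably below $r$. This is the usual Amdahl behavior.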

5. Empirical Results

RadixMLP delivers substantial performance improvements. The synthetic microbenchmarks use batch size $B = 32$, shared prefixes of $32$–$2048$ tokens, and suffixes of $256$–$1024$ tokens. On Qwen3 models (0.6B, 4B, and 8B parameters), observed speedups reach $5.0\times$ (8B model, 2048-token prefix). End-to-end serving was measured on MS MARCO v1.1 using TEI with Qwen3-0.6B/4B/8B (ragged layout, FlashAttention-2). The reported median latency improvements are $1.44\times$, $1.56\times$, and $1.59\times$, respectively.

| Model | Base latency (s) | RadixMLP latency (s) | Speedup |
|---|---|---|---|
| Qwen3-0.6B | 0.78 | 0.54 | 1.44× |
| Qwen3-4B | 3.76 | 2.42 | 1.56× |
| Qwen3-8B | 5.96 | 3.74 | 1.59× |

Amdahl-style analytic predictions based on the observed $\gamma$ and $f_c$ closely match the measured speedups.

6. Implementation and Practical Considerations

RadixMLP is stateless, with no persistent KV cache or eviction logic. It integrates into existing ragged-layout inference engines with minimal modification: two index-select calls (gather and scatter) around each attention sublayer. The memory overhead is $\mathcal{O}(N)$ 32-bit integers ($\approx 4N$ bytes), negligible relative to activations ($\approx 2Nd$ bytes in fp16). Dedicated CUDA gather/scatter kernels are up to $22\times$ faster than a naive index_select. The CPU scheduler (Rust) generates the trie and indices in $0.1$–$2.5~\mathrm{ms}$, roughly three orders of magnitude below typical GPU execution times.
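The overhead estimate is simple arithmetic. Index storage of $4N$ bytes against $2Nd$ bytes for one fp16 activation tensor gives a ratio of $2/d$. The sizes below ($N = 32768$ tokens, $d = 4096$) and the helper name are illustrative assumptions.

```python
def index_overhead_ratio(n_tokens, d, index_bytes=4, act_bytes=2):
    # (4N bytes of int32 indices) / (2Nd bytes of fp16 activations) = 2 / d
    return (index_bytes * n_tokens) / (act_bytes * n_tokens * d)

n_tokens, d = 32768, 4096        # illustrative batch size and hidden size
print(4 * n_tokens)              # 131072 bytes of indices (128 KiB)
print(2 * n_tokens * d)          # 268435456 bytes per fp16 activation (256 MiB)
print(index_overhead_ratio(n_tokens, d))  # 0.00048828125, i.e. 1/2048
```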

A threshold on the compact-token ratio $\gamma$ is used: if $\gamma > 0.95$, RadixMLP is skipped. In batches with little redundancy, the gather/scatter overhead would outweigh the savings. The integration steps are as follows:

  • Compute compact position IDs for RoPE.
  • Run all position-wise operations in compact space.
  • Scatter QKV to the full layout and perform attention.
  • Gather the attention output back to compact space for the subsequent steps.
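The skip rule and the compaction of position IDs amount to a few lines. `plan_radix` and `GAMMA_SKIP` are hypothetical names, and the 0.95 threshold is the one given in the text. The index maps reuse the example batch `[1, 2, 3]`, `[1, 2, 4]`, `[5]`, with position ids assumed to restart at 0 per sequence.

```python
import numpy as np

GAMMA_SKIP = 0.95  # skip RadixMLP when gamma = N'/N exceeds this

def plan_radix(position_ids, gather, scatter, threshold=GAMMA_SKIP):
    """Decide whether to deduplicate a batch and compact its position ids.

    Returns (use_radix, position ids for RoPE in the layout that will run).
    """
    gamma = len(gather) / len(scatter)
    if gamma > threshold:
        return False, position_ids      # too little sharing: run as usual
    return True, position_ids[gather]   # compact position ids for RoPE

position_ids = np.array([0, 1, 2, 0, 1, 2, 0])  # batch [1,2,3], [1,2,4], [5]
gather = np.array([0, 1, 2, 5, 6])
scatter = np.array([0, 1, 2, 0, 1, 3, 4])
use, pos_c = plan_radix(position_ids, gather, scatter)
print(use, pos_c.tolist())  # True [0, 1, 2, 2, 0]
```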

RadixMLP is compatible with autograd. In the forward pass, both the gather and the expansion are index_select operations. The backward pass of the expansion $Y = Y'[I_{\mathrm{scatter}}]$ is a scatter_add, which sums the gradients of all duplicated positions into their shared compact row. Empirical checks show gradient discrepancies below $2\times 10^{-5}$ in multi-layer models. Larger differences are attributable to variation between attention kernels.
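The gradient rule can be checked with a manual backward pass in NumPy. The check assumes a toy linear layer and the example index map for the batch `[1, 2, 3]`, `[1, 2, 4]`, `[5]`. `expand_forward` and `expand_backward` are hypothetical names, and `np.add.at` plays the role of scatter_add. The weight gradient from the compact path matches the full-batch gradient because duplicated rows share the same input.

```python
import numpy as np

def expand_forward(y_c, scatter):
    # Y = Y'[I_scatter]: an index_select whose indices repeat
    return y_c[scatter]

def expand_backward(grad_y, scatter, n_unique):
    # Adjoint of the expansion: scatter-add each row's gradient into its
    # compact row, so duplicated positions accumulate instead of overwrite
    grad_c = np.zeros((n_unique,) + grad_y.shape[1:], dtype=grad_y.dtype)
    np.add.at(grad_c, scatter, grad_y)
    return grad_c

rng = np.random.default_rng(0)
scatter = np.array([0, 1, 2, 0, 1, 3, 4])  # example batch [1,2,3], [1,2,4], [5]
x_c = rng.standard_normal((5, 4))          # compact inputs, N' = 5
w = rng.standard_normal((4, 3))            # toy linear layer
grad_y = rng.standard_normal((7, 3))       # upstream gradient on the full N = 7 rows

# Full-batch weight gradient: X^T dY over all N rows
dw_full = x_c[scatter].T @ grad_y
# Compact path: forward on N' rows, expand, then backprop through the expansion
y = expand_forward(x_c @ w, scatter)
dw_compact = x_c.T @ expand_backward(grad_y, scatter, 5)
print(y.shape)                           # (7, 3)
print(np.allclose(dw_full, dw_compact))  # True
```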

By horizontally deduplicating position-wise computation for shared prefixes, RadixMLP achieves up to $5\times$ acceleration in synthetic settings and $1.4$–$1.6\times$ in real reranking inference. It does so with only $\mathcal{O}(N)$ memory overhead and stateless CPU-GPU coordination (Feil et al., 21 Jan 2026).
