MoE Multi-Token Prediction (MTP) Layer

Updated 12 June 2026
  • MoE Multi-Token Prediction (MTP) layer is a specialized component that predicts several future tokens, improving overall language model performance.
  • It sits atop a mixture-of-experts backbone, reusing the backbone's gated expert routing and adding small per-depth transformer submodules plus a shared output head.
  • Empirical results indicate performance improvements of up to 1 point on benchmarks and higher multi-token acceptance rates in speculative decoding.

The MoE Multi-Token Prediction (MTP) layer is an architectural component attached to the mixture-of-experts (MoE) backbone of the SlimQwen model. It augments conventional next-token prediction with direct supervision over a short span of future tokens at each position. The approach complements knowledge distillation (KD) and pruning-based compression, improving model efficiency and language-modeling quality, particularly on knowledge-intensive downstream benchmarks (Tang et al., 9 May 2026).

1. Architectural Integration of the MTP Layer

In SlimQwen, each input token embedding $x_i$ passes through $L$ layers of alternating Gated Attention (or Gated DeltaNet) and MoE sublayers. The MoE sublayer's router computes

$$z(x) = \mathrm{softmax}\big(\mathrm{TopK}(x W^G, k)\big) \in \mathbb{R}^{n_{\text{routed}}}$$

for routed experts, along with a shared-expert gate

$$z_s(x) = \sigma(x\, w_{\text{sh}}) \in \mathbb{R}^{n_{\text{shared}}}$$

Each expert is a SwiGLU MLP, and the MoE output combines the routed and shared expert results:

$$\text{MoE}(x) = \sum_{e=1}^{n_{\text{routed}}} z_e(x)\,\text{Expert}_e(x) + \sum_{s=1}^{n_{\text{shared}}} z_s(x)\,\text{Expert}_s(x)$$

After all $L$ layers, the backbone yields hidden states $h^0_{1:T} \in \mathbb{R}^{T \times d}$. The MTP module operates on these representations, generalizing next-token prediction to the prediction of several future tokens per position, one extra token for each MTP depth $k = 1, \ldots, D$.
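The routed-plus-shared computation above can be sketched for a single token in NumPy. This is a minimal illustration, not SlimQwen's code: the weight shapes, the SwiGLU parameterization, and reading $\mathrm{softmax}(\mathrm{TopK}(\cdot))$ as a softmax restricted to the $k$ largest router logits (all other gates exactly zero) are assumptions, and all sizes are toy values.

```python
import numpy as np

def swiglu_expert(x, w_gate, w_up, w_down):
    """SwiGLU MLP: (SiLU(x @ w_gate) * (x @ w_up)) @ w_down."""
    a = x @ w_gate
    silu = a / (1.0 + np.exp(-a))
    return (silu * (x @ w_up)) @ w_down

def top_k_softmax(logits, k):
    """Softmax over the k largest logits; all other gates are exactly 0."""
    idx = np.argsort(logits)[-k:]
    sel = logits[idx] - logits[idx].max()    # shift for numerical stability
    gates = np.zeros_like(logits)
    gates[idx] = np.exp(sel) / np.exp(sel).sum()
    return gates

def moe_forward(x, W_G, w_sh, routed, shared, k):
    """MoE(x) for one token vector x of shape (d,); returns output and routed gates."""
    z = top_k_softmax(x @ W_G, k)            # z(x), shape (n_routed,)
    z_s = 1.0 / (1.0 + np.exp(-(x @ w_sh)))  # z_s(x), sigmoid gates, shape (n_shared,)
    out = np.zeros_like(x)
    for e, params in enumerate(routed):
        if z[e] > 0.0:                       # unselected experts are never evaluated
            out += z[e] * swiglu_expert(x, *params)
    for s, params in enumerate(shared):
        out += z_s[s] * swiglu_expert(x, *params)
    return out, z

# Toy sizes, chosen only for illustration.
rng = np.random.default_rng(0)
d, d_ff, n_routed, n_shared, k = 8, 16, 4, 1, 2

def make_expert():
    return (rng.normal(0, 0.1, (d, d_ff)),
            rng.normal(0, 0.1, (d, d_ff)),
            rng.normal(0, 0.1, (d_ff, d)))

routed = [make_expert() for _ in range(n_routed)]
shared = [make_expert() for _ in range(n_shared)]
W_G = rng.normal(0, 0.1, (d, n_routed))
w_sh = rng.normal(0, 0.1, (d, n_shared))
x = rng.normal(size=d)
y, z = moe_forward(x, W_G, w_sh, routed, shared, k)
```

Skipping experts with a zero gate is what makes top-$k$ routing cheap: only $k$ routed experts plus the shared experts run per token.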

For each position $i$ and prediction depth $k = 1, \ldots, D$, the MTP module:

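  • RMS-normalizes $h_i^{k-1}$ and the embedding $\mathrm{Emb}(t_{i+k})$ of token $t_{i+k}$, then concatenates them,
  • Projects the concatenated $2d$-dimensional vector back to $d$ dimensions via a projection matrix $M_k$,
  • Applies a small transformer block $\mathrm{TRM}_k$ (unique to each depth $k$),
  • Passes the result through a shared linear "OutHead" that projects to the vocabulary size.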

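The process is formalized by:

$$h_i^{k} = \mathrm{TRM}_k\Big(M_k\big[\mathrm{RMSNorm}(h_i^{k-1});\ \mathrm{RMSNorm}(\mathrm{Emb}(t_{i+k}))\big]\Big), \qquad P^{k}_{i+k+1} = \mathrm{softmax}\big(\mathrm{OutHead}(h_i^{k})\big)$$

where $[\,\cdot\,;\,\cdot\,]$ denotes concatenation and $P^{k}_{i+k+1}$ is the predicted distribution over the vocabulary for the token $k+1$ positions ahead of $i$. The case $k = 0$, $P^{0}_{i+1} = \mathrm{softmax}(\mathrm{OutHead}(h_i^{0}))$, is traditional next-token prediction; each depth $k \ge 1$ predicts one token further ahead, so together the depths supervise $D+1$ future tokens at every position.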
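The per-depth steps listed above can be chained over a whole sequence as in the following minimal NumPy sketch. Assumptions: RMSNorm has no learned scale, the depth-$k$ input uses $\mathrm{Emb}(t_{i+k})$ and predicts $t_{i+k+1}$ (a DeepSeek-V3-style alignment), and each $\mathrm{TRM}_k$ is replaced by a hypothetical residual tanh MLP instead of a real transformer block with causal attention.

```python
import numpy as np

def rms_norm(h, eps=1e-6):
    """RMSNorm without a learned scale (scale omitted for brevity)."""
    return h / np.sqrt(np.mean(h * h, axis=-1, keepdims=True) + eps)

def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def mtp_forward(h0, emb, tokens, M, trm, W_out):
    """Chain D MTP depths over one sequence.

    h0: (T, d) backbone states h^0; emb: (V, d) embedding table;
    tokens: (T,) ids t_1..t_T; M: list of D (2d, d) projections M_k;
    trm: list of D callables standing in for TRM_k; W_out: (d, V) shared OutHead.
    Entry k-1 of the result holds P^k for positions i = 1..T-k, i.e. the
    distributions for t_{i+k+1} (the last row has no target in the sequence).
    """
    probs, h_prev = [], h0
    T = h0.shape[0]
    for k in range(1, len(M) + 1):
        n = T - k                                  # positions whose t_{i+k} is known
        e_next = emb[tokens[k:k + n]]              # Emb(t_{i+k}) for i = 1..n
        cat = np.concatenate([rms_norm(h_prev[:n]), rms_norm(e_next)], axis=-1)
        h_k = trm[k - 1](cat @ M[k - 1])           # h_i^k = TRM_k(M_k [..; ..])
        probs.append(softmax(h_k @ W_out))         # same OutHead at every depth
        h_prev = h_k                               # depth k+1 consumes h^k
    return probs

rng = np.random.default_rng(0)
T, d, V, D = 6, 8, 11, 2
h0 = rng.normal(size=(T, d))
emb = rng.normal(size=(V, d))
tokens = rng.integers(0, V, size=T)
M = [rng.normal(0, 0.1, (2 * d, d)) for _ in range(D)]
Ws = [rng.normal(0, 0.1, (d, d)) for _ in range(D)]
trm = [lambda h, W=W: h + np.tanh(h @ W) for W in Ws]  # hypothetical TRM_k stand-in
W_out = rng.normal(0, 0.1, (d, V))
P = mtp_forward(h0, emb, tokens, M, trm, W_out)
```

Because depth $k$ reads $h^{k-1}$, the depths run sequentially, while all positions within a depth are computed in one vectorized step.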

2. Multi-Token Distillation and Training Objective

The training objective jointly balances the standard next-token LM loss, its distillation (KD) analog, and the corresponding MTP losses for every depth $k \in \{1, \ldots, D\}$. For a sequence of length $T$ over a vocabulary of size $V$, the losses are:

  • Next-token LM: $\mathcal{L}_{\text{LM}} = -\frac{1}{T}\sum_{i} \log P^{0}_{i+1}(t_{i+1})$
  • Next-token KD: $\mathcal{L}_{\text{KD}} = \frac{1}{T}\sum_{i}\sum_{v=1}^{V} Q_{i+1}(v)\,\log\frac{Q_{i+1}(v)}{P^{0}_{i+1}(v)}$ (where $Q_{i+1}$ is the teacher's distribution for token $t_{i+1}$)
  • MTP LM: $\mathcal{L}^{k}_{\text{MTP-LM}} = -\frac{1}{T}\sum_{i} \log P^{k}_{i+k+1}(t_{i+k+1})$
  • MTP KD: $\mathcal{L}^{k}_{\text{MTP-KD}} = \frac{1}{T}\sum_{i}\sum_{v=1}^{V} Q_{i+k+1}(v)\,\log\frac{Q_{i+k+1}(v)}{P^{k}_{i+k+1}(v)}$

Each sum runs over the positions $i$ whose target token lies inside the sequence.
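The depth-$k$ loss terms can be computed as below. This is a minimal NumPy sketch; the KL direction (teacher to student) and the row alignment between student distributions and teacher next-token distributions are assumptions chosen to match the formulas written here, not details confirmed by the paper.

```python
import numpy as np

def ce_loss(P, targets):
    """Mean negative log-likelihood: P is (n, V), targets is (n,) token ids."""
    return float(-np.mean(np.log(P[np.arange(len(targets)), targets])))

def kd_loss(Q, P, eps=1e-12):
    """Mean KL(Q || P) over rows; Q is the teacher, P the student."""
    return float(np.mean(np.sum(Q * (np.log(Q + eps) - np.log(P + eps)), axis=-1)))

def mtp_losses(P_k, Q_next, tokens, k):
    """MTP-LM and MTP-KD at depth k.

    P_k:    (n, V) student distributions P^k_{i+k+1} for positions i = 1..n
    Q_next: (T, V) teacher next-token distributions; 0-based row j predicts
            t_{j+2}, so the teacher row matching P^k_{i+k+1} is j = i+k-1
    tokens: (T,) token ids t_1..t_T
    """
    n = min(len(P_k), len(tokens) - k - 1)   # keep positions whose target exists
    targets = tokens[k + 1:k + 1 + n]         # t_{i+k+1} for i = 1..n
    lm = ce_loss(P_k[:n], targets)
    kd = kd_loss(Q_next[k:k + n], P_k[:n])
    return lm, kd

# Sanity check: a uniform student has NLL log(V), and matching the teacher gives KL 0.
T, V, k = 6, 5, 1
tokens = np.array([0, 1, 2, 3, 4, 0])
P_uniform = np.full((T - k, V), 1.0 / V)
Q_uniform = np.full((T, V), 1.0 / V)
lm, kd = mtp_losses(P_uniform, Q_uniform, tokens, k)
```

The next-token terms $\mathcal{L}_{\text{LM}}$ and $\mathcal{L}_{\text{KD}}$ use the same two helpers with the main-head distributions $P^0$ and targets shifted by one position.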

The total objective is a weighted sum of these terms, with the MTP terms included for each depth $k$. Two of the weights follow schedules over training: one decays linearly, and another follows a cosine decay.
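Linear and cosine weight schedules and their use in a weighted sum can be sketched as follows. Which loss terms the two schedules control, and their start and end values, are hypothetical placeholders here, not SlimQwen's settings.

```python
import math

def linear_decay(step, total_steps, start, end):
    """Weight moving linearly from `start` (step 0) to `end` (final step)."""
    frac = min(max(step / total_steps, 0.0), 1.0)
    return start + (end - start) * frac

def cosine_decay(step, total_steps, start, end):
    """Weight following a half-cosine from `start` (step 0) to `end` (final step)."""
    frac = min(max(step / total_steps, 0.0), 1.0)
    return end + 0.5 * (start - end) * (1.0 + math.cos(math.pi * frac))

def weighted_objective(terms, weights):
    """Total loss as a weighted sum of named loss terms."""
    return sum(weights[name] * value for name, value in terms.items())

# Hypothetical assignment of schedules to terms, for illustration only.
total_steps, step = 1000, 250
weights = {
    "kd": 1.0,
    "lm": linear_decay(step, total_steps, start=1.0, end=0.0),      # assumed endpoints
    "mtp_kd": 1.0,
    "mtp_lm": cosine_decay(step, total_steps, start=1.0, end=0.0),  # assumed endpoints
}
terms = {"kd": 2.0, "lm": 3.0, "mtp_kd": 2.5, "mtp_lm": 3.5}  # made-up loss values
loss = weighted_objective(terms, weights)
```

Compared with a linear schedule, the cosine schedule stays close to its starting value early in training and changes fastest near the midpoint.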

3. Interaction with MoE Gating and Expert Selection

The MTP layer sits atop the MoE backbone but remains tightly coupled to expert routing. Because the MTP depths consume the backbone states $h^0_{1:T}$, gradients from the MTP losses propagate through the entire MoE network, including the router gates. The gradient of the MTP-KD term with respect to the parameters $\theta_e$ of expert $e$ is scaled by the gating score $z_e(x)$, so only the experts selected for a token receive it, while the gradient reaching the router weights $W^G$ lets routing shift toward expert selections that improve future-token prediction. MTP itself introduces no new gating mechanism; it relies on the MoE's existing top-$k$ softmax routing.
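The gating claim follows from the chain rule applied to the MoE output formula. Writing $x_i$ for the input to a given MoE sublayer at position $i$, and noting that $z_e(x_i)$ depends on $W^G$ and $x_i$ but not on $\theta_e$:

$$\frac{\partial \mathcal{L}^{k}_{\text{MTP-KD}}}{\partial \theta_e} = \sum_{i} z_e(x_i)\left(\frac{\partial \mathcal{L}^{k}_{\text{MTP-KD}}}{\partial\, \mathrm{MoE}(x_i)}\right)^{\!\top} \frac{\partial\, \mathrm{Expert}_e(x_i)}{\partial \theta_e}$$

$$\frac{\partial \mathcal{L}^{k}_{\text{MTP-KD}}}{\partial z_e(x_i)} = \left(\frac{\partial \mathcal{L}^{k}_{\text{MTP-KD}}}{\partial\, \mathrm{MoE}(x_i)}\right)^{\!\top} \mathrm{Expert}_e(x_i)$$

For any expert outside the top-$k$ set of token $i$, $z_e(x_i) = 0$, so that token contributes nothing to $\partial \mathcal{L} / \partial \theta_e$. The second expression is what flows back through the softmax over the selected logits into $W^G$.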

4. MTP Forward and Backward Pass

The core stages of the MTP layer's dataflow are summarized below (see the pseudocode section of Tang et al., 9 May 2026):

  • Forward pass:
    • Initialize $x_{1:T}$ from token embeddings.
    • Iterate through the $L$ backbone layers, with router gating in each MoE sublayer, to obtain $h^0_{1:T}$.
    • For each prediction depth $k$ in $\{1, \ldots, D\}$, concatenate the RMS-normalized $h_i^{k-1}$ and $\mathrm{Emb}(t_{i+k})$; project with $M_k$ and pass the result through the transformer block $\mathrm{TRM}_k$.
    • Compute $P^{k}$ via the shared OutHead.
    • Aggregate the loss terms $\mathcal{L}_{\text{LM}}$, $\mathcal{L}_{\text{KD}}$, $\mathcal{L}^{k}_{\text{MTP-LM}}$, and $\mathcal{L}^{k}_{\text{MTP-KD}}$, then compute the overall objective $\mathcal{L}$ as above.
  • Backward pass: compute $\nabla_{\theta}\mathcal{L}$ via automatic differentiation and update all parameters $\theta$ with the chosen optimizer.
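The forward half of the steps above, from backbone states to the four kinds of loss terms, can be sketched in NumPy as follows. NumPy has no automatic differentiation, so the backward pass is left to a framework such as JAX (`jax.grad`) or PyTorch autograd. Assumptions: the random `h0` stands in for the backbone output, `Q_next` is a random stand-in for teacher distributions, the main head reuses the shared OutHead `W_out`, and each `TRM_k` is a hypothetical residual tanh MLP.

```python
import numpy as np

def rms_norm(h, eps=1e-6):
    return h / np.sqrt(np.mean(h * h, axis=-1, keepdims=True) + eps)

def softmax(a):
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def nll(P, y):
    return float(-np.mean(np.log(P[np.arange(len(y)), y])))

def kl(Q, P, eps=1e-12):
    return float(np.mean(np.sum(Q * (np.log(Q + eps) - np.log(P + eps)), axis=-1)))

def training_losses(h0, tokens, emb, M, trm, W_out, Q_next):
    """Forward loss aggregation only (no backward pass).

    Q_next: (T, V) teacher next-token distributions; 0-based row j is the
    teacher's prediction for t_{j+2}.
    Returns (L_LM, L_KD, [L_MTP-LM^k], [L_MTP-KD^k]).
    """
    T = len(tokens)
    P0 = softmax(h0[:T - 1] @ W_out)             # main head: P^0 for t_{i+1}
    l_lm, l_kd = nll(P0, tokens[1:]), kl(Q_next[:T - 1], P0)
    mtp_lm, mtp_kd, h_prev = [], [], h0
    for k in range(1, len(M) + 1):
        n = T - k - 1                             # positions with target t_{i+k+1}
        cat = np.concatenate([rms_norm(h_prev[:n]), rms_norm(emb[tokens[k:k + n]])], axis=-1)
        h_prev = trm[k - 1](cat @ M[k - 1])       # h_i^k, fed to depth k+1
        Pk = softmax(h_prev @ W_out)              # shared OutHead
        mtp_lm.append(nll(Pk, tokens[k + 1:k + 1 + n]))
        mtp_kd.append(kl(Q_next[k:k + n], Pk))
    return l_lm, l_kd, mtp_lm, mtp_kd

rng = np.random.default_rng(1)
T, d, V, D = 8, 8, 13, 2
h0 = rng.normal(size=(T, d))
emb = rng.normal(size=(V, d))
tokens = rng.integers(0, V, size=T)
M = [rng.normal(0, 0.1, (2 * d, d)) for _ in range(D)]
Ws = [rng.normal(0, 0.1, (d, d)) for _ in range(D)]
trm = [lambda h, W=W: h + np.tanh(h @ W) for W in Ws]   # hypothetical TRM_k stand-in
W_out = rng.normal(0, 0.1, (d, V))
Q_next = softmax(rng.normal(size=(T, V)))                # random stand-in teacher
l_lm, l_kd, mtp_lm, mtp_kd = training_losses(h0, tokens, emb, M, trm, W_out, Q_next)
```

The returned terms would then be combined with the scheduled weights into $\mathcal{L}$ before differentiation.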

All computational steps are vectorized over batch and sequence. MTP adds only $D$ small transformer submodules and their projections $M_k$, one set per depth with no parameter sharing across $k$, and reuses a single shared OutHead.
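The parameter overhead implied by this structure can be counted with a small helper. This is a sketch under stated assumptions: normalization scales are ignored, the size of each $\mathrm{TRM}_k$ is passed in because its internals are not specified here, and `head_shared=True` assumes the OutHead is reused from the main model.

```python
def mtp_param_overhead(d, D, trm_params_per_depth, vocab_size, head_shared=True):
    """Parameters added by D MTP depths (normalization scales ignored).

    Each depth k owns a projection M_k (2d x d) and its own TRM_k. The OutHead
    (d x V) is shared across depths; with head_shared=True it is assumed to be
    reused from the main model and therefore adds nothing.
    """
    per_depth = 2 * d * d + trm_params_per_depth
    head = 0 if head_shared else d * vocab_size
    return D * per_depth + head

# Hypothetical sizes, for illustration only.
print(mtp_param_overhead(d=1024, D=1, trm_params_per_depth=20_000_000, vocab_size=150_000))
```

The cost grows linearly in $D$, while the vocabulary-sized head, usually the largest single matrix, is paid at most once.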

5. Empirical Gains from MTP Distillation

SlimQwen's experiments show consistent improvements from MTP distillation. For the 23A2B model trained on 120B tokens, the training objectives compare as follows:

Training objective                 MMLU    MMLU-Pro
NTP KD alone                       74.16   50.97
NTP KD + LM                        74.93   51.44
NTP KD + MTP KD                    75.13   51.94
NTP KD + LM + MTP LM + MTP KD      75.67   51.19

Adding MTP KD to next-token KD raises both MMLU and MMLU-Pro by about 1 point (74.16 to 75.13 and 50.97 to 51.94). Combining all four losses gives the best MMLU score (75.67), although its MMLU-Pro score (51.19) falls below that of NTP KD + MTP KD.
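The gains relative to the NTP-KD-only baseline follow directly from the table; a short script makes the arithmetic explicit. The values are copied from the table above; nothing else is assumed.

```python
results = {
    "NTP KD alone": (74.16, 50.97),
    "NTP KD + LM": (74.93, 51.44),
    "NTP KD + MTP KD": (75.13, 51.94),
    "NTP KD + LM + MTP LM + MTP KD": (75.67, 51.19),
}
base_mmlu, base_pro = results["NTP KD alone"]

# Difference from the baseline, rounded to the table's two decimals.
deltas = {
    name: (round(mmlu - base_mmlu, 2), round(pro - base_pro, 2))
    for name, (mmlu, pro) in results.items()
}
for name, (d_mmlu, d_pro) in deltas.items():
    print(f"{name:32s} MMLU {d_mmlu:+.2f}  MMLU-Pro {d_pro:+.2f}")
```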

In speculative decoding, where the MTP head drafts several future tokens and the main backbone verifies them, MTP KD raises multi-token acceptance rates. During pretraining, two-token acceptance ("acc_2") on GSM8K is higher with MTP KD than with MTP LM alone, and after supervised fine-tuning, four-token acceptance ("acc_4") on MTBench likewise rises when MTP KD is used.
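A minimal sketch of the verification step in its greedy variant: the backbone scores all drafted positions in one parallel pass, and drafted tokens are accepted up to the first disagreement with its argmax choice. This is a generic illustration; it is not necessarily how acc_2 and acc_4 are computed in the paper, and sampling-based speculative decoding uses a rejection rule instead of exact matching.

```python
def accepted_prefix_length(draft, verify):
    """Number of drafted tokens accepted under greedy verification.

    draft:  tokens proposed by the MTP head for the next positions
    verify: the backbone's argmax token at each of those positions
    Acceptance stops at the first mismatch.
    """
    n = 0
    for d_tok, v_tok in zip(draft, verify):
        if d_tok != v_tok:
            break
        n += 1
    return n

# The third drafted token disagrees with the backbone, so two tokens are accepted.
accepted = accepted_prefix_length([5, 7, 9], [5, 7, 2])
```

Higher agreement between draft and verifier means longer accepted prefixes, which is what the acceptance rates above summarize.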

6. Significance and Practical Considerations

By supervising both next-token and multiple future-token predictions through a lightweight MTP head, SlimQwen gives its MoE experts and router an additional, distillation-based training signal. The result is better language-modeling quality and more efficient multi-token generation, especially under speculative decoding. The overhead is modest: $D$ additional transformer submodules (not shared across depths) and a shared output layer. The integration of MTP thus suggests practical value in scaling multi-token objectives for efficient LLM pretraining and inference (Tang et al., 9 May 2026).
