
Conditional MTP Projector (cMTPP)

Updated 5 February 2026
  • The Conditional MTP Projector (cMTPP) is a lightweight, parameter-efficient module that generates dynamic multi-token predictions via adaptive normalization and gated MLP transformations.
  • It integrates with the Temporal Guidance framework to contrast 'amateur' logits with expert predictions, boosting generation quality while minimizing computational and memory overhead.
  • Empirical evaluations demonstrate improved benchmark performance with reduced VRAM usage and latency compared to traditional dual-model contrastive decoding methods.

The Conditional MTP Projector (cMTPP) is a parameter-efficient auxiliary projection module introduced for self-contrastive decoding in LLMs within the Temporal Guidance (TeGu) framework. Its core function is to enable dynamic multi-token prediction (MTP) for arbitrary temporal offsets $k$ without the overhead of multiple independent networks or full auxiliary models. By applying lightweight transformations and adaptive normalization to cached hidden states, cMTPP constructs the "amateur" predictions needed for contrastive decoding, yielding measurable improvements in generation quality while keeping a compact computational and memory footprint (Zheng et al., 29 Jan 2026).

1. Role and Motivation within Temporal Guidance

Temporal Guidance (TeGu) implements self-contrastive decoding by contrasting predictions from an "expert" (the LLM's standard forward prediction) with those of an "amateur" that only sees the context up to $k$ steps in the past. At each generation step $t$, the scheme contrasts

  • Expert log-probabilities: $\log P_\mathrm{exp}(x_t \mid x_{<t})$
  • Amateur log-probabilities: $\log P_\mathrm{amt}(x_t \mid x_{<t-k})$

Guided scores are computed as:

$$V(x_t) = \log P_\mathrm{exp}(x_t) + \alpha \left[ \log P_\mathrm{exp}(x_t) - \log P_\mathrm{amt}(x_t) \right]$$
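A minimal sketch of the guided score, assuming the expert and amateur logits over the same vocabulary are already available as plain Python lists. The helper names `log_softmax` and `tegu_scores` and the toy numbers are illustrative, not taken from the paper:

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def tegu_scores(expert_logits, amateur_logits, alpha):
    """V(x) = log P_exp(x) + alpha * (log P_exp(x) - log P_amt(x)) for every token x."""
    lp_exp = log_softmax(expert_logits)
    lp_amt = log_softmax(amateur_logits)
    return [e + alpha * (e - a) for e, a in zip(lp_exp, lp_amt)]

# Toy 3-token vocabulary: the expert slightly prefers token 0,
# while the amateur is confidently biased toward token 0.
expert = [2.0, 1.9, 0.0]
amateur = [2.5, 0.5, 0.0]
scores = tegu_scores(expert, amateur, alpha=0.5)
print(max(range(3), key=scores.__getitem__))   # 1
```

Since $V = (1+\alpha)\log P_\mathrm{exp} - \alpha \log P_\mathrm{amt}$, setting $\alpha = 0$ recovers plain expert decoding. In the toy example the expert alone would pick token 0, but penalizing the token the amateur is overconfident about shifts the choice to token 1.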

However, native LLM architectures typically lack MTP heads flexible enough to produce predictions for arbitrary $k$. cMTPP fills this gap with a compact, shared mechanism that produces "amateur" logits from a cached hidden state $h_{t-1-k}$ and a step offset $k$, avoiding extra full-sized models or multiple MTP heads. Its output is passed through the LLM's frozen language-modeling (LM) head, so the amateur distribution lives on the same vocabulary as the expert at little extra cost.
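Because the amateur's input is a hidden state the model already computed $k$ steps earlier, decoding only needs to retain the last $k+1$ last-layer states. A minimal sketch, assuming a fixed $k$ and a `collections.deque` as the cache; the class name `HiddenStateCache` is hypothetical:

```python
from collections import deque

class HiddenStateCache:
    """Keeps the last k+1 last-layer hidden states so h_{t-1-k} is always at index 0."""

    def __init__(self, k):
        self.k = k
        self.buf = deque(maxlen=k + 1)   # oldest entries are evicted automatically

    def push(self, h):
        """Append h_{t-1}, the hidden state just produced for position t-1."""
        self.buf.append(h)

    def amateur_input(self):
        """Return h_{t-1-k}, or None while fewer than k+1 states have been seen."""
        return self.buf[0] if len(self.buf) == self.k + 1 else None

cache = HiddenStateCache(k=2)
for pos in range(6):                 # stand-in "hidden states": position labels 0..5
    cache.push(f"h_{pos}")
print(cache.amateur_input())         # t-1 = 5 and k = 2, so this prints h_3
```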

2. Architecture and Workflow

The cMTPP module transforms a cached last-layer hidden state and a temporal step index into logits approximating $P(x_t \mid x_{<t-k})$. The principal components and data flow are:

  • Input representations:

(1) h=ht1kRdh = h_{t-1-k} \in \mathbb{R}^d: cached hidden state. (2) kk: integer offset.

  • Step-ID Embedding:

$$e_k = \mathrm{StepEmbed}(k) \in \mathbb{R}^d$$

  • Adaptive LayerNorm (AdaLN):

Compute the scale $\gamma_k$ and bias $\beta_k$ with an MLP over $e_k$, then modulate the normalized hidden state:

$$\tilde{h} = \gamma_k \odot \mathrm{RMSNorm}(h) + \beta_k$$

  • Gated Feed-Forward Network ("SwiGLU" variant):

$$u_1 = \mathrm{GELU}(W_1 \tilde{h} + b_1), \qquad u_2 = W_2 \tilde{h} + b_2$$

$$u = u_1 \odot u_2$$

  • Down-projection:

$$z = W_3 u + b_3$$

  • LM head projection:

$$\mathrm{logits}_\mathrm{amt} = W_\mathrm{LM} z$$

Parameters:

  • $W_1, W_2 \in \mathbb{R}^{e \times d}$ and $W_3 \in \mathbb{R}^{d \times e}$, with hidden width $e = r \cdot d$ (expansion factor $r = 2.7$)
  • The AdaLN and StepEmbed parameters are small by comparison

The full LM backbone and its output projection $W_\mathrm{LM}$ are frozen during cMTPP training.
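The data flow above can be traced end to end in plain Python. This is a dependency-free sketch with toy dimensions and random weights, not the reference implementation; the helper names, the single linear layer standing in for the AdaLN MLP, and the rounding of $e = r \cdot d$ to an integer are all assumptions:

```python
import math
import random

def matvec(W, x, b=None):
    """Return W x (+ b), where W is a list of rows."""
    y = [sum(w * xi for w, xi in zip(row, x)) for row in W]
    return y if b is None else [yi + bi for yi, bi in zip(y, b)]

def rms_norm(h, eps=1e-6):
    """RMSNorm without a learned gain: h / sqrt(mean(h^2) + eps)."""
    scale = 1.0 / math.sqrt(sum(v * v for v in h) / len(h) + eps)
    return [v * scale for v in h]

def gelu(x):
    """Exact (erf-based) GELU."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def rand_mat(rows, cols, rng):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

# Toy sizes; real models use d in the thousands.
d, r, V, max_k = 8, 2.7, 16, 4
e = round(r * d)                                  # r*d = 21.6, rounded to 22
rng = random.Random(0)
step_embed = rand_mat(max_k + 1, d, rng)          # StepEmbed: row k is e_k
W_ada = rand_mat(2 * d, d, rng)                   # assumed one-layer map e_k -> [gamma_k; beta_k]
W1, W2, W3 = rand_mat(e, d, rng), rand_mat(e, d, rng), rand_mat(d, e, rng)
b1, b2, b3 = [0.0] * e, [0.0] * e, [0.0] * d
W_LM = rand_mat(V, d, rng)                        # frozen LM head, shared with the expert

def cmtpp_logits(h, k):
    gb = matvec(W_ada, step_embed[k])
    gamma_k, beta_k = gb[:d], gb[d:]
    h_tilde = [g * n + bt for g, n, bt in zip(gamma_k, rms_norm(h), beta_k)]            # AdaLN
    u = [gelu(a) * c for a, c in zip(matvec(W1, h_tilde, b1), matvec(W2, h_tilde, b2))]  # gated MLP
    z = matvec(W3, u, b3)                         # down-projection back to d
    return matvec(W_LM, z)                        # amateur logits, one per vocabulary entry

h_old = [rng.gauss(0.0, 1.0) for _ in range(d)]   # stand-in for the cached h_{t-1-k}
print(len(cmtpp_logits(h_old, k=1)))              # 16
```

Only the step-embedding table, the AdaLN map, and `W1`, `W2`, `W3` with their biases would be trained here; `W_LM` stays frozen.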

3. Mathematical Representation

Let $h = h_{t-1-k}$ and let $k$ be the offset index. The transformations are:

  1. Adaptive normalization:

$$\tilde{h} = \mathrm{AdaLN}(h, k) = \gamma_k \odot \mathrm{RMSNorm}(h) + \beta_k$$

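  2. Gated MLP ("SwiGLU" variant with a GELU gate):

$$u = \mathrm{GELU}(W_1 \tilde{h} + b_1) \odot (W_2 \tilde{h} + b_2)$$

Because the gate activation is GELU rather than Swish (SiLU), this is strictly the GEGLU member of the gated-linear-unit family; the gate is applied to $W_1 \tilde{h} + b_1$ only, not multiplied by its own input.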

  3. Down-projection:

$$z = W_3 u + b_3$$

  4. Final logits via the frozen LM head, followed by a softmax:

$$P_\mathrm{amt}(x_t \mid x_{<t-k}) = \mathrm{Softmax}(W_\mathrm{LM} z)$$

These "amateur" logits are used for contrastive adjustment of the expert logits as per TeGu.
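Composing the four steps above gives the amateur distribution as a single expression in $h_{t-1-k}$ and $k$; the $W_\mathrm{LM}$ here is the same frozen matrix the expert uses:

```latex
P_\mathrm{amt}(x_t \mid x_{<t-k})
  = \mathrm{Softmax}\Big( W_\mathrm{LM} \big[ W_3 \big( \mathrm{GELU}(W_1 \tilde{h} + b_1) \odot (W_2 \tilde{h} + b_2) \big) + b_3 \big] \Big),
\qquad
\tilde{h} = \gamma_k \odot \mathrm{RMSNorm}(h_{t-1-k}) + \beta_k
```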

4. Training Objective and Optimization

All cMTPP parameters are trained while the LLM backbone stays frozen. For each offset $k \in K$ at training position $t$:

  • Cross-entropy loss (CE):

$$L_\mathrm{CE}^{(t,k)} = -\log P_\mathrm{amt}(x_t \mid x_{<t-k})$$

  • KL-Distillation loss (KD):

$$L_\mathrm{KD}^{(t,k)} = \mathrm{KL}\left[ P_\mathrm{exp}(\cdot \mid x_{<t}) \,\|\, P_\mathrm{amt}(\cdot \mid x_{<t-k}) \right]$$

  • Total loss:

$$\mathcal{L} = \sum_t \sum_{k \in K} \left[ (1 - A_k) L_\mathrm{CE}^{(t,k)} + A_k L_\mathrm{KD}^{(t,k)} \right]$$

$A_k$ controls the CE/KD mix (e.g., $A_k = 0.7$); the KD term can use a temperature $T > 1$.
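A minimal per-position sketch of the mixed objective for one $(t, k)$ pair, assuming the temperature is applied to both distributions inside the KD term only and that the expert's distribution acts as a fixed target. The function names are illustrative; the Hinton-style $T^2$ rescaling of the KD term, a common option, is omitted here:

```python
import math

def softmax(logits, T=1.0):
    m = max(x / T for x in logits)
    exps = [math.exp(x / T - m) for x in logits]
    s = sum(exps)
    return [v / s for v in exps]

def cmtpp_loss(amt_logits, exp_logits, target, A_k=0.7, T=2.0):
    """(1 - A_k) * CE(target | amateur) + A_k * KL(P_exp^T || P_amt^T) for one (t, k) pair."""
    ce = -math.log(softmax(amt_logits)[target])          # hard-label cross-entropy
    p_exp = softmax(exp_logits, T)                        # expert (teacher) distribution
    p_amt = softmax(amt_logits, T)                        # amateur (student) distribution
    kd = sum(p * (math.log(p) - math.log(q)) for p, q in zip(p_exp, p_amt) if p > 0)
    return (1 - A_k) * ce + A_k * kd

x = [3.0, 1.0, 0.0]
print(cmtpp_loss(x, x, target=0))   # KD term is 0 here, so only the weighted CE remains
```

The total loss $\mathcal{L}$ is then the sum of this quantity over positions $t$ and offsets $k \in K$.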

Ablation experiments indicate that combining CE and KD objectives stabilizes and enhances benchmark performance, whereas CE-only objectives yield degradation or instability.

5. Efficiency: Parameter, Memory, and Computation Analysis

A summary of cMTPP’s computational cost and memory usage:

  • Parameter count:

$W_1, W_2, W_3$ together hold $\approx 3 r d^2$ parameters ($r = 2.7$). For $d = 4096$ this is about 135.9 million parameters, under 2% of an 8B model.

  • FLOPs per inference step:

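Three matrix-vector products, about $3 r d^2$ multiply-accumulates ($\approx 6 r d^2$ FLOPs) per token, plus one extra projection of $z$ through the frozen LM head; this is far less than a second forward pass through the backbone.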

  • Additional memory and latency:

On Qwen3-8B, cMTPP raises VRAM usage from 17.72 GB (greedy decoding) to 19.72 GB (≈11% more) and increases latency by only 2%.

| Decoding mode | VRAM (GB) | Latency (× greedy) |
|---|---|---|
| Greedy | 17.72 | 1.0 |
| Standard CD (1.7B) | 23.11 | 1.2 |
| DoLa | 19.22 | 1.04 |
| TeGu + cMTPP | 19.72 | 1.02 |

This overhead is substantially lower than that of dual-model contrastive decoding and close to that of DoLa.
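The parameter, compute, and memory figures can be checked with a few lines of arithmetic. This sketch assumes $e = r \cdot d$ is rounded to the nearest integer, biases are ignored, a multiply-add counts as two FLOPs, and the backbone costs roughly $2N$ FLOPs per token for a nominal $N = 8 \times 10^9$ parameters (a standard rule of thumb that ignores attention):

```python
d, r = 4096, 2.7
e = round(r * d)                          # 11059 (r*d = 11059.2 is rounded; an assumption)
proj_params = 3 * e * d                   # W1, W2 (e x d) and W3 (d x e), biases ignored
print(f"cMTPP weights: {proj_params / 1e6:.1f}M ({proj_params / 8e9:.2%} of 8B)")

# Per-token compute: a matrix-vector product with an m x n matrix costs about 2*m*n FLOPs.
cmtpp_flops = 2 * proj_params
backbone_flops = 2 * 8e9                  # ~2N FLOPs per token, attention ignored
print(f"cMTPP / backbone FLOPs per token: {cmtpp_flops / backbone_flops:.1%}")

# Relative VRAM overhead vs. greedy decoding, using the table above.
vram = {"Greedy": 17.72, "Standard CD (1.7B)": 23.11, "DoLa": 19.22, "TeGu + cMTPP": 19.72}
for mode, gb in vram.items():
    print(f"{mode}: +{gb / vram['Greedy'] - 1:.1%}")
```

Under these assumptions the projector adds about 1.7% to both the weights and the per-token backbone FLOPs, and the VRAM figures correspond to overheads of roughly 30% for standard CD, 8.5% for DoLa, and 11.3% for TeGu + cMTPP.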

6. Empirical Evaluation and Ablations

Main Results

  • On Qwen3-1.7B ($\alpha = 0.2$): GSM8K $72.48\% \rightarrow 75.51\%$, IFEval $15.16\% \rightarrow 26.99\%$
  • On Qwen3-8B ($\alpha = 0.5$): Math500 $20.40\% \rightarrow 24.20\%$, IFEval $24.77\% \rightarrow 34.20\%$

Key Ablation Insights

  • CE-only training leads to instability; the inclusion of KL-distillation is critical for robust task performance across all evaluated benchmarks.
  • A single-step offset ($k = 1$) as the amateur outperforms larger $k$ or mixtures of offsets: larger values of $k$ dilute the contrastive effect and reduce accuracy.
  • The optimal guidance strength $\alpha$ depends on model size: smaller models peak at $\alpha \approx 0.2$, while larger models tolerate up to $\alpha \approx 0.7$.
  • Amateur logits generated by cMTPP demonstrate elevated entropy relative to the expert head, confirming their function as high-uncertainty, "amateur" projections.
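An entropy comparison of this kind can be computed in a few lines. The logits below are made up to contrast a peaked and a flatter distribution and are not the paper's measurements; the helper names are illustrative:

```python
import math

def entropy(probs):
    """Shannon entropy in nats; zero-probability entries contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0)

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [v / s for v in exps]

# Illustrative logits only: a peaked "expert" and a flatter "amateur" distribution.
h_exp = entropy(softmax([4.0, 1.0, 0.5, 0.0]))
h_amt = entropy(softmax([1.5, 1.0, 0.5, 0.0]))
print(h_amt > h_exp)   # True: the flatter distribution has higher entropy
```

In practice the same per-position entropies would be averaged over a decoding run for the expert head and the cMTPP head separately.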

7. Implementation Considerations and Pseudocode

Principal hyperparameters include expansion ratio $r = 2.7$, loss weights (CE: 0.3, KD: 0.7), KD temperature $T = 2.0$, the AdamW optimizer (peak LR 2e-4), a cosine schedule, and 5% warmup.
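A sketch of the learning-rate schedule, assuming linear warmup over the first 5% of steps followed by cosine decay to zero. The warmup shape, the floor `min_lr`, and the function name `lr_at` are assumptions; the source only names a cosine schedule with 5% warmup:

```python
import math

def lr_at(step, total_steps, peak_lr=2e-4, warmup_frac=0.05, min_lr=0.0):
    """Linear warmup to peak_lr over warmup_frac of training, then cosine decay to min_lr."""
    warmup = max(1, int(warmup_frac * total_steps))
    if step < warmup:
        return peak_lr * (step + 1) / warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

total = 1000
print(lr_at(0, total), lr_at(49, total), lr_at(999, total))   # ramp start, peak, near zero
```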

cMTPP is used during inference as follows:

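```python
import torch
import torch.nn.functional as F

# One TeGu decoding step. `lm.encode_last_hidden`, `cache`, and `cMTPP_forward` are
# placeholders: `cache` is assumed to be a collections.deque(maxlen=k + 1) of last-layer
# hidden states, and W_LM is the frozen LM-head weight with shape (V, d).
h_cur = lm.encode_last_hidden(x_prefix)          # h_{t-1} for the prefix x_{<t}, shape (d,)
cache.append(h_cur)
logp_exp = F.log_softmax(W_LM @ h_cur, dim=-1)   # expert log-probabilities

h_old = cache[0]                                 # h_{t-1-k} once the cache holds k+1 states
z = cMTPP_forward(h_old, k)
logp_amt = F.log_softmax(W_LM @ z, dim=-1)       # amateur log-probabilities

guided_logits = logp_exp + alpha * (logp_exp - logp_amt)
next_token = int(torch.argmax(guided_logits))    # greedy; or sample from softmax(guided_logits)
```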

The cMTPP forward step is:

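```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CMTPP(nn.Module):
    """Conditional MTP projector: (h_{t-1-k}, k) -> z, which is fed to the frozen LM head."""
    def __init__(self, d, max_k, r=2.7):
        super().__init__()
        e = round(r * d)                                   # expanded width; rounding assumed
        self.step_embed = nn.Embedding(max_k + 1, d)       # StepEmbed
        self.ada_ln = nn.Sequential(nn.Linear(d, d), nn.SiLU(), nn.Linear(d, 2 * d))  # AdaLN MLP (layout assumed)
        self.norm = nn.RMSNorm(d, elementwise_affine=False)
        self.w1, self.w2, self.w3 = nn.Linear(d, e), nn.Linear(d, e), nn.Linear(e, d)

    def forward(self, h, k):                               # h: (..., d); k: int or LongTensor
        e_k = self.step_embed(torch.as_tensor(k, device=h.device))
        gamma_k, beta_k = self.ada_ln(e_k).chunk(2, dim=-1)
        h_tilde = gamma_k * self.norm(h) + beta_k           # AdaLN
        u = F.gelu(self.w1(h_tilde)) * self.w2(h_tilde)     # gated MLP, (..., e)
        return self.w3(u)                                  # z, (..., d)

# Usage, e.g.: cMTPP_forward = CMTPP(d=4096, max_k=4); z = cMTPP_forward(h_old, k)
```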

All LLM backbone parameters remain frozen throughout cMTPP training; only cMTPP parameters are updated (Zheng et al., 29 Jan 2026).
