Conditional MTP Projector (cMTPP)
- The Conditional MTP Projector (cMTPP) is a lightweight, parameter-efficient module that generates dynamic multi-token predictions via adaptive normalization and gated MLP transformations.
- It integrates with the Temporal Guidance framework to contrast 'amateur' logits with expert predictions, boosting generation quality while minimizing computational and memory overhead.
- Empirical evaluations demonstrate improved benchmark performance with reduced VRAM usage and latency compared to traditional dual-model contrastive decoding methods.
The Conditional MTP Projector (cMTPP) is a parameter-efficient, auxiliary projection module introduced for self-contrastive decoding in LLMs within the Temporal Guidance (TeGu) framework. Its core function is to enable dynamic multi-token prediction (MTP) for arbitrary temporal offsets without the overhead of multiple independent networks or full auxiliary models. By leveraging lightweight transformations and adaptive normalization of cached hidden states, cMTPP facilitates the construction of "amateur" predictions for contrastive learning in LLM decoding processes, yielding measurable improvements in generation quality while maintaining a compact computational and memory footprint (Zheng et al., 29 Jan 2026).
1. Role and Motivation within Temporal Guidance
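Temporal Guidance (TeGu) implements self-contrastive decoding by contrasting predictions from an "expert" (the standard forward LLM prediction) with those of an "amateur" conditioned on the context as it stood $k$ steps in the past. At each generation step $t$, the scheme contrasts
- Expert logits: $\ell^{\mathrm{exp}}_t = W_{\mathrm{LM}}\, h_t$, computed from the current last-layer hidden state $h_t$ (which encodes $x_{<t}$)
- Amateur logits: $\ell^{\mathrm{amt}}_t = W_{\mathrm{LM}}\, \mathrm{cMTPP}(h_{t-k}, k)$, computed from the hidden state cached $k$ steps earlier
Guided scores are computed as:
$$s_t = \log p^{\mathrm{exp}}_t + \alpha \left( \log p^{\mathrm{exp}}_t - \log p^{\mathrm{amt}}_t \right), \qquad p^{\mathrm{exp}}_t = \mathrm{softmax}(\ell^{\mathrm{exp}}_t), \quad p^{\mathrm{amt}}_t = \mathrm{softmax}(\ell^{\mathrm{amt}}_t),$$
where $\alpha$ is the guidance strength.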
However, native LLM architectures typically lack flexible MTP heads capable of generating predictions for an arbitrary offset $k$. The cMTPP module addresses this gap by providing a compact, shared mechanism for producing "amateur" logits from a cached hidden state and a step offset $k$, circumventing the need for extra full-sized models or multiple MTP heads. Its output is passed through the frozen LM head, ensuring compatibility and efficiency.
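The contrast between expert and amateur predictions can be sketched in a few lines. This is a minimal illustration that follows the guided-score form used in the Section 7 pseudocode (expert log-probabilities plus α times the expert-minus-amateur gap); the function names and the toy logits are made up for the example and are not taken from the source.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def guided_scores(expert_logits, amateur_logits, alpha):
    """Expert log-probs plus alpha times the expert-minus-amateur gap."""
    lp_exp = log_softmax(expert_logits)
    lp_amt = log_softmax(amateur_logits)
    return [e + alpha * (e - a) for e, a in zip(lp_exp, lp_amt)]

# Toy 3-token vocabulary (made-up numbers). The expert slightly prefers
# token 0, but the amateur is much less confident about token 1.
expert = [2.0, 1.9, 0.0]
amateur = [2.0, 0.5, 0.0]
scores = guided_scores(expert, amateur, alpha=0.5)
best = max(range(len(scores)), key=scores.__getitem__)
print(best, [round(s, 3) for s in scores])
```

In this toy case the guided scores promote token 1, the token on which the expert most exceeds the amateur, even though plain greedy decoding on the expert logits would pick token 0. Setting `alpha=0.0` recovers the expert log-probabilities unchanged.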
2. Architecture and Workflow
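The cMTPP module transforms a cached last-layer hidden state $h \in \mathbb{R}^d$ and a temporal step index $k$ into amateur logits for the current position. The principal components and data flow are:
- Input representations:
(1) $h \in \mathbb{R}^d$: cached hidden state. (2) $k$: integer offset.
- Step-ID Embedding: $e_k = \mathrm{StepEmbed}[k] \in \mathbb{R}^d$
- Adaptive LayerNorm (AdaLN):
Compute scale $\gamma_k$ and bias $\beta_k$ via an MLP over $e_k$, then modulate the normalized state: $\tilde{h} = \gamma_k \odot \mathrm{RMSNorm}(h) + \beta_k$
- Gated Feed-Forward Network ("SwiGLU" variant): $u = \phi(W_1 \tilde{h} + b_1) \odot (W_2 \tilde{h} + b_2)$
- Down-projection: $z = W_3 u + b_3$
- LM head projection: $\ell^{\mathrm{amt}} = W_{\mathrm{LM}}\, z$
Parameters:
- $W_1, W_2 \in \mathbb{R}^{rd \times d}$, $W_3 \in \mathbb{R}^{d \times rd}$ (expansion factor $r$)
- AdaLN and StepEmbed parameters are minor in scale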
The full LM backbone and its output projection are frozen during cMTPP training.
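As a rough guide to module size, the weight count of the three projections can be worked out from their shapes. The sketch below assumes that `W1` and `W2` map $d$ to $r \cdot d$ and `W3` maps back to $d$, as the shape comments in the Section 7 pseudocode indicate; the values of `d` and `r` are hypothetical and chosen only for illustration.

```python
def cmtpp_matrix_params(d, r):
    """Weights in W1, W2 (each r*d x d) and W3 (d x r*d).

    Biases, AdaLN, and StepEmbed parameters are excluded; the text
    describes them as minor in scale.
    """
    up = 2 * (r * d) * d      # W1 and W2
    down = d * (r * d)        # W3
    return up + down

# Hypothetical sizes for illustration only (not taken from the source).
d, r = 4096, 2
n = cmtpp_matrix_params(d, r)
print(f"{n:,} weights, equal to 3*r*d^2: {n == 3 * r * d * d}")
```

The count grows as $3rd^2$, so it is linear in the expansion factor and independent of vocabulary size, because the LM head is reused rather than duplicated.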
3. Mathematical Representation
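Let $h \in \mathbb{R}^d$ be the cached hidden state and $k$ the offset index. The transformations are:
- Adaptive normalization:
$$e_k = \mathrm{StepEmbed}[k], \qquad (\gamma_k, \beta_k) = \mathrm{MLP}_{\mathrm{AdaLN}}(e_k), \qquad \tilde{h} = \gamma_k \odot \mathrm{RMSNorm}(h) + \beta_k$$
- Gated MLP with SwiGLU:
$$u = \phi(W_1 \tilde{h} + b_1) \odot (W_2 \tilde{h} + b_2)$$
where $\odot$ denotes elementwise multiplication and $\phi$ is the gate activation (SiLU in a standard SwiGLU; the pseudocode in Section 7 uses GELU)
- Down-projection:
$$z = W_3 u + b_3$$
- Final logits via frozen head:
$$\ell^{\mathrm{amt}} = W_{\mathrm{LM}}\, z$$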
These "amateur" logits are used for contrastive adjustment of the expert logits as per TeGu.
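The transformations in this section can be exercised on a toy example. The following is a standard-library sketch, not the authors' implementation: it takes the AdaLN scale and bias as direct inputs instead of computing them from a step embedding, omits the bias terms, and uses SiLU as the gate activation (the usual choice for SwiGLU, whereas the Section 7 pseudocode uses GELU). All helper names and sizes are hypothetical.

```python
import math
import random

def matvec(W, x):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def rms_norm(x, eps=1e-6):
    """Divide by the root-mean-square of the vector (no learned gain)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def silu(v):
    return v / (1.0 + math.exp(-v))

def cmtpp_forward(h, gamma, beta, W1, W2, W3):
    """AdaLN-modulated RMSNorm, gated MLP, then down-projection."""
    h_tilde = [g * v + b for g, v, b in zip(gamma, rms_norm(h), beta)]
    gate = [silu(v) for v in matvec(W1, h_tilde)]   # gate branch, length r*d
    value = matvec(W2, h_tilde)                     # value branch, length r*d
    u = [a * b for a, b in zip(gate, value)]        # elementwise product
    return matvec(W3, u)                            # back to length d

random.seed(0)
d, r = 4, 2
rand = lambda rows, cols: [[random.gauss(0, 0.1) for _ in range(cols)] for _ in range(rows)]
W1, W2, W3 = rand(r * d, d), rand(r * d, d), rand(d, r * d)
h = [0.5, -1.0, 2.0, 0.1]
z = cmtpp_forward(h, gamma=[1.0] * d, beta=[0.0] * d, W1=W1, W2=W2, W3=W3)
print(len(z))
```

Because of the normalization step, rescaling `h` leaves `z` essentially unchanged; the step-dependent behaviour enters only through `gamma` and `beta`.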
4. Training Objective and Optimization
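All parameters of cMTPP are trained while the LLM backbone stays frozen. For an offset $k$ at training position $t$:
- Cross-entropy loss (CE): $\mathcal{L}_{\mathrm{CE}} = -\log p^{\mathrm{amt}}_t(x_t)$, the negative log-likelihood of the ground-truth token $x_t$ under the softmax of the cMTPP logits
- KL-Distillation loss (KD): $\mathcal{L}_{\mathrm{KD}}$, a KL divergence between the frozen model's (teacher) distribution and the cMTPP (student) distribution
- Total loss: $\mathcal{L} = \lambda_{\mathrm{CE}}\, \mathcal{L}_{\mathrm{CE}} + \lambda_{\mathrm{KD}}\, \mathcal{L}_{\mathrm{KD}}$
The weights control CE/KD mixing (e.g., $\lambda_{\mathrm{CE}} = 0.3$ and $\lambda_{\mathrm{KD}} = 0.7$, the values listed in Section 7); KD can use a temperature $\tau$.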
Ablation experiments indicate that combining CE and KD objectives stabilizes and enhances benchmark performance, whereas CE-only objectives yield degradation or instability.
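A mixed CE and KD objective of this kind can be sketched as follows. The weights 0.3 and 0.7 are the values listed in Section 7; the direction of the KL term (teacher relative to student), the way the temperature is applied, and all function names are assumptions made for illustration, since the source does not spell them out here.

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(student_logits, target):
    """Negative log-probability of the ground-truth token index."""
    return -math.log(softmax(student_logits)[target])

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

def cmtpp_loss(student_logits, teacher_logits, target,
               w_ce=0.3, w_kd=0.7, temperature=1.0):
    """Weighted sum of CE on the label and KL to the (assumed) teacher."""
    ce = cross_entropy(student_logits, target)
    kd = kl_divergence(softmax(teacher_logits, temperature),
                       softmax(student_logits, temperature))
    return w_ce * ce + w_kd * kd

student = [1.0, 2.0, 3.0]
teacher = [3.0, 2.0, 1.0]
print(round(cmtpp_loss(student, teacher, target=2), 4))
```

When the student already matches the teacher, the KD term vanishes and only the weighted CE term remains.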
5. Efficiency: Parameter, Memory, and Computation Analysis
A summary of cMTPP’s computational cost and memory usage:
- Parameter count:
The three projection matrices $W_1$, $W_2$, and $W_3$ collectively hold $3rd^2$ parameters; the reported overhead is in the millions of parameters, a small fraction of an 8B model.
- FLOPs per inference step:
Three matrix multiplications, on the order of $rd^2$ FLOPs per token, substantially less than a second full forward pass through the backbone.
- Additional memory and latency:
On Qwen3-8B, cMTPP increases VRAM usage from 17.72 GB (greedy) to 19.72 GB and latency by just 2%.
| Decoding Mode | VRAM (GB) | Latency (×) |
|---|---|---|
| Greedy | 17.72 | 1.0 |
| Standard CD (1.7B) | 23.11 | 1.2 |
| DoLa | 19.22 | 1.04 |
| TeGu + cMTPP | 19.72 | 1.02 |
This overhead is substantially lower than that of dual-model contrastive decoding and close to DoLa.
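The overheads relative to greedy decoding follow directly from the table. The short script below only restates the table's figures as differences and percentages; no new measurements are introduced.

```python
# (VRAM in GB, latency relative to greedy), copied from the table above.
rows = {
    "Greedy": (17.72, 1.00),
    "Standard CD (1.7B)": (23.11, 1.20),
    "DoLa": (19.22, 1.04),
    "TeGu + cMTPP": (19.72, 1.02),
}
base_vram, base_lat = rows["Greedy"]

# Extra VRAM in GB and extra latency in percent, both relative to greedy.
overhead = {
    name: (round(vram - base_vram, 2), round((lat / base_lat - 1.0) * 100))
    for name, (vram, lat) in rows.items()
}
for name, (extra_gb, extra_pct) in overhead.items():
    print(f"{name:20s} +{extra_gb:.2f} GB  +{extra_pct}% latency")
```

Read this way, TeGu + cMTPP adds 2.00 GB and 2% latency, DoLa adds 1.50 GB and 4%, and standard contrastive decoding with a 1.7B amateur adds 5.39 GB and 20%.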
6. Empirical Evaluation and Ablations
Main Results
- On Qwen3-1.7B ($\alpha = 0.2$): improvements reported on GSM8K and IFEval
- On Qwen3-8B ($\alpha = 0.5$): improvements reported on Math500 and IFEval
Key Ablation Insights
- CE-only training leads to instability; the inclusion of KL-distillation is critical for robust task performance across all evaluated benchmarks.
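- A single-step offset ($k = 1$) as the amateur outperforms higher $k$ or mixtures of offsets: larger $k$ dilutes the contrastive effect and reduces accuracy.
- Optimal guidance strength depends on model size: smaller models peak at a lower $\alpha$, while larger models tolerate a higher one (compare $\alpha = 0.2$ for Qwen3-1.7B with $\alpha = 0.5$ for Qwen3-8B in the main results).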
- Amateur logits generated by cMTPP demonstrate elevated entropy relative to the expert head, confirming their function as high-uncertainty, "amateur" projections.
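The entropy comparison in the last point can be reproduced for any pair of logit vectors with a few lines of code. This is a generic Shannon-entropy sketch on made-up logits, intended only to show how such a check could be carried out; it does not reproduce the paper's measurements.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats; zero-probability entries contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

# Made-up logits: a confident "expert" and a flatter "amateur".
expert_probs = softmax([4.0, 0.0, 0.0, 0.0])
amateur_probs = softmax([1.0, 0.5, 0.0, 0.0])
print(entropy(expert_probs) < entropy(amateur_probs))
```

A flatter distribution has higher entropy, with the uniform distribution over $V$ tokens reaching the maximum of $\log V$.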
7. Implementation Considerations and Pseudocode
Principal hyperparameters include the expansion ratio $r$, the loss weights (CE: 0.3, KD: 0.7), the KD temperature $\tau$, the AdamW optimizer (peak LR 2e-4), a cosine schedule, and warmup.
cMTPP is used during inference as follows:
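```python
# One TeGu decoding step (pseudocode: LM, W_LM, cache, and sample_or_argmax are placeholders).
h_cur = LM.encode_last_hidden(x[:t])       # last-layer hidden state for x_{<t}
logp_exp = log_softmax(W_LM @ h_cur)       # expert log-probabilities
h_old = cache.pop_oldest_if_needed()       # hidden state cached k steps earlier
z = cMTPP_forward(h_old, k)                # projected "amateur" hidden state
logp_amt = log_softmax(W_LM @ z)           # amateur log-probabilities
guided_scores = logp_exp + alpha * (logp_exp - logp_amt)
next_token = sample_or_argmax(guided_scores)
```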
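The inference step relies on a cache that returns the hidden state from $k$ steps earlier. One way to realize this is a fixed-length buffer, sketched below. The class and method names are hypothetical stand-ins for the `cache.pop_oldest_if_needed()` call, whose exact semantics the source does not specify; returning `None` until the buffer has filled is likewise an assumption.

```python
from collections import deque

class HiddenStateCache:
    """Keeps the last k hidden states so the one from k steps back can be read."""

    def __init__(self, k):
        self.k = k
        self.buf = deque(maxlen=k)

    def push_and_get_old(self, h):
        """Return the state pushed k steps ago (None until full), then store h."""
        old = self.buf[0] if len(self.buf) == self.k else None
        self.buf.append(h)  # once full, this evicts the entry just returned
        return old

# Integers stand in for hidden-state vectors.
cache = HiddenStateCache(k=2)
outs = [cache.push_and_get_old(t) for t in range(5)]
print(outs)
```

During the first $k$ steps no amateur state is available, so a decoder built on this sketch would need a fallback, for example plain expert decoding, until the buffer fills.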
The cMTPP forward step is:
```python
def cMTPP_forward(h, k):
    e_k = StepEmbed[k]                    # step-ID embedding, shape (d,)
    gamma_k, beta_k = AdaLN_MLP(e_k)      # scale and bias, each (d,)
    h_norm = RMSNorm(h)                   # (d,)
    h_tilde = gamma_k * h_norm + beta_k   # AdaLN modulation, (d,)
    u1 = GELU(W1 @ h_tilde + b1)          # gate branch, (r*d,)
    u2 = W2 @ h_tilde + b2                # value branch, (r*d,)
    u = u1 * u2                           # elementwise product, (r*d,)
    z = W3 @ u + b3                       # down-projection, (d,)
    return z
```
All LLM backbone parameters remain frozen throughout cMTPP training; only cMTPP parameters are updated (Zheng et al., 29 Jan 2026).