- The paper introduces a masked-input formulation and gated LoRA to efficiently enable multi-token prediction in autoregressive LLMs.
- It demonstrates up to 5.2× inference speedup in code and math tasks while preserving next-token prediction quality.
- The lightweight sampler head and quadratic decoding strategy yield high acceptance rates with minimal computational overhead.
Multi-Token Prediction in Autoregressive LLMs: Architecture, Training, and Inference Acceleration
The paper "Your LLM Knows the Future: Uncovering Its Multi-Token Prediction Potential" (2507.11851) presents a systematic investigation into the latent capacity of standard autoregressive LLMs to predict multiple future tokens in a single inference step. The authors propose a practical framework that enables efficient multi-token generation with minimal architectural changes and without sacrificing the quality of next-token prediction. This work is situated at the intersection of speculative decoding, multi-token prediction (MTP), and parameter-efficient fine-tuning.
Core Contributions
The paper introduces several key innovations:
- Masked-Input Formulation: The model is augmented to accept k mask tokens appended to the input, prompting it to predict k future tokens jointly from a shared prefix (a minimal sketch follows this list).
- Gated LoRA Adaptation: Fine-tuning is performed using a gated version of Low-Rank Adaptation (LoRA), where LoRA parameters are only activated for mask (MTP) tokens, preserving the original model's next-token prediction (NTP) behavior.
- Lightweight Sampler Module: A two-layer MLP sampler head is introduced to generate coherent token sequences from the predicted future token distributions, conditioning each token on the previously sampled token.
- Auxiliary Consistency Loss: A latent consistency loss (LCM) is used to align the hidden representations of MTP and NTP tokens, improving the acceptance rate of speculative decoding.
- Quadratic Decoding Strategy: The authors propose a quadratic decoding algorithm that interleaves mask tokens within speculative tokens, ensuring a consistent supply of speculative candidates and higher acceptance rates compared to linear decoding.
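To make the masked-input formulation concrete, the sketch below is a toy reconstruction rather than the paper's code: `TinyLM`, `MASK_ID`, and all sizes are illustrative stand-ins. The prompt is extended with k mask tokens so that a single forward pass returns both the ordinary next-token prediction and k additional future-token guesses.

```python
import torch

VOCAB, HIDDEN, MASK_ID, K = 100, 32, 99, 4   # toy sizes; MASK_ID is a reserved id

class TinyLM(torch.nn.Module):
    """Stand-in for a decoder-only LLM: embedding -> sequence model -> LM head."""
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(VOCAB, HIDDEN)
        self.body = torch.nn.GRU(HIDDEN, HIDDEN, batch_first=True)  # placeholder for the transformer stack
        self.lm_head = torch.nn.Linear(HIDDEN, VOCAB)

    def forward(self, ids):
        h, _ = self.body(self.embed(ids))
        return self.lm_head(h)                                      # (batch, seq, vocab)

model = TinyLM()
prompt = torch.randint(0, VOCAB - 1, (1, 10))                       # hypothetical prompt ids
masked = torch.cat([prompt, torch.full((1, K), MASK_ID)], dim=1)    # append k mask tokens

logits = model(masked)                                              # one forward pass
next_token = logits[:, prompt.size(1) - 1].argmax(-1)               # ordinary NTP prediction
future_draft = logits[:, prompt.size(1):].argmax(-1)                # k additional future-token guesses
print(next_token.item(), future_draft.tolist())
```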
Methodological Details
During fine-tuning, the model is trained to predict both the next token and k future tokens at each position by inserting mask tokens after each NTP token. This is implemented efficiently by processing all such masked sequences in parallel within a single batch. The attention mask is carefully constructed so that NTP tokens attend only to previous NTP tokens, while MTP tokens attend to both previous NTP and MTP tokens within the same block, but not to earlier MTP blocks.
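The sketch below reconstructs such an attention mask under stated assumptions (the interleaved layout and variable names are ours, not taken from the paper): each NTP token is followed by a block of k mask tokens, NTP queries attend causally to NTP keys only, and MTP queries additionally see earlier MTP tokens of their own block.

```python
import torch

def build_mtp_attention_mask(n_ntp: int, k: int) -> torch.Tensor:
    """Boolean mask over the interleaved sequence; True means 'query may attend to key'."""
    block = k + 1                                   # one NTP token + its k MTP tokens
    seq = n_ntp * block
    block_id = [p // block for p in range(seq)]     # which NTP block a position belongs to
    offset = [p % block for p in range(seq)]        # 0 = NTP token, 1..k = MTP tokens
    allowed = torch.zeros(seq, seq, dtype=torch.bool)
    for q in range(seq):
        for kv in range(seq):
            earlier_ntp = offset[kv] == 0 and block_id[kv] <= block_id[q]
            if offset[q] == 0:
                # NTP query: causal attention over NTP tokens only
                allowed[q, kv] = earlier_ntp
            else:
                # MTP query: earlier NTP tokens, plus earlier MTP tokens of the
                # same block -- never MTP tokens from earlier blocks
                same_block_mtp = (offset[kv] > 0 and block_id[kv] == block_id[q]
                                  and offset[kv] <= offset[q])
                allowed[q, kv] = earlier_ntp or same_block_mtp
    return allowed

print(build_mtp_attention_mask(n_ntp=3, k=2).int())
```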
Gated LoRA
Gated LoRA ensures that only the MTP tokens are affected by the LoRA adaptation, while NTP tokens retain the original model's output. This is achieved with a binary gate applied at each position, whose value is set deterministically by token type (1 for MTP tokens, 0 for NTP tokens). This approach avoids the degradation of NTP performance observed with standard LoRA fine-tuning and eliminates the need for complex training recipes or loss reweighting.
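A minimal sketch of this gating, assuming a standard LoRA parameterization (the class name, rank, and scaling are illustrative, not the paper's code): the low-rank update is multiplied by a per-position 0/1 gate, so NTP positions reproduce the frozen base projection exactly.

```python
import torch

class GatedLoRALinear(torch.nn.Module):
    def __init__(self, base: torch.nn.Linear, rank: int = 4, alpha: float = 8.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)                      # base weights stay frozen
        self.A = torch.nn.Linear(base.in_features, rank, bias=False)
        self.B = torch.nn.Linear(rank, base.out_features, bias=False)
        torch.nn.init.zeros_(self.B.weight)              # start as a zero update
        self.scale = alpha / rank

    def forward(self, x: torch.Tensor, is_mtp: torch.Tensor) -> torch.Tensor:
        # is_mtp: (batch, seq) boolean mask, True at mask-token (MTP) positions
        gate = is_mtp.unsqueeze(-1).to(x.dtype)          # (batch, seq, 1)
        return self.base(x) + gate * self.scale * self.B(self.A(x))

# NTP rows (gate = 0) reproduce the base model's output exactly.
layer = GatedLoRALinear(torch.nn.Linear(16, 16))
x = torch.randn(1, 5, 16)
is_mtp = torch.tensor([[False, False, True, True, True]])
out = layer(x, is_mtp)
```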
Sampler Head
The sampler head is a lightweight MLP that takes as input the concatenation of the current masked token's hidden state and the embedding of the previously sampled token. This design allows the model to generate more coherent multi-token outputs, especially as the number of predicted tokens increases.
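The following sketch shows one plausible shape of such a head (the dimensions, activation, and greedy sampling are assumptions): a two-layer MLP scores the next draft token from the concatenation of the current mask position's hidden state and the embedding of the previously sampled token.

```python
import torch

class SamplerHead(torch.nn.Module):
    def __init__(self, hidden: int, vocab: int):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(2 * hidden, hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden, vocab),
        )

    def forward(self, mask_hidden: torch.Tensor, prev_token_emb: torch.Tensor) -> torch.Tensor:
        return self.mlp(torch.cat([mask_hidden, prev_token_emb], dim=-1))

# Draft tokens are produced one by one, each conditioned on the previous sample.
hidden, vocab, k = 32, 100, 4
head = SamplerHead(hidden, vocab)
embed = torch.nn.Embedding(vocab, hidden)                # stand-in for the base model's embedding table
mask_states = torch.randn(k, hidden)                     # hidden states at the k mask positions
prev = torch.randint(0, vocab, (1,))                     # the verified next token
draft = []
for i in range(k):
    logits = head(mask_states[i:i + 1], embed(prev))
    prev = logits.argmax(-1)                             # greedy for illustration
    draft.append(prev.item())
print(draft)
```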
Speculative Decoding
The framework supports both linear and quadratic speculative decoding. In linear decoding, speculative tokens are verified sequentially, and only fully verified blocks are accepted. Quadratic decoding interleaves mask tokens within speculative outputs, guaranteeing that k speculative tokens are always available for verification, thus improving throughput at the cost of slightly increased sequence length.
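The verification step common to both modes can be sketched as follows (a greedy simplification; the paper may use a probabilistic acceptance rule, and the tensor layout here is assumed): the base model re-scores the draft in one forward pass, and the longest prefix it agrees with is accepted.

```python
import torch

def verify_draft(base_logits: torch.Tensor, draft: list) -> list:
    """Greedy verification: keep the longest draft prefix the base model agrees with."""
    accepted = []
    for i, token in enumerate(draft):
        # base_logits[i] are the model's logits at the position preceding draft[i],
        # obtained from a single forward pass over prompt + draft.
        if base_logits[i].argmax().item() != token:
            break                                        # first disagreement ends acceptance
        accepted.append(token)
    return accepted

# Toy call with random logits; the first draft token is chosen to match on purpose.
toy_logits = torch.randn(3, 100)
draft = [int(toy_logits[0].argmax()), 5, 7]
print(verify_draft(toy_logits, draft))
```

Quadratic decoding simply interleaves fresh mask tokens into the draft, so that whichever prefix is accepted, k new speculative candidates are already in flight for the next verification step.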
Training Loss
The total loss is a sum of cross-entropy losses for both the base and sampler heads, plus the latent consistency loss. The LCM loss is only applied to MTP tokens and encourages their hidden representations to match those of the corresponding NTP tokens, acting as a form of self-distillation.
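A rough sketch of this objective in code (the squared-error form of the consistency term, the stop-gradient on the NTP side, and equal weighting of the three terms are assumptions; the text above only states that the terms are summed):

```python
import torch
import torch.nn.functional as F

def total_loss(base_logits, sampler_logits, targets, h_mtp, h_ntp_matched):
    """Sum of the two cross-entropy terms plus the latent consistency (LCM) term."""
    ce_base = F.cross_entropy(base_logits, targets)          # base LM head
    ce_sampler = F.cross_entropy(sampler_logits, targets)    # sampler head
    # Consistency: pull each MTP hidden state toward the hidden state of the
    # corresponding NTP token (detached target, i.e. a form of self-distillation).
    lcm = F.mse_loss(h_mtp, h_ntp_matched.detach())
    return ce_base + ce_sampler + lcm

# Toy shapes: 6 supervised positions, vocab 100, hidden size 32.
logits_a, logits_b = torch.randn(6, 100), torch.randn(6, 100)
targets = torch.randint(0, 100, (6,))
h_mtp, h_ntp = torch.randn(6, 32), torch.randn(6, 32)
print(total_loss(logits_a, logits_b, targets, h_mtp, h_ntp))
```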
Experimental Results
Experiments are conducted on the Tulu3-8B SFT model, a LLaMA-3 derivative, fine-tuned with k=8 mask tokens. The main findings are:
- Speedup: The method achieves up to 5.2× speedup in code and math generation, and 2.5× speedup in general chat and knowledge tasks, with no loss in output quality.
- Quality Preservation: Gated LoRA maintains zero-shot accuracy on ARC-Challenge and preserves NTP loss, in contrast to standard LoRA which degrades NTP performance.
- Ablations: The combination of quadratic decoding, the sampler head, and LCM loss yields the highest acceptance rates. Notably, even LoRA ranks as low as 1 or 4 provide substantial speedup, indicating that the base model already encodes significant information about future tokens.
- Resource Efficiency: The memory overhead of LoRA and the sampler head is minimal, especially at low ranks, making the approach suitable for deployment in resource-constrained environments.
Implications and Future Directions
This work demonstrates that standard autoregressive LLMs possess substantial implicit knowledge of future tokens, which can be efficiently harnessed for multi-token prediction with minimal fine-tuning and architectural changes. The approach is compatible with existing speculative decoding frameworks and can be integrated with more advanced tree-based or hardware-aware speculative algorithms.
Practical implications include:
- Inference Acceleration: The method provides a straightforward path to accelerate LLM inference in production without retraining from scratch or sacrificing model quality.
- Parameter-Efficient Adaptation: Gated LoRA enables selective adaptation for new tasks or domains while preserving base model performance, facilitating continual learning and domain adaptation.
- Hardware Efficiency: The lightweight nature of the sampler and LoRA modules makes the approach attractive for edge deployment and large-scale serving.
Theoretical implications and future research avenues:
- Pretraining with MTP: Investigating the impact of multi-token prediction objectives during pretraining could further enhance the model's ability to anticipate future tokens.
- Diffusion-Based Generation: Exploring the integration of diffusion models with MTP may yield new trade-offs between fully autoregressive and non-autoregressive generation.
- Generalization to Other Modalities: The masked-input and gated adaptation strategies may be applicable to multi-modal or non-textual sequence generation tasks.
Conclusion
The paper provides a comprehensive framework for unlocking the multi-token prediction potential of autoregressive LLMs, achieving significant inference speedups with minimal overhead and no loss in quality. The proposed techniques—masked-input fine-tuning, gated LoRA, lightweight sampling, and quadratic decoding—collectively establish a new baseline for efficient LLM inference and open several promising directions for future research in both model architecture and training methodology.