Masked Next Token Prediction (MNTP)
- Masked Next Token Prediction (MNTP) is a paradigm that generalizes next-token prediction by training models to predict masked tokens, combining bidirectional context with causal decoders.
- It relies on randomized masking and, for language models, on gated LoRA and lightweight sampler modules to improve parallelism, efficiency, and robustness across data modalities.
- MNTP has demonstrated practical improvements in language, audio, and video tasks, achieving significant speedups and quality gains compared to standard next-token prediction.
Masked Next Token Prediction (MNTP) is a training and inference paradigm that generalizes the standard next-token prediction (NTP) objective of autoregressive models. It enables models to predict tokens at masked or held-out positions in a sequence, leveraging causal or unidirectional decoders. MNTP has been formulated and evaluated across multiple domains, including language modeling, audio generation, and video point tracking. The core appeal of MNTP is to combine the contextual richness of masked modeling (as in BERT) with the streaming and causality advantages of autoregressive architectures, often yielding improvements in efficiency, generalization, and parallelism.
1. Fundamental Concepts and Formalism
MNTP extends traditional next-token prediction by training models to predict not only the immediate next token, but also arbitrary “future” tokens, given a masked or dropped prefix. Formally, for a token sequence $x = (x_1, \dots, x_T)$, a binary masking vector $m \in \{0, 1\}^T$ is applied, and the model predicts $x_t$ at a masked position $t$ (with $m_t = 1$) given the observed tokens $\{x_s : s < t,\ m_s = 0\}$ and, optionally, additional inputs such as prompts or original sequence context. The backbone model is typically a causal decoder (e.g., a Transformer), adapted to this objective by modifying input masking or architectural routing, or by adding auxiliary modules (Yang et al., 14 Jul 2025, Samragh et al., 16 Jul 2025).
This paradigm subsumes standard NTP (prediction of $x_t$ from $x_{<t}$) as a special case and can interpolate between autoregressive and bidirectional training regimes. In MNTP, the set of masked tokens and their positions are randomized during training, exposing the model to a wider variety of contexts.
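The randomized masking step can be sketched in a few lines of Python. This is a minimal illustration under an assumed setup, not code from any of the cited papers: `sample_mask` and `causal_context` are hypothetical helpers, tokens are plain integers, and a target at position t is assumed to see only the earlier tokens that survive the mask. With a masking ratio of 0 the full prefix stays visible, which is the standard NTP context.

```python
import random
from typing import List, Sequence


def sample_mask(length: int, ratio: float, rng: random.Random) -> List[bool]:
    """Randomized masking vector: True marks a masked (held-out) position."""
    return [rng.random() < ratio for _ in range(length)]


def causal_context(tokens: Sequence[int], mask: Sequence[bool], t: int) -> List[int]:
    """Tokens a causal decoder may condition on when predicting position t:
    earlier positions s < t that were not masked."""
    return [tokens[s] for s in range(t) if not mask[s]]


rng = random.Random(0)
tokens = [11, 12, 13, 14, 15, 16]

# Ratio 0: nothing is hidden, so the context is the full prefix x_{<t} (standard NTP).
ntp_mask = sample_mask(len(tokens), 0.0, rng)
print(causal_context(tokens, ntp_mask, 4))  # [11, 12, 13, 14]

# Ratio 0.5: each training step sees a different, randomly thinned prefix.
mntp_mask = sample_mask(len(tokens), 0.5, rng)
print(causal_context(tokens, mntp_mask, 4))
```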
2. Architectures and Masking Mechanisms
Language and Multi-Token Generation
MNTP enables the joint prediction of multiple future tokens in LLMs, as introduced with a masked-input formulation (Samragh et al., 16 Jul 2025). Given a sequence $x_1, \dots, x_n$, $k$ mask tokens are appended, yielding $(x_1, \dots, x_n, \langle m_1 \rangle, \dots, \langle m_k \rangle)$. The model, possibly equipped with gated Low-Rank Adaptation (LoRA), predicts the next $k$ tokens $x_{n+1}, \dots, x_{n+k}$ non-autoregressively at these positions. The “gated LoRA” technique applies trainable adapters only to the mask tokens, ensuring that native positions remain functionally identical to the pretrained LLM and preserving its original next-token prediction performance.
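A minimal sketch of the gating idea, assuming a single linear projection and pure-Python lists in place of tensors. The function `gated_lora` and its arguments are hypothetical; in a real model the adapter would sit inside the attention or MLP projections of each layer. What it shows is that the low-rank update `B @ A` is added only where the gate is on (the appended mask tokens), so outputs at native positions match the frozen base weight exactly.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Plain matrix-vector product."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def gated_lora(w: Matrix, a: Matrix, b: Matrix,
               xs: List[Vector], is_mask: List[bool]) -> List[Vector]:
    """Per position: y = W x + g * B (A x), with gate g = 1 only on mask tokens.

    w: frozen pretrained weight (d_out x d_in)
    a: LoRA down-projection (r x d_in); b: LoRA up-projection (d_out x r)
    """
    out = []
    for x, gate in zip(xs, is_mask):
        y = matvec(w, x)                    # base model path, always applied
        if gate:                            # adapter active only at mask positions
            delta = matvec(b, matvec(a, x))
            y = [yi + di for yi, di in zip(y, delta)]
        out.append(y)
    return out


W = [[1.0, 0.0], [0.0, 1.0]]  # frozen base weight (identity, for readability)
A = [[1.0, 1.0]]              # rank r = 1
B = [[0.5], [-0.5]]
xs = [[1.0, 2.0], [3.0, 4.0], [0.0, 1.0], [0.0, 1.0]]  # 2 native tokens + k = 2 mask tokens
is_mask = [False, False, True, True]

ys = gated_lora(W, A, B, xs, is_mask)
print(ys[:2])  # [[1.0, 2.0], [3.0, 4.0]]  (identical to the base model)
print(ys[2])   # [0.5, 0.5]
```

Because the gate is a hard 0/1 per position, the base model's next-token behavior is untouched by construction, which is the property the text attributes to gated LoRA.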
A lightweight sampler module, a 2-layer MLP, further improves coherence by conditioning each future-token prediction on the corresponding hidden state and the previously generated token, producing a distribution over the vocabulary.
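The sampler can be sketched as follows. The random weights, the ReLU hidden layer, the concatenation of the hidden state with a previous-token embedding, and the names `Sampler`, `linear`, and `softmax` are all illustrative assumptions rather than the published architecture. The loop shows why the previous token matters: each drafted token is fed back as conditioning for the next mask position.

```python
import math
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def linear(w: Matrix, b: Vector, x: Vector) -> Vector:
    """Affine map W x + b."""
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(w, b)]


def softmax(z: Vector) -> Vector:
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]


class Sampler:
    """Hypothetical 2-layer MLP: [h ; emb(prev_token)] -> ReLU -> vocab probabilities."""

    def __init__(self, d_model: int, vocab: int, d_hidden: int, seed: int = 0):
        rng = random.Random(seed)

        def init(rows: int, cols: int) -> Matrix:
            return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

        self.emb = init(vocab, d_model)                   # previous-token embeddings
        self.w1, self.b1 = init(d_hidden, 2 * d_model), [0.0] * d_hidden
        self.w2, self.b2 = init(vocab, d_hidden), [0.0] * vocab

    def __call__(self, h: Vector, prev_token: int) -> Vector:
        x = list(h) + self.emb[prev_token]                # concatenate conditioning inputs
        hidden = [max(0.0, v) for v in linear(self.w1, self.b1, x)]  # layer 1 + ReLU
        return softmax(linear(self.w2, self.b2, hidden))  # layer 2 -> distribution


sampler = Sampler(d_model=4, vocab=10, d_hidden=8)
# Stand-ins for the decoder's hidden states at two appended mask positions.
mask_hidden = [[0.1, -0.2, 0.3, 0.0], [0.0, 0.5, -0.1, 0.2]]
prev = 3  # last token already committed to the output
drafts = []
for h in mask_hidden:
    probs = sampler(h, prev)
    prev = max(range(len(probs)), key=probs.__getitem__)  # greedy pick, fed forward
    drafts.append(prev)
print(drafts)
```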
Audio Generation
For audio, MNTP operates on continuous-valued tokens obtained via variational autoencoding (Yang et al., 14 Jul 2025). Instead of masking by substituting tokens, “drop-instead-of-mask” is used: masked positions are entirely removed from the context input to the Transformer decoder, significantly reducing sequence length and compute requirements. Each prediction is informed by the remaining, randomly thinned past, and the context is reconstructed for each target position with corresponding positional encodings.
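The drop-instead-of-mask step can be sketched as below, assuming independent per-position dropping and toy 2-dimensional latents; `drop_context` and `context_for_target` are hypothetical names. The key difference from masking is that dropped tokens never enter the decoder input, so the sequence gets shorter, while survivors keep their original time index for positional encoding.

```python
import random
from typing import List, Sequence, Tuple

Frame = List[float]  # one continuous-valued (VAE latent) token


def drop_context(frames: Sequence[Frame], drop_ratio: float, rng: random.Random
                 ) -> Tuple[List[Frame], List[int]]:
    """Remove dropped tokens from the input entirely (no [MASK] embedding is
    substituted). Survivors keep their original position index."""
    kept: List[Frame] = []
    positions: List[int] = []
    for pos, frame in enumerate(frames):
        if rng.random() >= drop_ratio:  # keep with probability 1 - drop_ratio
            kept.append(frame)
            positions.append(pos)
    return kept, positions


def context_for_target(kept: Sequence[Frame], positions: Sequence[int], t: int
                       ) -> List[Tuple[int, Frame]]:
    """Causal context for target t: surviving tokens strictly before t, each
    paired with the position index used for its positional encoding."""
    return [(p, f) for p, f in zip(positions, kept) if p < t]


latents = [[float(i), 0.0] for i in range(10)]  # toy latents for 10 time steps
kept, positions = drop_context(latents, drop_ratio=0.4, rng=random.Random(0))
print(f"decoder input length: {len(kept)} instead of {len(latents)}")
print(context_for_target(kept, positions, t=7))
```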
Video Point Tracking
In TAPNext (Zholus et al., 8 Apr 2025), MNTP frames video point tracking as imputation over a set of spatiotemporal “point-tokens,” each corresponding to the coordinate of a query point at a specific timestep. Only the token at the initial observation is unmasked; the others are initialized as learned mask tokens. The recurrent causal transformer backbone (TRecViT), which interleaves State Space Model (SSM) and Vision Transformer (ViT) blocks, propagates information strictly causally without access to future frames, enabling per-frame online inference.
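The imputation framing can be sketched as follows. This is an illustration under stated assumptions, not TAPNext code: `None` stands in for the learned mask token, the names are hypothetical, and `hold_last` is a placeholder predictor (it simply repeats the last known position) standing in for the model, which in TAPNext would also consume the video frames up to t. The loop enforces the causal constraint: the prediction for frame t only receives tokens from earlier frames.

```python
from typing import Callable, List, Optional, Tuple

Point = Tuple[float, float]
MASK: Optional[Point] = None  # hypothetical stand-in for the learned mask token


def init_point_tokens(num_frames: int, query_frame: int,
                      query_xy: Point) -> List[Optional[Point]]:
    """One point-token per frame; only the query observation is unmasked."""
    tokens: List[Optional[Point]] = [MASK] * num_frames
    tokens[query_frame] = query_xy
    return tokens


def track_online(tokens: List[Optional[Point]],
                 predict: Callable[[int, List[Optional[Point]]], Optional[Point]]
                 ) -> List[Optional[Point]]:
    """Impute masked tokens frame by frame; predict(t, history) sees frames < t only."""
    track = list(tokens)
    for t in range(len(track)):
        if track[t] is MASK:
            track[t] = predict(t, track[:t])
    return track


def hold_last(t: int, history: List[Optional[Point]]) -> Optional[Point]:
    """Placeholder predictor (not a real tracker): repeat the last known position."""
    return next((p for p in reversed(history) if p is not MASK), None)


tokens = init_point_tokens(num_frames=4, query_frame=0, query_xy=(0.5, 0.25))
print(tokens)                           # [(0.5, 0.25), None, None, None]
print(track_online(tokens, hold_last))  # [(0.5, 0.25), (0.5, 0.25), (0.5, 0.25), (0.5, 0.25)]
```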
3. Training Objectives and Loss Functions
MNTP leverages task- and modality-specific losses.
- Cross-Entropy Loss: For discrete tokens (language, video), cross-entropy is computed at each masked (predicted) position (Zholus et al., 8 Apr 2025, Schneider, 2024, Samragh et al., 16 Jul 2025).
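- Diffusion-Based Loss: For continuous audio tokens, a diffusion MSE loss is used. For a given position $t$, the loss is
  $$\mathcal{L}(x_t, z_t) = \mathbb{E}_{\varepsilon, s}\left[\left\lVert \varepsilon - \varepsilon_\theta\left(\sqrt{\bar\alpha_s}\, x_t + \sqrt{1 - \bar\alpha_s}\, \varepsilon \mid s, z_t\right)\right\rVert^2\right],$$
  where $\varepsilon \sim \mathcal{N}(0, I)$ is Gaussian noise, $\varepsilon_\theta$ is the MLP diffusion head, $\{\bar\alpha_s\}$ define the diffusion schedule over steps $s$, and $z_t$ is the Transformer context. For MNTP, the loss is summed over all dropped (masked) positions (Yang et al., 14 Jul 2025).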
- Combination and Auxiliary Losses: In video, a combined coordinate (Huber + cross-entropy) and visibility head loss are applied at every layer (intermediate supervision) (Zholus et al., 8 Apr 2025). For language, a latent consistency matching (LCM) loss aligns hidden state representations between standard autoregressive and masked branches (Samragh et al., 16 Jul 2025).
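The per-position diffusion loss can be sketched in pure Python, assuming the standard noise-prediction form used by MAR-style diffusion heads: for a random step s, a noisy input sqrt(ᾱ_s)·x + sqrt(1 − ᾱ_s)·ε is built and the head is trained to recover ε. The schedule values, the `Denoiser` signature, and the function names are illustrative assumptions; a real head is an MLP conditioned on the Transformer output z.

```python
import math
import random
from typing import Callable, List, Sequence

Vector = List[float]
# Hypothetical denoiser interface: eps_theta(noisy_x, step, context) -> predicted noise.
Denoiser = Callable[[Vector, int, Vector], Vector]


def diffusion_loss(x: Sequence[float], z: Vector, eps_theta: Denoiser,
                   alpha_bar: Sequence[float], rng: random.Random) -> float:
    """One Monte Carlo sample of ||eps - eps_theta(sqrt(a) x + sqrt(1 - a) eps | s, z)||^2."""
    s = rng.randrange(len(alpha_bar))            # sample a diffusion step s
    a = alpha_bar[s]
    eps = [rng.gauss(0.0, 1.0) for _ in x]       # eps ~ N(0, I)
    noisy = [math.sqrt(a) * xi + math.sqrt(1.0 - a) * ei for xi, ei in zip(x, eps)]
    pred = eps_theta(noisy, s, z)
    return sum((ei - pi) ** 2 for ei, pi in zip(eps, pred))


def mntp_diffusion_loss(targets: Sequence[Sequence[float]], contexts: Sequence[Vector],
                        eps_theta: Denoiser, alpha_bar: Sequence[float],
                        rng: random.Random) -> float:
    """Sum the per-position loss over every predicted (dropped) position."""
    return sum(diffusion_loss(x, z, eps_theta, alpha_bar, rng)
               for x, z in zip(targets, contexts))


alpha_bar = [0.9, 0.5, 0.1]          # toy noise schedule (assumed values)
targets = [[0.2, -1.0], [1.5, 0.3]]
contexts = [[0.0, 0.0], [0.0, 0.0]]  # stand-ins for the Transformer outputs z


def zero_head(noisy: Vector, s: int, z: Vector) -> Vector:
    return [0.0] * len(noisy)        # untrained head: predicts no noise


print(mntp_diffusion_loss(targets, contexts, zero_head, alpha_bar, random.Random(0)))
```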
4. Empirical Results and Comparative Performance
MNTP has demonstrated advantages in multiple domains.
| Domain | Baseline model | Baseline result | MNTP result | Relative gain | Reference |
|---|---|---|---|---|---|
| Audio (AudioCaps) | AudioGen Base (discrete NTP) | FAD 2.14 | FAD 1.68 | 21% FAD gain, 40% KL gain | (Yang et al., 14 Jul 2025) |
| Audio | AudioNTP Base (continuous NTP) | FAD 2.28 | FAD 1.68 | 26% FAD gain, 10% KL gain | (Yang et al., 14 Jul 2025) |
| Language | Tulu3-8B SFT | Autoregressive NTP decoding | 2.3–5.3× speedup | No loss in NTP accuracy | (Samragh et al., 16 Jul 2025) |
| Video tracking | CoTracker3 | ∼80 ms latency (windowed) | 5.05 ms latency (TAPNext) | >15× lower latency | (Zholus et al., 8 Apr 2025) |
| Language (next-token accuracy) | GPT-2 (OpenWebText) | NTP accuracy 42.40% | +0.06 pp (AGR) | Small but significant gain | (Schneider, 2024) |
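The relative gains for FAD and latency follow from simple ratios of the values listed in the table. The short check below recomputes them, assuming the approximate ∼80 ms CoTracker3 latency is taken at face value; `relative_reduction` is a hypothetical helper.

```python
def relative_reduction(baseline: float, improved: float) -> float:
    """Fractional improvement for a lower-is-better metric (FAD, latency)."""
    return (baseline - improved) / baseline


print(f"FAD vs. AudioGen Base:  {relative_reduction(2.14, 1.68):.1%}")  # 21.5%
print(f"FAD vs. AudioNTP Base:  {relative_reduction(2.28, 1.68):.1%}")  # 26.3%
print(f"Latency vs. CoTracker3: {80 / 5.05:.1f}x")                      # 15.8x
```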
MNTP in audio yields substantial improvements in both Frechet Audio Distance (FAD) and Kullback-Leibler (KL) divergence over both discrete-token and continuous-token NTP baselines. In LLMs, MNTP speeds up inference by predicting multiple tokens simultaneously, reaching up to ~5× faster generation on code and math tasks without loss of standard NTP accuracy, provided gated LoRA is used. In video tracking, MNTP enables low-latency online tracking, and classical tracking behaviors emerge naturally from end-to-end MNTP training (Zholus et al., 8 Apr 2025).
5. Emergent Behaviors and Heuristics
MNTP-trained models frequently rediscover classical domain-specific heuristics as emergent properties. In video tracking, attention maps exhibit cost-volume-like attention (global matching of query and patch features), coordinate-readout heads (local focus near the last prediction), and motion clustering (grouping points that lie on the same rigid object) (Zholus et al., 8 Apr 2025). These phenomena support the claim that end-to-end MNTP training can induce priors such as motion smoothness and temporal continuity without explicit heuristics or dedicated submodules.
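For readers unfamiliar with the classical operation these attention maps resemble, a cost-volume matching step can be sketched as below. This is a generic illustration of cost-volume matching with a soft-argmax readout, not TAPNext's learned attention; the function name, the dot-product score, and the `temperature` parameter are assumptions. Every patch feature is scored against the query feature, the scores are normalized with a softmax, and the location is read out as the score-weighted average of patch centers.

```python
import math
from typing import Sequence, Tuple


def cost_volume_readout(query: Sequence[float],
                        patches: Sequence[Sequence[float]],
                        centers: Sequence[Tuple[float, float]],
                        temperature: float = 1.0) -> Tuple[float, float]:
    """Global matching of a query feature against all patch features,
    followed by a soft-argmax over patch-center coordinates."""
    # 1-D "cost volume": one similarity score per patch.
    scores = [sum(q * p for q, p in zip(query, patch)) / temperature for patch in patches]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]  # stable softmax numerators
    total = sum(weights)
    x = sum(w * c[0] for w, c in zip(weights, centers)) / total
    y = sum(w * c[1] for w, c in zip(weights, centers)) / total
    return x, y


query = [1.0, 0.0]
patches = [[4.0, 0.0], [0.0, 4.0], [1.0, 1.0]]
centers = [(8.0, 8.0), (24.0, 8.0), (8.0, 24.0)]
print(cost_volume_readout(query, patches, centers))
```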
A plausible implication is that, in sufficiently expressive models and with ample data, standard tracking, matching, or context-propagation algorithms may be subsumed by MNTP-style objectives, rendering handcrafted modules obsolete.
6. Applications, Efficiency, and Trade-offs
- Parallelization and Latency: MNTP facilitates speculative decoding for LLMs, allowing multiple tokens to be drafted and verified in parallel (“linear” and “quadratic” speculative decoding). Acceptance rates scale with the number of appended mask tokens $k$, yielding average speedups of 2.3–5.3× across benchmarks (Samragh et al., 16 Jul 2025).
- Memory/Compute Overheads: Overheads are minimal with careful architectural choices. Gated LoRA adaptation adds only $O(rd)$ parameters per adapted layer (with rank $r \ll d$, the hidden size), and the sampler MLP remains small relative to the full model (Samragh et al., 16 Jul 2025).
- Generalization and Robustness: Randomized masking during training simulates various context lengths and input-dropout, providing regularization effects and robustness to missing context, especially evident in audio modeling (Yang et al., 14 Jul 2025).
- Parameter Efficiency: Continuous-token MNTP models require fewer parameters than large discrete-vocabulary models, achieving state-of-the-art performance at smaller scale (Yang et al., 14 Jul 2025).
- Quality Preservation: Provided proper gating and modularization (e.g., gated LoRA), base model performance on next-token prediction is unaffected (Samragh et al., 16 Jul 2025).
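A generic greedy acceptance rule for speculative decoding can be sketched as below. This is an illustration under assumptions: `verify_draft` is a hypothetical helper, acceptance is by exact token match against the base model's greedy choice, and it does not reproduce the specific “linear” or “quadratic” schemes of Samragh et al. It shows where the speedup comes from: each verification pass commits the accepted draft prefix plus one token from the base model, so more than one token is committed per pass whenever drafts are accepted.

```python
from typing import List, Sequence


def verify_draft(draft: Sequence[int], verified: Sequence[int]) -> List[int]:
    """Greedy speculative-decoding acceptance.

    draft:    k tokens proposed in one pass (e.g., from the mask-token positions).
    verified: the base model's greedy token at each of the k + 1 positions,
              obtained from a single forward pass over prefix + draft.
    Returns the longest matching draft prefix plus one verifier token
    (a correction on mismatch, or a bonus token if the whole draft matched).
    """
    if len(verified) != len(draft) + 1:
        raise ValueError("verified must contain exactly len(draft) + 1 tokens")
    accepted: List[int] = []
    for proposed, target in zip(draft, verified):
        if proposed != target:
            break
        accepted.append(proposed)
    accepted.append(verified[len(accepted)])
    return accepted


print(verify_draft([5, 6, 7], [5, 6, 9, 4]))  # [5, 6, 9]
```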
7. Connections and Comparisons to Related Paradigms
MNTP bridges autoregressive next-token prediction (NTP) and bidirectional masked models such as masked autoregressive (MAR) models. Whereas BERT-style masked language modeling hides a random subset of tokens and predicts them with a bidirectional encoder, MNTP can use structured or stochastic masking patterns while retaining a causal decoder, and it still offers flexible context and multi-position prediction.
Compared with classical BERT/MLM objectives, MNTP often attains higher throughput and parallelism (e.g., a 4× speedup via blockwise masking; Schneider, 2024), with computational cost determined by the masking ratio and block size. In speculative generation, MNTP exploits the latent knowledge of future tokens already present in pretrained LLMs, in contrast to earlier approaches that required major architectural modifications or, without careful gating, degraded output quality (Samragh et al., 16 Jul 2025).
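Blockwise masking can be contrasted with per-token random masking in a short sketch. This is a generic aligned-block scheme written for illustration; `blockwise_mask` and its parameters are hypothetical and do not reproduce Schneider's exact procedure. The two knobs are the ones the text names: the block size and the masking ratio. Whole contiguous blocks are masked until roughly the requested fraction of the sequence is covered.

```python
import random
from typing import List


def blockwise_mask(length: int, block: int, ratio: float, rng: random.Random) -> List[bool]:
    """Mask whole aligned blocks of `block` tokens, in random order, until at
    least round(ratio * length) tokens are masked."""
    mask = [False] * length
    starts = list(range(0, length, block))
    rng.shuffle(starts)
    target = round(ratio * length)
    masked = 0
    for s in starts:
        if masked >= target:
            break
        end = min(s + block, length)
        for i in range(s, end):
            mask[i] = True
        masked += end - s
    return mask


m = blockwise_mask(16, block=4, ratio=0.5, rng=random.Random(0))
print("".join("#" if b else "." for b in m))  # masked blocks shown as '#'
```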
A plausible implication is that MNTP enables models to interpolate between unidirectional and bidirectional predictive regimes, allowing fine-tuning for target downstream requirements in generation speed, streaming, and fidelity.
Key references:
- "TAPNext: Tracking Any Point (TAP) as Next Token Prediction" (Zholus et al., 8 Apr 2025)
- "Generative Audio Language Modeling with Continuous-valued Tokens and Masked Next-Token Prediction" (Yang et al., 14 Jul 2025)
- "Your LLM Knows the Future: Uncovering Its Multi-Token Prediction Potential" (Samragh et al., 16 Jul 2025)
- "Improving Next Tokens via Second-to-Last Predictions with Generate and Refine" (Schneider, 2024)