
Joint State Prediction & Next-Token Objective

Updated 10 December 2025
  • Joint state-prediction + next-token objective is a composite pretraining strategy that fuses autoregressive generation with auxiliary state prediction to yield robust latent representations.
  • It utilizes techniques like masked particle modeling and next-latent prediction to improve long-range context coherence and enhance downstream classification and planning accuracy.
  • Empirical results demonstrate significant gains in generative fidelity, classification metrics, and representation compactness compared to models using next-token prediction alone.

A joint state-prediction + next-token objective refers to training sequence models—typically Transformer architectures—using both autoregressive next-token prediction and auxiliary objectives that require state or representation prediction. This composite pretraining strategy aims to encourage models to acquire both strong generative abilities (as from standard next-token prediction) and rich, contextually-aligned latent state representations that capture underlying structure and dependencies in the data, thereby improving both generative fidelity and performance on downstream discriminative tasks.

1. Problem Motivation and Conceptual Foundation

Standard next-token prediction (NTP), the backbone of modern generative models in language and the physical sciences, minimizes the cross-entropy of predicting each token $t_i$ from its prefix $t_{<i}$. However, NTP alone does not necessarily encourage models to learn context representations that behave like belief states or sufficient statistics for planning, control, or classification. This limitation motivates adding state-prediction objectives (masked prediction, next-latent prediction, joint multi-token prediction, or context-level prediction) that compel the model to represent internally more than marginal short-range dependencies.

A joint objective, combining NTP with state-prediction, typically takes the form:

$$\mathcal{L}_{\mathrm{total}} = \lambda_{\mathrm{gen}}\,\mathcal{L}_{\mathrm{NTP}} + \lambda_{\mathrm{state}}\,\mathcal{L}_{\mathrm{state~pred}}$$

where $\mathcal{L}_{\mathrm{NTP}}$ is the standard generative (next-token cross-entropy) loss and $\mathcal{L}_{\mathrm{state~pred}}$ penalizes failures to reconstruct or predict masked tokens, latent states, or context aggregates. This construction is motivated directly in works such as "Enhancing next token prediction based pre-training for jet foundation models" (Birk et al., 3 Dec 2025), "Next-Latent Prediction Transformers Learn Compact World Models" (Teoh et al., 8 Nov 2025), "Efficient Joint Prediction of Multiple Future Tokens" (Ahn et al., 24 Mar 2025), and others.
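As a minimal sketch, assuming both losses have already been reduced to scalars, the composite objective is simply a weighted sum. The helper names and toy logits below are hypothetical, and a real training loop would compute these terms with a tensor library that supports automatic differentiation.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vector of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def cross_entropy(logits_per_pos: Sequence[Sequence[float]],
                  targets: Sequence[int]) -> float:
    """Mean token-level negative log-likelihood (the NTP loss)."""
    nll = [-log_softmax(lg)[t] for lg, t in zip(logits_per_pos, targets)]
    return sum(nll) / len(nll)


def joint_loss(l_ntp: float, l_state: float,
               lambda_gen: float = 1.0, lambda_state: float = 1.0) -> float:
    """L_total = lambda_gen * L_NTP + lambda_state * L_state_pred."""
    return lambda_gen * l_ntp + lambda_state * l_state


# Toy example (hypothetical values): vocabulary of 3 tokens, 2 positions.
ntp_logits = [[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]]
ntp_targets = [0, 1]
l_ntp = cross_entropy(ntp_logits, ntp_targets)
l_state = 0.25  # stand-in for any auxiliary state-prediction loss
print(joint_loss(l_ntp, l_state))
```

Equal weights ($\lambda = 1$) are used as defaults because the text reports them as sufficient in most settings; any other weighting is a tuning choice.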

2. Canonical Instantiations and Objective Forms

Multiple architectures instantiate this principle, each formalizing the state-prediction loss distinctively:

  • Masked Modeling / Masked Particle Modeling (MPM): A subset of input positions is masked, and the model must reconstruct the masked tokens from the unmasked context, using either bidirectional or causal attention. The MPM loss is a cross-entropy computed only at masked positions. In jet physics, this is realized as masked token prediction over VQ-VAE-discretized particle features (Birk et al., 3 Dec 2025), where bidirectional attention was found to be critical for downstream classification.
  • Next-Latent Prediction (NextLat): An auxiliary latent-dynamics model $f_\psi$ is trained to predict the model's next hidden state given the current hidden state and the next token. The latent-prediction loss is a SmoothL1 or squared error between the ground-truth and predicted latents, with optional KL regularization to align predictive posteriors (Teoh et al., 8 Nov 2025). Inference relies only on the autoregressive transformer.
  • Joint Multi-Token Prediction (JTP): The model predicts multiple future tokens jointly, conditioning on the representation at the current position and on teacher-forced ground-truth future tokens passed through a bottleneck (the "Fetch" module), which enforces co-dependence among the predictions (Ahn et al., 24 Mar 2025). The loss sums the chain-rule factors (marginal and conditional future-token predictions), increasing the density of gradient signals.
  • Chunked Next-Context Prediction: Each multi-token segment ("chunk") of an input sequence is pooled to a context embedding, which is then autoregressively predicted and injected into subsequent decoding, all using the same cross-entropy supervision. This design strengthens long-range context coherence (Dai et al., 23 Oct 2025).
  • Full-sequence Diffusion Forcing: The unified loss interpolates next-token teacher-forcing (zero-diffusion step) and multi-step denoising objectives over variable noise schedules (Chen et al., 1 Jul 2024). The model learns to generate individual or joint future segments with causal dependency, optimizing a per-token variational lower bound.
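Two of the state-prediction losses above can be sketched in plain Python: the MPM loss (cross-entropy restricted to masked positions) and a NextLat-style latent loss (SmoothL1 between the latent-dynamics prediction $f_\psi(h_t, x_{t+1})$ and the next hidden state $h_{t+1}$). The function names and the `toy_f_psi` dynamics model are hypothetical, the optional KL term is omitted, and a real implementation would work on batched tensors with autograd.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def masked_cross_entropy(logits_per_pos: Sequence[Sequence[float]],
                         targets: Sequence[int],
                         mask: Sequence[bool]) -> float:
    """MPM-style loss: cross-entropy averaged over masked positions only."""
    total, count = 0.0, 0
    for logits, t, is_masked in zip(logits_per_pos, targets, mask):
        if not is_masked:
            continue  # unmasked positions contribute no loss
        m = max(logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_z - logits[t]
        count += 1
    return total / count if count else 0.0


def smooth_l1(pred: Vector, target: Vector, beta: float = 1.0) -> float:
    """Elementwise SmoothL1 (Huber-style) loss, averaged over dimensions."""
    losses = []
    for p, y in zip(pred, target):
        d = abs(p - y)
        losses.append(0.5 * d * d / beta if d < beta else d - 0.5 * beta)
    return sum(losses) / len(losses)


def nextlat_loss(hidden: List[Vector], tokens: Sequence[int],
                 f_psi: Callable[[Vector, int], Vector]) -> float:
    """Average SmoothL1 between f_psi(h_t, x_{t+1}) and h_{t+1}."""
    terms = [smooth_l1(f_psi(hidden[t], tokens[t + 1]), hidden[t + 1])
             for t in range(len(hidden) - 1)]
    return sum(terms) / len(terms)


def toy_f_psi(h: Vector, token: int) -> Vector:
    """Hypothetical latent dynamics: shift every coordinate by 0.1 * token."""
    return [x + 0.1 * token for x in h]


hidden = [[0.0, 0.0], [0.1, 0.1], [0.3, 0.3]]  # toy hidden states h_1..h_3
tokens = [0, 1, 2]                             # toy tokens x_1..x_3
print(nextlat_loss(hidden, tokens, toy_f_psi))
print(masked_cross_entropy([[1.0, 0.0], [0.0, 2.0]], [0, 1], [True, False]))
```

Here the toy dynamics model reproduces the hidden states exactly, so the latent loss is zero up to floating-point rounding; in training, $f_\psi$ and the backbone are optimized together so that this term shapes the hidden states.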

3. Combined Losses, Architectural Realizations, and Training Schedules

The principal architectural design and training protocol can be summarized as:

  • Hybrid Input Targeting: Training input is often a continuous or decoded pseudo-continuous vector (especially with VQ-VAE discretization), while the pretraining target is a discrete token (index into a codebook) (Birk et al., 3 Dec 2025). Fine-tuning for discriminative tasks uses full-resolution continuous input, bypassing quantization artifacts.
  • Loss Composition: The generative NTP loss averages cross-entropy over all tokens; the state-prediction loss targets masked, latent, or multi-token reconstructions, with task-dependent masking rates (e.g., 40% for jets (Birk et al., 3 Dec 2025)) and equal or tuned loss weights (usually $\lambda = 1$).
  • Head Specialization: Each objective gets its own projection head. MPM and joint heads often add extra transformer blocks, whereas NTP-only models use a single linear unembedding; the extra head capacity counteracts over-specialization to the generative pathway.
  • Bidirectional vs. Causal Attention: Empirical ablations confirm that, for masked reconstruction tasks, bidirectional attention in the masked head yields representations substantially better aligned with discriminative classification (Birk et al., 3 Dec 2025).
  • State-Dependent Conditioning: Models may incorporate global information (e.g., number of particles, trajectory length) using learned embeddings that modulate each token's encoding.
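The masking and attention choices above reduce to two small pieces of bookkeeping, sketched below: sampling which positions to mask (40% by default, following the jet setting) and building the attention pattern for a causal NTP pathway versus a bidirectional masked-prediction head. Sampling an exact count of positions is an assumption; independent per-position masking would be an equally plausible choice. The function names are hypothetical.

```python
import random
from typing import List, Optional


def sample_mask(seq_len: int, rate: float = 0.4,
                rng: Optional[random.Random] = None) -> List[bool]:
    """Mark round(rate * seq_len) positions as masked, chosen uniformly."""
    rng = rng if rng is not None else random.Random(0)
    n_masked = round(rate * seq_len)
    chosen = set(rng.sample(range(seq_len), n_masked))
    return [i in chosen for i in range(seq_len)]


def attention_mask(seq_len: int, causal: bool) -> List[List[bool]]:
    """allowed[i][j] is True if query position i may attend to key j."""
    return [[(j <= i) if causal else True for j in range(seq_len)]
            for i in range(seq_len)]


# The NTP pathway uses the causal pattern; the masked head sees all positions.
ntp_pattern = attention_mask(4, causal=True)
mpm_pattern = attention_mask(4, causal=False)
print(sample_mask(10))
```

The bidirectional pattern lets every masked position draw on context from both sides, which is the property the ablations above associate with better downstream classification.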

4. Empirical Findings and Comparative Ablations

Empirical evidence across applications consistently demonstrates that joint state-prediction + next-token objectives offer advantages over purely autoregressive or solely masked/latent modeling:

  • Generative Fidelity: Hybrid NTP+MPM pretraining preserves the quality of generated distributions, both for individual elements and for aggregates (e.g., jet-level statistics) (Birk et al., 3 Dec 2025).
  • Classification and Planning Performance: NTP+State models substantially improve downstream classification accuracy/AUC, often by 5–15 points over NTP-only, and frequently match or approach the best masked-objective-only models without losing generative power (Birk et al., 3 Dec 2025). In planning domains, JTP and NextLat solve synthetic structural tasks on which NTP fails (Ahn et al., 24 Mar 2025, Teoh et al., 8 Nov 2025).
  • Representation Compactness: NextLat and JTP architectures empirically yield more compressed, lower-rank latent representations, along with robust detour handling and improved sequence compression (Teoh et al., 8 Nov 2025).
  • Gradient Flow and Training Signal: Joint objectives deliver denser and more informative gradient signals ($O(T \cdot D)$ distinct gradients per position in JTP with $D$-step prediction), leading to better credit assignment and faster learning (Ahn et al., 24 Mar 2025).
  • Robustness to Overfitting: Second-to-last prediction and masked objectives are more robust to overfitting than next-token training, even with identical decoder backbones (Schneider, 23 Nov 2024).

A table summarizing reported empirical advantages is given below:

| Task/Domain | Improvement from joint NTP+State | Reference |
| --- | --- | --- |
| Jet classification | +5–15 points AUC/accuracy | (Birk et al., 3 Dec 2025) |
| World modeling | Sequence compression | (Teoh et al., 8 Nov 2025) |
| Long-range planning | Path accuracy | (Ahn et al., 24 Mar 2025) |
| Language modeling | Perplexity, coherence | (Dai et al., 23 Oct 2025) |

5. Theoretical Rationale and Belief-State Guarantees

Theoretical results elucidate why adding state-prediction to NTP enforces "belief state" representations—internal summaries sufficient for future prediction. In particular:

  • Under joint NTP+latent (NextLat), if both the token and latent-dynamics heads are Bayes optimal, the internal latent $z_t$ becomes a sufficient statistic, i.e., a belief state $b_t$ such that $p(x_{t+1:T} \mid z_t) = p(x_{t+1:T} \mid x_{1:t})$ (Teoh et al., 8 Nov 2025).
  • In JTP, the representation bottleneck (Fetch) forces $h_{t-1}$ to encode all information necessary for predicting all $D$ future tokens jointly, enforcing co-dependence and short-horizon planning capability (Ahn et al., 24 Mar 2025).
  • Masked-objective approaches with bidirectional attention further align embeddings with global context and classification targets, enhancing transfer from simulation-free pretraining to downstream supervised tasks (Birk et al., 3 Dec 2025).
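The belief-state condition in the first bullet can be checked numerically on a toy hidden Markov model. The sketch below does not train NextLat; it only shows that a filtered belief vector $b_t$, computed with the forward algorithm, predicts the next observation exactly as well as the full history $x_{1:t}$, which is computed here by brute-force enumeration over hidden paths. The HMM parameters are arbitrary illustrative values.

```python
import itertools
from typing import List

PI = [0.6, 0.4]               # p(s_1 = s)
A = [[0.7, 0.3], [0.2, 0.8]]  # A[s][s2] = p(s_{t+1} = s2 | s_t = s)
B = [[0.9, 0.1], [0.3, 0.7]]  # B[s][o]  = p(x_t = o | s_t = s)
S, O = 2, 2


def normalize(v: List[float]) -> List[float]:
    z = sum(v)
    return [x / z for x in v]


def belief_state(xs: List[int]) -> List[float]:
    """Forward filtering: b_t(s) = p(s_t = s | x_{1:t})."""
    b = normalize([PI[s] * B[s][xs[0]] for s in range(S)])
    for x in xs[1:]:
        b = normalize([sum(b[s] * A[s][s2] for s in range(S)) * B[s2][x]
                       for s2 in range(S)])
    return b


def predict_from_belief(b: List[float]) -> List[float]:
    """p(x_{t+1} = o | b_t), using only the belief vector."""
    return [sum(b[s] * A[s][s2] * B[s2][o]
                for s in range(S) for s2 in range(S)) for o in range(O)]


def joint_prob(xs: List[int]) -> float:
    """p(x_{1:n}) by enumerating every hidden path."""
    total = 0.0
    for path in itertools.product(range(S), repeat=len(xs)):
        p = PI[path[0]] * B[path[0]][xs[0]]
        for k in range(1, len(xs)):
            p *= A[path[k - 1]][path[k]] * B[path[k]][xs[k]]
        total += p
    return total


def predict_from_history(xs: List[int]) -> List[float]:
    """p(x_{t+1} = o | x_{1:t}) computed from the full history."""
    denom = joint_prob(xs)
    return [joint_prob(xs + [o]) / denom for o in range(O)]


history = [0, 1, 1, 0]
print(predict_from_belief(belief_state(history)))
print(predict_from_history(history))
```

By the Markov property the same equality extends to longer horizons $x_{t+1:T}$; the NextLat result concerns when a learned latent $z_t$ attains this property without access to the true hidden states.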

A plausible implication is that, whereas next-token-only models may overfit to local statistics or spurious correlations, joint objectives drive the model to internalize latent variables or histories required for robust, multi-step inference.

6. Representative Applications and Domain-Specific Instantiations

  • Particle Physics (Jet Foundation Models): Joint NTP+MPM in OmniJet-α+ achieves large gains over previous methods. Hybrid representations mitigate quantization artifacts in classification while preserving generation quality (Birk et al., 3 Dec 2025).
  • World Models and Planning: In synthetic path and navigation tasks, only JTP and NextLat yield high test performance at substantial graph depths and branch factors, highlighting the necessity of joint prediction for long-range coordination (Teoh et al., 8 Nov 2025, Ahn et al., 24 Mar 2025).
  • Language Modeling with Context Aggregation: ContextLM demonstrates that chunk-level context prediction pathways systematically lower perplexity and improve long-context attention allocation in GPT2 and Pythia (Dai et al., 23 Oct 2025).
  • Control and Embodied Policy Learning: Next-token prediction can be extended to multi-modal, continuous state-action sequences, enabling unified policy and dynamics modeling over diverse data, as in humanoid locomotion (Radosavovic et al., 29 Feb 2024).
  • Sequence Diffusion Models: Diffusion Forcing unifies next-token and full-sequence denoising, capturing the benefits of both autoregressive and diffusion frameworks for stable continuous sequence generation and planning (Chen et al., 1 Jul 2024).
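The chunk-pooling step behind ContextLM-style next-context prediction can be sketched as follows. Mean pooling, non-overlapping fixed-size chunks, and the rule that token $i$ may use only the pooled embedding of the last completed chunk are assumptions for illustration; the context predictor itself and the way ContextLM injects the embedding into decoding are not modeled, and the function names are hypothetical.

```python
from typing import List, Optional

Vector = List[float]


def chunk_embeddings(hidden: List[Vector], chunk_size: int) -> List[Vector]:
    """Mean-pool consecutive, non-overlapping chunks of hidden states.

    A trailing partial chunk is pooled over however many states it holds.
    """
    chunks = []
    for start in range(0, len(hidden), chunk_size):
        block = hidden[start:start + chunk_size]
        dim = len(block[0])
        chunks.append([sum(h[d] for h in block) / len(block)
                       for d in range(dim)])
    return chunks


def context_for_position(i: int, chunk_size: int,
                         chunks: List[Vector]) -> Optional[Vector]:
    """Pooled embedding of the last completed chunk visible at token i.

    Token i lies in chunk i // chunk_size; only earlier chunks are fully
    observed, so a causal model may use chunk (i // chunk_size) - 1.
    """
    j = i // chunk_size - 1
    return chunks[j] if j >= 0 else None


hidden = [[0.0], [2.0], [4.0], [6.0], [8.0]]  # toy 1-D hidden states
chunks = chunk_embeddings(hidden, chunk_size=2)
print(chunks)
print([context_for_position(i, 2, chunks) for i in range(len(hidden))])
```

Restricting each token to completed chunks keeps the context pathway causal, so it can be trained with the same next-token cross-entropy supervision described for chunked next-context prediction.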

7. Practical Recommendations and Future Directions

Multiple insights guide implementation and further research:

  • Input/Target Hybrids: Use full-resolution continuous input for discriminative downstream tasks, pseudo-continuous or decoded token-inverse features for generative pretraining, and discrete tokens as pretraining targets (Birk et al., 3 Dec 2025).
  • Bidirectional Attention: Prefer bidirectional attention in state-prediction heads for improved context encoding and transfer; causal-only masked heads typically fall short (Birk et al., 3 Dec 2025).
  • Loss Weights and Masking Rates: Equal loss weights ($\lambda = 1$) suffice for most objectives; a 40% masking rate is optimal for masked particle modeling on jets (Birk et al., 3 Dec 2025). For JTP/NextLat, moderate auxiliary loss weights produce robust improvements (Teoh et al., 8 Nov 2025).
  • Architectural Modularity: Auxiliary heads for state prediction should incorporate a lightweight transformer or multi-layer structure to avoid over-specialization to the main generative path (Birk et al., 3 Dec 2025).
  • Scalability and Cost: The compute and memory overhead of joint objectives is modest, typically a 10–50% increase over NTP depending on the roll-out or prediction horizon, and inference cost remains unchanged (Teoh et al., 8 Nov 2025, Ahn et al., 24 Mar 2025).
  • Extension to Fully Continuous Generation and Improved Tokenization: Promising future work includes learning protocols for continuous generation, improved VQ-VAE or quantization, and the integration of physics- or domain-specific priors while staying within a simulation-free, foundation-modeling paradigm (Birk et al., 3 Dec 2025).
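Collecting these recommendations in one place, a hypothetical configuration object might look like the sketch below. The field names, the two-layer state-head depth, and the validation rules are assumptions for illustration; only the equal loss weights, the 40% jet masking rate, the preference for bidirectional state heads, and the single linear unembedding for NTP come from the text.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class JointPretrainConfig:
    """Hypothetical config bundling the joint-objective recommendations."""
    lambda_gen: float = 1.0                # weight on L_NTP
    lambda_state: float = 1.0              # weight on L_state_pred
    mask_rate: float = 0.4                 # reported for MPM on jets
    state_head_bidirectional: bool = True  # bidirectional state head
    state_head_layers: int = 2             # hypothetical lightweight depth
    ntp_head_layers: int = 0               # 0 = single linear unembedding

    def __post_init__(self) -> None:
        # Reject settings that would make the objective ill-defined.
        if not 0.0 < self.mask_rate < 1.0:
            raise ValueError("mask_rate must lie strictly between 0 and 1")
        if self.lambda_gen < 0 or self.lambda_state < 0:
            raise ValueError("loss weights must be non-negative")
        if self.state_head_layers < 1:
            raise ValueError("state head needs at least one block")


print(JointPretrainConfig())
```

A frozen dataclass keeps the settings immutable once a run starts, which makes it harder to change a loss weight silently mid-training.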

This conceptual and empirical synthesis establishes joint state-prediction + next-token objectives as essential components for transferring the simulation-free generative strength of next-token pretraining to robust, context-sensitive downstream tasks across scientific and AI domains.
