Reasoning with Latent Tokens in Diffusion Language Models
Abstract: Discrete diffusion models have recently become competitive with autoregressive models for language modeling, even outperforming them on reasoning tasks requiring planning and global coherence, but they require more computation at inference time. We trace this trade-off to a key mechanism: diffusion models are trained to jointly predict a distribution over all unknown tokens, including those that will not actually be decoded in the current step. Ablating this joint prediction yields faster inference but degrades performance, revealing that accurate prediction at the decoded position relies on joint reasoning about the distribution of undecoded tokens. We interpret these as latent tokens and introduce a method for modulating their number, demonstrating empirically that this enables a smooth tradeoff between inference speed and sample quality. Furthermore, we demonstrate that latent tokens can be introduced into autoregressive models through an auxiliary multi-token prediction objective, yielding substantial improvements on the same reasoning tasks where they have traditionally struggled. Our results suggest that latent tokens, while arising naturally in diffusion, represent a general mechanism for improving performance on tasks requiring global coherence or lookahead.
Explain it Like I'm 14
Overview
This paper compares two ways an AI model can fill in missing parts of a sequence (like completing blanks in a sentence or finishing steps in a timeline). The figure shows two approaches:
- MDM (Bidirectional): predicts several missing pieces together while looking at the whole sequence.
- SIDM/SCDM: predicts one missing piece at a time while only looking at the already known parts.
Think of it like solving a puzzle: one method places several pieces together by looking at the entire puzzle, and the other places one piece at a time by only looking at the pieces already placed.
Key Objectives
The main questions the paper aims to explore are:
- Should a model use information from both the left and right sides of a sequence (bidirectional) when predicting a missing part?
- Is it better to predict several missing parts at the same time, or to predict them one by one?
- What happens if the model ignores some unknown parts during the prediction step versus including them as “placeholders” it can pay attention to?
Methods and Approach
The figure represents a sequence of items, labeled as x with subscripts like x_{π1}, x_{π2}, … The π (pi) here just means there is some chosen order for the items in the sequence (you can think of it as a shuffled or decided order to fill in blanks).
- “History” boxes: these are the parts of the sequence we already know.
- “Target” box: this is the specific missing part we are trying to predict right now.
- “Latent” boxes: these are other missing parts that will be predicted too.
- “M” (mask): a placeholder that marks a blank the model needs to fill.
- Arrows (“attention”): show which inputs the model looks at to make a prediction.
Here is how each approach works:
- MDM (Bidirectional)
- Input: The model sees the known parts (history) and placeholders (M) for all unknown parts.
- Prediction: It predicts multiple missing parts together (for example, positions 4 through 8).
- Attention: When predicting the target (like position 4), the model looks at all positions, including the masked ones. This is “bidirectional” because it can use context from both earlier and later positions.
- SIDM/SCDM
- Input: The model sees the known history and the specific target mask, but ignores the other masked positions during the forward pass (they’re “omitted”).
- Prediction: It predicts one target at a time (like position 4).
- Attention: When predicting the target, the model looks only at the known history and the target mask. It does not use the other unknown positions yet.
A helpful analogy (a short code sketch of the two patterns follows below):
- MDM: You’re solving a crossword by looking at the whole grid, using clues from across and down together, and filling several blanks at once.
- SIDM/SCDM: You’re solving the crossword one blank at a time, using only the clues you’ve already solved, and ignoring the other blanks until later.
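To make the contrast concrete, here is a minimal PyTorch sketch of the two input patterns. The mask-token id, tensor sizes, and encoder layer are all illustrative assumptions, not the paper's architecture; the point is only that the MDM-style step feeds every masked position into the model, while the SIDM/SCDM-style step drops the extra masks from the forward pass.

```python
# Minimal sketch (not the paper's code) of the two input patterns in the figure.
import torch
import torch.nn as nn

vocab_size, d_model = 100, 32
MASK_ID = vocab_size - 1                             # assumed id for the mask token "M"
embed = nn.Embedding(vocab_size, d_model)
encoder = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)

history = torch.randint(0, vocab_size - 1, (1, 4))   # known tokens (illustrative)
blanks = torch.full((1, 4), MASK_ID)                 # remaining blanks, all masked with "M"
tokens = torch.cat([history, blanks], dim=1)         # full sequence with placeholders

# MDM (bidirectional): every position, including every mask, enters the
# forward pass, so the target can attend to the other masked ("latent") slots.
mdm_hidden = encoder(embed(tokens))                  # shape (1, 8, 32)

# SIDM/SCDM-style step: keep only the history plus the single target mask;
# the remaining masked positions are omitted from the forward pass.
kept = tokens[:, :5]                                 # history + one target mask
sidm_hidden = encoder(embed(kept))                   # shape (1, 5, 32)

print(mdm_hidden.shape, sidm_hidden.shape)
```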
Main Findings and Why They Matter
From the diagram, the core takeaway is the difference in how information is used:
- MDM uses more context at once and predicts several blanks together. This can help the model keep the whole sequence consistent, because it learns relationships among multiple unknown parts at the same time.
- SIDM/SCDM focuses on one blank at a time and only uses already known information. This can make the process simpler and more direct, but it might miss helpful signals from the rest of the sequence until those parts are eventually filled.
Why this is important:
- Many tasks, like writing text, composing music, forecasting time series, or completing code, benefit from understanding both past and future context. Joint prediction can capture patterns that single-step methods might overlook.
- However, predicting one piece at a time can be easier to train and faster per step, and focusing on a single target may make errors easier to control.
Implications and Potential Impact
- Better sequence completion: Using bidirectional context and joint prediction (MDM) can lead to more coherent and globally consistent outputs, especially when different parts of the sequence depend on each other.
- Simpler, step-by-step generation: The SIDM/SCDM approach can be more straightforward, potentially faster at each step, and easier to scale, which is useful when resources are limited or when a task benefits from careful, incremental updates.
- Design trade-offs: Developers and researchers can choose between joint, bidirectional prediction and stepwise, history-only prediction depending on their goals—prioritizing coherence and global patterns versus simplicity and speed.
In short, the paper highlights two different strategies for filling in blanks in sequences and helps explain when and why you might prefer one over the other.
Knowledge Gaps
Knowledge gaps, limitations, and open questions
The paper’s figure introduces two schemes (MDM Bidirectional vs. SIDM/SCDM) but leaves several methodological and empirical aspects unspecified. Future work could address the following:
- Missing formal definitions of MDM (bidirectional) and SIDM/SCDM: precise architectures, training objectives, and loss functions (which variables are predicted, when, and how losses are aggregated).
- Unspecified policy for choosing the permutation/order (fixed, random, learned) and its impact on performance, convergence, coverage of positions, and reproducibility.
- Ambiguity in the “jointly predicted” block (e.g., x_{π4} through x_{π8} in the figure): whether outputs are predicted simultaneously or iteratively, how parameter sharing works, and how interdependencies are handled without circular conditioning.
- Mask token “M” representation: how “M” is encoded (single learned embedding, position-dependent, or structured), whether gradients flow through “M” inputs in MDM, and safeguards against information leakage via “M”.
- Attention masking details: in MDM all inputs attend to the target—what prevents masked positions from implicitly conveying target information; in SIDM/SCDM, how omission is implemented (pruning nodes vs. attention masks) and its side effects.
- Training–inference mismatch: how inference is performed when training uses joint prediction or omission (sequential vs. parallel decoding), and the extent of exposure bias or consistency issues.
- Computational and memory trade-offs: quantified cost of including all positions in the forward pass (MDM) vs. omitting many (SIDM/SCDM), with analyses of throughput, latency, and scaling to long sequences.
- Gradient signal allocation: which positions receive loss/gradients at each step for MDM vs. SIDM/SCDM, and the effect on learning dynamics and the quality of “history” vs. “latent” representations (a toy loss-allocation sketch follows this list).
- Block size and curriculum: how many positions are jointly predicted, whether block size is tuned/adaptive, and the effect on accuracy, coherence, and error propagation.
- Evaluation scope: absence of empirical results (datasets, tasks, metrics) comparing MDM and SIDM/SCDM and ablations isolating each design choice (joint prediction, omission, attention scope).
- Robustness and scaling: behavior on long sequences, variable history sizes, heterogeneous token types/modalities, and training stability/convergence characteristics.
- Consistency across jointly predicted positions: mechanisms (constraints, iterative refinement) for enforcing global consistency within the predicted block and measuring cross-token coherence.
- Theoretical guarantees: probabilistic factorization or likelihood interpretation of the objective for bidirectional joint prediction, and conditions under which the method estimates a valid model.
- Order selection policy: how targets like x_{π4} are chosen across steps to ensure coverage, fairness, and sample efficiency (e.g., curriculum, adaptive ordering, learned schedulers).
- Notational and semantic clarity: precise meanings of “history,” “decode,” and “latent” categories and their mapping to concrete model components and computation graph.
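As a toy illustration of the gradient-allocation question above, the snippet below contrasts a loss taken over all masked positions with a loss at a single decoded target. This is an assumption about how such losses are commonly written, not a definition taken from the paper.

```python
# Toy illustration (assumed formulation) of loss allocation in the two regimes.
import torch
import torch.nn.functional as F

vocab_size, seq_len = 100, 8
logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)  # stand-in model output
targets = torch.randint(0, vocab_size, (1, seq_len))              # ground-truth tokens
is_masked = torch.tensor([[False, False, False, False, True, True, True, True]])
target_pos = 4                                                    # position decoded this step

# MDM-style: average cross-entropy over all masked ("latent" + target) positions.
mdm_loss = F.cross_entropy(logits[is_masked], targets[is_masked])

# SIDM/SCDM-style: cross-entropy only at the single decoded target position.
sidm_loss = F.cross_entropy(logits[:, target_pos], targets[:, target_pos])

print(float(mdm_loss), float(sidm_loss))
```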
Practical Applications
Overview
The figure contrasts two training/inference regimes for sequence diffusion models:
- MDM (Bidirectional): jointly predicts multiple masked positions using bidirectional attention over the entire sequence (including masked/future positions), enabling infilling and orderless generation.
- SIDM/SCDM (Sequential/Stepwise, Causal, Self-Conditioned): predicts one target at a time conditioned only on the available history; masked/future positions are omitted from the forward pass, aligning training with causal, latency-sensitive next-step prediction and reducing compute.
From these design choices flow practical applications in forecasting, imputation, editing/infilling, anomaly detection, and efficient deployment.
Immediate Applications
The following use cases can be piloted or deployed now with standard diffusion-model tooling and moderate engineering effort.
- Time-series gap-filling and next-step forecasting (energy, finance, healthcare, industrial IoT)
- What: Use SIDM to impute missing readings and forecast the next time step(s) causally, with calibrated uncertainty via sampling.
- Tools/products/workflows: PyTorch/TF implementations with masking; batched SIDM sampling microservices; pipelines that first impute missingness then forecast; integration with BI dashboards.
- Assumptions/dependencies: Accelerated samplers (e.g., DDIM/fewer steps) keep latency acceptable; masking reflects real missingness; no information leakage from future tokens.
- Latency-aware session-based next-item prediction (e-commerce, media, fintech)
- What: Use SIDM as a left-to-right next-event predictor for clicks, plays, or transactions; combine with rankers.
- Tools/products/workflows: Feature store → SIDM next-step scores → candidate generation/ranking; A/B testing in recommender stacks.
- Assumptions/dependencies: Discrete vocabulary of items is stable; inference budget accommodates diffusion steps; cold-start handled by metadata features.
- Anomaly detection via reconstruction/imputation error (operations, cybersecurity, manufacturing)
- What: Train MDM or SIDM to reconstruct sequences; deviations between observed and denoised predictions flag anomalies.
- Tools/products/workflows: Streaming scoring service computing negative log-likelihood or denoise residuals; alerting thresholds and feedback loop.
- Assumptions/dependencies: Proper calibration of uncertainty; robust baselining for nonstationarity and concept drift.
- Document/code span infilling and editing (software tooling, productivity)
- What: Use bidirectional MDM to jointly predict masked spans for fill-in-the-middle (FIM) code/text edits.
- Tools/products/workflows: IDE/editor plugins; internal productivity tools for structured templates; server-side GPU inference with batching.
- Assumptions/dependencies: Discrete diffusion quality on tokens/subwords is adequate; tokenizer alignment with syntax; may require domain-specific pretraining.
- Data cleaning pipelines for sensor networks (transportation, climate, telemetry)
- What: Apply SIDM to impute intermittent sensor failures and harmonize multivariate streams before analytics.
- Tools/products/workflows: ETL stage with masking → SIDM imputation → data quality reports; governance hooks for audit.
- Assumptions/dependencies: Coverage across regimes of operation; resilience to distribution shifts.
- Research benchmarking and ablations (academia, ML R&D)
- What: Compare bidirectional joint prediction (MDM) vs causal SIDM on accuracy/latency/compute across datasets and masking policies.
- Tools/products/workflows: Open-source baselines; evaluation suites with standardized masks and permutation schedules; reproducible training recipes.
- Assumptions/dependencies: Access to public datasets with missingness/causal constraints; fair sampler settings across methods.
- Compute/energy savings in training and inference (ML infrastructure)
- What: Omit masked/future tokens from the forward pass in SIDM to reduce FLOPs and memory during training and streaming inference (see the sketch after this list).
- Tools/products/workflows: Attention masking, sparse computation kernels, compile-to-ONNX/TensorRT; cost reporting dashboards.
- Assumptions/dependencies: Efficient implementation of selective forward passes; careful batching to maintain high device utilization.
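A rough sketch of the compute point above, assuming a generic transformer layer and made-up sequence sizes; the quadratic-attention estimate is a back-of-the-envelope approximation, not a measurement from the paper.

```python
# Illustrative estimate of the saving from omitting masked/future positions
# (SIDM-style forward pass). Sizes and the cost model are assumptions.
import torch
import torch.nn as nn

d_model, nhead, seq_len, n_known = 64, 4, 512, 128
layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)

full_input = torch.randn(1, seq_len, d_model)        # history + all mask tokens (MDM-style)
kept_input = full_input[:, : n_known + 1]            # history + one target mask (SIDM-style)

with torch.no_grad():
    _ = layer(full_input)                            # attention cost ~ O(seq_len^2)
    _ = layer(kept_input)                            # attention cost ~ O((n_known + 1)^2)

ratio = (seq_len ** 2) / ((n_known + 1) ** 2)
print(f"Approx. attention-FLOP reduction per step: ~{ratio:.1f}x")
```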
Long-Term Applications
These use cases are promising but depend on further research, scaling, or systems optimization (e.g., faster samplers, larger models, regulatory validation).
- On-device real-time generative predictors (edge IoT, wearables, robotics)
- What: Deploy SIDM for causal, low-latency next-step prediction and imputation directly on devices.
- Tools/products/workflows: Quantization/distillation of diffusion models; few-step/implicit samplers; hardware-aware attention pruning.
- Assumptions/dependencies: Sub-10 ms per step budgets; robustness to intermittent connectivity; memory-efficient self-conditioning.
- Diffusion-based language and code models with parallel-span editing (software, education)
- What: Use MDM’s joint, bidirectional span prediction to build large-scale interactive editors supporting multi-span edits and orderless generation.
- Tools/products/workflows: Editor agents coordinating mask proposals + joint decoding; speculative decoding across spans.
- Assumptions/dependencies: Scaling discrete diffusion to LLM-level quality; improved tokenization/noise schedules; competitive latency vs transformer decoders.
- Probabilistic clinical decision support and EHR imputation (healthcare)
- What: Combine SIDM imputation with short-horizon forecasting of vitals/labs with uncertainty for triage and monitoring.
- Tools/products/workflows: Bedside dashboards showing predicted ranges; conformal prediction for risk thresholds; model monitoring for drift.
- Assumptions/dependencies: Rigorous clinical validation; bias/fairness assessments; privacy and regulatory approvals.
- Power grid operations and market scenario generation (energy, finance)
- What: Generate multi-scenario short-term loads/prices using MDM (for full-context modeling) and SIDM (for streaming updates).
- Tools/products/workflows: Scenario orchestration feeding stochastic optimization; stress-testing and what-if analyses.
- Assumptions/dependencies: Acceptance by operators; auditable uncertainty; robustness under extreme events.
- Autonomous driving and human motion prediction (mobility, robotics)
- What: SIDM for causal multi-agent trajectory forecasting; MDM for map/context infilling.
- Tools/products/workflows: Planner integration with probabilistic forecasts; closed-loop simulation suites.
- Assumptions/dependencies: Safety case development; real-world generalization; compute budgets on embedded platforms.
- Privacy-preserving synthetic sequential data release (policy, enterprise data sharing)
- What: Use bidirectional MDM to learn full-sequence distributions and release synthetic EHR/transaction logs that preserve utility while protecting privacy.
- Tools/products/workflows: Differential privacy training; disclosure risk audits; data licensing frameworks.
- Assumptions/dependencies: Strong, quantifiable privacy guarantees; stakeholder trust and governance.
- Foundation sequential models across modalities (audio, video, multimodal streams)
- What: Extend masking/permutation schemes to continuous streams for online generation, editing, and imputation.
- Tools/products/workflows: Hybrid discrete–continuous diffusion; learnable masking policies; cross-modal conditioning.
- Assumptions/dependencies: Scalable tokenization for continuous signals; efficient multi-step sampling; high-throughput training.
- Parallelized decoding via joint prediction heads (software systems)
- What: Exploit MDM’s “jointly predicted” blocks to decode multiple positions per step, reducing wall-clock latency (a minimal decoding sketch follows this list).
- Tools/products/workflows: Custom heads predicting token blocks; speculative verify-and-accept loops.
- Assumptions/dependencies: Accuracy–throughput trade-offs; architectural support for partial acceptance and rollback.
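The sketch below shows one possible verify-and-accept loop over jointly predicted blocks. Here `propose_block` and `verify_scores` are hypothetical stand-ins for a joint prediction head and a verifier, and the acceptance rule is an assumption rather than the paper's algorithm.

```python
# Hedged sketch of block decoding with verify-and-accept (hypothetical interfaces).
from typing import Callable, List

def block_decode(
    prompt: List[int],
    propose_block: Callable[[List[int], int], List[int]],          # returns k proposed tokens
    verify_scores: Callable[[List[int], List[int]], List[float]],  # per-token confidence
    k: int = 4,
    threshold: float = 0.9,
    max_len: int = 64,
) -> List[int]:
    out = list(prompt)
    while len(out) < max_len:
        block = propose_block(out, k)          # k tokens predicted jointly in one step
        scores = verify_scores(out, block)     # confidence for each proposed token
        accepted = 0
        for s in scores:                       # accept the longest confident prefix
            if s < threshold:
                break
            accepted += 1
        if accepted == 0:
            accepted = 1                       # always make progress: take one token
        out.extend(block[:accepted])           # roll back the unverified remainder
    return out[:max_len]
```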
Across all applications, key dependencies include: appropriate masking/permutation schedules (π), access to high-quality sequential data, accelerated diffusion samplers, alignment between training and deployment context (causal vs bidirectional), and governance for safety, privacy, and fairness.
Glossary
- Bidirectional: Uses information from both past and future positions in a sequence during computation. Example: "MDM (Bidirectional)"
- Forward pass: The computation phase of a neural network that maps inputs to outputs without updating parameters. Example: "omit from forward pass"
- History: The known prefix (previous tokens) of a sequence used as context for prediction. Example: "history"
- Jointly predicted: Multiple outputs are predicted together in a single step as a joint set. Example: "jointly predicted"
- MDM: A model configuration (here, bidirectional) in which multiple masked positions can participate and be predicted jointly. Example: "MDM (Bidirectional)"
- Mask token (M): A placeholder indicating an unknown or to‑be‑predicted token in the input. Example: "M"
- Permutation index (π_i): An index denoting the position of an element under a permutation π, indicating an arbitrary decoding order. Example: "x_{π1}"
- SIDM/SCDM: Model configurations where only the history participates in the forward pass for predicting the target, omitting future masked positions. Example: "SIDM/SCDM"
- Target: The specific position/token the model is currently tasked with predicting. Example: "target"