Train for the Worst, Plan for the Best: Understanding Token Ordering in Masked Diffusions (2502.06768v2)

Published 10 Feb 2025 in cs.LG

Abstract: In recent years, masked diffusion models (MDMs) have emerged as a promising alternative approach for generative modeling over discrete domains. Compared to autoregressive models (ARMs), MDMs trade off complexity at training time with flexibility at inference time. At training time, they must learn to solve an exponentially large number of infilling problems, but at inference time, they can decode tokens in essentially arbitrary order. In this work, we closely examine these two competing effects. On the training front, we theoretically and empirically demonstrate that MDMs indeed train on computationally intractable subproblems compared to their autoregressive counterparts. On the inference front, we show that a suitable strategy for adaptively choosing the token decoding order significantly enhances the capabilities of MDMs, allowing them to sidestep hard subproblems. On logic puzzles like Sudoku, we show that adaptive inference can boost solving accuracy in pretrained MDMs from $<7$% to $\approx 90$%, even outperforming ARMs with $7\times$ as many parameters and that were explicitly trained via teacher forcing to learn the right order of decoding.

Summary

  • The paper finds that Masked Diffusion Models (MDMs) are trained on a diverse, order-agnostic set of difficult infilling problems, in contrast to the fixed-order training of Autoregressive Models (ARMs).
  • Despite training on these challenging subproblems, MDMs achieve significant performance gains, particularly on logic puzzles like Sudoku and Zebra, by employing adaptive inference strategies that intelligently select the token decoding order.
  • Adaptive MDM inference can outperform larger ARMs trained with explicit ordering information and shows better generalization to harder versions of logic puzzles, suggesting robustness from diverse training.

This paper, "Train for the Worst, Plan for the Best: Understanding Token Ordering in Masked Diffusions" (2502.06768), investigates the trade-offs between Masked Diffusion Models (MDMs) and Autoregressive Models (ARMs) for generative modeling in discrete domains. It highlights that MDMs face more complex training due to their order-agnostic nature but offer greater flexibility at inference time.

The core argument is that while MDMs learn to solve an exponentially large number of challenging infilling subproblems during training ("train for the worst"), their inference-time flexibility, particularly with adaptive token decoding strategies, allows them to sidestep these hard subproblems and achieve superior performance ("plan for the best").

Masked Diffusion Models (MDMs)

MDMs operate through a forward and reverse process:

  • Forward Process: Gradually introduces noise by independently masking tokens of a sequence $x_0$ to a special "mask" token (denoted as 0). The probability of a token $x_0^i$ being masked at noise level $t$ is $1-\alpha_t$, where $\alpha_t$ is a noise schedule ($\alpha_0 \approx 1$, $\alpha_1 \approx 0$); a minimal code sketch of this step follows the list.

    $q_{t|0}(x_t^i \mid x_0^i) = \mathrm{Cat}\bigl(\alpha_t \mathbf{e}_{x_0^i} + (1-\alpha_t)\mathbf{e}_{0} \bigr)$

  • Reverse Process: Learns to denoise the masked sequence. A denoising network $p_\theta(x_t, t)$ predicts the original tokens $x_0$ given the noisy sequence $x_t$ and noise level $t$. The model is trained by minimizing an ELBO-based loss:

    $\mathcal{L}_\theta = \int_{0}^{1} \frac{\alpha_t'}{1-\alpha_t} \, \mathop{\mathbb{E}}_{\substack{x_0 \sim p_{\rm data} \\ x_t \sim q_{t|0}(\cdot \mid x_0)}} \left[\delta_{x_t,0}\, \mathbf{e}_{x_0}^\intercal \log p_\theta(x_t,t) \right] dt.$

    In practice, a time-embedding-free architecture $p_\theta(x_t)$ is often used.
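As promised above, here is a minimal sketch of the forward masking step in Python. Using token id 0 for the mask follows the paper's notation, but the function name and interface are illustrative assumptions, not the paper's code:

```python
import numpy as np

MASK_ID = 0  # the paper denotes the mask token by 0

def forward_mask(x0: np.ndarray, alpha_t: float, rng: np.random.Generator) -> np.ndarray:
    """Sample x_t ~ q_{t|0}(. | x_0): each token independently survives
    with probability alpha_t and is replaced by the mask token otherwise."""
    keep = rng.random(x0.shape) < alpha_t
    return np.where(keep, x0, MASK_ID)

rng = np.random.default_rng(0)
x0 = np.array([5, 3, 9, 1, 7])
xt = forward_mask(x0, alpha_t=0.5, rng=rng)  # roughly half the tokens become 0
```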

The paper reformulates the MDM loss to show it is equivalent to learning all possible infilling problems:

$\mathcal{L}_\theta = -\frac{1}{L}\sum_{M\subseteq [L],\, i \in M}\frac{1}{\binom{L-1}{|M|-1}}\, \mathop{\mathbb{E}}_{x_0 \sim p_{\rm data}} \left[ \log p_\theta(x^i_0 \mid x_0[M]) \right],$

where $x_0[M]$ denotes $x_0$ with the positions in $M$ masked.

This "order-agnostic" training contrasts with ARMs, which learn a fixed left-to-right prediction order.

Vanilla MDM inference repeats two steps (a sketch in code follows the list):

  1. Sampling a set $\mathcal{S}$ of masked tokens to unmask, chosen randomly according to the noise schedule.
  2. For each $i \in \mathcal{S}$, sampling $x_s^i \sim p_\theta(x^i \mid x_t)$.
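A minimal sketch of this sampler, written so the adaptive variant later in the summary only has to swap out the `select` argument; all names, and the fixed tokens-per-step schedule, are illustrative assumptions:

```python
import torch

MASK_ID = 0

def random_select(logits: torch.Tensor, masked_idx: torch.Tensor, k: int) -> torch.Tensor:
    """Vanilla MDM inference: choose k still-masked positions uniformly at random."""
    perm = torch.randperm(len(masked_idx))
    return masked_idx[perm[:k]]

@torch.no_grad()
def sample_mdm(model, length: int, steps: int, select=random_select) -> torch.Tensor:
    """Iteratively unmask a sequence; `select` decides WHICH positions to reveal."""
    xt = torch.full((1, length), MASK_ID, dtype=torch.long)
    per_step = -(-length // steps)                   # ceil(length / steps)
    for _ in range(steps):
        masked_idx = (xt[0] == MASK_ID).nonzero(as_tuple=True)[0]
        if len(masked_idx) == 0:
            break
        logits = model(xt)[0]                        # (length, vocab)
        chosen = select(logits, masked_idx, min(per_step, len(masked_idx)))
        probs = logits[chosen].softmax(-1)
        xt[0, chosen] = torch.multinomial(probs, 1).squeeze(-1)
    return xt[0]
```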

MDMs Train on Hard Problems

The paper argues that many subproblems MDMs encounter during training are computationally intractable.

Theoretical Evidence (Latents-and-Observations Distributions):

The authors introduce "Latents-and-Observations" (L&O) distributions, where some tokens are "latents" (random seeds) and others are "observations" (functions of the latents). For L&O distributions:

  • Order-aware training (e.g., predicting latents then observations) is computationally tractable.
  • Order-agnostic training (like MDMs) can encounter computationally hard masking problems.

One example is Sparse Predicate Observations: latent tokens are sampled at random, and observation tokens are predicates (e.g., NAE, Not-All-Equal) applied to subsets of the latent tokens. The paper shows that for certain masking fractions $\alpha$, predicting masked latent tokens given some unmasked observations becomes computationally hard, relating to the hardness of planted Constraint Satisfaction Problems (CSPs). Hardness thresholds ($D_{\rm KS}$, $D_{\rm cond}$) from statistical physics delineate these hard regimes; the paper illustrates the resulting phase transition for a planted NAE-SAT problem, where belief propagation fails.
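To fix intuition, here is a loose, simplified sketch of sampling from such an L&O distribution, where each observation token is just the truth value of an NAE predicate on a random triple of latent bits. This is an illustrative simplification, not the paper's exact construction:

```python
import numpy as np

def sample_lo_nae_sat(n_latents: int, n_obs: int, arity: int = 3, seed: int = 0) -> np.ndarray:
    """Sample one sequence from a toy L&O distribution: latent bits followed
    by NAE (Not-All-Equal) predicate values on random latent subsets."""
    rng = np.random.default_rng(seed)
    z = rng.integers(0, 2, size=n_latents)        # latent tokens: random bits
    obs = []
    for _ in range(n_obs):
        idx = rng.choice(n_latents, size=arity, replace=False)
        vals = z[idx]
        obs.append(int(vals.min() != vals.max()))  # NAE: the bits are not all equal
    return np.concatenate([z, np.array(obs)])
```

Predicting a masked observation from unmasked latents is easy (evaluate the predicate), while inverting many masked latents from observations is a planted-CSP-style inference problem.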

Empirical Evidence (Text Data):

To demonstrate hardness on real-world data, the authors use $\pi$-learners, which predict tokens in an order defined by a permutation $\pi$:

$\log p_{\theta}(x_0) = \sum_{i=0}^{L-1} \log p_\theta \bigl( x_0^{\pi(i)} \,\Big|\, x_0 [\pi\{i,\ldots,L-1\}] \bigr).$

The MDM loss is equivalent to averaging the loss of $\pi$-learners over all permutations $\pi$. Experiments on the SlimPajama dataset show that as $\pi$ deviates from the natural left-to-right order (the identity permutation), the likelihood achieved by the $\pi$-learner worsens. This suggests that many masking orders are inherently harder than left-to-right for text.
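To make the factorization concrete, a sketch (under the same hypothetical model interface as above) that evaluates a sequence's log-likelihood with a masked denoiser along a fixed order $\pi$; averaging this quantity over uniformly random $\pi$ recovers the MDM loss:

```python
import torch

MASK_ID = 0

@torch.no_grad()
def pi_learner_loglik(model, x0: torch.Tensor, pi: list[int]) -> torch.Tensor:
    """log p(x_0) factorized along decoding order pi: at step i, positions
    pi[i:] are still masked and the model scores the true token at pi[i]."""
    total = torch.tensor(0.0)
    xt = torch.full_like(x0, MASK_ID)        # start fully masked
    for pos in pi:
        logp = model(xt[None])[0, pos].log_softmax(-1)
        total += logp[x0[pos]]
        xt[pos] = x0[pos]                    # reveal the true token for later steps
    return total
```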

Error Imbalance:

The paper shows empirically that MDMs exhibit performance imbalance: they achieve lower error on easier subproblems and higher error on harder ones.

  • For L&O-NAE-SAT, MDMs perform better at predicting observation tokens than latent tokens.
  • For text, different permutations $\pi$ lead to varying validation losses, indicating error imbalance across the different ordering tasks.

MDMs Can Plan Around Hard Problems with Adaptive Inference

Despite training on hard problems, MDMs can avoid them during inference using adaptive strategies. The key insight is that an ideal MDM can decode tokens in any order. While practical MDMs aren't ideal, their logits often contain enough information to guide the decoding order.

Adaptive MDM Inference:

Instead of randomly selecting which tokens to unmask, an oracle $\mathcal{F}(\theta, x_t)$ strategically chooses the set $\mathcal{S}$:

  1. Sample $\mathcal{S} = \mathcal{F}(\theta, x_t) \subseteq \{i \mid x_t^i = 0\}$.
  2. For each $i \in \mathcal{S}$, sample $x_s^i \sim p_\theta(x^i \mid x_t)$.

Oracle Design Strategies:

Two main strategies are proposed for selecting tokens based on model certainty (both are sketched in code after this list):

  1. Top-$K$ probability: Select the $K$ positions where the model assigns the highest probability to its single most likely token value: $\mathcal{F}(\theta, x_t) = \text{Top-}K\left(\max_{j} p_\theta(x^i = j \mid x_t)\right)$.
  2. Top-$K$ probability margin: Select the $K$ positions with the largest difference between the probabilities of the two most likely token values $j_1$ and $j_2$: $\mathcal{F}(\theta, x_t) = \text{Top-}K\left(p_\theta(x^i = j_1 \mid x_t) - p_\theta(x^i = j_2 \mid x_t)\right)$. This is more robust when the model is torn between a few high-probability options.
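A hedged sketch of both oracles as drop-in `select` functions for the earlier `sample_mdm` sketch (hypothetical interface; the temperature term mentioned below for text generation is omitted):

```python
import torch

def top_prob_select(logits: torch.Tensor, masked_idx: torch.Tensor, k: int) -> torch.Tensor:
    """Top-K probability: unmask the positions where the model's single most
    likely token value has the highest probability."""
    conf = logits[masked_idx].softmax(-1).max(-1).values
    return masked_idx[conf.topk(k).indices]

def top_margin_select(logits: torch.Tensor, masked_idx: torch.Tensor, k: int) -> torch.Tensor:
    """Top-K probability margin: unmask the positions with the largest gap
    between the two most likely token values (robust to split confidence)."""
    top2 = logits[masked_idx].softmax(-1).topk(2, dim=-1).values  # (n_masked, 2)
    margin = top2[:, 0] - top2[:, 1]
    return masked_idx[margin.topk(k).indices]
```

Plugged into the earlier sampler, for example, `sample_mdm(model, length=81, steps=81, select=top_margin_select)` would decode one most-confident cell per step, loosely mirroring the regime behind the Sudoku results below.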

Experimental Validation of Adaptive Inference:

  • L&O-NAE-SAT: Adaptive inference (Top-$K$ probability margin) significantly improves accuracy at predicting observation tokens compared to vanilla MDM inference. For instance, with $(N, P) = (50, 250)$, accuracy jumps from 67.94% to 90.01%.
  • Text Data: Adaptive MDM inference (with temperature added to the oracle for diversity) reduces generative perplexity (GenPPL) on text generation while maintaining entropy, as measured with a LLaMA-7B evaluator.
  • Logic Puzzles (Sudoku and Zebra): These tasks have sequence-dependent reasoning paths, making them challenging for fixed-order ARMs.
    • Sudoku:
      • Vanilla MDM (6M params): 6.88% accuracy.
      • MDM with Top-$K$ probability: 18.51% accuracy.
      • MDM with Top-$K$ probability margin: 89.49% accuracy.
      • This outperforms ARMs (42M params) without ordering information (9.73%) and even ARMs explicitly trained via teacher forcing to follow the correct decoding order (87.18%).
    • Zebra puzzles: Similar improvements are seen.
      • Vanilla MDM (19M params): 76.9% accuracy.
      • MDM with Top-$K$ probability: 98.5% accuracy.
      • MDM with Top-$K$ probability margin: 98.3% accuracy.
      • These outperform ARMs (42M params) with teacher-forced ordering (91.17%).

Easy-to-Hard Generalization (Sudoku):

MDMs with adaptive inference were tested on harder Sudoku puzzles not seen during training.

  • MDM (6M params) with Top-$K$ probability margin: 49.88% accuracy.
  • ARM (42M params) with teacher-forced ordering: 32.57% accuracy.

This suggests that MDMs with adaptive inference are more robust to distribution shifts in problem difficulty, possibly because their diverse training helps them extract more generalizable knowledge.

Conclusion

The paper concludes that the training complexity of MDMs, stemming from learning numerous hard subproblems, is a drawback compared to ARMs. However, this is offset by their inference-time flexibility. Adaptive inference strategies allow MDMs to intelligently choose the token decoding order, sidestepping the difficult subproblems encountered during training. This approach leads to significant performance gains, particularly in tasks requiring complex, sequence-dependent reasoning like logic puzzles, where MDMs can even outperform larger ARMs trained with explicit ordering information. Future work could explore more sophisticated adaptive strategies for broader applications.