- The paper introduces the ReMDM sampler to enable iterative remasking of tokens during inference, improving sample quality in discrete diffusion models.
- It employs remasking strategies with flexible schedules to narrow the sample-quality gap to autoregressive models in applications such as text and image generation.
- Empirical results on datasets such as OpenWebText and ImageNet demonstrate ReMDM's efficacy, with improved MAUVE scores, FID, and Inception Score (IS).
Remasking Discrete Diffusion Models with Inference-Time Scaling
This paper presents "Remasking Discrete Diffusion Models with Inference-Time Scaling" (arXiv 2503.00307), which addresses a key weakness of masked discrete diffusion models by introducing a novel remasking diffusion model (ReMDM) sampler. The sampler gives masked diffusion models the ability to iteratively refine their outputs by remasking already-generated tokens at inference time, improving sample quality in tasks such as text generation and molecule design.
Introduction
Diffusion models have gained attention for generating high-quality outputs through iterative refinement, particularly in image and biological sequence generation. Masked diffusion models achieve state-of-the-art results in discrete data modeling, but they lack this iterative refinement capability: once a token is unmasked, it is never revisited, which limits both controllable generation and sample quality. The paper introduces the ReMDM sampler, which rectifies this by enabling remasking during generation and thereby lets sample quality scale with inference-time compute. Iteratively refining already-generated tokens pushes quality closer to autoregressive (AR) model performance.
Figure 1: Our family of masked diffusion processes allows more flexible generation by remasking already decoded tokens. This improves sample quality and further closes the gap to AR models. (Left) An illustrative example of errors fixed by ReMDM. (Right) MAUVE scores on OpenWebText.
Theoretical Framework
Discrete Diffusion Models
The ReMDM sampler is grounded in a probabilistic model defined by a remasking backward process. Diffusion models fit a parametric model p_θ that inverts a fixed noising process q; discrete extensions have been proposed for applications such as language modeling, but masked variants cannot remask a token once it has been decoded, which is precisely the failure mode ReMDM addresses.
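For concreteness, the standard masked (absorbing-state) forward process and its learned reverse step can be written as follows; this is conventional MDLM-style notation, supplied here for reference rather than quoted from the paper.

```latex
% Forward process: each token independently transitions to the mask state m,
% with the schedule \alpha_t decreasing from roughly 1 at t=0 to 0 at t=1.
q(z_t \mid x) = \mathrm{Cat}\big(z_t;\ \alpha_t\, x + (1 - \alpha_t)\, m\big)

% Reverse step: the network predicts the clean token x from the partially
% masked sequence z_t and substitutes it into the true posterior.
p_\theta(z_s \mid z_t) = q\big(z_s \mid z_t,\ x = x_\theta(z_t, t)\big)
```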
Remasking Diffusion Sampler Derivation
The ReMDM sampler defines a discrete diffusion model whose backward process can remask already-decoded tokens with a user-specified probability σ_t at each time step. Because generation remains standard ancestral sampling under this probabilistic model, additional denoising steps translate directly into better sample quality, giving a principled form of inference-time compute scaling.
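Below is a minimal sketch of one ReMDM reverse step, based on our reading of the paper's remasking posterior; the NumPy setup, the `MASK` sentinel, and all helper names are illustrative assumptions, not the authors' code.

```python
import numpy as np

MASK = -1  # illustrative mask-token id (an assumption, not the paper's encoding)

def remdm_step(z_t, model_probs, alpha_t, alpha_s, sigma_t, rng):
    """One ancestral-sampling step from time t to s < t with remasking.

    z_t         : (L,) int array, current sequence (token ids or MASK)
    model_probs : (L, V) array, denoiser distribution p_theta(x | z_t)
    alpha_t/s   : masking-schedule values at t and s (alpha decreases in t)
    sigma_t     : remasking probability; valid when
                  sigma_t <= min(1, (1 - alpha_s) / alpha_t)
    """
    z_s = z_t.copy()
    L, V = model_probs.shape
    for i in range(L):
        if z_t[i] != MASK:
            # A decoded token may be sent back to the mask state.
            if rng.random() < sigma_t:
                z_s[i] = MASK
        else:
            # A masked token is unmasked with a sigma-adjusted probability,
            # drawing its value from the denoiser's distribution.
            p_unmask = (alpha_s - (1.0 - sigma_t) * alpha_t) / (1.0 - alpha_t)
            if rng.random() < p_unmask:
                z_s[i] = rng.choice(V, p=model_probs[i])
    return z_s
```

Setting σ_t = 0 recovers the standard masked-diffusion ancestral step, which never remasks; larger σ_t buys the chance to revisit tokens at the cost of slower unmasking.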
Experimental Evaluation
Text Generation on OpenWebText
An evaluation of ReMDM's performance was conducted using pretrained models on the OpenWebText (OWT) dataset. ReMDM was compared against several baselines, including autoregressive (AR) decoding and the SEDD, MDLM, forward-backward (FB), and discrete flow matching (DFM) samplers.
The results show that ReMDM's sample quality scales favorably with inference-time compute, with MAUVE scores approaching those of AR models.
Figure 2: MAUVE scores of ReMDM inference-time scaling.
Image Generation
A MaskGIT model trained on ImageNet was adopted for the image generation experiments. The ReMDM sampler produced superior image quality at larger sampling budgets, with better Fréchet Inception Distance (FID) and Inception Score (IS) across a range of step counts T.
Figure 3: Effect of ReMDM components on OWT generation quality. Inference-time scaling with T ∈ {1024, 2048, 4096}.
Remasking Strategies
The implementation of the ReMDM sampler includes several strategies for setting the remasking probabilities σ_t, including max-capped, rescaled, and confidence-based schedules. The choice of schedule significantly affects performance, with the ReMDM-loop and ReMDM-cap schedules performing best in different settings, improving both flexibility and results; a sketch of two schedules follows below.
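As a hypothetical illustration of what such schedules could look like per step, the sketch below encodes one plausible reading of "max-capped" and "rescaled"; the strength parameter eta and the function names are our own, not taken from the paper.

```python
def sigma_max(alpha_t: float, alpha_s: float) -> float:
    """Largest remasking probability that keeps the reverse posterior valid."""
    if alpha_t <= 0.0:  # degenerate end of the schedule: any token may be remasked
        return 1.0
    return min(1.0, (1.0 - alpha_s) / alpha_t)

def sigma_cap(alpha_t: float, alpha_s: float, eta: float = 0.05) -> float:
    """Max-capped schedule: a constant eta, clipped so it stays valid."""
    return min(eta, sigma_max(alpha_t, alpha_s))

def sigma_rescale(alpha_t: float, alpha_s: float, eta: float = 0.05) -> float:
    """Rescaled schedule: a fixed fraction eta of the maximum valid value."""
    return eta * sigma_max(alpha_t, alpha_s)
```

Confidence-based variants instead direct the remasking budget toward tokens the denoiser is least certain about, rather than remasking uniformly at random.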
ReMDM is conceptually related to DDIM for continuous diffusion models: it augments masked diffusion with a non-Markovian remasking process in the same spirit as DDIM's formulation. Further comparisons show theoretical and empirical advantages over existing discrete predictor-corrector methods such as DFM, pointing to new directions for discrete diffusion frameworks.
Conclusion
The ReMDM sampler marks a notable advance for masked diffusion models, adding remasking capabilities that improve both their flexibility and their sample quality. By nearly matching the sample quality of AR models and extending masked diffusion to diverse generative tasks, the approach sets a new standard for discrete diffusion models.