Simplified and Generalized Masked Diffusion for Discrete Data (2406.04329v2)

Published 6 Jun 2024 in cs.LG and stat.ML

Abstract: Masked (or absorbing) diffusion is actively explored as an alternative to autoregressive models for generative modeling of discrete data. However, existing work in this area has been hindered by unnecessarily complex model formulations and unclear relationships between different perspectives, leading to suboptimal parameterization, training objectives, and ad hoc adjustments to counteract these issues. In this work, we aim to provide a simple and general framework that unlocks the full potential of masked diffusion models. We show that the continuous-time variational objective of masked diffusion models is a simple weighted integral of cross-entropy losses. Our framework also enables training generalized masked diffusion models with state-dependent masking schedules. When evaluated by perplexity, our models trained on OpenWebText surpass prior diffusion language models at GPT-2 scale and demonstrate superior performance on 4 out of 5 zero-shot language modeling tasks. Furthermore, our models vastly outperform previous discrete diffusion models on pixel-level image modeling, achieving 2.75 (CIFAR-10) and 3.40 (ImageNet 64x64) bits per dimension that are better than autoregressive models of similar sizes.

Authors (5)
  1. Jiaxin Shi
  2. Kehang Han
  3. Zhe Wang
  4. Arnaud Doucet
  5. Michalis K. Titsias

Summary

The paper "Simplified and Generalized Masked Diffusion for Discrete Data" by Jiaxin Shi, Kehang Han, Zhe Wang, Arnaud Doucet, and Michalis K. Titsias introduces a unified framework for masked (or absorbing) diffusion models for generative modeling of discrete data. The authors address several limitations of existing masked diffusion models, such as unnecessarily complex formulations and suboptimal training objectives, and present a simpler formulation that improves likelihoods on both text and image benchmarks.

Key Contributions

  1. Simplified Theoretical Framework: The authors show that the continuous-time variational objective for masked diffusion models can be represented as a simple weighted integral of cross-entropy losses. This formulation unifies various approaches from the literature and clarifies the relationships between them.
  2. Generalized Model with State-Dependent Masking Schedules: The paper extends the standard formulation by incorporating state-dependent masking schedules, allowing the model to prioritize the masking and unmasking of specific tokens based on their states. This generalization enhances the model's flexibility and performance.
  3. Improved Parameterization and Training Objectives: The authors parameterize the model with a prediction of the mean of the clean data given the masked data (mean parameterization), and argue that this yields more stable and effective training than the score-based parameterizations used in prior work. The resulting models outperform previous diffusion models on standard benchmarks.
  4. Empirical Results: The models trained using the proposed framework achieve better likelihoods and zero-shot transfer performance on text modeling tasks. Specifically, the models achieve better perplexity on OpenWebText and strong performance on several zero-shot language modeling tasks compared with existing diffusion models and GPT-2.
  5. Application to Image Data: The paper demonstrates the efficacy of the proposed framework in pixel-level image modeling tasks on datasets like CIFAR-10 and Downsampled ImageNet 64x64. The new models significantly outperform existing discrete diffusion models and match or exceed the performance of autoregressive models of similar size.
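The first contribution, the loss as a weighted integral of cross-entropy terms, can be sketched in a few lines of NumPy. This is a minimal single-sequence Monte Carlo estimate under assumptions of our own: the linear schedule alpha_t = 1 - t (so the weight -alpha_t'/(1 - alpha_t) equals 1/t), a cutoff EPS on t, a MASK = -1 sentinel, and a stand-in log_probs_fn for the network. None of these names come from the paper's code.

```python
import numpy as np

MASK = -1   # hypothetical sentinel id for the mask token (assumption)
EPS = 1e-3  # lower cutoff on t to avoid the 1/t singularity (assumption)


def mask_tokens(x0, t, rng):
    """Forward process with the linear schedule alpha_t = 1 - t:
    each token is kept with probability alpha_t, otherwise replaced by MASK."""
    keep = rng.random(x0.shape) < (1.0 - t)
    return np.where(keep, x0, MASK)


def md4_loss_estimate(x0, log_probs_fn, rng):
    """One-sample Monte Carlo estimate of the weighted cross-entropy integral
    over t in [EPS, 1] for a single sequence x0 of token ids."""
    t = rng.uniform(EPS, 1.0)
    xt = mask_tokens(x0, t, rng)
    log_probs = log_probs_fn(xt)               # (L, V): log mu_theta(x_t) per position
    ce = -log_probs[np.arange(len(x0)), x0]    # cross-entropy at every position
    masked = xt == MASK                        # only masked positions contribute
    weight = 1.0 / t                           # -alpha'_t / (1 - alpha_t) for alpha_t = 1 - t
    return (1.0 - EPS) * weight * np.sum(ce[masked])  # (1 - EPS) = length of the t interval


# Usage: a stand-in "model" that predicts the uniform distribution over V tokens.
rng = np.random.default_rng(0)
V, L = 8, 16
x0 = rng.integers(0, V, size=L)
uniform = lambda xt: np.full((len(xt), V), -np.log(V))
est = np.mean([md4_loss_estimate(x0, uniform, rng) for _ in range(20000)])
```

For the uniform stand-in, the expected number of masked tokens at time t is L·t, which cancels the 1/t weight, so the estimate averages to (1 - EPS)·L·log V. This is a convenient sanity check when wiring up a real network.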

Experimental Evaluation

Text Modeling

For text modeling, the authors train their models on OpenWebText and evaluate them on tasks such as LAMBADA, WikiText2, and Penn Treebank. The results show that their masked diffusion models (referred to as MD4 and GenMD4) outperform previous methods like D3PM and SEDD Absorb in terms of zero-shot perplexity. The models also demonstrate faster convergence and better final likelihoods on the validation set.
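Perplexity, bits per character and bits per dimension are all transforms of the same quantity: the negative ELBO in nats, summed over a sequence or image. The helpers below are illustrative (the function names are ours, not from the paper) and show the conversions. Because a diffusion model only provides a bound on the log-likelihood, the resulting numbers are upper bounds on the true values.

```python
import math


def perplexity_bound(neg_elbo_nats, num_tokens):
    """exp(average per-token negative ELBO). Since -ELBO >= -log p(x),
    this upper-bounds the model's true perplexity."""
    return math.exp(neg_elbo_nats / num_tokens)


def bits_per_unit(neg_elbo_nats, num_units):
    """Convert a total negative ELBO in nats into bits per character
    (e.g. text8) or bits per dimension (e.g. image pixels)."""
    return neg_elbo_nats / (num_units * math.log(2))
```

As a reference point, a model that is uniform over a V-token vocabulary has a per-token negative ELBO of log V nats, so its perplexity bound is V.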

On the text8 dataset, the MD4 and GenMD4 models achieve lower bits-per-character (BPC) than previous state-of-the-art diffusion models and any-order autoregressive models. The GenMD4 model further improves BPC, showcasing the benefits of state-dependent masking schedules.
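A state-dependent schedule lets the masking probability depend on the token itself. As an illustration, we assume a polynomial form alpha_t(v) = 1 - t^(w_v) with one exponent w_v per vocabulary entry; this is our reading of GenMD4, where such exponents are learned, whereas here they are fixed by hand. The function name state_dependent_mask and the MASK = -1 sentinel are hypothetical.

```python
import numpy as np

MASK = -1  # hypothetical mask-token id (assumption)


def state_dependent_mask(x0, t, w, rng):
    """Forward masking with a per-token polynomial schedule (assumed form).

    Token value v has alpha_t(v) = 1 - t ** w[v], so it is masked with
    probability t ** w[v]. A larger w[v] keeps v visible for longer in the
    forward process (so it tends to be revealed earlier when sampling)."""
    alpha = 1.0 - t ** w[x0]               # per-position keep probability
    keep = rng.random(x0.shape) < alpha
    return np.where(keep, x0, MASK)


# Usage: two token types with very different (illustrative) exponents.
rng = np.random.default_rng(0)
w = np.array([0.5, 4.0])
x0 = np.repeat([0, 1], 10000)
xt = state_dependent_mask(x0, 0.5, w, rng)
keep_rate_0 = np.mean(xt[x0 == 0] != MASK)  # expected to be near 1 - 0.5 ** 0.5
keep_rate_1 = np.mean(xt[x0 == 1] != MASK)  # expected to be near 1 - 0.5 ** 4
```

With a single shared exponent this reduces to an ordinary (state-independent) polynomial schedule; the per-token exponents are what let the model decide which tokens to mask first.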

Image Modeling

In pixel-level image modeling, MD4 sets a new state-of-the-art for discrete diffusion models on CIFAR-10 and matches the performance of autoregressive models on ImageNet 64x64. The paper includes several samples generated by MD4, demonstrating high-quality image synthesis despite modeling pixels as discrete tokens.
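Samples like these come from running the masking process in reverse. The sketch below assumes the linear schedule alpha_t = 1 - t, equally spaced time steps, and a hypothetical log_probs_fn standing in for the trained network. It illustrates standard ancestral sampling for masked diffusion and is not the authors' implementation.

```python
import numpy as np

MASK = -1  # hypothetical mask-token id (assumption)


def alpha(t):
    """Linear masking schedule alpha_t = 1 - t (illustrative choice)."""
    return 1.0 - t


def ancestral_sample(log_probs_fn, length, vocab_size, num_steps, rng):
    """Start fully masked at t = 1 and step down to t = 0.

    At each step t -> s, every still-masked position is revealed with
    probability (alpha_s - alpha_t) / (1 - alpha_t); a revealed position
    takes a token drawn from the model's predicted distribution."""
    x = np.full(length, MASK)
    times = np.linspace(1.0, 0.0, num_steps + 1)
    for t, s in zip(times[:-1], times[1:]):
        p_reveal = (alpha(s) - alpha(t)) / (1.0 - alpha(t))
        probs = np.exp(log_probs_fn(x))                # (length, vocab_size)
        reveal = (x == MASK) & (rng.random(length) < p_reveal)
        draws = np.array([rng.choice(vocab_size, p=p) for p in probs])
        x = np.where(reveal, draws, x)                 # revealed tokens never change again
    return x


# Usage with a stand-in model that predicts the uniform distribution.
rng = np.random.default_rng(0)
V = 5
uniform = lambda x: np.full((len(x), V), -np.log(V))
sample = ancestral_sample(uniform, length=12, vocab_size=V, num_steps=8, rng=rng)
```

Because the reveal probability in the final step (s = 0) is exactly 1, no position is left masked at the end.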

Theoretical Insights

The authors provide several theoretical results that clarify how masked diffusion models behave and how to train them. They derive the continuous-time limit of the Evidence Lower Bound (ELBO) for masked diffusion models and show that it is invariant to the choice of masking (noise) schedule, apart from the schedule's values at the endpoints. They also establish connections to existing work on continuous-time Markov chains and to alternative parameterizations.
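The invariance can be seen directly from the continuous-time objective. The following is a sketch in our own notation, not copied from the paper: alpha_t is the probability that a token is still unmasked at time t (decreasing from alpha_0 ≈ 1 to alpha_1 ≈ 0), m is the mask state, x_0^i is the i-th clean token, and mu_theta^i(x_t) is the model's predicted distribution over clean tokens at position i. We assume the model depends on t only through x_t (or through alpha_t), and write x_alpha for a partially masked sequence in which each token is kept with probability alpha. Substituting alpha = alpha_t turns the integral over time into an integral over alpha, so the schedule enters only through its endpoints:

```latex
\mathcal{L}_\infty
  = \int_0^1 \frac{\alpha_t'}{1-\alpha_t}\,
    \mathbb{E}_{q(x_t \mid x_0)}\Big[\sum_i \delta_{x_t^i,\,m}\,
    \log \mu_\theta^i(x_t)\big[x_0^i\big]\Big]\,dt
  = \int_{\alpha_1}^{\alpha_0} \frac{1}{1-\alpha}\,
    \mathbb{E}_{q(x_\alpha \mid x_0)}\Big[\sum_i \delta_{x_\alpha^i,\,m}\,
    \bigl(-\log \mu_\theta^i(x_\alpha)\big[x_0^i\big]\bigr)\Big]\,d\alpha
```

Both factors in the first integrand are non-positive (alpha_t' ≤ 0 and log-probabilities ≤ 0), so the loss is non-negative. For the linear schedule alpha_t = 1 - t, the weight alpha_t'/(1 - alpha_t) is -1/t.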

Future Directions

The paper concludes by suggesting future research directions, including the development of better architectures for discrete diffusion models and more robust state-dependent masking schedules. Additionally, the authors highlight the potential for extending their framework to other domains beyond text and image data.

Conclusion

The proposed framework for simplified and generalized masked diffusion models represents a significant advancement in generative modeling of discrete data. By addressing complexities in existing models and introducing state-dependent masking schedules, the authors achieve substantial improvements in both theoretical formulation and empirical performance. This work provides a solid foundation for future research in discrete diffusion models and their applications.
