Steering Masked Discrete Diffusion Models via Discrete Denoising Posterior Prediction (2410.08134v1)

Published 10 Oct 2024 in cs.LG and cs.AI

Abstract: Generative modeling of discrete data underlies important applications spanning text-based agents like ChatGPT to the design of the very building blocks of life in protein sequences. However, application domains need to exert control over the generated data by steering the generative process - typically via RLHF - to satisfy a specified property, reward, or affinity metric. In this paper, we study the problem of steering Masked Diffusion Models (MDMs), a recent class of discrete diffusion models that offer a compelling alternative to traditional autoregressive models. We introduce Discrete Denoising Posterior Prediction (DDPP), a novel framework that casts the task of steering pre-trained MDMs as a problem of probabilistic inference by learning to sample from a target Bayesian posterior. Our DDPP framework leads to a family of three novel objectives that are all simulation-free, and thus scalable while applying to general non-differentiable reward functions. Empirically, we instantiate DDPP by steering MDMs to perform class-conditional pixel-level image modeling, RLHF-based alignment of MDMs using text-based rewards, and fine-tuning protein language models to generate more diverse secondary structures and shorter proteins. We substantiate our designs via wet-lab validation, where we observe transient expression of reward-optimized protein sequences.

Summary

  • The paper introduces DDPP, a method that reframes steering MDMs as Bayesian posterior inference for controlled output generation.
  • It proposes three efficient DDPP variants (DDPP-IS, DDPP-LB, DDPP-KL) that optimize model alignment without costly forward simulation.
  • Empirical results across image, text, and protein domains validate DDPP’s scalability and competitiveness for discrete data tasks.

Discrete Denoising Posterior Prediction: Steering Masked Diffusion Models

The paper addresses the challenge of steering Masked Diffusion Models (MDMs) for discrete data generation, proposing a novel framework called Discrete Denoising Posterior Prediction (DDPP). This approach reframes the task of controlling MDM outputs as a probabilistic inference problem, where the goal is to sample from a Bayesian posterior distribution informed by pre-trained MDMs and specified reward models.
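
Concretely, the target distribution is the pre-trained model tilted by the reward. A minimal sketch in standard RLHF notation (the raw reward $r$ and inverse temperature $\beta$ are illustrative symbols, not necessarily the paper's):

$$
p^{\mathrm{tgt}}(x) \;=\; \frac{1}{Z}\, p_{\mathrm{pre}}(x)\, R(x),
\qquad R(x) \;=\; \exp\!\big(r(x)/\beta\big),
\qquad Z \;=\; \sum_{x} p_{\mathrm{pre}}(x)\, R(x).
$$

The normalizer $Z$ is a sum over all sequences and is intractable, which is why the objectives below revolve around estimating or sidestepping the partition function.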

Background and Motivation

MDMs have emerged as a promising alternative to autoregressive models for discrete generative tasks, offering the advantage of non-sequential (any-order) generation. The ability to steer these models toward specific objectives, as in RLHF-style alignment, is crucial for practical applications. Traditional RLHF methods are difficult to apply to MDMs because these models do not admit straightforward likelihood computation.

Methodology

The authors introduce DDPP, a framework that transforms the steering of MDMs into a task of sampling from a Bayesian posterior. The posterior is defined by the product of a pre-trained MDM's distribution and a reward model. The paper details a discrete denoising approach that exploits the structure of MDMs to approximate this posterior without expensive simulation of the model's generative (denoising) process.
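
As a rough illustration of what "simulation-free" means here, the sketch below corrupts a clean sequence once and regresses the fine-tuned denoiser onto the reward-tilted posterior of the frozen pre-trained model, in the spirit of the DDPP-LB variant. This is a toy sketch under stated assumptions, not the authors' implementation; the uniform mask schedule, the toy denoiser, and reward_fn are all placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

VOCAB, MASK_ID = 32, 0  # toy vocabulary; id 0 reserved for the [MASK] token

class ToyDenoiser(nn.Module):
    """Stand-in for a masked diffusion denoiser: token ids -> per-position logits."""
    def __init__(self, vocab=VOCAB, dim=64):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.out = nn.Linear(dim, vocab)

    def forward(self, x):
        return self.out(self.emb(x))

def masked_logprob(model, xt, x0):
    """Log-probability the denoiser assigns to the clean tokens at masked positions."""
    logp = F.log_softmax(model(xt), dim=-1)
    logp = logp.gather(-1, x0.unsqueeze(-1)).squeeze(-1)
    return (logp * (xt == MASK_ID)).sum(-1)

def ddpp_lb_step(q_model, p_model, log_Z, x0, reward_fn, beta=1.0):
    """One simulation-free update: corrupt x0 once, then regress the fine-tuned
    posterior onto the reward-tilted pre-trained posterior. log_Z is a learned
    scalar standing in for the intractable log-partition term."""
    t = torch.rand(())                                    # random mask rate in (0, 1)
    mask = torch.rand(x0.shape) < t
    xt = torch.where(mask, torch.full_like(x0, MASK_ID), x0)
    log_q = masked_logprob(q_model, xt, x0)               # fine-tuned model
    with torch.no_grad():                                 # pre-trained model stays frozen
        log_p = masked_logprob(p_model, xt, x0)
    target = log_p + reward_fn(x0) / beta - log_Z         # unnormalized tilted posterior
    return ((log_q - target) ** 2).mean()
```

In practice q_model would be initialized from p_model and x0 drawn from a dataset or from the pre-trained model itself; crucially, no reverse-diffusion rollout is ever simulated inside the loss.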

Key Contributions

  1. Discrete Denoising Posterior Prediction (DDPP): The framework is structured to exploit the denoising capabilities of MDMs, aligning the model to a target Bayesian posterior. It establishes simplified objectives that are simulation-free, enabling scalable fine-tuning using non-differentiable rewards.
  2. DDPP Variants: Three strategies (DDPP-IS, DDPP-LB, and DDPP-KL) are proposed to estimate the partition function and optimize model alignment under different conditions, balancing computational efficiency against differentiability requirements on the reward (see the sketch after this list).
  3. Empirical Validation: The framework's effectiveness is demonstrated across diverse domains, including class-conditional image generation, language modeling steered for sentiment polarity, and protein sequence diversification, achieving competitive performance.
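
The variants in item 2 differ chiefly in how they treat the log-partition term. Below is a minimal sketch of the importance-sampling flavor (in the spirit of DDPP-IS), where x0_sampler is a hypothetical helper that draws one candidate denoising from the frozen pre-trained MDM given a partially masked x_t; because candidates come from the pre-trained posterior itself, the importance weights reduce to the exponentiated reward.

```python
import math
import torch

def log_partition_is(xt, x0_sampler, reward_fn, k=8, beta=1.0):
    """Self-normalized Monte Carlo estimate of
    log Z(x_t) = log E_{x0 ~ p_pre(. | x_t)}[exp(r(x0) / beta)]
    using K candidate denoisings per input."""
    # x0_sampler(xt) -> one sampled clean sequence per batch element (hypothetical helper)
    log_w = torch.stack([reward_fn(x0_sampler(xt)) / beta for _ in range(k)])  # (K, batch)
    return torch.logsumexp(log_w, dim=0) - math.log(k)
```

A larger K tightens the estimate at the cost of more calls to the pre-trained model and the reward.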

Theoretical Implications

The authors cast the steering task as an amortized sampler learning problem for Bayesian posteriors. This perspective aligns with classical RLHF techniques but extends their applicability to discrete diffusion models. They present a systematic approach to finetuning large MDMs, highlighting the importance of efficient partition function estimation.

Practical Implications

The approach opens new avenues for developing controlled generative models that adhere to predefined specifications, supporting tasks ranging from safe language generation to protein design. This work suggests that MDMs can be effectively utilized in domains that require non-sequential data modeling, thereby broadening their application scope.

Conclusion and Future Directions

The paper lays the groundwork for directly steering MDMs via a structured posterior prediction approach. Future research could scale this framework to larger models and further optimize DDPP strategies for different reward settings. Additionally, more sophisticated inference techniques for improved sample quality and fidelity remain an open area.

The paper provides a comprehensive pathway for achieving steerable generative models, with theoretical and empirical insights significantly contributing to the domain of discrete data modeling.