The paper "Large-scale Reinforcement Learning for Diffusion Models" (Zhang et al., 20 Jan 2024 ) introduces a scalable reinforcement learning (RL) framework designed to fine-tune text-to-image diffusion models. The primary objective is to align these models with diverse objectives, including human preferences, compositionality accuracy, and fairness metrics, addressing limitations inherent in models pre-trained on large, uncurated web datasets. These limitations manifest as suboptimal sample quality, propagation of societal biases, and difficulties in accurately rendering complex textual descriptions.
Methodology: RL for Diffusion Model Alignment
The core methodology formulates the diffusion model's iterative denoising process as a finite-horizon Markov Decision Process (MDP). The goal is to optimize the policy, represented by the diffusion model's parameters $\theta$, to maximize the expected terminal reward obtained from the generated image $x_0$.
- State ($s_t$): Defined by the tuple $s_t = (c, t, x_t)$, where $x_t$ is the noisy image at timestep $t$, $c$ is the conditioning text prompt, and $t$ is the current timestep.
- Action ($a_t$): The predicted denoised image for the previous step, $a_t = x_{t-1}$.
- Policy ($\pi_\theta$): The parameterized reverse diffusion process, $\pi_\theta(a_t \mid s_t) = p_\theta(x_{t-1} \mid x_t, c)$.
- Reward ($r$): A terminal reward function $r(x_0, c)$ evaluated on the final generated image $x_0$ given the prompt $c$. Intermediate rewards are zero.
The optimization employs the REINFORCE algorithm (specifically, the likelihood ratio gradient estimator) to update the model parameters $\theta$. The objective function is to maximize the expected reward:

$$J(\theta) = \mathbb{E}_{c \sim p(c),\; x_{0:T} \sim p_\theta(x_{0:T} \mid c)}\left[ r(x_0, c) \right]$$

The policy gradient is given by:

$$\nabla_\theta J(\theta) = \mathbb{E}\left[ r(x_0, c) \sum_{t=1}^{T} \nabla_\theta \log p_\theta(x_{t-1} \mid x_t, c) \right]$$
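A minimal PyTorch-style sketch of this estimator is shown below. It is not the authors' implementation: the `reverse_step` callable (returning the Gaussian mean and standard deviation of $p_\theta(x_{t-1} \mid x_t, c)$) and the stored trajectory format are assumptions made for illustration.

```python
import math
import torch

def gaussian_log_prob(x, mean, std):
    """Log-density of a diagonal Gaussian, summed over the image dimensions."""
    return (-((x - mean) ** 2) / (2 * std ** 2)
            - torch.log(std)
            - 0.5 * math.log(2 * math.pi)).sum(dim=(1, 2, 3))

def reinforce_loss(reverse_step, trajectory, prompt_emb, rewards):
    """
    reverse_step(x_t, t, prompt_emb) -> (mean, std) of p_theta(x_{t-1} | x_t, c)  [assumed interface]
    trajectory: list of (x_t, t, x_prev) tuples recorded while sampling with the current policy
    rewards:    terminal rewards r(x_0, c), shape (batch,)
    """
    log_prob_sum = 0.0
    for x_t, t, x_prev in trajectory:
        mean, std = reverse_step(x_t, t, prompt_emb)
        log_prob_sum = log_prob_sum + gaussian_log_prob(x_prev, mean, std)
    # Likelihood-ratio estimator: grad J = E[ r(x_0, c) * sum_t grad log p_theta(x_{t-1} | x_t, c) ]
    return -(rewards * log_prob_sum).mean()
```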
To enhance training stability and scalability for large datasets (millions of prompts) and large models (like Stable Diffusion v2), several techniques are incorporated:
- Importance Sampling and Clipping: Similar to Proximal Policy Optimization (PPO), importance sampling is used to leverage samples generated by older policies. Policy clipping is applied to the likelihood ratio to prevent large, destabilizing updates. The clipped surrogate objective is used as the RL loss $\mathcal{L}_{\text{RL}}$.
- Advantage Estimation: An advantage estimate $\hat{A}$ is used instead of the raw reward $r(x_0, c)$. For this terminal reward setting, the advantage simplifies. Crucially, rewards are normalized per batch using the mean and variance of rewards within the current minibatch, $\hat{A} = (r(x_0, c) - \mu_{\text{batch}}) / \sigma_{\text{batch}}$. This contrasts with prior work like DDPO, which used per-prompt normalization, and is identified as a key factor for enabling large-scale training.
- Pretraining Loss Regularization: To prevent the model from over-optimizing on the reward function ("reward hacking") and to maintain generative fidelity, the original diffusion model pretraining loss (the denoising score matching loss $\mathcal{L}_{\text{pre}}$) is added to the RL objective, weighted by a hyperparameter $\beta$:

$$\mathcal{L} = \mathcal{L}_{\text{RL}} + \beta\, \mathbb{E}\left[ \left\| \epsilon - \epsilon_\theta(x_t, c, t) \right\|^2 \right],$$

where $\epsilon$ is the ground truth noise and $\epsilon_\theta(x_t, c, t)$ is the model's noise prediction. A consolidated sketch of these stabilization components follows below.
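The sketch combines the clipped surrogate, batch-wise reward normalization, and the pretraining regularizer under the reconstructed loss above. The inputs (summed per-step log-probabilities under the current and sampling policies, noise targets and predictions) are assumed interfaces rather than the paper's actual code; the 0.2 clip range matches the implementation details reported later.

```python
import torch
import torch.nn.functional as F

def clipped_rl_loss(log_prob_new, log_prob_old, rewards, clip_range=0.2):
    """PPO-style surrogate with batch-wise reward normalization as the advantage."""
    adv = (rewards - rewards.mean()) / (rewards.std() + 1e-8)   # normalize within the minibatch
    ratio = torch.exp(log_prob_new - log_prob_old)              # importance weight vs. the old policy
    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range) * adv
    return -torch.min(unclipped, clipped).mean()

def total_loss(log_prob_new, log_prob_old, rewards, noise_pred, noise_true, beta=0.1):
    """RL objective regularized by the original denoising (score-matching) loss."""
    l_rl = clipped_rl_loss(log_prob_new, log_prob_old, rewards)
    l_pre = F.mse_loss(noise_pred, noise_true)                  # || eps - eps_theta(x_t, c, t) ||^2
    return l_rl + beta * l_pre
```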
Handling Distribution-based Rewards
A notable contribution is the method for incorporating reward functions that depend on the distribution of generated samples, rather than individual samples. This is essential for objectives like fairness or diversity (e.g., ensuring diverse skintone representation across generations for certain prompts). The paper proposes approximating the distribution-level reward by computing it empirically over the samples generated within each training minibatch. This minibatch statistic serves as the reward signal for the policy gradient update associated with that distribution-dependent objective. For instance, to promote skintone diversity, a statistical parity metric (negative difference from uniform distribution across skintone categories) is calculated over the images generated in a batch, and this single value is used as the reward for all samples in that batch contributing to the fairness objective.
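As a concrete illustration, the following sketch computes such a minibatch-level diversity reward from the outputs of a hypothetical 4-way skintone classifier; the exact parity metric and classifier interface used in the paper may differ.

```python
import torch

def skintone_diversity_reward(class_logits: torch.Tensor) -> torch.Tensor:
    """
    class_logits: (batch, num_classes) scores from a skintone classifier (assumed interface).
    Returns one reward per sample; every sample in the minibatch receives the same value.
    """
    num_classes = class_logits.shape[-1]
    preds = class_logits.argmax(dim=-1)                          # per-image category
    counts = torch.bincount(preds, minlength=num_classes).float()
    empirical = counts / counts.sum()                            # minibatch category distribution
    uniform = torch.full_like(empirical, 1.0 / num_classes)
    reward = -(empirical - uniform).abs().sum()                  # 0 when perfectly uniform, negative otherwise
    return reward.expand(class_logits.shape[0])                  # broadcast the scalar to every sample
```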
Multi-task Joint Optimization
The framework supports simultaneous optimization for multiple reward functions. The multi-task training procedure (Algorithm 1 in the paper) involves:
- Sampling: In each training step, sample a batch of prompts associated with different tasks (e.g., aesthetic preference, fairness, compositionality).
- Generation: Generate images for each prompt using the current policy $\pi_\theta$.
- Reward Calculation: Compute the relevant reward for each generated image based on its associated task.
- Gradient Updates: Sequentially compute and apply the policy gradient updates for each task, using the corresponding rewards and batch-wise normalization.
- Pretraining Loss Update: Compute and apply the gradient update for the pretraining loss $\mathcal{L}_{\text{pre}}$.
This allows a single model to be trained to balance and achieve competence across multiple objectives.
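One such multi-task step might look roughly like the sketch below; the `policy` and `task` interfaces are invented for illustration and do not correspond to the paper's actual implementation.

```python
def multi_task_step(policy, optimizer, tasks, pretrain_batch, beta=0.1):
    """One joint update over several reward objectives plus the pretraining loss."""
    for task in tasks:                                    # e.g. preference, fairness, compositionality
        prompts = task.sample_prompts()
        images, trajectories = policy.sample(prompts)     # generate with the current policy pi_theta
        rewards = task.reward(images, prompts)            # per-sample or minibatch-level reward
        loss_rl = policy.rl_loss(trajectories, rewards)   # clipped objective, batch-wise normalization
        optimizer.zero_grad()
        loss_rl.backward()
        optimizer.step()
    # Final update on the original diffusion pretraining (denoising) loss.
    loss_pre = policy.pretraining_loss(pretrain_batch)
    optimizer.zero_grad()
    (beta * loss_pre).backward()
    optimizer.step()
```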
Implementation Details and Scale
- Base Model: Stable Diffusion v2 (SDv2) with a 512x512 resolution UNet backbone.
- Training Scale: Experiments were conducted using 128 NVIDIA A100 GPUs. Datasets involved millions of prompts: 1.5M from DiffusionDB for preference, 240M BLIP-generated captions from Pinterest images for diversity (with race terms filtered), and over 1M synthetic prompts for compositionality.
- Reward Functions:
  - Preference: ImageReward (IR), a trained model predicting human aesthetic preference.
  - Fairness: A negative statistical parity score based on a 4-category skintone classifier (calculated per minibatch).
  - Compositionality: Average detection confidence score from a UniDet object detector for the objects mentioned in the prompt (a minimal sketch follows this list).
- Hyperparameters: The weight $\beta$ of the pretraining loss was tuned; values around 0.1 were often effective. The PPO clipping range was typically set to 0.2. The AdamW optimizer was used.
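For reference, a minimal sketch of the compositionality reward named above: it averages the best detection confidence over the objects mentioned in a prompt. The `detector` interface (returning `(label, confidence)` pairs) and the prompt-parsing step are assumptions, not the paper's actual UniDet wrapper.

```python
def compositionality_reward(image, prompt_objects, detector):
    """
    prompt_objects: object names extracted from the prompt, e.g. ["dog", "skateboard"].
    detector(image) -> list of (label, confidence) detections (assumed interface).
    Returns the average best-detection confidence over the prompted objects.
    """
    detections = detector(image)
    scores = []
    for obj in prompt_objects:
        confidences = [conf for label, conf in detections if label == obj]
        scores.append(max(confidences) if confidences else 0.0)   # 0 if the object is missing
    return sum(scores) / max(len(scores), 1)
```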
Experimental Results
The proposed RL framework demonstrated significant improvements over the base SDv2 model and existing alignment methods across various tasks.
- Human Preference: The RL-tuned model was preferred by human evaluators 80.3% of the time against the base SDv2 model. It also achieved higher ImageReward scores and aesthetic ratings (on DiffusionDB and PartiPrompts) compared to baselines like ReFL, RAFT, DRaFT, and Reward-weighted Resampling. The RL approach appeared more robust to reward hacking than direct gradient methods like ReFL, which sometimes introduced high-frequency artifacts.
- Fairness (Skintone Diversity): Using the minibatch-based distribution reward, the model significantly reduced skintone bias compared to SDv2 on out-of-domain datasets (occupations, HRS-Bench), producing more equitable distributions across Fitzpatrick scale categories.
- Compositionality: The model fine-tuned with the UniDet reward showed improved ability to generate the correct objects specified in prompts, outperforming SDv2, particularly for prompts involving multiple objects and spatial relationships.
- Multi-task Learning: A single model jointly trained on preference, fairness, and compositionality rewards substantially outperformed the base SDv2 on all three metrics simultaneously. While specialized single-task models achieved peak performance on their respective metric, they often degraded performance on others (an "alignment tax"). The jointly trained model successfully mitigated this alignment tax, achieving over 80% of the relative performance gain of the specialized models across all objectives.
Practical Significance and Applications
This work provides a scalable and general framework for fine-tuning large diffusion models using RL. Its key practical implications include:
- Scalability: Demonstrates feasibility of RLHF-style alignment at the scale of millions of prompts, enabled by techniques like batch-wise reward normalization.
- Generality: Applicable to arbitrary, potentially non-differentiable reward functions (e.g., outputs of object detectors, human feedback simulators, fairness metrics) and distribution-level objectives.
- Multi-Objective Alignment: Offers a practical method for balancing multiple alignment criteria (e.g., aesthetics, safety, fairness, instruction following) within a single model, crucial for real-world deployment.
- Improved Robustness: The combination of the pretraining loss regularizer and the RL optimization process appears less susceptible to reward hacking than methods that rely solely on direct reward-gradient backpropagation.
- Deployment Strategy: Provides a post-hoc fine-tuning mechanism to adapt pre-trained foundation models to specific downstream requirements and ethical considerations without costly inference-time guidance or complete retraining.
This RL framework can be applied to tailor diffusion models for specific applications requiring high aesthetic quality, adherence to complex compositional instructions, or mitigation of known biases present in the original training data.
In conclusion, the paper presents a robust and scalable RL-based approach for aligning text-to-image diffusion models. By successfully incorporating diverse reward functions, including distribution-level metrics, and enabling effective multi-task optimization at scale, it offers a significant advancement in improving the controllability, quality, and fairness of generative image models.