- The paper introduces a calibration technique that subtracts the expected model output at each timestep, so the calibrated score model matches the true data score's property of having zero expectation.
- It reduces the score matching loss and tightens KL divergence bounds, thereby enhancing the model’s evidence lower bound.
- Experimental results show improved sample quality and faster convergence, yielding better FID scores on benchmarks like CIFAR-10 and CelebA.
This paper introduces a method to calibrate pretrained Diffusion Probabilistic Models (DPMs) to improve their performance, specifically focusing on sample quality and model likelihood (arXiv:2302.10688). The core idea stems from the observation that the scaled true data score, $\alpha_t \nabla_{x_t} \log q_t(x_t)$, follows a martingale process in reverse time. A key property derived from this is that the expected value of the true data score should be zero at any timestep $t$: $\mathbb{E}_{q_t(x_t)}[\nabla_{x_t} \log q_t(x_t)] = 0$.
However, empirically trained DPMs, represented by a learned score model $s_\theta^t(x_t)$ (or equivalent parametrizations such as noise prediction $\epsilon_\theta^t(x_t)$), often violate this property, meaning $\mathbb{E}_{q_t(x_t)}[s_\theta^t(x_t)] \neq 0$. This "mis-calibration" can arise from training on finite datasets and sub-optimal model convergence. The paper proposes a simple calibration technique: subtract this non-zero expectation from the model's output during inference.
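To see why the true score has zero expectation, note that (under the standard regularity assumption that $q_t$ vanishes at infinity) the expectation reduces to the integral of the gradient of a density; the short derivation below is a standard argument, sketched here only for intuition:

$$\mathbb{E}_{q_t(x_t)}\big[\nabla_{x_t}\log q_t(x_t)\big] = \int q_t(x_t)\,\frac{\nabla_{x_t} q_t(x_t)}{q_t(x_t)}\,\mathrm{d}x_t = \int \nabla_{x_t} q_t(x_t)\,\mathrm{d}x_t = \mathbf{0},$$

since each coordinate integrates the derivative of a density that decays to zero at infinity.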
Calibration Method
The proposed calibration adjusts the score model output as follows:
$$s_{\theta,\text{calibrated}}^{t}(x_t) = s_\theta^t(x_t) - \eta_t^*$$
where $\eta_t^*$ is the optimal calibration term for timestep $t$:
$$\eta_t^* = \mathbb{E}_{q_t(x_t)}[s_\theta^t(x_t)]$$
This calibration offers several benefits:
- Reduced Score Matching Loss: Calibration provably reduces the Score Matching (SM) objective at each timestep $t$ (a short derivation is sketched after this list):
$$\mathcal{J}_{\mathrm{SM}}^t(\theta, \eta_t^*) = \mathcal{J}_{\mathrm{SM}}^t(\theta) - \tfrac{1}{2}\big\|\mathbb{E}_{q_t(x_t)}[s_\theta^t(x_t)]\big\|_2^2$$
Similar reductions apply to the Denoising Score Matching (DSM) objective and equivalent losses under other model parametrizations.
- Improved Likelihood Bounds: Since SM objectives are linked to the KL divergence between the data distribution $q_0$ and the model distribution $p_0$ (for both SDE and ODE formulations), reducing the SM loss leads to tighter (lower) upper bounds on $D_{\mathrm{KL}}(q_0 \,\|\, p_0)$, effectively increasing the evidence lower bound (ELBO) of the model likelihood.
- Improved Sample Quality: Experiments show that calibration leads to better FID scores on datasets like CIFAR-10 and CelebA 64x64, especially when using efficient ODE solvers such as DPM-Solver with few steps (a low number of function evaluations, NFE). It can also lead to faster convergence during sampling and reduce ambiguous generations.
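The loss reduction follows from a standard variance-decomposition argument; the sketch below writes the SM objective with an additive calibration vector $\eta_t$ and uses the zero-mean property of the true score (notation follows this summary, not the paper verbatim):

$$\mathcal{J}_{\mathrm{SM}}^t(\theta,\eta_t) = \tfrac{1}{2}\,\mathbb{E}_{q_t(x_t)}\big\|s_\theta^t(x_t) - \eta_t - \nabla_{x_t}\log q_t(x_t)\big\|_2^2,$$

which is quadratic in $\eta_t$ and minimized at $\eta_t^* = \mathbb{E}_{q_t(x_t)}[s_\theta^t(x_t) - \nabla_{x_t}\log q_t(x_t)] = \mathbb{E}_{q_t(x_t)}[s_\theta^t(x_t)]$. Substituting $\eta_t^*$ back in gives

$$\mathcal{J}_{\mathrm{SM}}^t(\theta,\eta_t^*) = \mathcal{J}_{\mathrm{SM}}^t(\theta) - \tfrac{1}{2}\big\|\mathbb{E}_{q_t(x_t)}[s_\theta^t(x_t)]\big\|_2^2.$$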
Implementation Strategies
The main practical challenge is estimating the calibration term $\eta_t^* = \mathbb{E}_{q_t(x_t)}[s_\theta^t(x_t)]$ (or its equivalent for other parametrizations). The paper suggests three approaches:
- Post-hoc Estimation using Training Data:
- This is the primary method used in the experiments.
- Requires access to the original training dataset $\{x_0^n\}_{n=1}^N$.
- For each required timestep t, estimate the expectation using Monte Carlo sampling:
- Sample mini-batches of x0 from the training set.
- Sample noise $\epsilon \sim \mathcal{N}(0, I)$.
- Compute noisy samples $x_t = \alpha_t x_0 + \sigma_t \epsilon$.
- Pass $x_t$ through the pretrained model $s_\theta^t(x_t)$.
- Average the model outputs $s_\theta^t(x_t)$ over a large number of samples to approximate $\eta_t^*$.
- Store the computed $\eta_t^*$ values (e.g., in a dictionary mapping $t$ to the vector $\eta_t^*$) for each timestep required by the sampler.
- During inference, retrieve the appropriate $\eta_t^*$ and subtract it from the model's output.
- Consideration: Requires potentially significant computation time upfront (once per model) and access to training data. Ablations show that using a large fraction of the training data (e.g., >20k images for CIFAR-10) might be necessary for accurate estimation.
Pseudocode for Post-hoc Estimation (Noise Prediction Parametrization):
```python
# Post-hoc estimation of the calibration term (noise-prediction parametrization).
# Assumes: a pretrained model `model(xt, t)` predicting epsilon_theta(xt, t),
# a `train_loader` over the training data, the timesteps `timesteps_to_calibrate`
# required by the sampler, `num_mc_samples` Monte Carlo samples per timestep,
# and a helper `get_alpha_sigma(t)` returning the schedule parameters (alpha_t, sigma_t).
import torch

calibration_terms = {}
model.eval()
with torch.no_grad():
    for t in timesteps_to_calibrate:
        eta_t_sum = None
        samples_processed = 0
        while samples_processed < num_mc_samples:
            for x0_batch in train_loader:
                batch_size = x0_batch.size(0)
                epsilon = torch.randn_like(x0_batch)
                alpha_t, sigma_t = get_alpha_sigma(t)  # diffusion schedule parameters
                xt_batch = alpha_t * x0_batch + sigma_t * epsilon
                predicted_noise_batch = model(xt_batch, t)  # shape [B, C, H, W]
                if eta_t_sum is None:
                    # Running sum has the per-sample output shape, e.g. [C, H, W].
                    eta_t_sum = torch.zeros_like(predicted_noise_batch[0])
                eta_t_sum += predicted_noise_batch.sum(dim=0)
                samples_processed += batch_size
                if samples_processed >= num_mc_samples:
                    break
        calibration_terms[t] = eta_t_sum / samples_processed  # Monte Carlo estimate of eta_t*

# During sampling:
#   predicted_noise = model(xt, t)
#   calibrated_noise = predicted_noise - calibration_terms[t]
#   Use calibrated_noise in the solver step.
```
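For sampling, one convenient pattern is to wrap the pretrained model so that the subtraction happens transparently inside any existing solver loop. The wrapper below is a minimal sketch; the class name `CalibratedNoiseModel` and the assumption that the solver only calls `model(x_t, t)` with the same timesteps used during estimation are ours, not from the paper:

```python
import torch

class CalibratedNoiseModel(torch.nn.Module):
    """Wraps a pretrained noise-prediction model and subtracts the
    precomputed calibration term for each timestep."""

    def __init__(self, model, calibration_terms):
        super().__init__()
        self.model = model
        self.calibration_terms = calibration_terms  # dict: t -> eta_t* tensor, e.g. [C, H, W]

    def forward(self, x_t, t):
        predicted_noise = self.model(x_t, t)
        # Timestep t must match a key computed during post-hoc estimation.
        eta_t = self.calibration_terms[t].to(predicted_noise.device)
        # The [C, H, W] calibration term broadcasts across the batch dimension.
        return predicted_noise - eta_t

# Usage (sketch): pass the wrapped model to an off-the-shelf sampler, e.g. DPM-Solver,
# in place of the original model; the sampling loop itself is unchanged.
# calibrated_model = CalibratedNoiseModel(model, calibration_terms)
```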
- Post-hoc Estimation using Generated Data:
- Use this if training data is unavailable.
- First, generate a large set of samples $\{\tilde{x}_0\}$ using the original pretrained DPM.
- Then, use these generated samples $\tilde{x}_0$ in place of the training data in the Monte Carlo estimation process described above.
- Consideration: The quality of calibration depends on the quality and diversity of the generated samples. Experiments suggest it can achieve results comparable to using training data if enough high-quality samples are generated (e.g., 20k for CIFAR-10), but using too many might not help further.
- Dynamic Recording during Training:
- Estimate ηt∗ during the DPM pretraining phase.
- Introduce a small auxiliary network $h_\phi(t)$ (e.g., an MLP) that takes the timestep $t$ as input.
- Train $h_\phi(t)$ to predict the expected model output $\mathbb{E}_{q_t(x_t)}[s_\theta^t(x_t)]$ using mini-batch averages and stop-gradients on $s_\theta^t(x_t)$. The loss for $h_\phi(t)$ at timestep $t$ is $\mathcal{J}_{\mathrm{Cal}}^t(\phi) = \mathbb{E}_{q_t(x_t)}\big[\|h_\phi(t) - s_\theta^t(x_t)^\dagger\|_2^2\big]$, where $\dagger$ denotes the stop-gradient (see the sketch after this list).
- At inference time, use the output of the trained $h_\phi(t)$ as the calibration term $\eta_t^*$.
- Consideration: Avoids post-hoc computation and data access issues but adds complexity to the training pipeline. Allows immediate generation after training.
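A minimal sketch of the dynamic-recording idea is given below, assuming a DDPM-style training loop with a noise-prediction model; the network `CalibrationHead` and the loop structure are illustrative choices, not the paper's exact implementation:

```python
import torch
import torch.nn as nn

class CalibrationHead(nn.Module):
    """Small MLP h_phi(t) mapping a scalar timestep to a calibration
    vector with the same shape as a single model output."""

    def __init__(self, out_shape, hidden_dim=256):
        super().__init__()
        self.out_shape = out_shape  # e.g. (C, H, W)
        out_dim = 1
        for d in out_shape:
            out_dim *= d
        self.net = nn.Sequential(
            nn.Linear(1, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, t):
        t = torch.as_tensor([[float(t)]])  # shape [1, 1]
        return self.net(t).view(*self.out_shape)

# Sketch of the extra step inside a standard DPM training loop:
#   h_phi = CalibrationHead(out_shape=(C, H, W))
#   eps_pred = model(x_t, t)                                   # [B, C, H, W]
#   dpm_loss = ((eps_pred - eps) ** 2).mean()                  # usual DSM loss
#   cal_loss = ((h_phi(t) - eps_pred.detach().mean(dim=0)) ** 2).mean()
#   (dpm_loss + cal_loss).backward()                           # detach() acts as the stop-gradient
# At inference, h_phi(t) approximates the calibration term eta_t*.
```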
Model Parametrizations
The paper provides calibration formulas for common DPM parametrizations:
| Parametrization | Model Output | Calibration Term $\eta_t^*$ | Calibrated Output |
| --- | --- | --- | --- |
| Score Prediction | $s_\theta^t(x_t)$ | $\mathbb{E}_{q_t(x_t)}[s_\theta^t(x_t)]$ | $s_\theta^t(x_t) - \eta_t^*$ |
| Noise Prediction (Common) | $\epsilon_\theta^t(x_t)$ | $\mathbb{E}_{q_t(x_t)}[\epsilon_\theta^t(x_t)]$ | $\epsilon_\theta^t(x_t) - \eta_t^*$ |
| Data Prediction | $x_\theta^t(x_t)$ | $\mathbb{E}_{q_t(x_t)}[x_\theta^t(x_t)] - \mathbb{E}_{q_0(x_0)}[x_0]$ | $x_\theta^t(x_t) - (\eta_t^* + \mathbb{E}_{q_0(x_0)}[x_0])$ |
| Velocity Prediction | $v_\theta^t(x_t)$ | $\mathbb{E}_{q_t(x_t)}[v_\theta^t(x_t)] + \sigma_t\,\mathbb{E}_{q_0(x_0)}[x_0]$ | $v_\theta^t(x_t) - (\eta_t^* - \sigma_t\,\mathbb{E}_{q_0(x_0)}[x_0])$ |

(Note: For data/velocity prediction, $\mathbb{E}_{q_0(x_0)}[x_0]$ is the mean of the training data, often assumed to be 0 after normalization.)
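As a consistency check on the score and noise rows, under the standard relation between the two parametrizations (conventions may differ slightly from the paper's):

$$\epsilon_\theta^t(x_t) = -\sigma_t\, s_\theta^t(x_t) \;\Longrightarrow\; \mathbb{E}_{q_t(x_t)}[\epsilon_\theta^t(x_t)] = -\sigma_t\,\mathbb{E}_{q_t(x_t)}[s_\theta^t(x_t)],$$

so subtracting each model's own expected output applies the same correction, just expressed in the respective parametrization.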
Conditional Models
The calibration principle extends to conditional DPMs $s_\theta^t(x_t, y)$. The calibration term becomes condition-specific: $\eta_t^*(y) = \mathbb{E}_{q_t(x_t \mid y)}[s_\theta^t(x_t, y)]$.
- Consideration: Post-hoc estimation is challenging if the condition space $y$ is large or continuous (e.g., text prompts), as $\eta_t^*(y)$ needs to be computed for potentially infinitely many values of $y$. Dynamic recording might be more practical in such cases, potentially using multimodal modules for the auxiliary network $h_\phi(t, y)$. A per-class sketch for a small, discrete condition space is shown below.
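For a class-conditional model with a small label set, the post-hoc estimator extends naturally by grouping samples per class. The sketch below is our illustration, assuming a hypothetical `labeled_loader` yielding `(x0, y)` pairs and a model called as `model(x_t, t, y)`; it is not code from the paper:

```python
import torch
from collections import defaultdict

def estimate_conditional_terms(model, labeled_loader, t, get_alpha_sigma, num_classes):
    """Monte Carlo estimate of eta_t*(y) for each class label y at one timestep t."""
    sums = defaultdict(lambda: None)
    counts = defaultdict(int)
    alpha_t, sigma_t = get_alpha_sigma(t)
    with torch.no_grad():
        for x0, y in labeled_loader:
            eps = torch.randn_like(x0)
            x_t = alpha_t * x0 + sigma_t * eps
            pred = model(x_t, t, y)  # conditional noise prediction, shape [B, C, H, W]
            for label in range(num_classes):
                mask = (y == label)
                if mask.any():
                    s = pred[mask].sum(dim=0)
                    sums[label] = s if sums[label] is None else sums[label] + s
                    counts[label] += int(mask.sum())
    # Per-class averages approximate eta_t*(y); classes absent from the loader are omitted.
    return {label: sums[label] / counts[label] for label in sums}
```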
Limitations
- While calibration provably improves likelihood bounds and often improves FID in practice, improvements in perceptual metrics such as FID are not strictly guaranteed by the theory.
- Post-hoc computation for conditional models with large condition spaces is impractical.
In summary, the paper presents a theoretically grounded and practically effective method to calibrate any pretrained DPM by subtracting the expected model output at each timestep. This one-time (per model) calibration procedure can improve sample quality, accelerate sampling convergence, and increase model likelihood bounds, requiring estimation of the calibration term via training data, generated data, or dynamic recording during training.