On Calibrating Diffusion Probabilistic Models (2302.10688v3)

Published 21 Feb 2023 in cs.LG, cs.CV, and stat.ML

Abstract: Recently, diffusion probabilistic models (DPMs) have achieved promising results in diverse generative tasks. A typical DPM framework includes a forward process that gradually diffuses the data distribution and a reverse process that recovers the data distribution from time-dependent data scores. In this work, we observe that the stochastic reverse process of data scores is a martingale, from which concentration bounds and the optional stopping theorem for data scores can be derived. Then, we discover a simple way for calibrating an arbitrary pretrained DPM, with which the score matching loss can be reduced and the lower bounds of model likelihood can consequently be increased. We provide general calibration guidelines under various model parametrizations. Our calibration method is performed only once and the resulting models can be used repeatedly for sampling. We conduct experiments on multiple datasets to empirically validate our proposal. Our code is at https://github.com/thudzj/Calibrated-DPMs.

Summary

  • The paper introduces a calibration technique that subtracts the expected model output at each timestep, so the calibrated score has zero mean, matching a property of the true data score.
  • It reduces the score matching loss and tightens KL divergence bounds, thereby enhancing the model’s evidence lower bound.
  • Experimental results show improved sample quality and faster convergence, yielding better FID scores on benchmarks like CIFAR-10 and CelebA.

This paper introduces a method to calibrate pretrained Diffusion Probabilistic Models (DPMs) to improve their performance, specifically focusing on sample quality and model likelihood (2302.10688). The core idea stems from the observation that the scaled true data score, $\alpha_t \nabla_{x_t} \log q_t(x_t)$, follows a martingale process in reverse time. A key property derived from this is that the expected value of the true data score is zero at any timestep $t$: $\mathbb{E}_{q_t(x_t)}[\nabla_{x_t} \log q_t(x_t)] = 0$.

However, empirically trained DPMs, represented by a learned score model $\boldsymbol{s}^t_\theta(x_t)$ (or equivalent parametrizations such as noise prediction $\boldsymbol{\epsilon}^t_\theta(x_t)$), often violate this property, meaning $\mathbb{E}_{q_t(x_t)}[\boldsymbol{s}^t_\theta(x_t)] \neq 0$. This "mis-calibration" can arise from training on finite datasets and from sub-optimal model convergence. The paper proposes a simple calibration technique: subtract this non-zero expectation from the model's output during inference.
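
The zero-mean property itself follows from a standard score-function identity. A brief sketch, assuming $q_t$ is smooth and vanishes sufficiently fast at infinity:

$$\mathbb{E}_{q_t(x_t)}\big[\nabla_{x_t} \log q_t(x_t)\big] = \int q_t(x_t)\,\frac{\nabla_{x_t} q_t(x_t)}{q_t(x_t)}\,\mathrm{d}x_t = \int \nabla_{x_t} q_t(x_t)\,\mathrm{d}x_t = 0,$$

since each component of $\nabla_{x_t} q_t(x_t)$ integrates to zero when the density decays at infinity.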

Calibration Method

The proposed calibration adjusts the score model output as follows:

$$\boldsymbol{s}^{t}_{\theta, \text{calibrated}}(x_t) = \boldsymbol{s}^{t}_{\theta}(x_t) - \eta_t^*$$

where $\eta_t^*$ is the optimal calibration term for timestep $t$:

$$\eta_t^* = \mathbb{E}_{q_t(x_t)}[\boldsymbol{s}^t_\theta(x_t)]$$
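
In code, the calibration amounts to a per-timestep vector subtraction around the pretrained network. A minimal sketch, assuming a PyTorch noise-prediction model and a precomputed dictionary `calibration_terms` mapping each sampler timestep to its estimated $\eta_t^*$ (both names are illustrative, not from the paper's codebase):

    import torch
    import torch.nn as nn

    class CalibratedDPM(nn.Module):
        """Wraps a pretrained noise-prediction model and subtracts a
        precomputed per-timestep calibration term from its output."""

        def __init__(self, model: nn.Module, calibration_terms: dict):
            super().__init__()
            self.model = model
            # {t: eta_t}, each eta_t a tensor of shape [C, H, W]
            self.calibration_terms = calibration_terms

        def forward(self, x_t: torch.Tensor, t) -> torch.Tensor:
            eps = self.model(x_t, t)                          # raw model prediction
            eta_t = self.calibration_terms[t].to(eps.device)  # estimated E[eps_theta(x_t)]
            return eps - eta_t                                # calibrated prediction

    # Any existing sampler can then call the wrapper in place of the original model.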

This calibration offers several benefits:

  1. Reduced Score Matching Loss: Calibration provably reduces the Score Matching (SM) objective at each timestep $t$ (a short derivation is sketched after this list):

    $$\mathcal{J}_{\mathrm{SM}^t}(\theta, \eta_t^*) = \mathcal{J}_{\mathrm{SM}^t}(\theta) - \frac{1}{2} \left\|\mathbb{E}_{q_t(x_t)}[\boldsymbol{s}^t_\theta(x_t)]\right\|_2^2$$

    Similar reductions apply to the Denoising Score Matching (DSM) objective and equivalent losses under other model parametrizations.

  2. Improved Likelihood Bounds: Since SM objectives are linked to the KL divergence between the data distribution $q_0$ and the model distribution $p_0$ (for both SDE and ODE formulations), reducing the SM loss yields tighter upper bounds on $\mathcal{D}_{\mathrm{KL}}(q_0 \,\|\, p_0)$, effectively increasing the evidence lower bound (ELBO) of the model likelihood.
  3. Improved Sample Quality: Experiments show that calibration leads to better FID scores on datasets such as CIFAR-10 and CelebA 64x64, especially when using efficient ODE solvers like DPM-Solver with a small number of function evaluations (NFE). Calibration can also speed up convergence during sampling and reduce ambiguous generations.
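
The reduction in benefit 1 follows from expanding the shifted objective. A sketch, writing $\boldsymbol{s}_q = \nabla_{x_t} \log q_t(x_t)$, assuming the $\tfrac{1}{2}$-weighted form of the SM objective and using $\mathbb{E}_{q_t}[\boldsymbol{s}_q] = 0$:

$$\mathcal{J}_{\mathrm{SM}^t}(\theta, \eta_t) = \tfrac{1}{2}\,\mathbb{E}_{q_t}\big[\|\boldsymbol{s}^t_\theta(x_t) - \eta_t - \boldsymbol{s}_q\|_2^2\big] = \mathcal{J}_{\mathrm{SM}^t}(\theta) - \eta_t^{\top}\mathbb{E}_{q_t}[\boldsymbol{s}^t_\theta(x_t)] + \tfrac{1}{2}\|\eta_t\|_2^2,$$

a quadratic in $\eta_t$ minimized at $\eta_t^* = \mathbb{E}_{q_t}[\boldsymbol{s}^t_\theta(x_t)]$, which yields exactly the $\tfrac{1}{2}\|\mathbb{E}_{q_t}[\boldsymbol{s}^t_\theta(x_t)]\|_2^2$ reduction quoted above.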

Implementation Strategies

The main practical challenge is estimating the calibration term $\eta_t^* = \mathbb{E}_{q_t(x_t)}[\boldsymbol{s}^t_\theta(x_t)]$ (or its equivalent under other parametrizations). The paper suggests three approaches:

  1. Post-hoc Estimation using Training Data:
    • This is the primary method used in the experiments.
    • Requires access to the original training dataset $\{x_0^n\}_{n=1}^N$.
    • For each required timestep $t$, estimate the expectation by Monte Carlo sampling:
      • Sample mini-batches of $x_0$ from the training set.
      • Sample noise $\epsilon \sim \mathcal{N}(0, I)$.
      • Compute noisy samples $x_t = \alpha_t x_0 + \sigma_t \epsilon$.
      • Pass $x_t$ through the pretrained model to obtain $\boldsymbol{s}^t_\theta(x_t)$.
      • Average the model outputs over a large number of samples to approximate $\eta_t^*$.
    • Store the computed $\eta_t^*$ values (e.g., in a dictionary mapping $t$ to the vector $\eta_t^*$) for each timestep required by the sampler.
    • During inference, retrieve the appropriate $\eta_t^*$ and subtract it from the model's output.
    • Consideration: Requires potentially significant upfront computation (once per model) and access to the training data. Ablations show that a large fraction of the training data (e.g., >20k images for CIFAR-10) may be necessary for accurate estimation.

    Pseudocode for Post-hoc Estimation (Noise Prediction Parametrization):

    import torch

    # Pretrained noise-prediction model: model(x_t, t) -> predicted noise
    # Training data loader: train_loader (yields batches of x_0)
    # Timesteps needed by the sampler: timesteps_to_calibrate
    # Number of samples for the Monte Carlo estimate: num_mc_samples

    calibration_terms = {}
    model.eval()
    with torch.no_grad():
        for t in timesteps_to_calibrate:
            alpha_t, sigma_t = get_alpha_sigma(t)  # diffusion schedule params at timestep t
            eta_t_sum = None
            samples_processed = 0
            while samples_processed < num_mc_samples:
                for x0_batch in train_loader:
                    batch_size = x0_batch.size(0)
                    epsilon = torch.randn_like(x0_batch)
                    xt_batch = alpha_t * x0_batch + sigma_t * epsilon  # forward diffusion

                    # Most models expect a per-sample timestep tensor
                    t_batch = torch.full((batch_size,), t, device=xt_batch.device)
                    predicted_noise_batch = model(xt_batch, t_batch)

                    batch_sum = predicted_noise_batch.sum(dim=0)  # sum over batch dim -> [C, H, W]
                    eta_t_sum = batch_sum if eta_t_sum is None else eta_t_sum + batch_sum
                    samples_processed += batch_size
                    if samples_processed >= num_mc_samples:
                        break

            calibration_terms[t] = eta_t_sum / samples_processed  # Monte Carlo estimate of eta_t*

    # During sampling:
    # predicted_noise = model(xt, t_batch)
    # calibrated_noise = predicted_noise - calibration_terms[t]
    # Use calibrated_noise in the solver step

  2. Post-hoc Estimation using Generated Data:
    • Use this if training data is unavailable.
    • First, generate a large set of samples $\{\tilde{x}_0\}$ using the original pretrained DPM.
    • Then, use these generated samples $\tilde{x}_0$ in place of the training data in the Monte Carlo estimation described above.
    • Consideration: The quality of the calibration depends on the quality and diversity of the generated samples. Experiments suggest it can match calibration with training data if enough high-quality samples are generated (e.g., 20k for CIFAR-10), though generating more beyond that may not help further.
  3. Dynamic Recording during Training:
    • Estimate $\eta_t^*$ during the DPM pretraining phase.
    • Introduce a small auxiliary network $h_\phi(t)$ (e.g., an MLP) that takes the timestep $t$ as input.
    • Train $h_\phi(t)$ to predict the expected model output $\mathbb{E}_{q_t(x_t)}[\boldsymbol{s}^t_\theta(x_t)]$ using mini-batch averages and a stop-gradient on $\boldsymbol{s}^t_\theta(x_t)$. The loss for $h_\phi(t)$ at timestep $t$ is $\mathcal{J}_{\mathrm{Cal}^t}(\phi) = \mathbb{E}_{q_t(x_t)}\big[\|h_\phi(t) - \boldsymbol{s}^t_\theta(x_t)^\dagger\|_2^2\big]$, where $\dagger$ denotes the stop-gradient.
    • At inference time, use the output of the trained $h_\phi(t)$ as the calibration term $\eta_t^*$.
    • Consideration: Avoids post-hoc computation and data-access issues, but adds complexity to the training pipeline. Allows immediate generation after training. A rough sketch follows below.
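
A rough sketch of how the auxiliary head and its loss could be wired into a training step, in PyTorch. The MLP architecture, time embedding, and surrounding names (`model`, `dsm_loss`) are assumptions for illustration, not the paper's exact setup:

    import torch
    import torch.nn as nn

    class CalibrationHead(nn.Module):
        """Small MLP h_phi(t) predicting the expected model output at timestep t."""

        def __init__(self, out_dim: int, hidden: int = 256):
            super().__init__()
            # A raw scalar timestep is used here for simplicity; in practice a
            # sinusoidal or learned time embedding would be more typical.
            self.net = nn.Sequential(
                nn.Linear(1, hidden), nn.SiLU(),
                nn.Linear(hidden, hidden), nn.SiLU(),
                nn.Linear(hidden, out_dim),  # flattened [C*H*W] prediction
            )

        def forward(self, t: torch.Tensor) -> torch.Tensor:
            return self.net(t.float().view(-1, 1))

    # Inside the usual DPM training step (t is a per-sample timestep tensor of shape [B]):
    # eps_pred = model(x_t, t)                            # [B, C, H, W]
    # h_t = calibration_head(t).view_as(eps_pred)         # h_phi(t) broadcast per sample
    # cal_loss = (h_t - eps_pred.detach()).pow(2).mean()  # J_Cal: stop-gradient on the DPM output
    # total_loss = dsm_loss + cal_loss                    # optimize theta and phi jointly
    # At inference, calibration_head(t) serves as eta_t*.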

Model Parametrizations

The paper provides calibration formulas for common DPM parametrizations:

| Parametrization | Model Output | Calibration Term $\eta_t^*$ | Calibrated Output |
|---|---|---|---|
| Score prediction | $\boldsymbol{s}^t_\theta(x_t)$ | $\mathbb{E}_{q_t(x_t)}[\boldsymbol{s}^t_\theta(x_t)]$ | $\boldsymbol{s}^t_\theta(x_t) - \eta_t^*$ |
| Noise prediction (most common) | $\boldsymbol{\epsilon}^t_\theta(x_t)$ | $\mathbb{E}_{q_t(x_t)}[\boldsymbol{\epsilon}^t_\theta(x_t)]$ | $\boldsymbol{\epsilon}^t_\theta(x_t) - \eta_t^*$ |
| Data prediction | $\boldsymbol{x}^t_\theta(x_t)$ | $\mathbb{E}_{q_t(x_t)}[\boldsymbol{x}^t_\theta(x_t)] - \mathbb{E}_{q_0(x_0)}[x_0]$ | $\boldsymbol{x}^t_\theta(x_t) - \eta_t^*$ |
| Velocity prediction | $\boldsymbol{v}^t_\theta(x_t)$ | $\mathbb{E}_{q_t(x_t)}[\boldsymbol{v}^t_\theta(x_t)] + \sigma_t\,\mathbb{E}_{q_0(x_0)}[x_0]$ | $\boldsymbol{v}^t_\theta(x_t) - \eta_t^*$ |

(Note: For data and velocity prediction, $\mathbb{E}_{q_0(x_0)}[x_0]$ is the mean of the training data, often assumed to be 0 after normalization.)
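
The $\mathbb{E}_{q_0(x_0)}[x_0]$ offsets in the data- and velocity-prediction rows arise because the zero-mean property holds for the score (equivalently, the noise), not for the other prediction targets. A sketch under the usual conventions $x_t = \alpha_t x_0 + \sigma_t \epsilon$ and $\boldsymbol{v} = \alpha_t \epsilon - \sigma_t x_0$ (the velocity definition is the standard one and is assumed here): the means of the ideal prediction targets are

$$\mathbb{E}[\epsilon] = 0, \qquad \mathbb{E}[x_0] = \mathbb{E}_{q_0(x_0)}[x_0], \qquad \mathbb{E}[\alpha_t \epsilon - \sigma_t x_0] = -\sigma_t\,\mathbb{E}_{q_0(x_0)}[x_0],$$

so a well-calibrated data or velocity model should match these nonzero target means rather than zero; the calibration term therefore subtracts only the deviation of the model's expectation from the corresponding ideal value.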

Conditional Models

The calibration principle extends to conditional DPMs $\boldsymbol{s}^t_\theta(x_t, y)$. The calibration term becomes condition-specific: $\eta_t^*(y) = \mathbb{E}_{q_t(x_t|y)}[\boldsymbol{s}^t_\theta(x_t, y)]$.

  • Consideration: Post-hoc estimation is challenging if the condition space is large or continuous (e.g., text prompts), since $\eta_t^*(y)$ would need to be computed for a potentially unbounded set of conditions $y$. Dynamic recording may be more practical in such cases, potentially using multimodal modules for the auxiliary network $h_\phi(t, y)$. For a small, discrete label set, however, per-class post-hoc estimation remains straightforward, as sketched below.
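
A minimal sketch of per-class post-hoc estimation for a class-conditional noise-prediction model, reusing the same Monte Carlo loop as above (all names here are illustrative assumptions):

    import torch

    # cond_model(x_t, t, y): pretrained class-conditional noise-prediction model
    # loaders_by_class[y]: loader over training images with label y
    calibration_terms = {}  # keyed by (timestep, class label)
    with torch.no_grad():
        for y, loader in loaders_by_class.items():
            for t in timesteps_to_calibrate:
                alpha_t, sigma_t = get_alpha_sigma(t)
                running_sum, n = None, 0
                for x0 in loader:
                    x_t = alpha_t * x0 + sigma_t * torch.randn_like(x0)
                    t_batch = torch.full((x0.size(0),), t, device=x0.device)
                    y_batch = torch.full((x0.size(0),), y, dtype=torch.long, device=x0.device)
                    pred = cond_model(x_t, t_batch, y_batch)
                    s = pred.sum(dim=0)
                    running_sum = s if running_sum is None else running_sum + s
                    n += x0.size(0)
                calibration_terms[(t, y)] = running_sum / n  # estimate of eta_t*(y)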

Limitations

  • While calibration provably improves likelihood bounds and often improves FID in practice, gains in perceptual metrics such as FID are not guaranteed by the theory.
  • Post-hoc computation for conditional models with large condition spaces is impractical.

In summary, the paper presents a theoretically grounded and practically effective method to calibrate any pretrained DPM by subtracting the expected model output at each timestep. This one-time (per model) calibration procedure can improve sample quality, accelerate sampling convergence, and increase model likelihood bounds, requiring estimation of the calibration term via training data, generated data, or dynamic recording during training.
