
Consistency Models (2303.01469v2)

Published 2 Mar 2023 in cs.LG, cs.CV, and stat.ML

Abstract: Diffusion models have significantly advanced the fields of image, audio, and video generation, but they depend on an iterative sampling process that causes slow generation. To overcome this limitation, we propose consistency models, a new family of models that generate high quality samples by directly mapping noise to data. They support fast one-step generation by design, while still allowing multistep sampling to trade compute for sample quality. They also support zero-shot data editing, such as image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks. Consistency models can be trained either by distilling pre-trained diffusion models, or as standalone generative models altogether. Through extensive experiments, we demonstrate that they outperform existing distillation techniques for diffusion models in one- and few-step sampling, achieving the new state-of-the-art FID of 3.55 on CIFAR-10 and 6.20 on ImageNet 64x64 for one-step generation. When trained in isolation, consistency models become a new family of generative models that can outperform existing one-step, non-adversarial generative models on standard benchmarks such as CIFAR-10, ImageNet 64x64 and LSUN 256x256.

Authors (4)
  1. Yang Song (299 papers)
  2. Prafulla Dhariwal (15 papers)
  3. Mark Chen (15 papers)
  4. Ilya Sutskever (58 papers)
Citations (645)

Summary

  • The paper introduces Consistency Models, which bypass iterative diffusion sampling by learning a function that maps any point on the PF ODE trajectory directly back to the original data.
  • It details two training methods—Consistency Distillation and Consistency Training—that enforce self-consistency using parameterized functions and EMA-based target networks.
  • Practical applications include fast one-step generation and versatile image editing tasks, achieving state-of-the-art FID scores on datasets like CIFAR-10 and ImageNet.

This paper introduces Consistency Models (CMs) (Song et al., 2023), a new class of generative models designed to address the slow iterative sampling process inherent in diffusion models while retaining many of their benefits. The core idea is to learn a function that directly maps points from any time step on a Probability Flow (PF) ODE trajectory back to the trajectory's origin (the data sample).

Core Concept: Consistency Function

  1. PF ODE: Diffusion models rely on a PF ODE (Eq. 2) that transforms data $x_0$ into noise $x_T$. The reverse process generates data from noise $x_T$ by solving the ODE backward.
  2. Consistency Function: Defined as $f: (x_t, t) \mapsto x_\epsilon$, where $x_t$ is a point on the ODE trajectory at time $t$, and $x_\epsilon$ is the point near the origin (the data point, typically at a small time $\epsilon > 0$).
  3. Self-Consistency Property: The defining characteristic is that for any two points $(x_t, t)$ and $(x_{t'}, t')$ on the same ODE trajectory, the consistency function output is identical: $f(x_t, t) = f(x_{t'}, t') = x_\epsilon$.
  4. Consistency Model: A parameterized function $F_\theta(x, t)$ is trained to approximate the true consistency function $f$ by enforcing this self-consistency property.
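
For reference, the PF ODE convention the paper adopts from EDM (assuming the $\sigma(t) = t$ noise schedule used throughout) and the defining properties of the consistency function can be written compactly as:

      $$\frac{dx_t}{dt} = -\,t\,\nabla_{x_t} \log p_t(x_t), \qquad t \in [\epsilon, T]$$

      $$f(x_t, t) = f(x_{t'}, t') = x_\epsilon \quad \forall\, t, t' \in [\epsilon, T], \qquad f(x_\epsilon, \epsilon) = x_\epsilon \ \ \text{(boundary condition)}$$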

Implementation: Parameterization

A crucial aspect is enforcing the boundary condition $F_\theta(x, \epsilon) = x$. The paper proposes and uses a practical parameterization using skip connections:

      $$F_\theta(x, t) = c_\text{skip}(t)\, x + c_\text{out}(t)\, NN_\theta(x, t)$$

where $NN_\theta(x, t)$ is a neural network (e.g., based on diffusion model architectures like U-Net), and $c_\text{skip}(t)$, $c_\text{out}(t)$ are differentiable functions satisfying:

  • $c_\text{skip}(\epsilon) = 1$
  • $c_\text{out}(\epsilon) = 0$

This structure ensures the boundary condition is met and allows leveraging existing diffusion model architectures. The paper uses modified versions of the scaling factors from EDM (Karras et al., 2022) to satisfy this for $\epsilon > 0$.
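
A minimal PyTorch-style sketch of this parameterization is given below. The coefficient formulas follow the modified EDM scalings described in the paper's appendix, with $\sigma_\text{data}=0.5$ and $\epsilon=0.002$ taken as assumed constants; the backbone network is left abstract.

```python
import torch
import torch.nn as nn

# Assumed constants (illustrative values; see the paper's appendix).
SIGMA_DATA = 0.5
EPS = 0.002

def c_skip(t):
    # Equals 1 at t = EPS, so F_theta(x, EPS) = x.
    return SIGMA_DATA**2 / ((t - EPS)**2 + SIGMA_DATA**2)

def c_out(t):
    # Equals 0 at t = EPS.
    return SIGMA_DATA * (t - EPS) / (SIGMA_DATA**2 + t**2).sqrt()

class ConsistencyModel(nn.Module):
    """F_theta(x, t) = c_skip(t) * x + c_out(t) * NN_theta(x, t)."""
    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone  # e.g. a U-Net that accepts (x, t)

    def forward(self, x, t):
        # Broadcast the per-sample time t over the remaining dimensions of x.
        t_ = t.view(-1, *([1] * (x.dim() - 1)))
        return c_skip(t_) * x + c_out(t_) * self.backbone(x, t)
```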

Implementation: Sampling

  • One-Step Generation: Sample noise $x_T \sim N(0, T^2 I)$ and compute the data sample directly: $\hat{x}_\epsilon = F_\theta(x_T, T)$. This is very fast, requiring only one network evaluation.
  • Multi-Step Sampling (Algorithm 1): Improves sample quality by trading compute. It alternates denoising steps with the CM and re-adding noise (see the sketch after this list):
    • Select a time $\tau_n$ (from a predefined sequence $T > \tau_1 > \dots > \tau_{N-1} > \epsilon$).
    • Add noise: Sample $z \sim N(0, I)$, compute $x_{\tau_n} = \hat{x}^{(n-1)} + \sqrt{\tau_n^2 - \epsilon^2}\, z$.
    • Denoise: Compute $\hat{x}^{(n)} = F_\theta(x_{\tau_n}, \tau_n)$.
    • Output $\hat{x}^{(N-1)}$ after the final denoising step.
    • The sequence $\{\tau_n\}$ can be found using optimization methods such as greedy ternary search to minimize FID.
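
A minimal sketch of the multistep procedure, assuming a trained consistency model `model(x, t)` (e.g. the `ConsistencyModel` sketched earlier) and a user-chosen decreasing sequence `taus`; `T = 80.0` is just the EDM default used as a placeholder here.

```python
import torch

@torch.no_grad()
def multistep_sample(model, taus, shape, T=80.0, eps=0.002, device="cpu"):
    """Sketch of Algorithm 1: denoise from pure noise, then alternately
    re-noise to level tau_n and denoise again with the consistency model."""
    b = shape[0]
    # One-step sample from x_T ~ N(0, T^2 I).
    x = torch.randn(shape, device=device) * T
    x = model(x, torch.full((b,), T, device=device))
    for tau in taus:  # T > tau_1 > ... > tau_{N-1} > eps
        z = torch.randn_like(x)
        x_tau = x + (tau**2 - eps**2) ** 0.5 * z        # add noise back
        x = model(x_tau, torch.full((b,), tau, device=device))
    return x
```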

Training Method 1: Consistency Distillation (CD)

This method trains a CM $F_\theta$ by distilling knowledge from a pre-trained diffusion (score) model $s_\phi$.

  1. Goal: Enforce $F_\theta(x_{t_{n+1}}, t_{n+1}) \approx F_{\theta^-}(\hat{x}_{t_n}^\phi, t_n)$ for adjacent points on the empirical PF ODE trajectory defined by $s_\phi$.
  2. Process (Algorithm 2):
    • Sample data $x_0$.
    • Sample time index $n \sim U\{1, \dots, N-1\}$.
    • Generate noisy sample $x_{t_{n+1}} \sim N(x_0, t_{n+1}^2 I)$.
    • Use one step of a numerical ODE solver (e.g., Heun) with the score model $s_\phi$ to estimate the previous point: $\hat{x}_{t_n}^\phi = x_{t_{n+1}} + (t_n - t_{n+1})\,\Phi(x_{t_{n+1}}, t_{n+1}; s_\phi)$.
    • Minimize the consistency distillation loss (Eq. 7):

      $$L_\text{CD}^N = \mathbb{E}\big[\lambda(t_n)\, d\big(F_\theta(x_{t_{n+1}}, t_{n+1}),\, F_{\theta^-}(\hat{x}_{t_n}^\phi, t_n)\big)\big]$$

  3. Implementation Details:
    • $F_{\theta^-}$ is a target network, updated via an Exponential Moving Average (EMA) of $F_\theta$ (Eq. 8). Using stop_gradient on the target network output is crucial for stability.
    • $d(\cdot, \cdot)$ is a distance metric. LPIPS (Zhang et al., 2018) works best for images, outperforming L1 and L2.
    • $\lambda(t_n)$ is a weighting function (often set to 1).
    • Higher-order ODE solvers (like Heun) generally perform better than lower-order ones (like Euler) for computing $\hat{x}_{t_n}^\phi$.
    • The number of discretization intervals $N$ needs tuning (e.g., $N=18$ for CIFAR-10 with Heun). A sketch of a single distillation step follows this list.
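
Below is a minimal sketch of one CD update, with illustrative names: `score_model(x, t)` is assumed to return the score $s_\phi(x, t)$, `ts` is an increasing list of discretization times, and `d_fn` stands in for the LPIPS distance. An Euler ODE step is used for brevity, even though the paper prefers Heun.

```python
import torch

def consistency_distillation_step(model, ema_model, score_model, x0, ts,
                                  opt, d_fn, ema_decay=0.999):
    """One Consistency Distillation update in the spirit of Algorithm 2."""
    b = x0.shape[0]
    n = torch.randint(0, len(ts) - 1, (1,)).item()  # corresponds to n in {1,...,N-1}
    t_n, t_np1 = float(ts[n]), float(ts[n + 1])

    # Noisy sample x_{t_{n+1}} ~ N(x0, t_{n+1}^2 I).
    x_np1 = x0 + t_np1 * torch.randn_like(x0)

    # One Euler step of the empirical PF ODE dx/dt = -t * s_phi(x, t)
    # to estimate x_{t_n}, plus the stop-gradient EMA target.
    with torch.no_grad():
        drift = -t_np1 * score_model(x_np1, torch.full((b,), t_np1))
        x_n = x_np1 + (t_n - t_np1) * drift
        target = ema_model(x_n, torch.full((b,), t_n))

    pred = model(x_np1, torch.full((b,), t_np1))
    loss = d_fn(pred, target).mean()       # lambda(t_n) taken to be 1

    opt.zero_grad()
    loss.backward()
    opt.step()

    # EMA update of the target network parameters.
    with torch.no_grad():
        for p_t, p in zip(ema_model.parameters(), model.parameters()):
            p_t.mul_(ema_decay).add_(p, alpha=1.0 - ema_decay)
    return loss.item()
```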

Training Method 2: Consistency Training (CT)

This method trains a CM $F_\theta$ from scratch, without requiring a pre-trained diffusion model. It makes CMs an independent class of generative models.

  1. Goal: Enforce $F_\theta(x_0 + t_{n+1}z, t_{n+1}) \approx F_{\theta^-}(x_0 + t_n z, t_n)$, where $z \sim N(0, I)$.
  2. Process (Algorithm 3): Based on the theoretical result (Theorem 2) that, for small step sizes, the CD loss with the Euler solver and the ground-truth score approximates the CT loss (Eq. 9), so no pre-trained score model is needed.
    • Sample data $x_0$.
    • Sample time index $n \sim U\{1, \dots, N(k)-1\}$ (where $N(k)$ increases during training).
    • Sample noise $z \sim N(0, I)$.
    • Minimize the consistency training loss:

      $$L_\text{CT}^N = \mathbb{E}\big[\lambda(t_n)\, d\big(F_\theta(x_0 + t_{n+1}z, t_{n+1}),\, F_{\theta^-}(x_0 + t_n z, t_n)\big)\big]$$

  3. Implementation Details:
    • Uses the same EMA target network $F_{\theta^-}$ as CD.
    • Crucially uses adaptive schedules for the number of time steps $N(k)$ and the EMA decay rate $\mu(k)$ (where $k$ is the training step). $N(k)$ typically starts small and increases, while $\mu(k)$ starts high (e.g., 0.9) and approaches 1. This balances convergence speed and final quality. Appendix C provides the specific schedule formulas; the sketch after this list simply takes $N(k)$ and $\mu(k)$ as inputs.
    • LPIPS is also effective here.
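
A minimal sketch of one CT update under similar assumptions: an EDM-style $\rho$-spaced time grid (an assumption here), $N(k)$ and $\mu(k)$ supplied by the caller, and `d_fn` standing in for LPIPS. Note that the same noise $z$ is reused at both time levels.

```python
import torch

def consistency_training_step(model, ema_model, x0, n_k, mu_k, opt, d_fn,
                              T=80.0, eps=0.002, rho=7.0):
    """One Consistency Training update in the spirit of Algorithm 3.
    n_k is N(k) from the adaptive schedule, mu_k the EMA decay mu(k)."""
    b = x0.shape[0]

    # Assumed EDM-style rho-spaced grid t_1 = eps < ... < t_{N(k)} = T.
    i = torch.arange(n_k, dtype=torch.float64)
    ts = (eps ** (1 / rho)
          + i / (n_k - 1) * (T ** (1 / rho) - eps ** (1 / rho))) ** rho

    n = torch.randint(0, n_k - 1, (1,)).item()        # n in {1, ..., N(k)-1}
    t_n, t_np1 = float(ts[n]), float(ts[n + 1])

    # The SAME Gaussian noise z is used at both time levels.
    z = torch.randn_like(x0)
    pred = model(x0 + t_np1 * z, torch.full((b,), t_np1))
    with torch.no_grad():
        target = ema_model(x0 + t_n * z, torch.full((b,), t_n))  # stop-gradient

    loss = d_fn(pred, target).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

    # EMA update with the schedule-dependent decay mu(k).
    with torch.no_grad():
        for p_t, p in zip(ema_model.parameters(), model.parameters()):
            p_t.mul_(mu_k).add_(p, alpha=1.0 - mu_k)
    return loss.item()
```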

Practical Applications & Results

  • Fast Generation: CMs achieve state-of-the-art FID scores for one-step and two-step generation on CIFAR-10 (3.55/2.93 FID) and ImageNet 64x64 (6.20/4.70 FID) when trained via CD, significantly outperforming Progressive Distillation (PD).
  • Standalone Performance: When trained via CT, CMs outperform other one-step non-adversarial methods (VAEs, Flows) and achieve results comparable to PD without needing distillation.
  • Zero-Shot Data Editing: CMs inherit the editing capabilities of diffusion models. Using variations of the multi-step sampling algorithm (Algorithm 4 in the Appendix), they can perform the following tasks (an inpainting sketch follows this list):
    • Inpainting: Mask unknown regions and iteratively refine using the CM.
    • Colorization: Treat color channels as missing information in a transformed space (e.g., YUV or using an orthogonal basis).
    • Super-Resolution: Treat high-frequency details as missing information in a transformed space (e.g., using patch averaging and orthogonal basis).
    • Stroke-guided Editing (SDEdit): Use a stroke image as the starting point $x_{\tau_1}$ in multi-step sampling.
    • Denoising: Apply $F_\theta(x_\sigma, \sigma)$ directly to an image $x_\sigma$ with noise level $\sigma$.
    • Interpolation: Interpolate between the initial noise vectors $z_1, z_2$ (e.g., using spherical linear interpolation) and then apply $F_\theta(\cdot, T)$.
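
As an example of the zero-shot editing recipe, a simplified inpainting sketch in the spirit of the appendix algorithm: at each noise level, the masked-out region is filled in by the consistency model while the known pixels (where `mask == 1`) are kept from the reference image `y`. The exact treatment of the known region (e.g. whether it is re-noised to the current level) is simplified here.

```python
import torch

@torch.no_grad()
def inpaint(model, y, mask, taus, T=80.0, eps=0.002):
    """Zero-shot inpainting sketch: y is the reference image, mask is 1 on
    known pixels and 0 on the region to be filled in."""
    b = y.shape[0]
    x = torch.randn_like(y) * T
    x = model(x, torch.full((b,), T))
    x = mask * y + (1.0 - mask) * x                   # impose known pixels
    for tau in taus:                                  # T > tau_1 > ... > eps
        z = torch.randn_like(y)
        x_tau = x + (tau**2 - eps**2) ** 0.5 * z      # re-noise the estimate
        x = model(x_tau, torch.full((b,), tau))
        x = mask * y + (1.0 - mask) * x               # re-impose known pixels
    return x
```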

Implementation Considerations

  • Architecture: Can reuse U-Net architectures from diffusion models (e.g., NCSN++, ADM).
  • Target Network: Using an EMA target network with stop_gradient is vital for both CD and CT.
  • Metric: LPIPS is highly recommended for image data.
  • Schedules (CT): Carefully designed adaptive schedules for $N$ and $\mu$ are important for CT performance.
  • Computational Cost: Training cost is comparable to that of diffusion models. Inference is much faster (one network evaluation for one-step generation, $N$ evaluations for $N$-step sampling).

Continuous-Time Extensions

The paper also derives continuous-time versions of the CD and CT losses (Appendix B), eliminating the need for discrete time steps $t_n$. These objectives require calculating Jacobian-vector products, often necessitating forward-mode automatic differentiation, which might not be standard in all frameworks. Experimental results show they can work well, especially continuous-time CT, but may require careful initialization or variance reduction techniques.
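
As a pointer for the forward-mode requirement, a minimal sketch of a Jacobian-vector product of a consistency model using `torch.func.jvp` (available in recent PyTorch; `jax.jvp` is the JAX analogue). This only illustrates the primitive the continuous-time losses rely on, not the losses themselves.

```python
import torch
from torch.func import jvp

def consistency_jvp(model, x, t, v_x, v_t):
    """Return (F_theta(x, t), d/ds F_theta(x + s*v_x, t + s*v_t)|_{s=0})
    computed in a single forward-mode pass."""
    return jvp(lambda x_, t_: model(x_, t_), (x, t), (v_x, v_t))
```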
