
Weak-to-Strong Diffusion with Reflection

Published 1 Feb 2025 in cs.LG and cs.CV | (2502.00473v3)

Abstract: The goal of diffusion generative models is to align the learned distribution with the real data distribution through gradient score matching. However, inherent limitations in training data quality, modeling strategies, and architectural design lead to an inevitable gap between generated outputs and real data. To reduce this gap, we propose Weak-to-Strong Diffusion (W2SD), a novel framework that utilizes the estimated difference between existing weak and strong models (i.e., the weak-to-strong difference) to bridge the gap between an ideal model and a strong model. By employing a reflective operation that alternates between denoising and inversion with the weak-to-strong difference, we show theoretically that W2SD steers latent variables along sampling trajectories toward regions of the real data distribution. W2SD is highly flexible and broadly applicable, enabling diverse improvements through the strategic selection of weak-to-strong model pairs (e.g., DreamShaper vs. SD1.5, good experts vs. bad experts in MoE). Extensive experiments demonstrate that W2SD significantly improves human preference, aesthetic quality, and prompt adherence, achieving SOTA performance across various modalities (e.g., image, video), architectures (e.g., UNet-based, DiT-based, MoE), and benchmarks. For example, Juggernaut-XL with W2SD achieves an HPSv2 winning rate of up to 90% over the original results. Moreover, the performance gains achieved by W2SD markedly outweigh its additional computational overhead, while the cumulative improvements from different weak-to-strong differences further solidify its practical utility and deployability.

Summary

  • The paper introduces a novel W2SD framework that enhances diffusion model inference via a reflective operation exploiting weak-to-strong model differences.
  • It combines strong model denoising with weak model inversion to approximate the ideal data distribution shift during the early high-noise inference steps.
  • Experimental results demonstrate significant improvements in generation quality, prompt adherence, and aesthetics across diverse benchmarks and modalities.

This paper introduces Weak-to-Strong Diffusion (W2SD) (2502.00473), a novel framework for enhancing the inference process of diffusion models to bridge the gap between the learned distribution and the real data distribution. The core problem addressed is that existing diffusion models, due to limitations in training data, architecture, and modeling, inevitably produce outputs that differ from the ideal data distribution. This gap, termed the "strong-to-ideal difference" ($\Delta_2 = \nabla p^{\mathrm{gt}} - \nabla p^{\mathrm{s}}$), is difficult to minimize directly because the ground-truth distribution $p^{\mathrm{gt}}$ is inaccessible.

The central idea of W2SD is to approximate this inaccessible strong-to-ideal difference $\Delta_2$ using the empirically estimable difference between a strong model $\mathcal{M}^{\mathrm{s}}$ (with density $p^{\mathrm{s}}$) and a weak model $\mathcal{M}^{\mathrm{w}}$ (with density $p^{\mathrm{w}}$). This "weak-to-strong difference" is defined as $\Delta_1 = \nabla p^{\mathrm{s}} - \nabla p^{\mathrm{w}}$. The hypothesis is that by leveraging $\Delta_1$, W2SD can steer the sampling process towards regions better aligned with the real data distribution, effectively improving the strong model.
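
Under this hypothesis, the mechanism can be stated in one line (a direct consequence of the definitions above): adding the weak-to-strong difference to the strong model's score approximates the ideal score,

$\nabla p^{\mathrm{s}} + \Delta_1 \approx \nabla p^{\mathrm{s}} + \Delta_2 = \nabla p^{\mathrm{gt}} \quad \text{whenever } \Delta_1 \approx \Delta_2.$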

Methodology: W2SD with Reflection

W2SD implements this idea through a "reflective operation" integrated into the diffusion sampling process. For a limited number of steps $\lambda$ at the beginning of inference (high noise levels), the standard denoising step is modified. Given the latent state $x_t$ at timestep $t$:

  1. A denoising step is performed using the strong model: $x_{t-\Delta t} = \mathcal{M}^{\mathrm{s}}(x_t, t)$.
  2. An inversion step is performed using the weak model: $\tilde{x}_t = \mathcal{M}^{\mathrm{w}}_{\mathrm{inv}}(x_{t-\Delta t}, t)$. This $\tilde{x}_t$ is the "reflected" (refined) latent state at time $t$.
  3. The standard denoising for the next step uses this refined state: $x_{t-1} = \mathcal{M}^{\mathrm{s}}(\tilde{x}_t, t)$.

Algorithm:

Algorithm 1: W2SD
Input: strong model Ms, weak model Mw, total steps T, reflection steps λ
Output: clean data x_0
Sample Gaussian noise x_T
for t = T down to 1:
  if t > T - λ:                    # reflection during the first λ (high-noise) steps
    x_denoised = Ms(x_t, t)        # denoise with the strong model
    x_t = Mw_inv(x_denoised, t)    # invert with the weak model (reflected latent)
  x_{t-1} = Ms(x_t, t)             # standard strong-model denoising step
return x_0
(Note: the paper's Algorithm 1 differs slightly in variable naming and flow, but the core operation is the same: reflection followed by strong-model denoising.)
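
To make the loop concrete, here is a minimal self-contained Python sketch. This is not the paper's code: it uses a toy 1-D setting in which the "strong" and "weak" models are analytic scores of Gaussian mixtures under variance-exploding noising, denoising and inversion are Euler steps of the probability-flow ODE, and all names and schedules are illustrative assumptions.

import numpy as np

# Toy illustration of W2SD (hypothetical, not the paper's code).
# "Strong" and "weak" models are analytic scores of 1-D Gaussian
# mixtures under VE noising, so denoise/invert are exact Euler steps.

def mixture_score(x, sigma, means, weights):
    """Score of a unit-variance Gaussian mixture noised with N(0, sigma^2)."""
    var = 1.0 + sigma**2
    diffs = x[None, :] - np.asarray(means)[:, None]
    resp = np.asarray(weights)[:, None] * np.exp(-0.5 * diffs**2 / var)
    return (resp * (-diffs / var)).sum(0) / resp.sum(0)

def strong_score(x, s):  # balanced mixture: the "stronger" model
    return mixture_score(x, s, means=[-2.0, 2.0], weights=[0.5, 0.5])

def weak_score(x, s):    # skewed mixture: under-represents the right mode
    return mixture_score(x, s, means=[-2.0, 2.0], weights=[0.9, 0.1])

T, lam = 50, 10                            # total steps, reflection steps
sigmas = np.geomspace(0.02, 10.0, T + 1)   # sigma_0 (small) ... sigma_T (large)

def denoise(x, t, score):  # Euler step t -> t-1 of the probability-flow ODE
    return x + 0.5 * (sigmas[t]**2 - sigmas[t - 1]**2) * score(x, sigmas[t])

def invert(x, t, score):   # approximate inverse step t-1 -> t
    return x - 0.5 * (sigmas[t]**2 - sigmas[t - 1]**2) * score(x, sigmas[t])

rng = np.random.default_rng(0)
x = rng.normal(0.0, sigmas[T], size=5000)  # x_T ~ high-noise prior
for t in range(T, 0, -1):
    if t > T - lam:                        # reflection in the first lam steps
        x = invert(denoise(x, t, strong_score), t, weak_score)
    x = denoise(x, t, strong_score)        # standard strong-model step

# The weak model over-weights the left mode, so the weak-to-strong
# difference points right; W2SD shifts mass toward the right mode.
print("fraction of samples in right mode:", (x > 0).mean())

With lam = 0 the loop reduces to standard strong-model sampling, so comparing the printed mode fraction with and without reflection shows the shift toward the mode the weak model under-represents.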

Theoretical Understanding (Theorem 1)

The paper provides theoretical justification (Theorem 1) showing that the reflection operation $\tilde{x}_t = \mathcal{M}^{\mathrm{w}}_{\mathrm{inv}}(\mathcal{M}^{\mathrm{s}}(x_t, t), t)$ modifies the latent state $x_t$ approximately as:

$\tilde{x}_t \approx x_t + \sigma^{2t}\Delta t \left(\nabla_{x_t} p_t^{\mathrm{s}}(x_t) - \nabla_{x_t} p_t^{\mathrm{w}}(x_t)\right)$

This confirms that the reflection perturbs the latent variable $x_t$ in the direction of the weak-to-strong difference $\Delta_1(t)$. If $\Delta_1$ approximates $\Delta_2$, this perturbation pushes the latent variable towards the ideal data distribution.
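
A first-order sketch of why this holds, writing the step size as $c_t = \sigma^{2t}\Delta t$ (an illustrative reconstruction in the document's notation, not the paper's full proof): composing a strong-model Euler denoising step with a weak-model Euler inversion step cancels everything except the score difference,

$x_{t-\Delta t} = x_t + c_t \nabla_{x_t} p_t^{\mathrm{s}}(x_t), \qquad \tilde{x}_t = x_{t-\Delta t} - c_t \nabla p_t^{\mathrm{w}}(x_{t-\Delta t})$

$\Rightarrow \tilde{x}_t \approx x_t + c_t \left(\nabla_{x_t} p_t^{\mathrm{s}}(x_t) - \nabla_{x_t} p_t^{\mathrm{w}}(x_t)\right),$

where the approximation drops an $O(c_t^2)$ term incurred by evaluating the weak score at $x_{t-\Delta t}$ instead of $x_t$.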

Flexibility and Applications: Defining Weak-to-Strong Pairs

A key strength of W2SD is its flexibility. The choice of "weak" and "strong" models allows tailoring the enhancement effect. The paper explores several types of differences:

  1. Weight Difference:
    • Full Fine-tuning: $\mathcal{M}^{\mathrm{s}}$ = fine-tuned model (e.g., DreamShaper, Juggernaut-XL), $\mathcal{M}^{\mathrm{w}}$ = base model (e.g., SD1.5, SDXL). Improves human preference (HPSv2, PickScore) and aesthetics (AES).
    • LoRA: $\mathcal{M}^{\mathrm{s}}$ = base model + LoRA, $\mathcal{M}^{\mathrm{w}}$ = base model. Enhances personalization and style adherence (CLIP-I, CLIP-T).
    • MoE: $\mathcal{M}^{\mathrm{s}}$ = top-k experts, $\mathcal{M}^{\mathrm{w}}$ = bottom-k experts (e.g., in DiT-MoE). Improves overall quality and reduces artifacts (FID, IS).
  2. Condition Difference:
    • Guidance Scale: $\mathcal{M}^{\mathrm{s}}$ = high CFG scale, $\mathcal{M}^{\mathrm{w}}$ = low/zero CFG scale (see the sketch after this list). Improves prompt adherence and human preference. Z-Sampling (Bai et al., 2024) is identified as a special case.
    • Prompt Semantics: $\mathcal{M}^{\mathrm{s}}$ = model with a refined/detailed prompt, $\mathcal{M}^{\mathrm{w}}$ = model with a raw/simple prompt. Improves generation based on detailed prompts.
  3. Sampling Pipeline Difference:
    • Enhanced Pipelines: $\mathcal{M}^{\mathrm{s}}$ = pipeline with added control (e.g., ControlNet, IP-Adapter), $\mathcal{M}^{\mathrm{w}}$ = standard pipeline (e.g., DDIM). Improves adherence to the specific control mechanism (e.g., edge maps, reference-image style).
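
As a hedged illustration of the guidance-scale pairing above, the following Python sketch builds a weak/strong pair from a single noise predictor by varying the CFG scale. ToyEps is a stand-in network invented for this example, not a real diffusion backbone.

import torch

# Condition-difference pairing (illustrative sketch): the weak and
# strong "models" are the same network under low vs. high CFG scale.

class ToyEps(torch.nn.Module):
    """Stand-in noise predictor eps(x, t, cond); any trained UNet/DiT would do."""
    def __init__(self, dim=4):
        super().__init__()
        self.net = torch.nn.Linear(dim + 1 + dim, dim)  # [x, t, cond] -> eps

    def forward(self, x, t, cond):
        t_feat = torch.full((x.shape[0], 1), float(t))
        return self.net(torch.cat([x, t_feat, cond], dim=-1))

def cfg_eps(model, cond, uncond, scale):
    """Classifier-free guidance: scale extrapolates from the unconditional
    toward the conditional noise prediction."""
    def eps(x, t):
        e_u = model(x, t, uncond)
        return e_u + scale * (model(x, t, cond) - e_u)
    return eps

model = ToyEps()
cond, uncond = torch.randn(8, 4), torch.zeros(8, 4)
strong = cfg_eps(model, cond, uncond, scale=7.5)  # M^s: high guidance
weak = cfg_eps(model, cond, uncond, scale=1.0)    # M^w: low guidance
x = torch.randn(8, 4)
print(strong(x, 10).shape, weak(x, 10).shape)

Note that the two "models" share all weights; only the sampling-time condition differs, which is why condition-difference W2SD requires no extra training.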

Experimental Validation

  • Synthetic Data (Gaussian Mixtures): Visualizations confirm W2SD steers sampling trajectories towards underrepresented data modes.
  • Real Data (CIFAR-10 subset): Demonstrates balancing generation ratios for classes underrepresented in the weak model's training data.
  • Benchmarks (Pick-a-Pic, DrawBench, GenEval, ImageNet, VBench): Extensive quantitative results show significant improvements in human preference (HPSv2, PickScore, MPS), aesthetics (AES), prompt alignment, personalization (CLIP scores), and generation quality (FID, IS) across image and video modalities.
  • Ablation Studies:
    • Confirm performance gains depend on $\mathcal{M}^{\mathrm{s}}$ being genuinely "stronger" than $\mathcal{M}^{\mathrm{w}}$.
    • Show W2SD provides better quality than standard sampling for the same computational budget.
    • Demonstrate that the improvements from different types of weak-to-strong differences (e.g., weight and condition) can be cumulative.
    • Analyze the impact of inversion approximation errors.
    • Connect W2SD to prior work like Re-Sampling (Akyürek et al., 2022), framing it as a specific instance where the weak inversion is approximated by adding random noise.

Contributions and Conclusion

  • First systematic integration of the weak-to-strong mechanism for inference enhancement in diffusion models.
  • Proposal of the W2SD framework using a reflective operation based on weak model inversion.
  • Theoretical insight linking the reflection to the weak-to-strong gradient difference.
  • Demonstration of W2SD's flexibility and broad applicability across various model pairs, tasks, and modalities.
  • Achieving state-of-the-art performance improvements with minimal computational overhead.
  • Providing a unifying perspective on several existing inference enhancement techniques.

W2SD presents a general and effective approach to improve diffusion model outputs by leveraging the differences between readily available models, pushing the generated distribution closer to the desired real data distribution without needing direct access to it.
