- The paper introduces a novel W2SD framework that enhances diffusion model inference via a reflective operation exploiting weak-to-strong model differences.
- It alternates strong-model denoising with weak-model inversion during the early high-noise inference steps, approximating the shift toward the ideal data distribution.
- Experimental results demonstrate significant improvements in generation quality, prompt adherence, and aesthetics across diverse benchmarks and modalities.
This paper introduces Weak-to-Strong Diffusion (W2SD) (arXiv:2502.00473), a novel framework for enhancing the inference process of diffusion models to bridge the gap between the learned distribution and the real data distribution. The core problem addressed is that existing diffusion models, due to limitations in training data, architecture, and modeling, inevitably produce outputs that differ from the ideal data distribution. This gap, termed the "strong-to-ideal difference" ($\Delta_2 = \nabla p_{gt} - \nabla p_s$), is difficult to minimize directly because the ground-truth distribution ($p_{gt}$) is inaccessible.
The central idea of W2SD is to approximate this inaccessible strong-to-ideal difference ($\Delta_2$) using the empirically estimable difference between a strong model ($M_s$, with density $p_s$) and a weak model ($M_w$, with density $p_w$). This "weak-to-strong difference" is defined as $\Delta_1 = \nabla p_s - \nabla p_w$. The hypothesis is that by leveraging $\Delta_1$, W2SD can steer the sampling process towards regions better aligned with the real data distribution, effectively improving the strong model.
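The logic of the approximation can be stated in one line: if the weak-to-strong gap tracks the strong-to-ideal gap, then adding $\Delta_1$ to the strong model's score lands near the ideal score. This restates the paper's premise rather than adding a new result:

$$\nabla p_s + \Delta_1 \;\approx\; \nabla p_s + \Delta_2 = \nabla p_{gt} \qquad \text{whenever } \Delta_1 \approx \Delta_2.$$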
Methodology: W2SD with Reflection
W2SD implements this idea through a "reflective operation" integrated into the diffusion sampling process. For a limited number of steps ($\lambda$) at the beginning of inference (high noise levels), the standard denoising step is modified. Given the latent state $x_t$ at timestep $t$:
- A denoising step is performed using the strong model: $x_{t-\Delta t} = M_s(x_t, t)$.
- An inversion step is performed using the weak model: $\tilde{x}_t = M^{inv}_w(x_{t-\Delta t}, t)$. This $\tilde{x}_t$ represents the "reflected" or refined latent state at time $t$.
- The standard denoising for the next step uses this refined state: $x_{t-1} = M_s(\tilde{x}_t, t)$.
Algorithm:

```
Algorithm 1: W2SD
Input: strong model M_s, weak model M_w, total steps T, reflection steps λ
Output: clean data x_0

sample Gaussian noise x_T
for t = T down to 1:
    if t > T - λ:
        # reflection: strong denoising followed by weak inversion
        x_denoised = M_s(x_t, t)
        x_t = M_w^inv(x_denoised, t)
    # standard strong-model denoising step
    x_{t-1} = M_s(x_t, t)
return x_0
```
(Note: the paper's Algorithm 1 differs slightly in variable naming and control flow, but the core operation is the same: reflection followed by strong-model denoising.)
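To make the loop concrete, here is a minimal, self-contained NumPy sketch of Algorithm 1 on a 2D Gaussian mixture, in the spirit of the paper's synthetic experiments. The noise schedule, the simplified Euler-style denoise/invert steps, and all function names are illustrative assumptions, not the paper's implementation:

```python
import numpy as np

def gmm_score(x, means, sigma):
    """Score ∇_x log p(x) of an isotropic Gaussian mixture with std sigma."""
    d2 = np.array([np.sum((x - m) ** 2) for m in means])
    resp = np.exp(-(d2 - d2.min()) / (2 * sigma ** 2))   # component responsibilities
    resp /= resp.sum()
    return sum(r * (m - x) for r, m in zip(resp, means)) / sigma ** 2

def denoise_step(x, means, sigma, dt):
    """Euler step of a simplified probability-flow ODE toward the data."""
    return x + sigma ** 2 * dt * gmm_score(x, means, sigma)

def invert_step(x, means, sigma, dt):
    """Approximate inversion: the same Euler step run in reverse."""
    return x - sigma ** 2 * dt * gmm_score(x, means, sigma)

def w2sd_sample(strong_means, weak_means, T=50, lam=10, dt=0.5, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.normal(scale=2.0, size=2)            # x_T ~ N(0, sigma_max^2 I)
    for t in range(T, 0, -1):
        sigma = 2.0 * t / T                      # linear noise schedule (illustrative)
        if t > T - lam:                          # reflection during the first lam steps
            x_d = denoise_step(x, strong_means, sigma, dt)   # strong denoising
            x = invert_step(x_d, weak_means, sigma, dt)      # weak inversion
        x = denoise_step(x, strong_means, sigma, dt)         # standard strong step
    return x

# Strong model covers both modes; weak model over-represents the right mode,
# so the reflection nudges trajectories toward the under-represented left mode.
strong = [np.array([-2.0, 0.0]), np.array([2.0, 0.0])]
weak = [np.array([2.0, 0.0])]
print(w2sd_sample(strong, weak))
```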
Theoretical Understanding (Theorem 1)
The paper provides theoretical justification (Theorem 1) showing that the reflection operation $\tilde{x}_t = M^{inv}_w(M_s(x_t, t), t)$ modifies the latent state $x_t$ approximately as:

$$\tilde{x}_t \approx x_t + \sigma_t^2 \Delta t \left( \nabla_{x_t} p_t^s(x_t) - \nabla_{x_t} p_t^w(x_t) \right)$$

This confirms that the reflection perturbs the latent variable $x_t$ in the direction of the weak-to-strong difference $\Delta_1(t)$. To the extent that $\Delta_1$ approximates $\Delta_2$, this perturbation pushes the latent variable towards the ideal data distribution.
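The expansion is first-order in $\Delta t$, so it can be sanity-checked numerically in one dimension. A minimal sketch, assuming analytic Gaussian scores as stand-ins for the two models and simple Euler-discretized denoise/invert steps (all values illustrative):

```python
sigma_t, dt, x = 1.0, 1e-3, 0.7
score_s = lambda v: -(v - 1.0)   # strong model: score of N(1, 1)
score_w = lambda v: -(v + 1.0)   # weak model:  score of N(-1, 1)

x_denoised = x + sigma_t**2 * dt * score_s(x)                      # strong denoising step
x_reflected = x_denoised - sigma_t**2 * dt * score_w(x_denoised)   # weak inversion step

predicted = x + sigma_t**2 * dt * (score_s(x) - score_w(x))        # Theorem 1's right-hand side
print(abs(x_reflected - predicted))                                # residual is O(dt^2): ~3e-07
```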
Flexibility and Applications: Defining Weak-to-Strong Pairs
A key strength of W2SD is its flexibility. The choice of "weak" and "strong" models allows tailoring the enhancement effect. The paper explores several types of differences (a sketch of the shared interface follows the list):
- Weight Difference:
- Full Fine-tuning: Ms = Fine-tuned model (e.g., DreamShaper, Juggernaut-XL), Mw = Base model (e.g., SD1.5, SDXL). Improves human preference (HPSv2, PickScore) and aesthetics (AES).
- LoRA: Ms = Base model + LoRA, Mw = Base model. Enhances personalization and style adherence (CLIP-I, CLIP-T).
- MoE: Ms = Top-k experts, Mw = Bottom-k experts (e.g., in DiT-MoE). Improves overall quality and reduces artifacts (FID, IS).
- Condition Difference:
- Guidance Scale: Ms = High CFG scale, Mw = Low/Zero CFG scale. Improves prompt adherence and human preference. Z-Sampling (Bai et al., 2024) is identified as a special case.
- Prompt Semantics: Ms = Model with refined/detailed prompt, Mw = Model with raw/simple prompt. Improves generation based on detailed prompts.
- Sampling Pipeline Difference:
- Enhanced Pipelines: Ms = Pipeline with added control (e.g., ControlNet, IP-Adapter), Mw = Standard pipeline (e.g., DDIM). Improves adherence to the specific control mechanism (e.g., edge maps, reference image style).
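One way to read this list is that every pairing reduces to the same interface: one strong denoising step plus one weak inversion step, both pluggable into the Algorithm 1 loop. A hedged Python sketch of that interface, with stub callables standing in for real model wrappers (all names here are hypothetical):

```python
from typing import Callable, NamedTuple

# (latent, timestep) -> latent; the stubs below stand in for real model wrappers
Step = Callable[[object, int], object]

class W2SDPair(NamedTuple):
    strong_denoise: Step   # M_s: one denoising step of the strong model/pipeline
    weak_invert: Step      # M_w^inv: one inversion step of the weak model/pipeline

def stub(name: str) -> Step:
    def step(x, t):
        print(f"{name}(x, t={t})")   # a real pair would call the model here
        return x
    return step

# Weight difference: fine-tuned checkpoint vs. its base model
weight_pair = W2SDPair(stub("finetuned_denoise"), stub("base_ddim_invert"))

# Condition difference: high CFG scale vs. low/zero CFG scale, same checkpoint
cfg_pair = W2SDPair(stub("cfg_high_denoise"), stub("cfg_zero_invert"))

# Pipeline difference: ControlNet-augmented pipeline vs. plain DDIM
pipe_pair = W2SDPair(stub("controlnet_denoise"), stub("ddim_invert"))

# Any pair drops into the reflection step of the Algorithm 1 loop unchanged:
x_reflected = cfg_pair.weak_invert(cfg_pair.strong_denoise("x_t", 50), 50)
```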
Experimental Validation
- Synthetic Data (Gaussian Mixtures): Visualizations confirm W2SD steers sampling trajectories towards underrepresented data modes.
- Real Data (CIFAR-10 subset): Demonstrates balancing generation ratios for classes underrepresented in the weak model's training data.
- Benchmarks (Pick-a-Pic, DrawBench, GenEval, ImageNet, VBench): Extensive quantitative results show significant improvements in human preference (HPSv2, PickScore, MPS), aesthetics (AES), prompt alignment, personalization (CLIP scores), and generation quality (FID, IS) across image and video modalities.
- Ablation Studies:
- Confirm performance gains depend on Ms being genuinely "stronger" than Mw.
- Show W2SD provides better quality than standard sampling for the same computational budget.
- Demonstrate that the improvements from different types of weak-to-strong differences (e.g., weight and condition) can be cumulative.
- Analyze the impact of inversion approximation errors.
- Connect W2SD to prior work like Re-Sampling (Akyürek et al., 2022), framing it as a specific instance where the weak inversion is approximated by adding random noise.
Contributions and Conclusion
- First systematic integration of the weak-to-strong mechanism for inference enhancement in diffusion models.
- Proposal of the W2SD framework using a reflective operation based on weak model inversion.
- Theoretical insight linking the reflection to the weak-to-strong gradient difference.
- Demonstration of W2SD's flexibility and broad applicability across various model pairs, tasks, and modalities.
- Achieving state-of-the-art performance improvements with minimal computational overhead.
- Providing a unifying perspective on several existing inference enhancement techniques.
W2SD presents a general and effective approach to improve diffusion model outputs by leveraging the differences between readily available models, pushing the generated distribution closer to the desired real data distribution without needing direct access to it.