Inference-time Scaling of Diffusion Models through Classical Search (2505.23614v1)

Published 29 May 2025 in cs.LG and stat.ML

Abstract: Classical search algorithms have long underpinned modern artificial intelligence. In this work, we tackle the challenge of inference-time control in diffusion models -- adapting generated outputs to meet diverse test-time objectives -- using principles from classical search. We propose a general framework that orchestrates local and global search to efficiently navigate the generative space. It employs a theoretically grounded local search via annealed Langevin MCMC and performs compute-efficient global exploration using breadth-first and depth-first tree search. We evaluate our approach on a range of challenging domains, including planning, offline reinforcement learning, and image generation. Across all tasks, we observe significant gains in both performance and efficiency. These results show that classical search provides a principled and practical foundation for inference-time scaling in diffusion models. Project page at diffusion-inference-scaling.github.io.

Summary

  • The paper introduces a framework that integrates classical search algorithms with a verifier function to guide diffusion model outputs.
  • It employs local search via annealed Langevin MCMC to refine samples and unifies gradient-based guidance with recurrence strategies.
  • Global search strategies (BFS and DFS) are used to explore diverse modes, significantly boosting performance and efficiency in various applications.

This paper introduces a framework for improving the outputs of diffusion models at inference time by incorporating classical search algorithms. The core idea is to adapt the generation process to meet specific objectives defined by a "verifier" function $f(\mathbf{x}_0)$, which quantifies the quality of a generated sample $\mathbf{x}_0$. The goal is to sample from a modified distribution $\tilde{p}_0(\mathbf{x}_0) \propto p_0(\mathbf{x}_0) f(\mathbf{x}_0)^{\lambda}$, where $p_0(\mathbf{x}_0)$ is the original diffusion model's distribution and $\lambda$ controls the influence of the verifier.

The proposed framework combines local and global search strategies:

Local Search via Annealed Langevin MCMC

To refine samples locally, the paper proposes using Langevin MCMC. This method explores the neighborhood of a current sample, guided by the gradient of the verifier function and the score function of the diffusion model. A key theoretical contribution is unifying prior gradient-based guidance techniques. The paper shows that training-free guidance with a recurrence strategy (Eq. 5) is, in the continuous limit, equivalent to performing Langevin MCMC on a sequence of annealed distributions $\tilde{q}_t(\mathbf{x}_t) \propto q_t(\mathbf{x}_t) \hat{f}_t(\mathbf{x}_t)$ (Proposition 1). Here, $q_t(\mathbf{x}_t)$ is the noisy distribution at timestep $t$, and $\hat{f}_t(\mathbf{x}_t)$ is an annealed version of the verifier. The recurrence steps (repeatedly denoising slightly and then re-noising) act like MCMC steps, pulling samples back towards the data manifold, while the guidance term biases the sampling towards regions favored by the verifier.
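
Concretely, one annealed Langevin step takes the form $\mathbf{x}_t \leftarrow \mathbf{x}_t + \frac{\eta}{2} \nabla_{\mathbf{x}_t} \log \tilde{q}_t(\mathbf{x}_t) + \sqrt{\eta}\,\bm{\epsilon}$ with $\bm{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, where $\nabla_{\mathbf{x}_t} \log \tilde{q}_t = \nabla_{\mathbf{x}_t} \log q_t + \nabla_{\mathbf{x}_t} \log \hat{f}_t$ and the score $\nabla_{\mathbf{x}_t} \log q_t(\mathbf{x}_t)$ is approximated by $-\epsilon_{\theta}(\mathbf{x}_t, t)/\sigma_t$. This update is a standard Langevin sketch shown for illustration, with step size $\eta$ as a tunable hyperparameter; it is not the paper's exact Eq. 5.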

Implementation:

  • Local search is implemented by modifying the reverse transition kernel $\tilde{p}_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_t)$ to include a sequence of Langevin MCMC steps before the standard denoising step (DDIM or DDPM).
  • Hyperparameters for this local search (e.g., guidance strength, number of recurrence steps $N_{\text{recur}}$) can be tuned using the design space from "Training-Free Guidance" (TFG) [ye2024tfg], as detailed in Appendix A.3. An example algorithm (Algorithm 3) outlines this process, showing how guidance terms $\bm{\Delta}_{\text{var}}$ (variance guidance) and $\bm{\Delta}_{\text{mean}}$ (mean guidance) are computed and applied within recurrence loops.

def local_search_step(x_t, t, model_epsilon_theta, verifier_f, hparams):
    """One local-search-enhanced reverse step (simplified pseudocode).

    `compute_verifier_gradient`, `ddim_step`, and `forward_process_step` are
    placeholders for the guidance, denoising, and re-noising routines; the
    noise-schedule coefficients alpha/sigma are assumed to live in `hparams`.
    """
    alpha_t, sigma_t = hparams.alpha[t], hparams.sigma[t]
    alpha_t_minus_1 = hparams.alpha[t - 1]

    for i in range(hparams.N_recur):  # N_recur: number of recurrence (MCMC) steps
        # Estimate the clean sample x_0 from the current noisy state x_t
        x_0_hat = (x_t - sigma_t * model_epsilon_theta(x_t, t)) / alpha_t

        # Guidance gradient of the (annealed) verifier, evaluated at x_0_hat / x_t
        grad_log_f = compute_verifier_gradient(verifier_f, x_0_hat, x_t, t, hparams)
        delta_guidance = hparams.rho_t * grad_log_f  # simplified: a single strength rho_t

        # One standard denoising step (e.g., DDIM), then bias it by the guidance term
        x_t_minus_1_pred = ddim_step(x_t, model_epsilon_theta, t)
        x_t_minus_1_guided = x_t_minus_1_pred + (alpha_t_minus_1 / alpha_t) * delta_guidance

        if i < hparams.N_recur - 1:
            # Recurrence: re-noise back to level t via the forward process q(x_t | x_{t-1})
            x_t = forward_process_step(x_t_minus_1_guided, t - 1, t)
        else:
            # Final recurrence step: keep the guided sample as x_{t-1}
            x_t_minus_1 = x_t_minus_1_guided

    return x_t_minus_1
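
For context, this step plugs into the overall reverse process roughly as follows. This is a minimal sketch, assuming a hypothetical sample_prior helper that draws $\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ and the same placeholder helpers as above:

def guided_sample(model_epsilon_theta, verifier_f, hparams, T):
    # Start from pure noise and apply the local-search-enhanced transition at every step
    x_t = sample_prior(hparams.shape)
    for t in range(T, 0, -1):
        x_t = local_search_step(x_t, t, model_epsilon_theta, verifier_f, hparams)
    return x_t  # final sample x_0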

Global Search for Mode Identification

Since local search can get stuck in local optima, global search strategies are introduced to explore diverse modes in the generative landscape. The denoising process is viewed as a search tree, where nodes are states $\mathbf{x}_t$ and edges are transitions $\tilde{p}_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_t)$.

1. BFS-style Linear Search (Algorithm 4):

These methods denoise a set of $K_t$ particles (samples) in parallel.

  • Best-of-N: Generate $N$ complete trajectories and pick the best one based on the final verifier score $f(\mathbf{x}_0)$. This is computationally expensive because it uses no intermediate information.
  • BFS-Resampling: At certain timesteps $t \in \mathcal{S}$, evaluate intermediate particles $\mathbf{x}_t^k$ using $f(\mathbf{x}_{0|t}^k)$. Then resample children for the next step, allocating more children $n_t^k$ to particles with higher scores: $n_t^k \propto f(\mathbf{x}_{0|t}^k)^{\tau_t}$, where $\tau_t$ is a temperature parameter.
  • BFS-Pruning: Similar to resampling, but when using deterministic samplers (where multiple children from one parent are identical), this method prunes low-scoring particles and keeps only the top ones; each parent has at most one child.

Implementation (BFS):

  • Maintain a set of $K_t$ active particles at each denoising step $t$.
  • At predefined evaluation timesteps $\mathcal{S}$ (e.g., $T/2$, $T/4$):
    • For each particle $\mathbf{x}_t^k$, predict $\mathbf{x}_{0|t}^k = (\mathbf{x}_t^k - \sigma_t \epsilon_{\theta}(\mathbf{x}_t^k, t)) / \alpha_t$.
    • Calculate the verifier score $f(\mathbf{x}_{0|t}^k)$.
    • For BFS-Resampling, compute the number of children $n_t^k$ for each particle from its normalized score raised to the power $\tau_t$.
    • For BFS-Pruning, $n_t^k$ is either 0 or 1.
    • Generate $n_t^k$ children for each $\mathbf{x}_t^k$ using the (potentially local-search-enhanced) transition $\tilde{p}_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_t^k)$.
    • Set $K_{t-1} = \sum_k n_t^k$.
  • The temperature $\tau_t$ and evaluation timesteps $\mathcal{S}$ are key hyperparameters. An "increase" schedule for $\tau_t$ (higher $\tau_t$ at lower noise levels) is suggested.
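
As an illustration, the resampling step at one evaluation timestep can be sketched as follows. This is a minimal NumPy-based sketch, not the paper's exact Algorithm 4; predict_x0 and transition are assumed helpers, and multinomial allocation is just one simple way to realize $n_t^k \propto f(\mathbf{x}_{0|t}^k)^{\tau_t}$ while keeping the particle budget fixed:

import numpy as np

def bfs_resample_step(particles, t, verifier_f, predict_x0, transition, tau_t):
    # Score each particle via its predicted clean sample x_{0|t}
    scores = np.array([verifier_f(predict_x0(x_t, t)) for x_t in particles])

    # Allocate children in proportion to f(x_{0|t})^tau_t, keeping the total budget fixed
    weights = scores ** tau_t
    probs = weights / weights.sum()
    n_children = np.random.multinomial(len(particles), probs)

    # Expand each particle into its allotted children via the (guided) transition kernel
    next_particles = []
    for x_t_k, n_k in zip(particles, n_children):
        next_particles.extend(transition(x_t_k, t) for _ in range(n_k))
    return next_particles  # particles at timestep t-1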

2. DFS-style Non-linear Search (Algorithm 5):

This method denoises a single particle iteratively.

  • If the verifier score $f(\mathbf{x}_{0|t})$ at an intermediate step $t$ drops below a scheduled threshold $\delta_t$, the algorithm backtracks.
  • Backtracking re-noises the current sample $\mathbf{x}_t$ to a previous, higher noise level $t_{\text{next}} = t + \Delta_T$ using the forward diffusion process $q(\mathbf{x}_{t_{\text{next}}}|\mathbf{x}_t)$.
  • The search then continues from this new state. A budget limits the total number of backtracks.

Implementation (DFS):

  • Denoise a single particle $\mathbf{x}_t$.
  • At evaluation timesteps $\mathcal{S}$:
    • Predict $\mathbf{x}_{0|t}$ and compute $f(\mathbf{x}_{0|t})$.
    • If $f(\mathbf{x}_{0|t}) < \delta_t$ and the backtrack budget $B > 0$:
      • Store the current $(\mathbf{x}_t, f(\mathbf{x}_{0|t}))$ in a buffer.
      • Set $t \leftarrow \min(t + \Delta_T, T)$.
      • Re-noise to the new level by sampling $\mathbf{x}_t \sim q(\mathbf{x}_t | \mathbf{x}_{t - \Delta_T})$, where $\mathbf{x}_{t - \Delta_T}$ is the state held before backtracking.
      • Decrement $B$.
    • Otherwise (the score is acceptable, or the budget is exhausted):
      • If the score is below threshold and $B = 0$, retrieve the best-scoring sample from the buffer for the current $t$.
      • Proceed to $t \leftarrow t - 1$ by sampling $\mathbf{x}_{t-1} \sim \tilde{p}_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_t)$.
  • Key hyperparameters: backtracking depth $\Delta_T$, threshold schedule $\delta_t$, and evaluation timesteps $\mathcal{S}$. An "increase" schedule for $\delta_t$ is suggested. A buffer stores prior results so the best sample can be retrieved if no path satisfies the constraints.
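
A minimal sketch of this backtracking loop is given below; predict_x0, transition, and forward_renoise are assumed helpers, and the buffer bookkeeping is simplified relative to the paper's Algorithm 5:

def dfs_sample(x_T, T, verifier_f, predict_x0, transition, forward_renoise,
               eval_steps, delta, Delta_T, budget):
    x_t, t = x_T, T
    buffer = []  # stores (score, x_t, t) of rejected states as a fallback
    while t > 0:
        if t in eval_steps:
            score = verifier_f(predict_x0(x_t, t))
            if score < delta[t] and budget > 0:
                # Backtrack: remember this state, then re-noise to a higher noise level
                buffer.append((score, x_t, t))
                t_next = min(t + Delta_T, T)
                x_t = forward_renoise(x_t, t, t_next)  # sample q(x_{t_next} | x_t)
                t = t_next
                budget -= 1
                continue
            if score < delta[t] and budget == 0 and buffer:
                # Budget exhausted: fall back to the best previously stored state
                _, x_t, t = max(buffer, key=lambda item: item[0])
        # Standard (possibly local-search-enhanced) denoising transition to t-1
        x_t = transition(x_t, t)
        t -= 1
    return x_t  # final sample x_0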

Experimental Applications

The framework is evaluated on three diverse tasks:

  1. Long Horizon Planning (PointMaze):
    • Task: Generate trajectories for an agent to navigate complex mazes ("Giant" and "Ultra"). The trajectory $\bm{\tau} = [(\mathbf{s}_1, \mathbf{a}_1), \dots, (\mathbf{s}_H, \mathbf{a}_H)]$ is denoised.
    • Verifier: A physics-based world model. The score $f(\bm{\tau})$ is high if the trajectory does not collide with walls; specifically, it is $e^{-\sum_i L(x_i, y_i)^2}$, where $L(x_i, y_i)$ measures the distance to the nearest wall for colliding points.
    • Results: Local search helps generate collision-free plans. Global search (BFS/DFS) significantly improves success rates (over 90%) and compute efficiency compared to Best-of-N, and produces diverse solutions. For example, local search steps were set to 2 for Giant and 6 for Ultra, and global BFS was evaluated at steps $\{12, 8, 4\}$ (out of 16 total denoising steps).
  2. Offline Reinforcement Learning (D4RL):
    • Task: Improve a diffusion-based policy $\mu(\mathbf{a}|\mathbf{s})$ at test time using a pretrained Q-function $Q_{\psi}(\mathbf{s},\mathbf{a})$. The goal is to sample actions from $\pi^*(\mathbf{a}|\mathbf{s}) \propto \mu(\mathbf{a}|\mathbf{s})\, e^{\beta Q_{\psi}(\mathbf{s},\mathbf{a})}$.
    • Verifier: The pretrained Q-function, with $f(\mathbf{a}) = e^{\beta Q_{\psi}(\mathbf{s},\mathbf{a})}$ (a minimal sketch of this verifier follows this list).
    • Results: This "Test-Time Search" (TTS) approach, using local and global search, achieves performance competitive with state-of-the-art training-based methods on D4RL locomotion benchmarks, without retraining the diffusion policy or the Q-function. This is particularly useful when Q-functions are large foundation models. Hyperparameters such as $\bar{\rho}$ and $\bar{\mu}$ for local search (Eq. 7, 8 in Appendix A.3) were tuned.
  3. Image Generation:
    • Compositional Text-to-Image (CompBench with the SSD-1B model):
      • Task: Generate images from complex text prompts (e.g., attribute binding, object relationships).
      • Verifiers: Non-differentiable oracle verifiers (BLIP-VQA for attributes, UniDet for relationships). Only global search is applied.
      • Results (DFS): DFS outperforms Best-of-N with up to 2x less compute and adaptively allocates more compute to harder prompts (those with initially lower verifier scores). For DFS, evaluation occurred at timesteps $\{25, 35, 45\}$ out of 50, with backtrack depth $\Delta_T = 25$.
    • Conditional ImageNet Generation (DDPM model):
      • Task: Generate class-conditional images using an unconditional diffusion model, guided by a classifier.
      • Verifier: A pretrained image classifier (ViT).
      • Challenge & Solution: Standard classifier guidance is prone to "verifier hacking" (adversarial examples). The paper proposes a "double-verifier" setup: one ViT (e.g., 224px input) supplies gradients for local search, while a different ViT (e.g., 384px input) scores particles for global-search resampling/pruning decisions.
      • Results (BFS): The double-verifier setup consistently improves FID and accuracy. BFS-Resampling and BFS-Pruning are more compute-efficient than Best-of-N. BFS evaluation steps were $\{75, 50, 25\}$ out of 100.
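
For the offline RL setting above, the verifier amounts to wrapping the pretrained Q-function. A minimal PyTorch sketch, where q_psi is a hypothetical differentiable module mapping a (state, action) pair to a scalar Q-value and beta is the inverse temperature from the target distribution:

import torch

def make_q_verifier(q_psi, state, beta):
    # Builds f(a) = exp(beta * Q_psi(s, a)) for a fixed state s.
    # Local search only needs gradients of log f, so working directly with
    # log f(a) = beta * Q_psi(s, a) is often the numerically safer choice.
    def verifier_f(action):
        return torch.exp(beta * q_psi(state, action))
    return verifier_f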

Limitations

  • Hyperparameter Tuning: The proposed search methods introduce new hyperparameters that require tuning (e.g., guidance strengths, search schedules, thresholds, temperatures).
  • Verifier Hacking: Generated samples might exploit weaknesses in the verifier, leading to high scores for low-quality outputs. While the double-verifier setup mitigates this, it remains a concern.

Conclusion

The paper presents a unified framework integrating classical search algorithms (Langevin MCMC, BFS, DFS) into the inference process of diffusion models. This "inference-time scaling" significantly enhances performance and compute efficiency across planning, offline RL, and image generation tasks by strategically navigating the generative space to find samples that better align with desired objectives. The theoretical unification of recurrence-based guidance with Langevin MCMC provides a principled understanding of local search mechanisms.
