Inference-time Scaling of Diffusion Models through Classical Search (2505.23614v1)
Abstract: Classical search algorithms have long underpinned modern artificial intelligence. In this work, we tackle the challenge of inference-time control in diffusion models -- adapting generated outputs to meet diverse test-time objectives -- using principles from classical search. We propose a general framework that orchestrates local and global search to efficiently navigate the generative space. It employs a theoretically grounded local search via annealed Langevin MCMC and performs compute-efficient global exploration using breadth-first and depth-first tree search. We evaluate our approach on a range of challenging domains, including planning, offline reinforcement learning, and image generation. Across all tasks, we observe significant gains in both performance and efficiency. These results show that classical search provides a principled and practical foundation for inference-time scaling in diffusion models. Project page at diffusion-inference-scaling.github.io.
Summary
- The paper introduces a framework that integrates classical search algorithms with a verifier function to guide diffusion model outputs.
- It employs local search via annealed Langevin MCMC to refine samples and unifies gradient-based guidance with recurrence strategies.
- Global search strategies (BFS and DFS) are used to explore diverse modes, significantly boosting performance and efficiency in various applications.
This paper introduces a framework for improving the outputs of diffusion models at inference time by incorporating classical search algorithms. The core idea is to adapt the generation process to meet specific objectives defined by a "verifier" function, $f(\mathbf{x}_0)$, which quantifies the quality of a generated sample $\mathbf{x}_0$. The goal is to sample from a modified distribution $\Tilde{p}_0(\mathbf{x}_0) \propto p_0(\mathbf{x}_0) f(\mathbf{x}_0)^{\lambda}$, where $p_0(\mathbf{x}_0)$ is the original diffusion model's distribution and $\lambda$ controls the influence of the verifier.
The proposed framework combines local and global search strategies:
Local Search via Annealed Langevin MCMC
To refine samples locally, the paper proposes using Langevin MCMC. This method explores the neighborhood of a current sample, guided by the gradient of the verifier function and the score function of the diffusion model. A key theoretical contribution is unifying prior gradient-based guidance techniques. The paper shows that training-free guidance with a recurrence strategy (Eq. 5) is, in the continuous limit, equivalent to performing Langevin MCMC on a sequence of annealed distributions $\Tilde{q}_t(\mathbf{x}_t) \propto q_t(\mathbf{x}_t) \hat{f}_t(\mathbf{x}_t)$ (Proposition 1). Here, $q_t(\mathbf{x}_t)$ is the noisy distribution at timestep $t$, and $\hat{f}_t(\mathbf{x}_t)$ is an annealed version of the verifier. The recurrence steps (repeatedly denoising slightly and then re-noising) act like MCMC steps, pulling samples back towards the data manifold, while the guidance term biases the sampling towards regions favored by the verifier.
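Concretely, under this view each local refinement at noise level $t$ can be read as a Langevin step on the annealed target; a minimal sketch with a generic step size $\eta_t$ (the paper's exact parameterization is given in Proposition 1 and Appendix A.3) is:

$$\mathbf{x}_t \leftarrow \mathbf{x}_t + \frac{\eta_t}{2}\Big(\nabla_{\mathbf{x}_t}\log q_t(\mathbf{x}_t) + \nabla_{\mathbf{x}_t}\log \hat{f}_t(\mathbf{x}_t)\Big) + \sqrt{\eta_t}\,\boldsymbol{\epsilon}, \qquad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}),$$

where the first gradient is the diffusion model's score (approximated by $-\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)/\sigma_t$) and the second is the verifier guidance term.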
Implementation:
- Local search is implemented by modifying the reverse transition kernel $\Tilde{p}_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_t)$ to include a sequence of Langevin MCMC steps before the standard denoising step (DDIM or DDPM).
- Hyperparameters for this local search (e.g., guidance strength, number of recurrence steps Nrecur) can be tuned using the design space from "Training-Free Guidance" (TFG) [ye2024tfg], as detailed in Appendix A.3. An example algorithm (Algorithm 3) outlines this process, showing how guidance terms Δvar (variance guidance) and Δmean (mean guidance) are computed and applied within recurrence loops.
```python
def local_search_step(x_t, t, model_epsilon_theta, verifier_f, hparams):
    """One reverse step t -> t-1 with hparams.N_recur Langevin/recurrence refinements.

    hparams carries the noise schedule (sigma, alpha), the guidance strength rho_t,
    and N_recur >= 1. Helper functions (compute_verifier_gradient, ddim_step,
    forward_process_step) are assumed to be provided by the surrounding sampler.
    """
    sigma_t, alpha_t = hparams.sigma[t], hparams.alpha[t]
    alpha_t_minus_1 = hparams.alpha[t - 1]

    for i in range(hparams.N_recur):  # N_recur: number of Langevin MCMC steps
        # Estimate x_0_hat from x_t (epsilon-parameterization)
        x_0_hat = (x_t - model_epsilon_theta(x_t, t) * sigma_t) / alpha_t

        # Guidance gradient of log f, taken w.r.t. x_0_hat or x_t (simplified)
        grad_log_f = compute_verifier_gradient(verifier_f, x_0_hat, x_t, t, hparams)
        delta_guidance = hparams.rho_t * grad_log_f

        # Denoise one step (e.g., DDIM), then bias the result toward the verifier
        x_t_minus_1_pred = ddim_step(x_t, model_epsilon_theta, t)
        x_t_minus_1_guided = x_t_minus_1_pred + (alpha_t_minus_1 / alpha_t) * delta_guidance

        if i < hparams.N_recur - 1:
            # Recurrence: re-noise x_{t-1} back to level t via q(x_t | x_{t-1})
            x_t = forward_process_step(x_t_minus_1_guided, t - 1, t)
        else:
            x_t_minus_1 = x_t_minus_1_guided

    return x_t_minus_1
```
Global Search for Mode Identification
Since local search can get stuck in local optima, global search strategies are introduced to explore diverse modes in the generative landscape. The denoising process is viewed as a search tree, where nodes are states xt and edges are transitions $\Tilde{p}_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_t)$.
1. BFS-style Linear Search (Algorithm 4):
These methods denoise a set of Kt particles (samples) in parallel.
- Best-of-N: Generate N complete trajectories and pick the best one based on the final verifier score f(x0). This is computationally expensive as it doesn't use intermediate information.
- BFS-Resampling: At certain timesteps $t \in S$, evaluate intermediate particles $\mathbf{x}_t^k$ using $f(\mathbf{x}_{0|t}^k)$. Then resample children for the next step, allocating more children $n_t^k$ to particles with higher scores: $n_t^k \propto f(\mathbf{x}_{0|t}^k)^{\tau_t}$, where $\tau_t$ is a temperature parameter.
- BFS-Pruning: Similar to resampling, but if using deterministic samplers (where multiple children from one parent are identical), this method prunes low-scoring particles, keeping only the top ones. Each parent has at most one child.
Implementation (BFS):
- Maintain a set of Kt active particles at each denoising step t.
- At predefined evaluation timesteps S (e.g., T/2,T/4):
- For each particle $\mathbf{x}_t^k$, predict $\mathbf{x}_{0|t}^k = (\mathbf{x}_t^k - \sigma_t \boldsymbol{\epsilon}_\theta(\mathbf{x}_t^k, t)) / \alpha_t$.
- Calculate the verifier score $f(\mathbf{x}_{0|t}^k)$.
- For BFS-Resampling, calculate the number of children $n_t^k$ for each particle from its normalized score raised to the power $\tau_t$ (see the sketch after this list).
- For BFS-Pruning, $n_t^k$ is either 0 or 1.
- Generate $n_t^k$ children for each $\mathbf{x}_t^k$ using the (potentially local-search-enhanced) transition $\Tilde{p}_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_t^k)$.
- Set $K_{t-1} = \sum_k n_t^k$.
- The temperature τt and evaluation timesteps S are key hyperparameters. An "increase" schedule for τt (higher τt for lower noise levels) is suggested.
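A compact sketch of the child-allocation step in BFS-Resampling described above; the function name and the multinomial rounding are illustrative choices, not the paper's implementation:

```python
import numpy as np

def bfs_resample_counts(scores, K_next, tau_t):
    """Allocate children n_t^k roughly proportional to f(x_{0|t}^k)^tau_t,
    keeping K_next particles in total at the next noise level."""
    weights = np.asarray(scores, dtype=np.float64) ** tau_t
    probs = weights / weights.sum()
    # A multinomial draw yields integer child counts that sum to K_next
    return np.random.multinomial(K_next, probs)

# Example: 4 particles, temperature 2.0, keep 4 children in total.
# Higher-scoring particles tend to receive more children; zero-count particles are dropped.
counts = bfs_resample_counts([0.9, 0.5, 0.1, 0.7], K_next=4, tau_t=2.0)
```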
2. DFS-style Non-linear Search (Algorithm 5):
This method denoises a single particle iteratively.
- If the verifier score $f(\mathbf{x}_{0|t})$ at an intermediate step $t$ drops below a scheduled threshold $\delta_t$, the algorithm backtracks.
- Backtracking re-noises the current sample $\mathbf{x}_t$ to a previous, higher noise level $t_{\text{next}} = t + \Delta T$ using the forward diffusion process $q(\mathbf{x}_{t_{\text{next}}}|\mathbf{x}_t)$.
- The search then continues from this new state. A budget limits the total number of backtracks.
Implementation (DFS):
- Denoise a single particle xt.
- At evaluation timesteps S:
- Predict $\mathbf{x}_{0|t}$ and compute $f(\mathbf{x}_{0|t})$.
- If $f(\mathbf{x}_{0|t}) < \delta_t$ and the backtrack budget $B > 0$:
- Store the current $(\mathbf{x}_t, f(\mathbf{x}_{0|t}))$ in a buffer.
- Set $t \leftarrow \min(t + \Delta T, T)$.
- Sample $\mathbf{x}_t \sim q(\mathbf{x}_t | \mathbf{x}^{\text{original}}_{t - \Delta T})$, i.e., forward-noise the pre-backtrack sample up to the new level $t$.
- Decrement B.
- Else (score is good, or no budget):
- If score is bad and B=0, retrieve best sample from buffer for current t.
- Proceed to t←t−1 by sampling $\mathbf{x}_{t-1} \sim \Tilde{p}_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_t)$.
- Key hyperparameters: backtracking depth ΔT, threshold schedule δt, evaluation timesteps S. An "increase" schedule for δt is suggested. A buffer stores prior results to retrieve the best sample if no path satisfies constraints.
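A simplified sketch of this DFS loop with backtracking; `denoise_step`, `predict_x0`, and `forward_noise` are hypothetical stand-ins for the guided reverse transition, the $\mathbf{x}_{0|t}$ prediction, and the forward process $q$, passed in by the caller:

```python
def dfs_denoise(x, T, verifier_f, delta, eval_steps, delta_T, budget,
                denoise_step, predict_x0, forward_noise):
    """Single-particle DFS-style denoising with backtracking (illustrative sketch).

    denoise_step(x, t)      -> x_{t-1} via the (guided) reverse transition
    predict_x0(x, t)        -> estimate of x_0 from x_t
    forward_noise(x, s, t)  -> sample of q(x_t | x_s), i.e. re-noising from level s to t
    """
    t = T
    buffer = {}  # noise level -> list of (score, x) candidates seen at that level
    while t > 0:
        if t in eval_steps:
            score = verifier_f(predict_x0(x, t))
            buffer.setdefault(t, []).append((score, x))
            if score < delta[t]:
                if budget > 0:
                    # Backtrack: re-noise the current sample to a higher noise level
                    t_next = min(t + delta_T, T)
                    x = forward_noise(x, t, t_next)
                    t = t_next
                    budget -= 1
                    continue
                # Budget exhausted: fall back to the best candidate seen at this level
                _, x = max(buffer[t], key=lambda item: item[0])
        x = denoise_step(x, t)  # ~p_theta(x_{t-1} | x_t), possibly with local search
        t -= 1
    return x
```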
Experimental Applications
The framework is evaluated on three diverse tasks:
- Long Horizon Planning (PointMaze):
- Task: Generate trajectories for an agent to navigate complex mazes ("Giant" and "Ultra"). The trajectory $\tau = [(s_1, a_1), \ldots, (s_H, a_H)]$ is denoised.
- Verifier: A physics-based world model. The score $f(\tau)$ is high if the trajectory does not collide with walls; specifically, $f(\tau) = e^{-\sum_i L(x_i, y_i)^2}$, where $L(x_i, y_i)$ is the squared distance to the nearest wall for colliding points.
- Results: Local search helps generate collision-free plans. Global search (BFS/DFS) significantly improves success rates (over 90%) and compute efficiency compared to Best-of-N, and produces diverse solutions. For example, local search steps were set to 2 for Giant and 6 for Ultra. Global BFS evaluated at steps {12,8,4} (out of 16 total steps).
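For intuition, a minimal sketch of such a collision-based verifier is given below, treating $L$ as the distance to the nearest wall for colliding waypoints; `maze.in_collision` and `maze.wall_distance` are hypothetical helpers standing in for the paper's physics-based world model, and the exact penalty form may differ.

```python
import numpy as np

def trajectory_verifier(trajectory_xy, maze):
    """Score f(tau) = exp(-sum of L(x_i, y_i)^2 over colliding waypoints).
    `maze.in_collision` / `maze.wall_distance` are hypothetical helpers."""
    penalty = 0.0
    for x, y in trajectory_xy:
        if maze.in_collision(x, y):
            penalty += maze.wall_distance(x, y) ** 2
    return float(np.exp(-penalty))  # 1.0 for collision-free plans, decays with violations
```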
- Offline Reinforcement Learning (D4RL):
- Task: Improve a diffusion-based policy $\mu(a|s)$ using a pretrained Q-function $Q_\psi(s,a)$ at test time. The goal is to sample actions from $\pi^*(a|s) \propto \mu(a|s)\, e^{\beta Q_\psi(s,a)}$.
- Verifier: The pretrained Q-function, $f(a) = e^{\beta Q_\psi(s,a)}$.
- Results: This "Test-Time Search" (TTS) approach, using local and global search, achieves performance competitive with state-of-the-art training-based methods on D4RL locomotion benchmarks, without retraining the diffusion policy or the Q-function. This is particularly useful when Q-functions are large foundation models. Hyperparameters such as $\bar{\rho}, \bar{\mu}$ for local search (Eqs. 7 and 8 in Appendix A.3) were tuned.
- Image Generation:
- Compositional Text-to-Image (CompBench with SSD-1B model):
- Task: Generate images based on complex text prompts (e.g., attribute binding, object relationships).
- Verifiers: Non-differentiable oracle verifiers (BLIP-VQA for attributes, UniDet for relationships). Only global search is applied.
- Results (DFS): DFS outperforms Best-of-N with up to 2x less compute. It adaptively allocates more compute to harder prompts (those with initially lower verifier scores). For DFS, evaluation occurred at timesteps {25,35,45} out of 50, with backtrack depth ΔT=25.
- Conditional ImageNet Generation (DDPM model):
- Task: Generate class-conditional images using an unconditional diffusion model, guided by a classifier.
- Verifier: A pretrained image classifier (ViT).
- Challenge & Solution: Standard classifier guidance is prone to "verifier hacking" (adversarial examples). The paper proposes a "double-verifier" setup: one ViT (e.g., 224px) for local search gradients and a different ViT (e.g., 384px) for global search resampling/pruning decisions.
- Results (BFS): The double-verifier setup consistently improves FID and accuracy. BFS-Resampling and BFS-Pruning are more compute-efficient than Best-of-N. BFS evaluation steps were {75,50,25} out of 100.
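A rough sketch of the double-verifier split described above, assuming two hypothetical pretrained classifiers: a 224px ViT supplying differentiable gradients for local search, and a separate 384px ViT used only to score particles for BFS resampling/pruning.

```python
import torch
import torch.nn.functional as F

def local_guidance_grad(vit_224, x0_hat, class_id):
    """Gradient of log p(class | x0_hat) under the 224px ViT (local-search guidance)."""
    x = x0_hat.detach().requires_grad_(True)
    logits = vit_224(F.interpolate(x, size=224, mode="bilinear"))
    log_prob = F.log_softmax(logits, dim=-1)[:, class_id].sum()
    return torch.autograd.grad(log_prob, x)[0]

@torch.no_grad()
def global_score(vit_384, x0_hat, class_id):
    """Verifier score from the different 384px ViT (used only to rank/resample particles)."""
    logits = vit_384(F.interpolate(x0_hat, size=384, mode="bilinear"))
    return F.softmax(logits, dim=-1)[:, class_id]
```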
Limitations
- Hyperparameter Tuning: The proposed search methods introduce new hyperparameters that require tuning (e.g., guidance strengths, search schedules, thresholds, temperatures).
- Verifier Hacking: Generated samples might exploit weaknesses in the verifier, leading to high scores for low-quality outputs. While double-verifier helps, this remains a concern.
Conclusion
The paper presents a unified framework integrating classical search algorithms (Langevin MCMC, BFS, DFS) into the inference process of diffusion models. This "inference-time scaling" significantly enhances performance and compute efficiency across planning, offline RL, and image generation tasks by strategically navigating the generative space to find samples that better align with desired objectives. The theoretical unification of recurrence-based guidance with Langevin MCMC provides a principled understanding of local search mechanisms.