Diffusion Tree Search (DTS*)

Updated 1 July 2025
  • Diffusion Tree Search (DTS*) is a scalable inference-time algorithm that applies tree search to align diffusion models with arbitrary reward functions, searching for high-reward samples without retraining.
  • DTS* structures the denoising process as a tree, reusing all prior computation and propagating values recursively, so that additional compute yields anytime performance improvements.
  • Empirically, DTS* improves sample quality and achieves significant compute savings over baselines such as best-of-N and SMC on image and text generation tasks.

Diffusion Tree Search (DTS*) is a scalable inference-time algorithm designed for aligning diffusion models to new objectives by casting the denoising sampling process as a global search or optimization problem over the space of generated trajectories. Building upon principles from Monte Carlo Tree Search (MCTS), DTS* achieves compute-efficient search for high-reward samples via recursive credit assignment and the reuse of all prior computational effort, enabling anytime performance improvements and broad applicability to both image and language generation tasks.

1. Theoretical Foundation and Motivation

Diffusion models sample by iteratively denoising a latent variable from noise to data, with each step governed by transitions p_\theta(x_{t-1} | x_t) from a pretrained generative model. Traditional guidance methods attempt to align the sampling trajectory with user-specified reward functions r(x_0) (such as classifier outputs or aesthetic reward models) directly at inference time. However, these approaches encounter significant accuracy limitations at high noise levels (large t), suffer from inefficient use of compute by failing to reuse prior samples, and are often unable to balance global search and local exploitation effectively.

DTS* addresses these limitations by organizing the exponentially large set of possible denoising trajectories into a tree structure, where nodes corresponding to intermediate noisy states are augmented with value estimates and visit counts, allowing for recursive improvement of sample quality as compute increases.

2. Tree Search and Value Backup Mechanism

DTS* structures the denoising process as a tree, where:

  • Nodes correspond to partial denoising states (x_t, t).
  • Edges represent stochastic transitions p_\theta(x_{t-1} | x_t) from the diffusion model.
  • Terminal nodes (at t = 0) yield samples x_0, which are immediately evaluated by the reward function r(x_0).
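
A minimal sketch of how such a node might be represented is shown below; the field names and the use of a Python dataclass are illustrative, not drawn from the reference implementation.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class DTSNode:
    """One node of the search tree: a partial denoising state (x_t, t)."""
    x_t: object                     # noisy latent at this depth (e.g. a tensor)
    t: int                          # diffusion timestep; t == 0 is a terminal node
    value: float = float("-inf")    # running soft value estimate for this subtree
    visits: int = 0                 # number of rollouts that passed through this node
    parent: Optional["DTSNode"] = None
    children: List["DTSNode"] = field(default_factory=list)

    @property
    def is_terminal(self) -> bool:
        return self.t == 0
```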

Upon rollout completion, DTS* performs recursive value propagation (soft Bellman backup) along the visited tree path. The value function at each node is defined by

V_t(x_t) = \frac{1}{\lambda} \log \mathbb{E}_{p_\theta(x_{0:t-1}|x_t)}[\exp(\lambda r(x_0))]

where \lambda is a temperature parameter. In practice, the backup operation uses

\hat{v}(x_{t+1}) \leftarrow \frac{1}{\lambda}\log\sum_{x_t \in \mathcal{C}(x_{t+1})} \exp(\lambda \hat{v}(x_t)),

where \mathcal{C}(x_{t+1}) is the set of children of x_{t+1}. This mechanism enables global credit assignment, as every trajectory contributes reward signal to all nodes it traverses.
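
A minimal sketch of this backup, assuming the node structure sketched above (parent pointers, value and visits fields), a positive temperature \lambda, and a numerically stable log-sum-exp:

```python
import math


def soft_bellman_backup(node, lam: float) -> None:
    """Propagate soft Bellman updates from `node` up to the root:
    v(x_{t+1}) <- (1/lam) * log sum_{x_t in children} exp(lam * v(x_t))."""
    while node is not None:
        if node.children:
            vals = [lam * child.value for child in node.children]
            m = max(vals)  # shift by the max for numerical stability
            node.value = (m + math.log(sum(math.exp(v - m) for v in vals))) / lam
        node.visits += 1
        node = node.parent
```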

The DTS* variant employs greedy selection during tree traversal,

x_{t-1} = \arg\max_{x' \in \mathcal{C}(x_t)} \hat{v}(x'),

yielding samples that optimize the terminal reward rather than aiming for unbiased sampling from the reward-aligned density.
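
As a small illustration of this distinction (mirrored in the comparison table of Section 4), a greedy selection rule for DTS* versus one plausible softmax rule for the stochastic DTS variant might look like the following; it assumes every child already carries a finite value estimate.

```python
import math
import random


def select_child(node, lam: float, greedy: bool = True):
    """Pick the next child during tree traversal.
    greedy=True  : DTS* rule -- follow the max-value child.
    greedy=False : softmax over values -- one plausible reading of the stochastic DTS variant."""
    if greedy:
        return max(node.children, key=lambda c: c.value)
    m = max(lam * c.value for c in node.children)              # shift for stability
    weights = [math.exp(lam * c.value - m) for c in node.children]
    return random.choices(node.children, weights=weights, k=1)[0]
```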

3. Algorithmic Structure and Compute Efficiency

DTS* alternates between expansion (sampling new child nodes via p_\theta), selection (following the maximal-value child at each node), rollout (completing incomplete trajectories, if necessary), and backup (soft Bellman propagation). Every visited node stores its state, value estimate, and visit count, and is reused in subsequent search iterations.
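
A compact sketch of one such iteration, reusing the DTSNode, soft_bellman_backup, and select_child helpers sketched above; model.denoise_step(x, t), which draws one stochastic transition p_\theta(x_{t-1} | x_t), and reward_fn are assumed interfaces rather than the paper's API.

```python
def dts_star_iteration(root, model, reward_fn, lam: float, branching: int):
    """One DTS* iteration: selection -> expansion -> rollout -> backup."""
    # Selection: greedily descend through the existing tree to a leaf.
    node = root
    while node.children:
        node = select_child(node, lam, greedy=True)

    # Expansion: sample `branching` stochastic transitions p_theta(x_{t-1} | x_t).
    if not node.is_terminal:
        for _ in range(branching):
            x_prev = model.denoise_step(node.x_t, node.t)       # assumed model interface
            node.children.append(DTSNode(x_t=x_prev, t=node.t - 1, parent=node))
        node = node.children[0]                                  # descend into a new child

    # Rollout: finish the trajectory to t = 0 by ordinary ancestral sampling,
    # then score the clean sample with the (possibly non-differentiable) reward.
    x, t = node.x_t, node.t
    while t > 0:
        x = model.denoise_step(x, t)
        t -= 1
    node.value = reward_fn(x)       # leaf value initialised with its rollout reward

    # Backup: soft Bellman propagation of the new reward along the visited path.
    soft_bellman_backup(node, lam)
```

Repeating this iteration and returning the best terminal sample found so far gives the anytime behaviour discussed below, since the tree, its value estimates, and its visit counts persist across iterations.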

This design results in several critical features:

  • Reuse of Prior Computation: All sampled trajectories and their rewards are retained in the tree, enabling future rollouts to leverage existing partial paths. This makes DTS* an anytime algorithm: additional compute leads to monotonic quality improvements, with no wasted cycles.
  • Scalability and Parallelism: The algorithm supports batched model calls at each level of the tree (see the sketch after this list), facilitating practical parallel implementation despite the fundamentally sequential nature of tree search.
  • Global Search for High-Reward Samples: Greedy path-following over soft Bellman value estimates does not merely exploit local optima; because each value aggregates reward over many trajectories, the search is steered toward globally high-volume, high-reward regions.
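
One way such batching might look, assuming a denoiser that accepts a batch of latents at a shared timestep and the DTSNode sketch from Section 2 (both assumptions, not the reference implementation):

```python
import torch


def expand_level(nodes, model, branching: int):
    """Expand every frontier node at the same depth with one batched model call."""
    t = nodes[0].t                                      # all nodes share this timestep
    x = torch.stack([n.x_t for n in nodes])             # (len(nodes), ...)
    x = x.repeat_interleave(branching, dim=0)           # duplicate each parent `branching` times
    x_prev = model.denoise_step(x, t)                   # single batched denoiser call
    for i, parent in enumerate(nodes):
        for j in range(branching):
            child = DTSNode(x_t=x_prev[i * branching + j], t=t - 1, parent=parent)
            parent.children.append(child)
```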

4. Performance and Empirical Results

DTS* achieves significant improvements in reward-aligned sample quality and computational efficiency across tasks:

  • On MNIST and CIFAR-10 (class-conditional generation), DTS* matches or outperforms Sequential Monte Carlo (SMC) and best-of-N sampling in terms of FID and reward, while requiring up to 10× less compute.
  • In text-to-image alignment and language generation, DTS* achieves high reward scores (as measured by CLIP similarity or classifier metrics) and matches the best-of-N method with up to 5× less compute.
  • The tree structure prevents mode collapse; DTS* samples reflect the high-probability "volume" under the reward-tilted density rather than focusing narrowly on a single reward spike.

A summary table comparing capabilities appears below:

Aspect                             | DTS (Sampling)     | DTS* (Search)
Output                             | Unbiased sampling  | High-reward "modes"
Selection                          | Stochastic         | Greedy (max-value child)
Backup                             | Soft Bellman       | Soft Bellman
Reuse of prior computation         | Yes                | Yes
Turns compute into better samples  | Yes                | Yes
Overoptimization avoidance         | Yes                | Yes (volume-based regularization)

5. Applications Across Domains

DTS* has demonstrated robust and scalable performance in various application settings:

  • Image Generation: Efficient optimization for class-conditional sampling and reward-guided text-to-image generation (with reward models such as CLIP or LAION-Aesthetic), achieving strong FID, diversity, and reward scores.
  • Text Generation: Search for high-acceptability completions using masked diffusion LLMs, with higher reward and diversity than devoting the same compute to SMC or best-of-N.
  • General Reward Functions: Compatibility with arbitrary (including non-differentiable) reward signals at inference time, enabling flexible post-hoc alignment for creative or domain-specific objectives (see the sketch below).
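
Because the search only ever consumes scalar evaluations of r(x_0), a non-differentiable reward can be plugged in directly; the classifier below is a hypothetical stand-in for any scorer of clean samples.

```python
def make_class_reward(classifier, target_class: int):
    """Hypothetical non-differentiable reward: 1.0 when `classifier` (any callable
    returning a predicted class index for a clean sample x_0) picks the target class."""
    def reward_fn(x0):
        return 1.0 if classifier(x0) == target_class else 0.0
    return reward_fn
```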

6. Scalability, Anytime Property, and Distinction from Prior Methods

DTS* provides anytime, scalable inference-time alignment by design. All prior tree-based computations directly improve value estimation for future samples, so each unit of additional compute yields measurable sample quality gains. This stands in contrast to SMC and best-of-N baselines, which require fresh computation from scratch for improved results and disregard prior progress, resulting in less efficient scaling.

The algorithm's use of soft Bellman backups enables globally consistent value estimation, correcting for biases in high-noise regions of the diffusion chain and supporting robust search even where stepwise reward prediction is inaccurate.

7. Position in the Broader Field and Conceptual Significance

DTS* operationalizes inference-time alignment as a global search and incremental credit assignment problem, reflecting a trend toward integrating tree search with probabilistic and deep generative methods. By extending core ideas from MCTS to the denoising chains of diffusion models and incorporating mechanisms for value backup and computation reuse, DTS* offers a principled, scalable, and effective approach for sample improvement and objective alignment without retraining.

The algorithm's scalability, efficiency, and robustness to overoptimization establish it as a reference method for inference-time search and reward optimization in generative modeling. Its anytime nature enables dynamic adaptation to computational budgets and real-world constraints in diverse domains.