Diffusion Tree Search (DTS*)
- Diffusion Tree Search (DTS*) is a scalable inference algorithm applying tree search to align diffusion models with arbitrary reward functions during inference, optimizing for high-reward samples.
- DTS* structures the denoising process as a tree, reusing all computational effort and leveraging recursive value propagation to enable anytime performance improvements.
- Empirically, DTS* improves sample quality and achieves significant compute savings over baselines like best-of-N and SMC in image and text generation tasks.
Diffusion Tree Search (DTS) is a scalable inference-time algorithm designed for aligning diffusion models to new objectives by casting the denoising sampling process as a global search or optimization problem over the space of generated trajectories. Building upon principles from Monte Carlo Tree Search (MCTS), DTS achieves compute-efficient search for high-reward samples via recursive credit assignment and the reuse of all prior computational effort, enabling anytime performance improvements and broad applicability to both image and language generation tasks.
1. Theoretical Foundation and Motivation
Diffusion models sample by iteratively denoising a latent variable from noise to data, with each step governed by transitions $p_\theta(x_{t-1} \mid x_t)$ from a pretrained generative model. Traditional guidance methods attempt to align the sampling trajectory with user-specified reward functions (such as classifier outputs or aesthetic reward models) directly at inference time. However, these approaches encounter significant accuracy limitations at high noise levels (high $t$), suffer from inefficient use of compute by failing to reuse prior samples, and are often unable to balance global search and local exploitation effectively.
DTS addresses these limitations by organizing the exponentially large set of possible denoising trajectories into a tree structure, where nodes corresponding to intermediate noisy states are augmented with value estimates and visit counts, allowing for recursive improvement of sample quality as compute increases.
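As a concrete illustration of this bookkeeping, the sketch below shows a minimal node structure holding the state, value estimate, and visit count described above; the class name and fields are illustrative, not the paper's implementation.

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional


@dataclass
class Node:
    """One intermediate noisy state x_t in the denoising tree (illustrative sketch)."""
    state: Any                    # the latent x_t (e.g. an image or token tensor)
    t: int                        # diffusion timestep: T at the root (pure noise), 0 at the leaves
    parent: Optional["Node"] = None
    children: List["Node"] = field(default_factory=list)
    value: float = float("-inf")  # running soft value estimate for this node
    visits: int = 0               # number of rollouts that have passed through this node
```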
2. Tree Search and Value Backup Mechanism
DTS structures the denoising process as a tree, where:
- Nodes correspond to partial denoising states $x_t$.
- Edges represent stochastic transitions from the diffusion model.
- Terminal nodes (at $t = 0$) yield samples $x_0$, which are immediately evaluated by the reward function $r(x_0)$.
Upon rollout completion, DTS performs recursive value propagation (a soft Bellman backup) along the visited tree path. The value function at each node is defined by
$$V(x_t) \;=\; \lambda \log \mathbb{E}_{p_\theta(x_0 \mid x_t)}\!\big[\exp\!\big(r(x_0)/\lambda\big)\big],$$
where $\lambda$ is a temperature parameter. Empirically, the backup operation uses
$$\hat{V}(x_t) \;=\; \lambda \log \frac{1}{|\mathcal{C}(x_t)|} \sum_{x_{t-1} \in \mathcal{C}(x_t)} \exp\!\big(\hat{V}(x_{t-1})/\lambda\big),$$
where $\mathcal{C}(x_t)$ is the set of children of $x_t$. This mechanism enables global credit assignment, as every completed trajectory contributes reward signal to all nodes it traverses.
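A minimal sketch of this backup, reusing the illustrative `Node` class from Section 1; `lam` plays the role of the temperature $\lambda$, and the log-mean-exp is computed in a numerically stable form. This is a hedged illustration under those assumptions, not the reference implementation.

```python
import math
from typing import List


def soft_bellman_backup(leaf_value: float, path: List["Node"], lam: float = 1.0) -> None:
    """Propagate a terminal reward back up the visited path.

    Each ancestor's value becomes lam * log-mean-exp of its visited children's
    values, i.e. the empirical soft Bellman estimate described above.
    """
    path[-1].value = leaf_value          # deepest node on the path takes the leaf value directly
    path[-1].visits += 1
    for node in reversed(path[:-1]):     # walk back toward the root
        vals = [c.value for c in node.children if c.visits > 0]
        m = max(vals)                    # shift by the maximum to stabilise the exponentials
        mean_exp = sum(math.exp((v - m) / lam) for v in vals) / len(vals)
        node.value = m + lam * math.log(mean_exp)
        node.visits += 1
```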
The DTS* variant employs greedy selection during tree traversal, following the maximum-value child at each node, yielding samples that optimize the terminal reward rather than aiming for unbiased sampling from the reward-aligned density.
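The two traversal rules can be sketched as follows, again using the illustrative `Node` class; the softmax-over-values rule for DTS is an assumption consistent with the soft value formulation above, while the greedy rule follows the DTS* description.

```python
import math
import random
from typing import List


def select_child_dts(children: List["Node"], lam: float = 1.0) -> "Node":
    """DTS (sampling): choose a child with probability proportional to exp(value / lam)."""
    m = max(c.value for c in children)
    weights = [math.exp((c.value - m) / lam) for c in children]
    return random.choices(children, weights=weights, k=1)[0]


def select_child_dts_star(children: List["Node"]) -> "Node":
    """DTS* (search): greedily follow the maximum-value child."""
    return max(children, key=lambda c: c.value)
```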
3. Algorithmic Structure and Compute Efficiency
DTS alternates between selection (traversing down the tree, following the maximal-value child at each node), expansion (sampling new child nodes via the diffusion transition $p_\theta(x_{t-1} \mid x_t)$), rollout (completing incomplete trajectories, if necessary), and backup (soft Bellman propagation). Every visited node stores its state, value estimate, and visit count, and is reused in subsequent search iterations.
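The loop below is a schematic sketch of one such iteration, reusing the illustrative `Node` class and `soft_bellman_backup` helper from Section 2; `denoise_step` (a single stochastic transition of the pretrained model), `reward`, and the branching factor `n_children` are placeholders, and the greedy DTS*-style selection rule is shown.

```python
def search_iteration(root: "Node", denoise_step, reward, lam: float = 1.0, n_children: int = 2) -> float:
    """One tree-search iteration: selection -> expansion -> rollout -> backup (illustrative)."""
    # Selection: descend through existing children, visiting unexplored children first,
    # otherwise following the maximum-value child, until reaching a leaf of the tree.
    node, path = root, [root]
    while node.children:
        unvisited = [c for c in node.children if c.visits == 0]
        node = unvisited[0] if unvisited else max(node.children, key=lambda c: c.value)
        path.append(node)

    # Expansion: sample fresh children from the diffusion transition p_theta(x_{t-1} | x_t).
    if node.t > 0:
        for _ in range(n_children):
            node.children.append(Node(state=denoise_step(node.state, node.t), t=node.t - 1, parent=node))
        node = node.children[0]
        path.append(node)

    # Rollout: complete the trajectory down to t = 0, keeping every new state in the tree.
    while node.t > 0:
        child = Node(state=denoise_step(node.state, node.t), t=node.t - 1, parent=node)
        node.children.append(child)
        node = child
        path.append(child)

    # Backup: score the terminal sample and propagate the soft Bellman value along the path.
    r = reward(node.state)
    soft_bellman_backup(r, path, lam)
    return r
```

Repeated calls to `search_iteration` grow the same tree, so each call reuses every previously sampled state and reward, and the best terminal sample found so far can be returned at any point in the budget.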
This design results in several critical features:
- Reuse of Prior Computation: All sampled trajectories and their rewards are retained in the tree, enabling future rollouts to leverage existing partial paths. This makes DTS an anytime algorithm: additional compute leads to monotonic quality improvements, with no wasted cycles.
- Scalability and Parallelism: The algorithm supports batched model calls at each level of the tree, facilitating practical parallel implementation despite the fundamentally sequential nature of tree search (see the sketch after this list).
- Global Search for High-Reward Samples: Greedy traversal over backed-up soft values ensures that search does not merely exploit local optima but steers toward globally high-volume, high-reward regions.
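As anticipated in the parallelism bullet above, expansions of all frontier nodes at the same timestep can share a single batched model call. The sketch below assumes a placeholder `denoise_step_batched` that maps a list of states at timestep `t` to a list of sampled states at `t - 1`, and reuses the illustrative `Node` class.

```python
from typing import List


def expand_level(frontier: List["Node"], denoise_step_batched, n_children: int = 2) -> List["Node"]:
    """Expand every frontier node at the same timestep with one batched model call."""
    parents = [node for node in frontier for _ in range(n_children)]  # one slot per requested child
    t = frontier[0].t                                                 # all frontier nodes share a timestep
    new_states = denoise_step_batched([p.state for p in parents], t)  # single forward pass for the level
    children = []
    for parent, state in zip(parents, new_states):
        child = Node(state=state, t=t - 1, parent=parent)
        parent.children.append(child)
        children.append(child)
    return children
```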
4. Performance and Empirical Results
DTS achieves significant improvements in reward-aligned sample quality and computational efficiency across tasks:
- On MNIST and CIFAR-10 (class-conditional generation), DTS matches or outperforms Sequential Monte Carlo (SMC) and best-of-$N$ sampling in terms of FID and reward, while requiring substantially less compute.
- In text-to-image alignment and language generation, DTS achieves high reward scores (as measured by CLIP similarity or classifier metrics) and matches the best-of-$N$ method with substantially less compute.
- The tree structure prevents mode collapse; DTS samples reflect the high-probability "volume" under the reward-tilted density rather than focusing narrowly on a single reward spike.
A summary table comparing capabilities appears below:
| Aspect | DTS (Sampling) | DTS* (Search) |
|---|---|---|
| Output | Unbiased sampling | High-reward "modes" |
| Selection | Stochastic | Greedy (max-value child) |
| Backup | Soft Bellman | Soft Bellman |
| Reuse of prior computation | Yes | Yes |
| Turns extra compute into better samples | Yes | Yes |
| Overoptimization avoidance | Yes | Yes (volume-based regularization) |
5. Applications Across Domains
DTS has demonstrated robust and scalable performance in various application settings:
- Image Generation: Efficient optimization for class-conditional sampling, reward-guided text-to-image (with reward models such as CLIP or LAION-Aesthetic), achieving strong FID, diversity, and reward scores.
- Text Generation: Search for high-acceptability completions using masked diffusion LLMs, achieving higher reward and diversity than funneling all effort through SMC or best-of-$N$.
- General Reward Functions: Compatibility with arbitrary (including non-differentiable) reward signals at inference time, enabling flexible post-hoc alignment for creative or domain-specific objectives.
6. Scalability, Anytime Property, and Distinction from Prior Methods
DTS provides anytime, scalable inference-time alignment by design. All prior tree-based computation directly improves value estimation for future samples, so each unit of additional compute yields measurable sample-quality gains. This stands in contrast to SMC and best-of-$N$ baselines, which require fresh computation from scratch for improved results and disregard prior progress, resulting in less efficient scaling.
The algorithm's use of soft Bellman backups enables globally consistent value estimation, correcting for biases in high-noise regions of the diffusion chain and supporting robust search even where stepwise reward prediction is inaccurate.
7. Position in the Broader Field and Conceptual Significance
DTS operationalizes inference-time alignment as a global search and incremental credit assignment problem, reflecting a trend toward integrating tree search with probabilistic and deep generative methods. By extending core ideas from MCTS to the denoising chains of diffusion models and incorporating mechanisms for value backup and computation reuse, DTS offers a principled, scalable, and effective approach for sample improvement and objective alignment without retraining.
The algorithm's scalability, efficiency, and robustness to overoptimization establish it as a reference method for inference-time search and reward optimization in generative modeling. Its anytime nature enables dynamic adaptation to computational budgets and real-world constraints in diverse domains.