- The paper presents φ-Decoding, a novel inference-time method that enhances LLM reasoning by simulating future steps and balancing exploration with exploitation.
- It employs dynamic advantage estimation and alignment assessment via clustering to re-weight candidate steps based on simulated future rewards.
- The approach integrates dynamic pruning strategies to optimize computational cost while achieving significant performance gains on benchmarks like GSM8K and AIME.
This paper introduces φ-Decoding, a novel inference-time optimization algorithm designed to improve the reasoning capabilities of LLMs by balancing exploration and exploitation more effectively than previous methods. It addresses the limitations of standard auto-regressive decoding (which is short-sighted) and search-based methods like Tree-of-Thought (ToT) or Monte Carlo Tree Search (MCTS) (which can involve excessive exploration in vast search spaces).
The core idea is adaptive foresight sampling. Instead of conditioning only on the past steps $a_{<t}$, φ-Decoding estimates the value of a candidate next step $a_t$ by simulating future steps $a_{>t}$ and evaluating their quality. The probability of selecting a step $\hat{a}_t$ is re-weighted by an estimated reward function $R$ derived from these foresight simulations:

$$\hat{a}_t \sim p_\theta(a_t \mid x, a_{<t}) \, \exp\!\big[\, R(x, a_{\le t}, a_{>t}) / \tau \,\big]$$
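A minimal sketch of this re-weighted step sampling, assuming the candidate steps' model log-probabilities and precomputed reward estimates are already available as arrays (the function name and signature are illustrative, not the paper's API):

```python
import numpy as np

def sample_step(step_logprobs: np.ndarray, rewards: np.ndarray, tau: float = 0.6) -> int:
    """Sample a candidate index ~ p_theta(a_t | x, a_<t) * exp(R / tau).

    step_logprobs: log p_theta(a_t | x, a_<t) for each candidate step.
    rewards:       estimated reward R(x, a_<=t, a_>t) for each candidate.
    """
    logits = step_logprobs + rewards / tau        # combine step prior with foresight reward
    probs = np.exp(logits - logits.max())         # numerically stable softmax
    probs /= probs.sum()
    return int(np.random.choice(len(probs), p=probs))
```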
The key innovation lies in how the step value estimation function R is constructed. It combines two complementary perspectives:
- Dynamic Advantage Estimation ($A_t$): This estimates the absolute benefit of a candidate step $a_t$. It is the difference between the average log probability of the foresight path starting from $a_t$ ($F_t$) and that of the foresight path from the previous step $a_{t-1}$ ($F_{t-1}$):

$$A_t = F_t - F_{t-1}$$

where $F_t = p_\theta(a_{>t} \mid x, a_{\le t})$, implemented as the average log probability of the foresight sequence to mitigate length bias. This captures the confidence gain provided by the step.
- Alignment Assessment by Clustering ($C_t$): This provides a relative value estimate to guard against the model being confidently wrong (local optima). After generating foresight paths (rollouts) for multiple candidate steps, the paths are clustered (using TF-IDF features in the main experiments, or sentence embeddings). The alignment score $C_t$ of a step $a_t$ is the normalized size of the cluster its foresight path belongs to:

$$C_t = \frac{|\mathrm{Cluster}(a_t)|}{\#\,\text{Foresight Paths}}$$

Steps whose foresight paths are consistent with many other candidate steps receive higher alignment scores; a code sketch of both $A_t$ and $C_t$ follows below.
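A minimal sketch of both step-value signals, assuming each candidate's foresight path is available as text plus per-token log probabilities, and using TF-IDF features with k-means (via scikit-learn) as one plausible clustering choice; all helper names here are illustrative rather than the paper's API:

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans

def avg_logprob(token_logprobs: list[float]) -> float:
    # Length-normalized log probability of a foresight path (mitigates length bias).
    return float(np.mean(token_logprobs))

def advantage(curr_path_logprobs: list[float], prev_path_logprobs: list[float]) -> float:
    # A_t = F_t - F_{t-1}: confidence gain of the candidate step over the previous step.
    return avg_logprob(curr_path_logprobs) - avg_logprob(prev_path_logprobs)

def alignment_scores(foresight_paths: list[str], k: int = 3) -> list[float]:
    # C_t per candidate: fraction of all foresight paths falling in the same cluster.
    features = TfidfVectorizer().fit_transform(foresight_paths)
    labels = KMeans(n_clusters=k, n_init=10).fit_predict(features)
    sizes = {c: int((labels == c).sum()) for c in set(labels)}
    return [sizes[label] / len(foresight_paths) for label in labels]
```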
The final reward $R$ combines normalized versions of the advantage and alignment scores; steps are sampled from their joint distribution:

$$R(x, a_{\le t}, a_{>t}) = \mathrm{Norm}(A_t) + \mathrm{Norm}(C_t)$$

where $\mathrm{Norm}(v) = \dfrac{\exp(v/\tau_v)}{\sum_{a_t} \exp(v/\tau_v)}$. In the implementation, $\tau_1 = \tau_2 = 0.6$ and the two terms are weighted equally.
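A sketch of the combined score, assuming the advantage and alignment values of all candidates at the current step are collected into arrays and normalized with a softmax over candidates at temperature 0.6, as stated above (function names are illustrative):

```python
import numpy as np

def softmax_norm(values: np.ndarray, tau: float = 0.6) -> np.ndarray:
    """Norm(v): softmax over the candidate steps with temperature tau."""
    z = np.exp((values - values.max()) / tau)   # shift by max for numerical stability
    return z / z.sum()

def combined_reward(advantages: np.ndarray, alignments: np.ndarray) -> np.ndarray:
    """R = Norm(A_t) + Norm(C_t), equally weighted."""
    return softmax_norm(advantages) + softmax_norm(alignments)
```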
To manage the computational cost introduced by foresight sampling and avoid "overthinking", φ-Decoding incorporates a Dynamic Pruning Strategy:
- In-Width Pruning: Before running the computationally expensive foresight simulation for all candidate steps (generated via beam search with M beams and N rollouts per beam), this step filters out unpromising candidates. It computes the mean ($\mu_t$) and standard deviation ($\sigma_t$) of the initial generation probabilities $s_t = p_\theta(a_t \mid x, a_{<t})$ over all $M \times N$ candidates and keeps only those with $s_t^{(i)} \ge \mu_t - \sigma_t$ for foresight simulation.
- In-Depth Pruning: This strategy enables early stopping to save computation on later, potentially easier steps. It leverages the clustering results from the Alignment Assessment: if the largest cluster contains a fraction of the foresight paths exceeding a threshold $\delta$ (e.g., $\delta = 0.7$), the algorithm stops the step-by-step foresight process and reverts to standard auto-regressive generation for the remainder of the sequence. This applies only after a minimum number of foresight steps ($T_{\min}$) have been taken. A sketch of both pruning rules follows this list.
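A compact sketch of the two pruning rules under the thresholds stated above; the candidate probabilities and cluster sizes are assumed to be collected into arrays, and the function names are illustrative:

```python
import numpy as np

def in_width_prune(step_probs: np.ndarray) -> np.ndarray:
    """Keep indices of candidates whose initial step probability is >= mean - std."""
    threshold = step_probs.mean() - step_probs.std()
    return np.flatnonzero(step_probs >= threshold)

def should_stop_early(cluster_sizes: list[int], n_paths: int, delta: float = 0.7) -> bool:
    """In-depth pruning: stop foresight once a single cluster dominates the rollouts."""
    return max(cluster_sizes) / n_paths > delta
```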
Implementation:
- The algorithm operates stepwise, maintaining M active beams (sequences).
- At each step t, N candidate next steps are sampled for each beam.
- In-width pruning filters these candidates.
- Foresight paths of length Tmin to Tmax tokens are generated for the remaining candidates using the LLM.
- Advantage (At) and Alignment (Ct) scores are calculated based on these paths.
- The combined score R is used to re-weight the initial probabilities, and M steps are sampled to form the beams for step t+1.
- In-depth pruning checks if early stopping is applicable.
- The process uses the vLLM engine for efficient inference on GPUs.
- Hyperparameters (M, N, K, Tmin, Tmax, δ) are tuned per task and model (see Appendix Table 4 for examples). For LLaMA3.1-8B on GSM8K, typical values are M=4, N=4, K=3, Tmin=4, Tmax=8, δ=0.7 (a hypothetical configuration sketch follows this list).
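A hypothetical configuration object collecting the hyperparameters listed above; the values shown are the LLaMA3.1-8B / GSM8K settings mentioned in the summary, while the dataclass and field names are illustrative, not the paper's API:

```python
from dataclasses import dataclass

@dataclass
class PhiDecodingConfig:
    num_beams: int = 4        # M: active beams maintained per step
    num_rollouts: int = 4     # N: candidate steps sampled per beam
    num_clusters: int = 3     # K: clusters used for alignment assessment
    t_min: int = 4            # minimum foresight steps before early stopping may trigger
    t_max: int = 8            # maximum foresight depth
    delta: float = 0.7        # dominant-cluster threshold for in-depth pruning
    tau: float = 0.6          # softmax temperature used in Norm(.)
```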
Algorithm Pseudocode Overview (Algorithm 1):
```
function phi_decoding(x, model, M, N, T_min, T_max, K, delta):
    beams = initialize_beams(x)
    previous_path_prob = {}                        // stores F_{t-1} per prefix for the advantage
    for t = 1 to MAX_STEPS:
        // Step Rollout (generate M*N candidates)
        candidates = {}
        for beam in beams:
            next_steps, step_probs = sample_next_steps(model, beam, N)
            add_candidates(candidates, next_steps, step_probs, beam)

        // In-Width Pruning: drop low-probability candidates before foresight
        pruned_candidates = in_width_prune(candidates)

        // Step Foresight & Value Estimation
        step_values = {}
        foresight_paths = {}
        for cand_step, beam_prefix in pruned_candidates:
            path, path_prob = generate_foresight(model, beam_prefix + cand_step, T_max)
            foresight_paths[cand_step] = path
            advantage = calculate_advantage(path_prob, previous_path_prob.get(beam_prefix, 0))
            step_values[cand_step] = {"advantage": advantage}
            previous_path_prob[beam_prefix + cand_step] = path_prob   // becomes F_{t-1} next step

        // Alignment Assessment via clustering of the foresight paths
        clusters = cluster_paths(foresight_paths, K)
        for cand_step in step_values:
            alignment = calculate_alignment(cand_step, clusters, len(foresight_paths))
            step_values[cand_step]["alignment"] = alignment
            step_values[cand_step]["final_value"] = combine_scores(
                step_values[cand_step]["advantage"], alignment)

        // Sample M steps to form the beams for step t+1
        beams = sample_next_beams(pruned_candidates, step_values, M)

        // In-Depth Pruning: stop early once one cluster dominates
        if t >= T_min and check_early_stop(clusters, len(foresight_paths), delta):
            return complete_autoregressive(model, beams[0])   // complete best beam

    // Fallback if max steps reached
    return complete_autoregressive(model, beams[0])
```
Evaluation and Results:
- Tested on GSM8K, MATH-500, GPQA, ReClor, LogiQA, ARC-C, and AIME benchmarks.
- Compared against Auto-Regressive (CoT), ToT, MCTS, Guided Decoding, and Predictive Decoding.
- Used LLaMA3.1 (8B, 70B), Mistral-v0.3-7B, Qwen2.5-3B, and R1-Distill-LLaMA-8B models.
- φ-Decoding significantly outperformed CoT (e.g., +14.6% avg on LLaMA3.1-8B) and strong baselines across benchmarks, often with lower or comparable computational cost (FLOPS).
- Showed strong inference-time scaling: performance improved consistently with increased compute budget, outperforming other methods at similar FLOPS levels (Figure 1).
- Ablation studies confirmed the positive contributions of foresight sampling, clustering, and dynamic pruning (Table 2). Pruning significantly reduced FLOPS while sometimes even improving accuracy by filtering noise.
- Demonstrated strong generalization across model sizes (3B to 70B) and effectiveness on challenging competition-level tasks like AIME, improving performance even for reasoning-specialized models such as R1-Distill-LLaMA-8B (Table 3, Table 5, Appendix C).
- Analysis suggested its step value estimation is more accurate than baselines and correlates positively with final task performance (Figure 2).
Practical Implications:
- φ-Decoding offers a practical way to boost the reasoning performance of existing LLMs at inference time without requiring model retraining or external reward models.
- It provides a better trade-off between performance and computational cost compared to methods like MCTS or ToT, making advanced reasoning more feasible.
- The dynamic pruning allows for adaptive compute allocation, spending more resources on difficult steps and saving compute on easier ones.
- It can be implemented as a decoding strategy within existing LLM serving frameworks like vLLM.
In summary, φ-Decoding presents an effective and relatively efficient inference-time algorithm for improving LLM reasoning by combining foresight simulation, a novel step value estimation based on advantage and alignment, and dynamic pruning strategies. Its strong empirical results and scalability make it a promising technique for practical applications requiring robust reasoning.