- The paper introduces Adaptive Parallel Reasoning (APR), enabling language models to spawn and join child threads for distributed inference.
- It significantly improves accuracy and reduces latency by balancing serial and parallel computations to overcome context window limits.
- Reinforcement learning fine-tuning further optimizes APR, teaching models to widen parallel search and use additional compute effectively.
This paper introduces Adaptive Parallel Reasoning (APR), a framework designed to improve the reasoning capabilities of LLMs by enabling them to dynamically manage both serial and parallel computation during inference. It addresses the limitations of existing methods: serialized approaches like chain-of-thought often hit context window limits and suffer high latency, while standard parallel methods like self-consistency lack coordination and efficient resource use.
APR equips LLMs with a multi-threading mechanism built on two primary operations:
- `spawn(msgs)`: allows a parent inference thread to initiate multiple child threads. The `msgs` argument is a list of strings, each providing the initial context for a separate child thread, letting the parent delegate distinct sub-tasks or exploration paths.
- `join(msg)`: used by a child thread to terminate its execution and return a specific message (`msg`) back to the parent thread. Child threads can selectively return information (e.g., only successful solution paths), keeping the parent's context concise.
Execution flow: the parent thread generates tokens until it outputs a `spawn()` command. The child threads then execute in parallel (leveraging batching in serving frameworks like SGLang for efficiency). Once all children call `join()`, the parent thread resumes decoding, conditioned on its prior context and the messages returned by the children. Distributing computation this way keeps any single thread from exceeding the context limit and reduces overall latency compared to fully sequential methods.
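The loop below is a minimal sketch of this execution flow. The `generate` stub (assumed to decode until the model emits a complete `spawn([...])` call, a `join(...)` call, or end-of-sequence), the regex-based parsing, and the textual encoding of `spawn`/`join` are illustrative assumptions, not the paper's actual serving integration.

```python
import ast
import re
from concurrent.futures import ThreadPoolExecutor

def generate(prompt: str) -> str:
    """Decode from the model until it emits a complete spawn([...]) call,
    a join(...) call, or end-of-sequence. Stubbed for brevity."""
    raise NotImplementedError

def run_child(child_context: str) -> str:
    """A child thread decodes until join(msg) and returns msg to its parent."""
    out = generate(child_context)
    m = re.search(r"join\((.*)\)\s*$", out, re.DOTALL)
    return m.group(1) if m else out

def run_parent(prompt: str) -> str:
    ctx = prompt
    while True:
        ctx += generate(ctx)
        m = re.search(r"spawn\((\[.*\])\)\s*$", ctx, re.DOTALL)
        if m is None:
            return ctx  # no trailing spawn: the parent finished decoding
        msgs = ast.literal_eval(m.group(1))  # initial contexts for children
        # Children run concurrently; a serving framework batches their
        # forward passes, so extra threads add little wall-clock time.
        with ThreadPoolExecutor(max_workers=len(msgs)) as pool:
            results = list(pool.map(run_child, msgs))
        # The parent resumes decoding, conditioned on the joined messages.
        ctx += " " + " ".join(f"join({r})" for r in results)
```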
Training APR Models:
Training involves a two-stage process:
- Supervised Initialization: An LLM (trained from scratch in the paper's experiments) is first trained on demonstrations generated by a symbolic solver. Unlike the standard Stream of Search (SoS) solver, which produces purely sequential traces (DFS or BFS), the APR solver generates "hybrid" search traces that explicitly include `spawn()` and `join()` operations, demonstrating parallel exploration. This teaches the model the syntax and basic usage of the parallel operations (a rough sketch of how such traces might be serialized appears after this list). The paper also introduces an improved sequential baseline solver, SoS+.
- Reinforcement Learning (RL) Fine-tuning: The supervised model is further optimized with RL (specifically GRPO), using a reward based on task success, e.g., finding a correct solution in the Countdown task (an outcome-reward sketch also follows this list). This end-to-end optimization lets the model learn when and how to parallelize, trading off exploration width (number of child threads) against depth (length of individual threads) to maximize success rate without predefined search structures.
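As a rough illustration of how a hybrid demonstration might be flattened into per-thread training strings, here is a sketch. The `Thread` structure and the exact trace format are assumptions for illustration; the authoritative construction is given by the solvers in the paper's appendix.

```python
from dataclasses import dataclass, field

@dataclass
class Thread:
    context: str                 # the sub-task handed to this thread
    steps: str                   # serialized search steps taken in-thread
    result: str                  # the message this thread returns via join()
    children: list["Thread"] = field(default_factory=list)

def to_training_examples(t: Thread) -> list[str]:
    """Emit one training string per thread: a parent's trace interleaves its
    own steps with spawn([...]) and its children's join messages, so the
    model sees both sides of the protocol during supervised training."""
    trace = f"{t.context} {t.steps}"
    if t.children:
        trace += f" spawn({[c.context for c in t.children]})"
        trace += " " + " ".join(f"join({c.result})" for c in t.children)
    trace += f" join({t.result})"
    examples = [trace]
    for c in t.children:
        examples.extend(to_training_examples(c))
    return examples
```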
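For the RL stage, the reward is outcome-based. Below is a sketch of such a verifier for Countdown, assuming the trace ends with a line like `answer: (4+6)*2`; the trailing-answer format and parsing are illustrative assumptions, not the paper's exact implementation.

```python
import re

def countdown_reward(trace: str, numbers: list[int], target: int) -> float:
    """Return 1.0 iff the final expression uses each given number exactly
    once and evaluates to the target; otherwise 0.0."""
    m = re.search(r"answer:\s*([-+*/()\d\s]+)\s*$", trace)
    if m is None:
        return 0.0
    expr = m.group(1)
    used = sorted(int(tok) for tok in re.findall(r"\d+", expr))
    try:
        # eval is tolerable here: the regex admits only arithmetic characters.
        return float(used == sorted(numbers) and eval(expr) == target)
    except Exception:  # malformed expression, division by zero, etc.
        return 0.0
```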
Experiments and Key Findings (Countdown Task):
- Model: Llama2 architecture (228M non-embedding params), 4k context window, trained from scratch.
- Baselines: SoS+ (serialized), Self-consistency (cons@n on SoS+ outputs), Pass@n.
- Metrics: Accuracy, Total Tokens (compute), Sequential Tokens (proxy for minimum latency), Wall-clock Latency.
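To make the two compute metrics concrete, here is a small sketch over a hypothetical thread tree: total tokens sum the decoding done by every thread, while sequential tokens count only the longest chain of dependent decoding, since sibling threads run in parallel. The `ThreadNode` structure is an assumed stand-in.

```python
from dataclasses import dataclass, field

@dataclass
class ThreadNode:
    tokens: int                              # tokens decoded by this thread
    children: list["ThreadNode"] = field(default_factory=list)

def total_tokens(t: ThreadNode) -> int:
    """Total compute: every thread's tokens count."""
    return t.tokens + sum(total_tokens(c) for c in t.children)

def sequential_tokens(t: ThreadNode) -> int:
    """Latency proxy: siblings overlap in time, so only the longest child
    chain extends the critical path."""
    return t.tokens + max((sequential_tokens(c) for c in t.children), default=0)

# A parent decoding 100 tokens with two children (300 and 500 tokens):
tree = ThreadNode(100, [ThreadNode(300), ThreadNode(500)])
assert total_tokens(tree) == 900       # all compute is billed
assert sequential_tokens(tree) == 600  # only the slowest child adds latency
```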
Results:
- Context Window Efficiency: APR achieves significantly higher accuracy than SoS+ within the same context window limit. At a 4k token limit, APR (+RL) reached 83.4% accuracy compared to 60.0% for SoS+ (+RL). APR distributes the computation, preventing any single thread from easily exhausting the context.
- Compute Scaling: APR scales better with increasing computational budget (total tokens). With a budget of ~20k tokens, APR achieved 80.1% accuracy, surpassing SoS+ cons@7 (66.6%) and even SoS+ pass@8 (68.4%).
- Latency Improvement: APR offers better accuracy for a given latency. Compared to SoS+, APR achieves higher accuracy with fewer sequential tokens (the longest chain of dependent computations). In wall-clock time tests (on 8 GPUs), APR achieved 75.2% accuracy at ~5000ms latency, while SoS+ reached only 57.3%.
- RL Impact: RL significantly boosted APR's performance (from 75.5% to 83.4%). Analysis showed RL primarily taught the model to use more computation effectively (spawning more child threads, increasing search width) rather than just improving reasoning quality within a fixed budget. The gain from RL was much larger for APR (7.9% absolute) than for SoS+ (2.7%).
Implementation Considerations:
- Serving: APR relies on efficient batching of parallel child threads. The paper uses SGLang, which supports features like continuous batching and RadixAttention (prefix sharing), making parallel execution feasible with low overhead (see the fork sketch after this list).
- Training Data: Requires generating specialized training data with `spawn`/`join` tokens using a symbolic solver (see Algorithms 1 & 2 in the Appendix).
- Compute Control: During training/inference, compute can be controlled by conditioning the model on the number of child threads to spawn (for APR) or target context length (for SoS+).
- Symbolic Solvers: Pseudocode for the SoS+ and APR symbolic solvers is provided (Algorithms 1 & 2). These solvers use heuristics (e.g., `is_promising`) to decide when to explore deeply or branch out, providing initial strategies for the supervised learning phase (a condensed sketch of this control logic follows below).
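As an illustration of why such serving support matters, the sketch below uses SGLang's `fork` primitive to run child continuations as one batched request with a shared prefix. This is generic SGLang frontend usage, not the paper's actual serving code, and the prompt strings are made up.

```python
import sglang as sgl

@sgl.function
def parallel_children(s, sub_tasks):
    # Each fork shares the parent's prefix KV cache (RadixAttention), and the
    # forked generations are batched together by the runtime.
    forks = s.fork(len(sub_tasks))
    for f, task in zip(forks, sub_tasks):
        f += f"Explore: {task}\n" + sgl.gen("result", max_tokens=256)
    # "Join": the parent continues, conditioned on every child's output.
    s += "Results: " + "; ".join(f["result"] for f in forks)
    s += sgl.gen("continuation", max_tokens=256)

# Invoked against a running SGLang backend, e.g.:
# parallel_children.run(sub_tasks=["try 4+6 first", "try 4*6 first"])
```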
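And a condensed sketch of the solvers' control logic (the real procedures are Algorithms 1 & 2 in the appendix): `solved`, `successors`, and `is_promising` stand in for the Countdown primitives and heuristic, stubbed here, and the rule of recursing serially on promising states while branching otherwise is an assumption for illustration.

```python
from concurrent.futures import ThreadPoolExecutor

def solved(state) -> bool:        # Countdown goal check (stubbed)
    raise NotImplementedError

def successors(state) -> list:    # legal next states (stubbed)
    raise NotImplementedError

def is_promising(state) -> bool:  # solver heuristic (stubbed)
    raise NotImplementedError

def hybrid_search(state):
    """Explore serially while the heuristic rates the state promising;
    otherwise branch, delegating each successor to a parallel child thread."""
    if solved(state):
        return state
    if is_promising(state):
        # Deep, sequential exploration inside the current thread (DFS-style).
        for nxt in successors(state):
            found = hybrid_search(nxt)
            if found is not None:
                return found
        return None
    # Branch out: spawn one child per candidate state, then join the results.
    with ThreadPoolExecutor() as pool:
        results = list(pool.map(hybrid_search, successors(state)))
    return next((r for r in results if r is not None), None)
```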
Limitations and Future Work:
- Experiments focused on the Countdown task and models trained from scratch. Applying APR to large pre-trained models and more general reasoning tasks is future work.
- The current approach relies on supervised pre-training with solver-generated data. Exploring direct RL fine-tuning from pre-trained models is a potential direction.
- The `spawn`/`join` mechanism is a basic form of inter-thread communication; more complex protocols could be explored.
In summary, APR presents a novel method for training LLMs to adaptively manage parallel computation during reasoning. By learning to use `spawn()` and `join()` operations through a combination of supervised learning and RL, models can overcome context limitations, reduce latency, and scale more effectively with compute than purely sequential or basic parallel approaches.