Scaling up Test-Time Compute with Latent Reasoning: A Recurrent Depth Approach (2502.05171v2)
Abstract: We study a novel LLM architecture that is capable of scaling test-time computation by implicitly reasoning in latent space. Our model works by iterating a recurrent block, thereby unrolling to arbitrary depth at test-time. This stands in contrast to mainstream reasoning models that scale up compute by producing more tokens. Unlike approaches based on chain-of-thought, our approach does not require any specialized training data, can work with small context windows, and can capture types of reasoning that are not easily represented in words. We scale a proof-of-concept model to 3.5 billion parameters and 800 billion tokens. We show that the resulting model can improve its performance on reasoning benchmarks, sometimes dramatically, up to a computation load equivalent to 50 billion parameters.
Summary
- The paper introduces a novel architecture that uses a latent recurrent block to iteratively scale compute at test time without needing chain-of-thought token generation.
- The paper demonstrates that a 3.5B-parameter model trained on 800B tokens keeps improving on reasoning benchmarks as its test-time recurrence is scaled up to a compute load equivalent to a 50B-parameter model.
- The paper highlights that adaptive per-token compute and efficient KV-cache sharing enable improved non-verbal reasoning and effective text generation without specialized training data.
The paper introduces a novel LLM architecture that reasons in latent space by iterating a recurrent block to arbitrary depth at test time. This contrasts with chain-of-thought methods, which scale computation by producing more tokens and require specialized training data. The proposed architecture needs no such data, can work with small context windows, and is designed to capture reasoning that is not easily expressed in words.
The authors scale a proof-of-concept model to 3.5B parameters, pre-trained on 800B tokens. Performance improvements on reasoning benchmarks are demonstrated, scaling up to an equivalent computation load of a 50B parameter model. The abstract provides links to the model on Hugging Face and the code/data on GitHub.
The paper posits that current methods are wasteful because expensive internal reasoning must always be compressed into a single verbalized next token. The authors suggest that models could be more competent if they "think" natively in their continuous latent space, where computation can be carried on for as many steps as needed. They note that this idea has been rediscovered across many areas of machine learning, including recurrent neural networks, diffusion models, and looped transformers.
The benefits of depth-recurrent LLMs are shown through effective learning, efficient training, and significant performance improvements via test-time compute scaling. The architecture relies on a latent depth-recurrent block iterated a randomly sampled number of times during training. The results show that the model competes with larger open-source models using more parameters and training data. The paper highlights that recurrent depth models naturally support per-token adaptive compute, (self)-speculative decoding, and KV-cache sharing at inference time, features that typically demand substantial tuning in non-recurrent models. The paper tracks token trajectories in latent space, revealing emergent computation behaviors like shape rotations for numerical computations.
The advantages of latent recurrent thinking over long-context reasoning are:
- No need for bespoke training data. Chain-of-thought reasoning requires training on long demonstrations specific to the domain, whereas latent reasoning models can train with a variable compute budget using standard training data.
- Less memory required for training and inference compared to chain-of-thought reasoning models.
- Increased FLOPs per parameter, reducing communication costs between accelerators.
- An architectural prior towards solving problems via meta-strategies, logic, and abstraction, instead of memorization.
The paper suggests that latent reasoning could capture non-verbal facets of human reasoning, like spatial thinking, physical intuition, and motor planning. The authors argue that scaling compute via latent reasoning is a third axis to scale model performance, complementing verbalized inference scaling and parameter count scaling in pretraining.
The table of contents outlines the following sections:
- Introduction of the latent recurrent-depth model architecture and training objective.
- Description of data selection and engineering for large-scale training.
- Benchmark results showing model improvements with scaled inference compute.
- Application examples demonstrating simplified LLM use cases with recurrent models.
- Visualization of computation patterns emerging at scale, showing context-dependent behaviors in latent space like "orbiting" for numerical reasoning prompts.
The architecture is structured around decoder-only transformer blocks, divided into three functional groups:
- Prelude P: Embeds the input data into a latent space.
- Core Recurrent Block R: The central unit of recurrent computation, modifying states $s \in \mathbb{R}^{n \times h}$, where:
  - $n$ is the sequence dimension
  - $h$ is the hidden dimension of the model
- Coda C: Un-embeds from latent space and contains the prediction head.
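Given $r$ recurrent iterations and input tokens $x \in V^n$ (where $V$ is the vocabulary), these groups produce output probabilities $p \in \mathbb{R}^{n \times |V|}$ as follows:
$e = P(x)$
$s_0 \sim \mathcal{N}(0, \sigma^2 I_{n \cdot h})$
$s_i = R(e, s_{i-1}) \quad \text{for } i \in \{1, \dots, r\}$
$p = C(s_r)$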
- $\sigma$ is the standard deviation used to initialize the random state.
- $e$ is the embedded input.
- $s_0$ is the initial random state.
- $s_i$ is the latent state at iteration $i$.
- $p$ holds the next-token probabilities.
This architecture is based on the deep thinking literature, where injecting the latent inputs $e$ at every step and initializing the latent vector with a random state stabilizes the recurrence.
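The forward pass above can be sketched in plain Python. This is a toy illustration, not the authors' implementation: `prelude`, `core`, and `coda` are hypothetical stand-ins for P, R, and C, with `core` chosen as a simple contraction that re-injects $e$ at every step so the effect of the random initial state can be observed.

```python
import math
import random

HIDDEN = 4  # toy hidden size h (hypothetical)


def prelude(tokens):
    # Hypothetical stand-in for P: a fixed embedding per token id.
    return [[math.sin(t + j) for j in range(HIDDEN)] for t in tokens]


def core(e, s):
    # Hypothetical stand-in for R: a contraction that re-injects e each step.
    return [[0.5 * s_ij + math.tanh(e_ij) for s_ij, e_ij in zip(s_row, e_row)]
            for s_row, e_row in zip(s, e)]


def coda(s):
    # Hypothetical stand-in for C: a softmax over a toy "vocabulary" of size HIDDEN.
    probs = []
    for row in s:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        z = sum(exps)
        probs.append([v / z for v in exps])
    return probs


def forward(tokens, r, sigma=1.0, seed=0):
    rng = random.Random(seed)
    e = prelude(tokens)                                   # e = P(x)
    s = [[rng.gauss(0.0, sigma) for _ in range(HIDDEN)]   # s_0 ~ N(0, sigma^2 I)
         for _ in tokens]
    for _ in range(r):                                    # s_i = R(e, s_{i-1})
        s = core(e, s)
    return coda(s)                                        # p = C(s_r)
```

Because the toy `core` halves the previous state before adding the injected input, runs started from different random states $s_0$ end up with practically identical probabilities after enough iterations, a miniature version of the path independence the recurrence is meant to have.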
The recurrent design is meant to yield stable iterative operators; the authors draw an analogy to gradient descent, where the data must be injected repeatedly throughout optimization. Using several layers to embed input tokens into a hidden latent space is motivated by empirical analyses of standard fixed-depth transformers, which find that the first and last layers of LLMs are noticeably distinct, whereas middle layers are largely interchangeable.
The authors address the similarity of their iterative architecture to diffusion models, particularly latent diffusion models. Ablations with diffusion-like schemes, such as $s_i = R(e, s_{i-1}) + n$ with $n \sim \mathcal{N}(0, \sigma_i I_{n \cdot h})$, did not show improvements. Similarly, $s_i = R_i(e, s_{i-1})$, where the core block also takes the current step as input, interacted negatively with path independence.
Within each group, the model follows standard transformer layer design. Each block contains multiple layers, and each layer contains a standard causal self-attention block using RoPE with a base of $50000$, and a gated SiLU MLP. RMSNorm is used as the normalization function, and the model has learnable biases on queries and keys. A "sandwich" layer format, with norm layers $n_i$, stabilizes the recurrence:
$\hat{x}_l = n_2\left(x_{l-1} + \mathrm{Attn}(n_1(x_{l-1}))\right)$
$x_l = n_4\left(\hat{x}_l + \mathrm{MLP}(n_3(\hat{x}_l))\right)$
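A minimal sketch of one sandwich layer for a single position, assuming parameter-free RMSNorm for all four norms $n_1, \dots, n_4$ (the model's norms may carry learned weights) and treating `attn` and `mlp` as hypothetical caller-supplied functions, since attention over the full sequence is out of scope here:

```python
import math


def rms_norm(x, eps=1e-6):
    # Parameter-free RMSNorm: divide by the root-mean-square of the vector.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def sandwich_layer(x, attn, mlp, norm=rms_norm):
    # x_hat = n2(x + Attn(n1(x))): normalize before AND after the residual add.
    x_hat = norm([a + b for a, b in zip(x, attn(norm(x)))])
    # x_l = n4(x_hat + MLP(n3(x_hat)))
    return norm([a + b for a, b in zip(x_hat, mlp(norm(x_hat)))])
```

Unlike a pre-norm block, every output passes through a final norm, so its scale stays fixed however many times the layer is applied; one plausible reading of why this format helps a deeply unrolled recurrence.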
Given an embedding matrix $E$ and embedding scale $\gamma$, the prelude block first embeds input tokens $x$ as $\gamma E(x)$ and then applies $l_P$ prelude layers. The core recurrent block $R$ starts with an adapter matrix $A: \mathbb{R}^{2h} \to \mathbb{R}^h$ that maps the concatenation of $s_i$ and $e$ into the hidden dimension $h$. The result is then fed into $l_R$ transformer layers. The coda contains $l_C$ layers, a normalization $n_c$, and a projection into the vocabulary using the tied embeddings $E^T$.
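The tensor plumbing described above can be sketched with plain lists. This is an illustration under simplifying assumptions: the $l_P$ prelude layers, the core transformer layers, and the coda norm are omitted, and the toy matrices are hypothetical. It only shows the scaled lookup $\gamma E(x)$, the adapter $A$ acting on the concatenation of $s_i$ and $e$, and the tied unembedding with $E^T$.

```python
def matvec(M, v):
    # Multiply a matrix (list of rows) by a vector.
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def embed(token, E, gamma):
    # Prelude input: gamma * E(x), a scaled row lookup in the |V| x h table E.
    return [gamma * v for v in E[token]]


def adapter(A, s, e):
    # A: R^{2h} -> R^h, applied to the concatenation [s; e] (length 2h).
    return matvec(A, s + e)


def unembed_logits(E, s):
    # Tied embeddings: the same |V| x h table E maps a state back to |V| logits,
    # i.e. s E^T in row-vector notation.
    return matvec(E, s)


# Toy usage with |V| = 3 and h = 2 (hypothetical values).
E = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
A = [[1.0, 0.0, 0.0, 0.0],       # h x 2h adapter
     [0.0, 0.0, 0.0, 1.0]]
e = embed(2, E, gamma=2.0)       # [2.0, 2.0]
h_in = adapter(A, [0.5, -1.0], e)  # [0.5, 2.0]
logits = unembed_logits(E, h_in)   # [0.5, 2.0, 2.5]
```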
The architecture is summarized by the triplet $(l_P, l_R, l_C)$, giving the number of layers in each stage, and by the number of recurrences $r$. Small-scale models with shape $(1, 4, 1)$ and hidden size $h = 1024$ are trained, along with a larger model with shape $(2, 4, 2)$ and $h = 5280$.
To ensure the model functions when scaling up recurrent iterations at test-time, iteration counts are randomly sampled during training, assigning a random number of iterations r to every input sequence. The loss function is:
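$\mathcal{L}(\theta) = \mathbb{E}_{x \in X}\, \mathbb{E}_{r \sim \Lambda}\, L\left(m_\theta(x, r), x'\right)$
- $m_\theta$ is the model output
- $x'$ is the sequence $x$ shifted left
- $\Lambda$ is a log-normal Poisson distribution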
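Given a targeted mean recurrence $\bar{r} + 1$ and a variance $\sigma = \tfrac{1}{2}$, samples are drawn via:
$\tau \sim \mathcal{N}\left(\log(\bar{r}) - \tfrac{1}{2}\sigma^2, \sigma\right)$
$r \sim \mathcal{P}(e^{\tau}) + 1$
- $\mathcal{N}$ is the normal distribution
- $\mathcal{P}$ is the Poisson distribution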
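The sampling procedure above can be sketched with Python's standard library. This is a minimal sketch, not the authors' code: Python has no built-in Poisson sampler, so a small Knuth-style sampler is included, and `sigma` is passed as the standard deviation of the normal, matching how the formula is written. Since $\mathbb{E}[e^{\tau}] = \exp(\log \bar{r} - \tfrac12\sigma^2 + \tfrac12\sigma^2) = \bar{r}$, the mean of $r$ is $\bar{r} + 1$.

```python
import math
import random


def sample_poisson(lam, rng):
    # Knuth's multiplication method; adequate for the moderate rates used here.
    threshold = math.exp(-lam)
    k, p = 0, 1.0
    while True:
        p *= rng.random()
        if p <= threshold:
            return k
        k += 1


def sample_recurrence(r_bar, sigma=0.5, rng=random):
    # tau ~ N(log(r_bar) - sigma^2 / 2, sigma)
    tau = rng.gauss(math.log(r_bar) - 0.5 * sigma ** 2, sigma)
    # r ~ Poisson(e^tau) + 1, so every sequence gets at least one iteration
    return sample_poisson(math.exp(tau), rng) + 1


rng = random.Random(0)
draws = [sample_recurrence(32, rng=rng) for _ in range(20000)]
print(sum(draws) / len(draws))  # close to r_bar + 1 = 33
```

The log-normal mixing gives a heavy right tail, so some sequences are trained with far more than $\bar{r}$ iterations. With locked-step sampling, a single such draw would be shared by a whole micro-batch.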
To keep computation and memory low at train time, backpropagation occurs only through the last k iterations of the recurrent unit, with k=8 in the main experiments.
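Truncated backpropagation can be illustrated on a one-dimensional toy recurrence $s_i = a\,s_{i-1} + b\,e$ with loss $\tfrac12(s_r - y)^2$. The recurrence, its parameters, and the function name are hypothetical; the point is only that derivatives flow through the last $k$ iterations, while earlier iterations run forward without tracking gradients, as a detach operation would do in an autodiff framework.

```python
def truncated_grad(a, b, e, s0, r, k, y):
    """d/da of 0.5 * (s_r - y)^2 for s_i = a * s_{i-1} + b * e,
    backpropagating only through the last k of the r iterations."""
    s, ds_da = s0, 0.0
    for i in range(1, r + 1):
        if i > r - k:
            # Tracked step: product rule on a * s_{i-1} (s still holds s_{i-1}).
            ds_da = s + a * ds_da
        else:
            # Untracked step: s_{r-k} is treated as a constant w.r.t. a.
            ds_da = 0.0
        s = a * s + b * e
    return (s - y) * ds_da
```

With $k \ge r$ this recovers the exact gradient; with smaller $k$, memory and compute for the backward pass no longer grow with the sampled depth $r$, at the cost of a biased gradient.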
The training setup covers the architecture details, the optimization setup, and the pretraining data. All training data, pretraining code, and intermediate model checkpoints are publicly released. The dataset mixture is skewed toward code and mathematical reasoning data, combined with general webtext. Instruction data is mixed directly into the pretraining data. A vocabulary of $65536$ tokens is constructed via BPE, trained directly on the instruction-data split of the pretraining corpus. Documents are packed into sequences of length 4096, discarding document ends that would otherwise lack preceding context, to address the "grounding problem".
The model's layers are set to $(2, 4, 2)$, and it is trained with a mean recurrence of $\bar{r} = 32$. The hidden size is scaled to $h = 5280$, yielding $55$ heads of size $96$. The MLP inner dimension is $17920$, and the RMSNorm $\varepsilon$ is $10^{-6}$. The model has approximately $1.5$B parameters in the non-recurrent prelude and head, $1.5$B parameters in the core recurrent block, and $0.5$B in the tied input embedding. The initialization scheme of \citet{takase_spike_2024-1} is used, with variances $\sigma_h^2 = \frac{2}{5h}$ and $\sigma_{\text{out}}^2 = \frac{1}{5hl}$ for the hidden and out-projection layers, respectively.
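The stated shapes can be sanity-checked with a few lines of arithmetic. The per-layer count is a rough estimate under assumptions not spelled out in the text: four $h \times h$ attention projections, three $h \times 17920$ matrices in the gated SiLU MLP, and no biases or norm weights.

```python
h = 5280                  # hidden size
n_heads, head_dim = 55, 96
d_ff = 17920              # MLP inner dimension
l_P, l_R, l_C = 2, 4, 2   # layers in prelude, core block, coda
r_bar = 32                # mean recurrence used in training

assert n_heads * head_dim == h

attn_params = 4 * h * h        # assumed: query, key, value, output projections
mlp_params = 3 * h * d_ff      # assumed: gate, up, down projections
per_layer = attn_params + mlp_params

print(f"per layer:      {per_layer / 1e6:.1f}M")               # 395.4M
print(f"prelude + coda: {(l_P + l_C) * per_layer / 1e9:.2f}B")  # 1.58B
print(f"core block:     {l_R * per_layer / 1e9:.2f}B")          # 1.58B
# Layers traversed per token when the core block is unrolled r_bar times:
print("unrolled depth:", l_P + r_bar * l_R + l_C)               # 132
```

Both layer groups land near 1.58B under these assumptions, in line with the roughly 1.5B figures quoted above.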
Locked-step sampling is used to keep parallel workers synchronized: a single depth $r$ is sampled for each micro-batch. The Adam optimizer with decoupled weight regularization ($\beta_1 = 0.9$, $\beta_2 = 0.95$, $\eta = 5 \times 10^{-4}$) is used, modified to include update clipping and to remove the $\varepsilon$ constant. Gradients are clipped above $1$, and training uses a warm-up followed by a constant learning rate.
The model is trained with compute time on Oak Ridge National Lab's Frontier cluster, using nodes with 8× AMD MI250X GPUs connected via 4× HPE Slingshot NICs and scheduled with SLURM. Training is performed in bfloat16 mixed precision using a PyTorch-based implementation. Each MI250X chip has a peak of 192 TFLOP/s, with a measured maximum of 125 TFLOP/s for a single matrix multiplication. The implementation achieves a single-node training speed of 108.75 TFLOP/s, and the largest model is trained using only data parallelism with optimizer sharding and gradient checkpointing. The global batch size is 16M tokens per step. When running on 4096 GPUs, training achieves 52 to 64 TFLOP/s per GPU, i.e., about 1 to 1.2M tokens per second. A hand-crafted distributed data parallel implementation was written to circumvent an AMD interconnect issue. Training proceeded through 21 segments, and for comparison a baseline with the same architecture was run in a feedforward (non-recurrent) manner. The main model was trained for 795B tokens.
At small scales, most normalization strategies and initializations work, but larger scales require specific configurations. An initial training run with parameter-free RMSNorm layers, no embedding scale $\gamma$, and a parameter-free adapter $A(s, e) = s + e$ stalled because the model's representation collapsed: token correlation in the hidden states went to 1.0, meaning the model predicted the same hidden state for every token. A subsequent attempt introduced the embedding scale factor, switched back to a conventional pre-normalization block, and used the learned adapter. The final run resolved the remaining issues by reverting to the sandwich block format and reducing the peak learning rate to $4 \times 10^{-5}$.
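The collapse diagnostic above can be approximated as follows. The text does not define "token correlation" precisely; this sketch assumes it means the mean pairwise cosine similarity between the hidden states of different tokens, which approaches 1.0 when every token is mapped to the same direction. The function name is hypothetical.

```python
import math


def mean_token_cosine(hidden):
    """Mean pairwise cosine similarity between per-token hidden states."""
    def cos(u, v):
        dot = sum(a * b for a, b in zip(u, v))
        return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

    n = len(hidden)
    sims = [cos(hidden[i], hidden[j]) for i in range(n) for j in range(i + 1, n)]
    return sum(sims) / len(sims)


collapsed = [[1.0, 2.0, 3.0]] * 4                 # every token has the same state
healthy = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]   # distinct directions
print(mean_token_cosine(collapsed), mean_token_cosine(healthy))
```

Tracking such a statistic during training would flag a collapse like the one described before the loss curve stalls.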
The final model, Huginn-0125, was trained for 800B tokens, and a non-recurrent baseline was trained for 180B tokens. Benchmarks were performed against other open-source models trained on public datasets of similar size, including Amber, Pythia, and various OLMo variants. Standard benchmarks were executed through the lm-eval harness, and code benchmarks were performed via bigcode-bench. The model has 3.5B parameters but consumes raw FLOPs similar to a 32B parameter transformer during pretraining and can continuously improve with test-time scaling.
The model outperforms the older Pythia series and is comparable to the first OLMo generation, OLMo-7B, in most metrics, but lags behind later OLMo models trained on larger datasets. The authors suggest this is promising for the first recurrent-depth model for language trained at this scale. The authors collected results for established benchmark tasks and showed all models side-by-side.
For math, the model was evaluated on GSM8k (both zero-shot and in the 8-way CoT setup), MATH (with the Minerva evaluation rules), and MathQA; for coding, MBPP and HumanEval were used. In mathematical reasoning, the model surpasses all compared models except the latest OLMo-2 model. On coding benchmarks, it beats all other general-purpose open-source models but does not outperform dedicated code models such as StarCoder2. Code and mathematical reasoning performance continue to improve steadily throughout training.
The recurrent model was compared against its non-recurrent twin, trained to 180B tokens in the same setting. The recurrent model outperforms the baseline, with a pronounced advantage on harder tasks and gains on GSM8k. When evaluated with only a single recurrence, the recurrent model effectively stops improving between the early 180B checkpoint and the 800B checkpoint. Performance also improves as a function of test-time compute, with the saturation point being highly task-dependent.
ARC-C performance was evaluated as a function of the number of recurrences and the number of few-shot examples. Without few-shot examples, the model saturates in compute at around 8 to 12 iterations; with more context, the model can reason about more information and saturates at around 20 to 32 iterations. When OBQA is re-evaluated with a relevant fact provided in the prompt, the recurrent model improves significantly, closing the gap to OLMo-2.
Because training used a constant learning rate, averaging weights across checkpoints can yield further improvements, simulating the effect of a learning-rate cooldown. An exponential moving average was used, incorporating the last 75 checkpoints with a dilation factor of 7.
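A minimal sketch of checkpoint averaging, under assumptions: the text does not say exactly how the dilation is applied or which decay is used, so this version walks back from the newest checkpoint taking every `dilation`-th one, keeps at most `n_last` of them, and uses a hypothetical `decay`. Checkpoints are toy dicts of floats standing in for parameter tensors.

```python
def ema_of_checkpoints(checkpoints, n_last=75, dilation=7, decay=0.9):
    """EMA over a dilated selection of the most recent checkpoints.

    `checkpoints` is a chronological list of {param_name: float} dicts.
    `decay` is a hypothetical value, not taken from the text.
    """
    # Every `dilation`-th checkpoint counting back from the newest, at most
    # `n_last` of them, then put back into chronological order.
    chosen = checkpoints[::-1][::dilation][:n_last][::-1]
    avg = dict(chosen[0])
    for ckpt in chosen[1:]:
        for name, value in ckpt.items():
            # Standard EMA update: the running average decays, the newer checkpoint blends in.
            avg[name] = decay * avg[name] + (1.0 - decay) * value
    return avg


ckpts = [{"w": float(i)} for i in range(100)]
print(ema_of_checkpoints(ckpts))
```

Setting `decay=0.0` returns the newest selected checkpoint and `decay=1.0` the oldest one, which makes the selection easy to check.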
Recurrent-depth models naturally support methods that require substantial effort with standard transformers. The model can already vary its compute per query, but it would be more efficient to stop recurring early for each token when the prediction is easy. A simple exit criterion is the KL-divergence between the output distributions of two successive steps: if it falls below $5 \times 10^{-4}$, the model stops iterating. The number of steps required to exit differs notably between MTBench categories, and experiments on MTBench show that this adaptivity does not significantly impact performance.
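The exit rule above can be sketched as follows, assuming the divergence is measured as $\mathrm{KL}(p_i \,\|\, p_{i-1})$ between the next-token distributions of successive iterations (the text does not fix the direction). `step` and `to_logits` are hypothetical callables standing in for one application of the core block and for the coda.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    z = sum(exps)
    return [v / z for v in exps]


def kl_divergence(p, q):
    # KL(p || q) for strictly positive distributions.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


def recur_until_converged(step, s0, to_logits, threshold=5e-4, max_iters=64):
    """Apply `step` until KL(p_i || p_{i-1}) < threshold; return (state, iterations)."""
    s = step(s0)
    prev = softmax(to_logits(s))
    for i in range(2, max_iters + 1):
        s = step(s)
        cur = softmax(to_logits(s))
        if kl_divergence(cur, prev) < threshold:
            return s, i
        prev = cur
    return s, max_iters


# Toy usage: a contraction whose logits converge, so the loop exits early.
toy_step = lambda s: [0.5 * v + 1.0 for v in s]
state, n_iters = recur_until_converged(toy_step, [0.0, 5.0, -3.0], lambda s: s)
```

Run per token, this yields a different iteration count for each position, which is where the compute savings come from; `max_iters` caps the cost for tokens that never settle.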
The authors note that a concern with token-wise early exits in models with self-attention is that they break KV-caching: a later token may need KV states from a recurrence depth that an earlier, early-exited token never computed. In that case, the later token attends to the last, deepest available KV states in the cache. KV-caches can also be shared across recurrence steps with minimal impact on performance: a fixed KV-cache budget of $k$ entries is set for the recurrence at every token, and at iteration $i$ the model reads and writes cache entry $i \bmod k$. On MTBench, this does not reduce performance. Instead of sampling a random initial state $s_0$ at every generation step, the model can also warm-start with the last state $s_r$ from the previous token, reducing the average number of steps required to converge. Finally, recurrent-depth models can generate text more efficiently via speculative decoding without a separate draft model: the model can be run with fewer iterations to draft the next $N$ tokens in the generated sequence, which can then be verified with any desired number of iterations $M > N$. Drafting is also efficient because the states computed during drafting are not wasted and can be reused during verification.
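The cache-sharing rule and the early-exit fallback described above can be sketched as a toy data structure. The class and method names are hypothetical, entries are plain strings instead of key/value tensors, and iterations are counted from 1; the sketch only illustrates the indexing: iteration $i$ of a token uses slot $i \bmod k$, and a read asking for a deeper iteration than a token computed falls back to that token's deepest state.

```python
class SharedRecurrentKVCache:
    """Toy per-token KV store with a fixed budget of `budget` slots per token."""

    def __init__(self, budget):
        self.budget = budget
        self.slots = {}   # (token_position, slot_index) -> cached entry
        self.depth = {}   # token_position -> deepest iteration written so far

    def write(self, pos, iteration, entry):
        # Iteration i reuses slot i mod budget, overwriting older iterations.
        self.slots[(pos, iteration % self.budget)] = entry
        self.depth[pos] = max(self.depth.get(pos, 0), iteration)

    def read(self, pos, iteration):
        # A token that exited early never wrote deeper iterations; fall back to
        # its deepest one. Writes happen in increasing order, so that slot still
        # holds the deepest entry.
        it = min(iteration, self.depth[pos])
        return self.slots[(pos, it % self.budget)]


cache = SharedRecurrentKVCache(budget=4)
for i in range(1, 11):            # token 0 recurs 10 times
    cache.write(0, i, f"tok0-it{i}")
for i in range(1, 6):             # token 1 exits early after 5 iterations
    cache.write(1, i, f"tok1-it{i}")
print(cache.read(0, 3))   # tok0-it7: slot 3 was last written at iteration 7
print(cache.read(1, 12))  # tok1-it5: falls back to token 1's deepest state
```

Memory per token is bounded by the budget regardless of how many iterations run, at the cost of a read sometimes returning an entry from a different iteration than the one requested.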
To understand what the model does while recurring in latent space, the authors analyze its trajectories $\{s_i\}_{i=1}^r$ on qualitative examples, measuring the norm distance $\|s_i - s^*\|$ between each state $s_i$ in a trajectory and an approximate limit point $s^*$ computed with 128 iterations. The model may trace out complicated orbits in latent space while processing information. Many tokens simply converge to a fixed point, but for harder questions a token's state quickly falls into an orbit pattern. Context dependence is also visible: three identical tokens in the same example behave differently. Some tokens are encoded as "sliders", whose trajectory drifts noticeably in a single direction, implementing a mechanism to count how many iterations have occurred. The emergence of structured trajectories in latent space provides insight into how the model performs computations. The model also maintains path independence: when re-initialized from multiple starting points $s_0$, it moves along similar trajectories.
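The distance measurement above can be sketched for arbitrary toy dynamics. `step` is a hypothetical stand-in for one application of the core block to a single token's state, and $s^*$ is approximated, as in the text, by running 128 iterations from the same start. The two maps in the usage lines are illustrative only: a contraction that converges to a fixed point and an exact 90-degree rotation that keeps orbiting.

```python
import math


def distances_to_limit(step, s0, r, limit_iters=128):
    """Return ||s_i - s*|| for i = 1..r, where s* is the state after limit_iters steps."""
    s_star = s0
    for _ in range(limit_iters):
        s_star = step(s_star)
    dists, s = [], s0
    for _ in range(r):
        s = step(s)
        dists.append(math.dist(s, s_star))
    return dists


# Fixed point: a contraction toward the origin; distances shrink every step.
shrink = distances_to_limit(lambda s: [0.5 * v for v in s], [1.0, 1.0], r=6)
# Orbit: a 90-degree rotation never settles; distances repeat with period 4.
orbit = distances_to_limit(lambda s: [-s[1], s[0]], [1.0, 0.0], r=8)
```

A monotone decay signals convergence, while a periodic pattern that never reaches zero is the signature of an orbit.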
The extent to which recurrence is a foundational concept of machine learning is hard to overstate. For transformers, recurrence was applied in \citet{dehghani_universal_2019}, who highlight the aim of recurrent depth to model universal, i.e. Turing-complete, machines. It was used at scale (but with fixed recurrence) in \citet{lan_albert_2019}, and interesting recent improvements in this line of work are described in \citet{tan_sparse_2023,abnar_adaptivity_2023} and \citet{csordas_moeut_2024}. \citet{schwarzschild_can_2021,bansal_end--end_2022,bear_rethinking_2024,mcleish_transformers_2024} show that depth recurrence is advantageous for learning generalizable algorithms when training with randomized unrolling and input injections. Recent work has described depth-recurrent, looped transformers and studied their potential benefits with careful theoretical and small-scale analysis \citep{giannou_looped_2023,gatmiry_can_2024,yang_looped_2024,fan_looped_2025}. From another angle, these models can be described as neural networks learning a fixed-point iteration, as studied in deep equilibrium models \citep{bai_deep_2019,bai_neural_2022}. They are further related to diffusion models \citep{song_generative_2019-1}, especially latent diffusion models \citep{rombach_high-resolution_2022}, although language diffusion models are usually run with a per-sequence, rather than a per-token, iteration count \citep{lee_deterministic_2018}. A key difference between this approach and both equilibrium models and diffusion models lies in the training objective: equilibrium methods solve the “direct” problem \citep{geiping_parametric_2019-1}, diffusion models solve a surrogate training objective, and this work proposes truncated unrolling as a scalable alternative.
Architectures that recur in depth can also be understood as directly learning the analog to the gradient of a latent energy-based model \citep{lecun_loss_2005,lecun_path_2022}, or to an implicitly defined intermediate layer \citep{amos_optnet:_2017}. These analogies to gradient descent at inference time also show the connection to test time adaptation \citep{sun_test-time_2020}, especially test-time adaptation of output states \citep{boudiaf_parameter-free_2022}.
In terms of future work, a large number of novel post-training schemes could further enhance the capabilities of these models, such as fine-tuning to compress the recurrence, reinforcement learning on data of different hardness levels, or internalizing reasoning from CoT data into the recurrence. Another aspect not covered is the relationship to other modern architecture improvements. With recurrent depth, blocks containing linear operators can repeat until all necessary comparisons between sequence elements are computed. For simplicity, the authors focus on a single recurrence, whereas prior work has considered multiple successive recurrent stages. Finally, the proposed architecture is set up to be compute-heavy, with more "materialized" parameters than actual parameters, which naturally mirrors mixture-of-experts (MoE) models. In a standard MoE model, each expert can only be activated once per forward pass or skipped entirely, but a recurrent MoE model could also refine its latent state over multiple iterations, routing to the same expert multiple times before switching to a different one.
The models described are still a proof of concept: the authors describe how to train a latent recurrent-depth architecture and which parameters were chosen, and they trained a single model at scale. Still, a number of interesting behaviors emerge naturally from recurrent training, especially the ability to use latent reasoning to dramatically improve performance on reasoning tasks by expending test-time computation.