Reasoning to Learn from Latent Thoughts (2503.18866v1)
Abstract: Compute scaling for language model (LM) pretraining has outpaced the growth of human-written texts, leading to concerns that data will become the bottleneck to LM scaling. To continue scaling pretraining in this data-constrained regime, we propose that explicitly modeling and inferring the latent thoughts that underlie the text generation process can significantly improve pretraining data efficiency. Intuitively, our approach views web text as the compressed final outcome of a verbose human thought process, where the latent thoughts contain important contextual knowledge and reasoning steps that are critical to data-efficient learning. We empirically demonstrate the effectiveness of our approach through data-constrained continued pretraining for math. We first show that synthetic data approaches to inferring latent thoughts significantly improve data efficiency, outperforming training on the same amount of raw data (5.7% → 25.4% on MATH). Furthermore, we demonstrate latent thought inference without a strong teacher, where an LM bootstraps its own performance by using an EM algorithm to iteratively improve the capability of the trained LM and the quality of thought-augmented pretraining data. We show that a 1B LM can bootstrap its performance across at least three iterations and significantly outperform baselines trained on raw data, with increasing gains from additional inference compute when performing the E-step. The gains from inference scaling and EM iterations suggest new opportunities for scaling data-constrained pretraining.
Summary
- The paper demonstrates that augmenting LM pretraining with inferred latent thoughts can improve performance (e.g., math reasoning accuracy from 5.7% to 25.4%).
- The methodology leverages both synthetic data generation via a strong teacher model and an EM-based bootstrapping approach to extract intermediate reasoning steps.
- The research highlights that increased computational investment during the inference (E-step) directly scales model capabilities in data-constrained settings.
The paper "Reasoning to Learn from Latent Thoughts" (2503.18866) investigates methods to enhance language model (LM) pretraining data efficiency, particularly in scenarios where high-quality human-written text is becoming a limiting factor for scaling. The central hypothesis is that observable text (e.g., web text) represents a compressed version of an underlying, more verbose human thought process. By explicitly modeling and inferring these "latent thoughts," which contain crucial contextual knowledge and reasoning steps, LMs can learn more effectively from the available data.
Methodology: Inferring and Leveraging Latent Thoughts
The core idea is to augment the pretraining data with inferred latent thoughts. The process treats the observed text x as the final output of a latent thought process z. The goal is to learn the model parameters θ by maximizing the likelihood of the observed data, marginalizing over the latent variables: p_θ(x) = ∫ p_θ(x, z) dz. Instead of directly optimizing this intractable marginal likelihood, the paper proposes practical methods to approximate and leverage the latent thoughts z.
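The EM framing that follows rests on the standard variational decomposition of this marginal likelihood; the identity below is the textbook form (not spelled out explicitly in the summary above), written in the paper's notation:

```latex
% Standard ELBO decomposition of the intractable marginal log-likelihood.
\log p_\theta(x)
  = \underbrace{\mathbb{E}_{q(z)}\!\left[\log \frac{p_\theta(x, z)}{q(z)}\right]}_{\text{ELBO}}
  + \underbrace{\mathrm{KL}\!\left(q(z) \,\|\, p_\theta(z \mid x)\right)}_{\ge 0}
% E-step: set q(z) = p_{\theta_t}(z \mid x), which closes the KL gap.
% M-step: \theta_{t+1} = \arg\max_\theta \; \mathbb{E}_{q(z)}\!\left[\log p_\theta(x, z)\right],
% i.e., train on data augmented with the inferred thoughts.
```

Maximizing the ELBO with q fixed to the E-step posterior is exactly what makes "train on (x, ẑ) pairs" a valid surrogate for the intractable objective.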
Two primary approaches are explored for inferring and utilizing these latent thoughts:
- Synthetic Data Generation with a Strong Teacher: This approach utilizes a powerful, pre-existing LM (like GPT-4) as a teacher model. Given an instance of observed text x (e.g., a math problem and its solution), the teacher model is prompted to generate the plausible latent thoughts z (e.g., intermediate reasoning steps, relevant theorems, failed attempts) that could have led to x. The original data instance x is then augmented with these synthetic thoughts ẑ to form a richer training example (x, ẑ). The target LM is then trained on this augmented dataset. The intuition is that the verbose, step-by-step reasoning provided by the teacher model offers a more explicit learning signal than the compressed final text alone.
- Bootstrapping via Expectation-Maximization (EM): To circumvent the need for a powerful external teacher model, a bootstrapping approach based on the EM algorithm is proposed. This allows an LM to iteratively improve its own capabilities and the quality of the inferred thoughts. The process alternates between two steps:
- E-step (Expectation/Inference): Given the current model parameters θ_t, infer the posterior distribution over latent thoughts p_{θ_t}(z | x) for each data instance x in the pretraining corpus. In practice, this involves using the current LM θ_t to generate plausible thoughts ẑ conditioned on x. This step essentially performs inference to "fill in" the missing reasoning steps based on the model's current understanding. Increasing the computational resources allocated to this inference step (e.g., using more sampling, beam search, or self-consistency techniques) can potentially improve the quality of the inferred thoughts ẑ.
- M-step (Maximization/Training): Update the model parameters by training on the data augmented with the inferred thoughts from the E-step. Maximize the expected complete-data log-likelihood, which translates to training the model θ_{t+1} on the dataset {(x, ẑ)} generated in the E-step.
This iterative process allows the model to progressively refine its understanding of the underlying reasoning processes and generate increasingly accurate latent thoughts, thereby bootstrapping its performance without external supervision beyond the initial raw text data.
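The alternation described above can be sketched as a single loop. This is a toy illustration, not the paper's implementation: the helpers `generate_thought`, `score_thought`, and `train` are hypothetical stand-ins for real LM inference and a real training step, and the "model" here is just a counter so the sketch runs end-to-end.

```python
import random

def generate_thought(model, x):
    # Placeholder: a real E-step would sample a thought from the current LM.
    return f"step-by-step reasoning for {x} (variant {random.randint(0, 9)})"

def score_thought(model, x, z):
    # Placeholder heuristic; a real implementation might score candidates
    # by likelihood under the current model (best-of-k selection).
    return len(z)

def train(model, augmented_data):
    # Placeholder M-step "update"; the model is just a counter in this toy.
    return model + 1

def em_bootstrap(model, raw_dataset, num_iterations=3, k=4):
    for _ in range(num_iterations):
        # E-step: sample k candidate thoughts per example, keep the best.
        augmented_data = []
        for x in raw_dataset:
            candidates = [generate_thought(model, x) for _ in range(k)]
            z_hat = max(candidates, key=lambda z: score_thought(model, x, z))
            augmented_data.append((x, z_hat))
        # M-step: train on the thought-augmented pairs (x, z_hat).
        model = train(model, augmented_data)
    return model
```

Raising `k` is the knob that corresponds to the paper's "more inference compute in the E-step" axis: more candidates per example means a better chance of selecting a high-quality thought before the M-step.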
Empirical Evaluation on Mathematical Reasoning
The effectiveness of these approaches was evaluated through continued pretraining on mathematical reasoning tasks, specifically using the MATH dataset. A 1B parameter LM was used as the base model.
Key findings include:
- Synthetic Data Efficacy: Training on data augmented with synthetic thoughts generated by a strong teacher (GPT-4) significantly improved performance compared to training on the same amount of raw data. Specifically, performance on the MATH dataset increased from a baseline of 5.7% (trained on raw data) to 25.4% when trained on the thought-augmented data, demonstrating a substantial gain in data efficiency. This highlights the value locked within the latent reasoning process that isn't explicit in the final text.
- Bootstrapping Performance: The EM-based bootstrapping approach successfully improved the LM's performance without relying on an external teacher. The 1B parameter LM demonstrated performance gains across at least three iterations of the EM algorithm. This iterative self-improvement significantly outperformed the baseline trained only on raw data.
- Inference Scaling: The paper observed that increasing the computational budget allocated to the inference phase (E-step) within the EM algorithm led to further performance improvements. This suggests that investing more compute in generating higher-quality latent thoughts during the E-step directly translates to better model capabilities after the M-step. This finding introduces a new scaling dimension – inference compute – which can be leveraged in data-constrained pretraining regimes.
Implementation Considerations
Implementing the latent thought methodology involves several practical considerations:
- Thought Generation (Teacher Model): When using a teacher model, careful prompt engineering is required to elicit detailed and relevant reasoning steps (thoughts). The prompt should guide the teacher model to generate step-by-step derivations, justifications, definitions, or even potential pitfalls related to the input text (e.g., math problem). The quality and cost of the teacher model are critical factors.
```python
# Example prompt structure (simplified)
prompt = f"""
Given the following math problem and solution:
Problem: {problem_text}
Solution: {solution_text}

Generate the detailed step-by-step thought process, including intermediate
calculations, relevant formulas, and reasoning, that leads from the problem
to the solution. Be verbose and explicit.

Thought Process:
"""
# teacher_model.generate(prompt) -> yields synthetic thoughts 'z_hat'
```
- Thought Generation (Bootstrapping E-step): In the EM approach, the E-step requires using the current model θ_t to generate thoughts ẑ for each x. This can be computationally intensive, especially for large datasets. Sampling strategies (e.g., temperature sampling, top-k, nucleus sampling) or more sophisticated search methods might be needed. The compute allocated here is a key hyperparameter.
```python
# Simplified E-step logic
augmented_data = []
for x in raw_dataset:
    # Generate thoughts using the current model (theta_t)
    prompt = f"Generate the thought process for: {x}"
    # Inference might involve sampling multiple thoughts and selecting the best
    z_hat = model_t.generate(prompt, num_samples=k, ...)
    best_z_hat = select_best_thought(z_hat, model_t, x)  # Optional scoring/selection
    augmented_data.append((x, best_z_hat))
```
- Data Augmentation (M-step): The format for incorporating thoughts needs consideration. Should thoughts precede the original text, be interleaved, or appended? The paper implicitly suggests concatenating or structuring them such that the model learns to predict the text x given the thoughts z. The M-step involves standard LM training (e.g., fine-tuning) on this augmented dataset (x, ẑ).
```python
# Simplified M-step logic
# model_{t+1} = train(model_t, augmented_data)
# Example training instance format: "Thought: <z_hat> Text: <x>"
# Optimize model parameters using standard LM objectives (e.g., cross-entropy loss)
```
- EM Algorithm Management: The iterative nature requires careful management of model checkpoints and datasets across iterations. Convergence criteria or a fixed number of iterations need to be defined. Monitoring performance on a validation set across iterations is crucial.
- Computational Cost: The primary cost shift is towards inference, especially in the bootstrapping approach. The E-step can require significant compute, potentially exceeding the cost of the M-step training update, depending on the inference strategy and dataset size. However, this inference compute directly contributes to improving data efficiency.
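The checkpoint and validation-monitoring concern above can be made concrete with a small driver. This is a minimal sketch under assumed interfaces: `em_step` (one full E+M iteration) and `evaluate` (validation-set scoring) are hypothetical callables, not APIs from the paper.

```python
def run_em_with_checkpointing(model, em_step, evaluate, max_iters=5):
    """Run EM iterations, keeping the checkpoint with the best validation score.

    `em_step` maps a model to its next-iteration model (E-step + M-step);
    `evaluate` maps a model to a validation score (higher is better).
    """
    best_model, best_score = model, evaluate(model)
    for _ in range(max_iters):
        model = em_step(model)
        score = evaluate(model)
        if score > best_score:
            best_model, best_score = model, score
        else:
            # Simple stopping rule: halt once validation stops improving.
            break
    return best_model, best_score
```

A fixed iteration budget (the paper reports gains across at least three iterations) or a patience-based rule could replace the greedy stop shown here; the point is only that some explicit convergence criterion and checkpoint policy must be chosen up front.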
Practical Implications and Applications
This research offers a promising direction for addressing the data scarcity challenge in LM pretraining.
- Enhanced Data Efficiency: By extracting more learning signal from existing data through inferred thoughts, this method can significantly reduce the amount of raw text needed to achieve a target performance level, especially in knowledge-intensive domains like mathematics, science, or coding, where reasoning is paramount.
- Scaling Pretraining: It introduces inference compute as a new axis for scaling LM capabilities, complementing traditional model and data scaling. When data is the bottleneck, investing compute in the E-step inference offers a way to continue improving model performance.
- Domain Specialization: The bootstrapping approach is particularly valuable for specialized domains where large, high-quality teacher models might not be available or suitable. It allows models to improve using only the domain-specific raw text.
- Interpretability and Reasoning: While not the primary focus, the generated thoughts could potentially offer insights into the model's reasoning process, although the faithfulness of these inferred thoughts requires further investigation.
The trade-off between using a strong teacher (simpler setup, potentially higher initial quality thoughts, but requires a capable teacher) and bootstrapping (no teacher dependency, potential for continuous self-improvement, but more complex iterative process and sensitive to initial model quality) depends on the specific application constraints and available resources.
Conclusion
The work on learning from latent thoughts presents a compelling approach to improve data efficiency in LM pretraining by explicitly modeling the reasoning process underlying text generation. The empirical results, particularly the significant gains on the MATH dataset using both teacher-based synthesis and EM-based bootstrapping, validate the core hypothesis. The demonstration that performance can be scaled by increasing inference compute during the E-step opens up new avenues for scaling LMs in data-constrained environments, suggesting that inferential reasoning can partially substitute for raw data quantity. This methodology holds considerable potential for training more capable models, especially in specialized domains requiring complex reasoning, even when faced with limited high-quality training data.