- The paper demonstrates that training on synthetic data from weaker but cheaper models can yield reasoning performance gains of up to 6% over training on data from stronger, more expensive models at a matched compute budget.
- The methodology evaluates three finetuning strategies, analyzing trade-offs in coverage, diversity, and false positive rates across different sampling budgets.
- Improved generalization on MATH and GSM-8K datasets underscores the potential of compute-efficient training for advanced LLM reasoners.
A Computationally Optimal Perspective on Training LLM Reasoners
Introduction
The paper "Smaller, Weaker, Yet Better: Training LLM Reasoners via Compute-Optimal Sampling" by Hritik Bansal, Arian Hosseini, Rishabh Agarwal, Vinh Q. Tran, and Mehran Kazemi challenges the prevailing assumption that leveraging stronger but more expensive (SE) LLMs (LMs) for synthetic data generation is the optimal strategy for improving LLM reasoning capabilities. Instead, the authors propose a compute-optimal approach by using weaker but cheaper (WC) models.
Methodological Overview
The paper investigates the trade-offs between generating synthetic data from an SE model versus a WC model under a fixed computational budget, measured in floating-point operations (FLOPs); a sketch of this compute-matched sampling appears after the list below. The evaluation criteria include coverage, diversity, and false positive rate (FPR) of the synthetic data. The authors also present three finetuning paradigms:
- Knowledge distillation from a teacher model.
- Self-improvement, where an LM learns from its own generated data.
- Weak-to-strong improvement (W2S-I), where a weaker LM enhances the reasoning of a stronger LM.
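To make the fixed-budget comparison concrete, the sketch below estimates how many more solutions a WC model can sample than an SE model for the same number of FLOPs. This is a minimal illustration under the standard assumption that generation cost scales roughly linearly with parameter count (about 2P FLOPs per output token); the function name and numbers are illustrative, not the authors' code.

```python
# A minimal sketch of compute-matched sampling (not the authors' code).
# Assumption: generating a token with a P-parameter model costs about 2*P FLOPs,
# so per-solution cost scales roughly linearly with parameter count.

def compute_matched_samples(params_se_b: float, params_wc_b: float,
                            samples_per_question_se: int) -> int:
    """Number of solutions the WC model can sample per question for the same
    FLOP budget the SE model spends on `samples_per_question_se` solutions."""
    return round(samples_per_question_se * params_se_b / params_wc_b)

# Example: Gemma2-27B (SE) vs Gemma2-9B (WC), parameter counts in billions.
# For every solution drawn from the 27B model, the 9B model can draw roughly 3.
print(compute_matched_samples(27, 9, 10))  # -> 30
```

With the paper's Gemma2 pairing this ratio is 3, which is why the WC data can reach higher coverage and diversity at a matched sampling budget.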
Experimental Setup
The empirical evaluation utilizes Gemma2 models of different sizes (9B and 27B parameters) to generate synthetic data for the MATH and GSM-8K datasets. Models are finetuned under various setups, followed by an evaluation of their reasoning capabilities. This multi-dimensional evaluation investigates coverage, diversity, and FPR at both low and high sampling budgets. Comparisons are also drawn using state-of-the-art Gemini-1.5 models, specifically examining cost efficiency in synthetic data generation.
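To make these metrics concrete, the sketch below shows one common way to compute coverage (the fraction of problems for which at least one sampled solution is correct, estimated with the unbiased pass@k formula) and diversity (the average number of distinct correct solutions per problem). The estimator choice, function names, and toy data are assumptions for illustration, not the paper's evaluation code.

```python
import math
from typing import List

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimator: probability that at least one of k samples,
    drawn from n total samples of which c are correct, is correct."""
    if n - c < k:
        return 1.0
    return 1.0 - math.comb(n - c, k) / math.comb(n, k)

def coverage_at_k(correct_counts: List[int], n: int, k: int) -> float:
    """Average pass@k over all problems: fraction solvable within k samples."""
    return sum(pass_at_k(n, c, k) for c in correct_counts) / len(correct_counts)

def diversity(unique_correct_counts: List[int]) -> float:
    """Average number of distinct correct solutions obtained per problem."""
    return sum(unique_correct_counts) / len(unique_correct_counts)

# Toy example: 3 problems, 30 samples each, with 0, 4, and 12 correct samples.
print(coverage_at_k([0, 4, 12], n=30, k=10))
print(diversity([0, 3, 7]))
```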
Key Findings
Synthetic Data Quality
- Coverage: Data from the Gemma2-9B model (WC) showed higher coverage than Gemma2-27B (SE), with relative improvements of up to 11% on MATH and 8% on GSM-8K.
- Diversity: Gemma2-9B demonstrated significantly higher diversity in generated solutions, particularly noticeable at higher sampling budgets.
- FPR: Data from the WC model exhibited a higher FPR, i.e., more solutions that reach the correct final answer through flawed reasoning, as assessed by both human and automatic evaluation; despite this, the downstream finetuning gains reported below still favored WC data (see the filtering sketch after this list).
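The FPR arises because synthetic solutions are typically kept or discarded based only on whether their final answer matches the gold answer, so a flawed chain of reasoning that happens to land on the right answer slips through. Below is a minimal sketch of such an answer-matching filter, with an assumed record format; it is an illustration, not the authors' pipeline.

```python
from typing import Dict, List

def filter_by_final_answer(samples: List[Dict], gold_answers: Dict[str, str]) -> List[Dict]:
    """Keep sampled solutions whose extracted final answer matches the gold answer.

    Each sample is assumed to look like:
      {"question_id": "...", "solution": "...", "final_answer": "..."}
    Note: this check cannot detect a correct answer reached via flawed reasoning,
    which is exactly what the false positive rate (FPR) measures.
    """
    kept = []
    for s in samples:
        gold = gold_answers.get(s["question_id"])
        if gold is not None and s["final_answer"].strip() == gold.strip():
            kept.append(s)
    return kept

# Toy usage: only the sample with the matching final answer is retained.
samples = [
    {"question_id": "q1", "solution": "step-by-step reasoning", "final_answer": "42"},
    {"question_id": "q1", "solution": "step-by-step reasoning", "final_answer": "41"},
]
print(len(filter_by_final_answer(samples, {"q1": "42"})))  # -> 1
```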
Finetuning Results
- Student-LM Finetuning: Finetuning Gemma-7B on WC-generated data yielded relative gains of up to 6% over finetuning on SE-generated data.
- WC-LM Finetuning: Gemma2-9B finetuned on its own synthetic data outperformed the variant finetuned on SE-generated data by up to 3.8%.
- SE-LM Finetuning: Remarkably, Gemma2-27B finetuned on data from Gemma2-9B (W2S-I) outperformed its self-improvement counterpart by up to 5.8%. The three setups are summarized in the sketch after this list.
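For reference, the three finetuning setups can be written down as simple (student model, data source) pairings. The structure below is an illustrative summary of the paper's Gemma configuration, not code from the paper.

```python
# The three finetuning setups, each compared with WC (Gemma2-9B) versus
# SE (Gemma2-27B) synthetic data at a matched sampling budget.
FINETUNING_SETUPS = {
    "Student-LM": {"student": "Gemma-7B",
                   "paradigm": "knowledge distillation from WC or SE data"},
    "WC-LM":      {"student": "Gemma2-9B",
                   "paradigm": "self-improvement on WC data vs distillation from SE data"},
    "SE-LM":      {"student": "Gemma2-27B",
                   "paradigm": "weak-to-strong improvement from WC data vs self-improvement on SE data"},
}

for name, cfg in FINETUNING_SETUPS.items():
    print(f"{name}: finetune {cfg['student']} ({cfg['paradigm']})")
```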
Generalization Capabilities
The models trained on WC data exhibited stronger generalization on the Functional MATH dataset compared to those trained on SE data. Relative gains ranged from 2% to 6.5% across various settings.
Discussion and Implications
The empirical evidence challenges the conventional wisdom of relying predominantly on SE models for synthetic data generation. The superior performance and generalization of LMs trained on WC-generated data point to more compute-efficient training methodologies. This is increasingly relevant given the narrowing performance gap between small and large LMs, as inferred from an analysis of LMs released over the past year.
Future Perspectives
Considering the rapid improvements in the capabilities of smaller LMs, the authors speculate that training with data from these models could become increasingly advantageous. As the performance gaps narrow, the cost and efficiency benefits of sampling from WC models will likely become more prominent, fostering a paradigm shift in training strategies for advanced LM reasoners.
Conclusion
This paper provides robust, empirical backing for a compute-optimal approach to LLM training. By strategically leveraging synthetic data from WC models, the research suggests avenues to enhance both the efficiency and efficacy of training protocols for reasoning tasks, laying groundwork for future advancements in the field.