Think, Prune, Train, Improve: Scaling Reasoning without Scaling Models (2504.18116v1)

Published 25 Apr 2025 in cs.LG

Abstract: LLMs have demonstrated strong capabilities in programming and mathematical reasoning tasks, but are constrained by limited high-quality training data. Synthetic data can be leveraged to enhance fine-tuning outcomes, but several factors influence this process, including model size, synthetic data volume, pruning strategy, and number of fine-tuning rounds. We explore these axes and investigate which conditions enable model self-improvement. We introduce the Think, Prune, Train process, a scalable framework that iteratively fine-tunes models on their own reasoning traces, using ground-truth pruning to ensure high-quality training data. This approach yields improved performance: on GSM8K, Gemma2-2B achieves a Pass@1 of 57.6% (from 41.9%), Gemma2-9B reaches 82%, matching LLaMA-3.1-70B, and LLaMA-3.1-70B attains 91%, even surpassing GPT-4o, demonstrating the effectiveness of self-generated reasoning and systematic data selection for improving LLM capabilities.

Authors (4)
  1. Caia Costello
  2. Simon Guo
  3. Anna Goldie
  4. Azalia Mirhoseini

Summary

This paper introduces the Think, Prune, Train (TPT) framework, an iterative process designed to improve the reasoning capabilities of LLMs using self-generated data, without requiring larger teacher models, complex reinforcement learning, or extensive external datasets. The core challenge addressed is the diminishing returns from training LLMs on public text and the risks (like model collapse) associated with recursive fine-tuning on unfiltered self-generated data.

The TPT framework operates as a loop of three steps (a minimal sketch follows the list):

  1. Think: The current model is prompted to generate step-by-step reasoning solutions for problems using a structured approach such as Chain-of-Thought (CoT). The paper describes sampling at a temperature of 0.8 and generating multiple candidate solutions per problem (e.g., 10 for the training set).
  2. Prune: The generated solutions are filtered to select only the correct ones. Crucially, this pruning relies on ground-truth correctness filtering, meaning the output is programmatically checked against the known correct answer for the problem. This step is highlighted as essential for preventing model collapse seen in prior self-training attempts on unfiltered data. The paper experiments with strict ground truth matching and notes that unpruned or softly pruned data generally leads to worse performance compared to strictly pruned data.
  3. Train: The current model is fine-tuned via supervised fine-tuning (SFT) on the selected, correct self-generated solutions. This process is repeated iteratively, with each new model generating the data for the next training round (Algorithm 1). The paper investigates different data retention strategies and finds that replacing the dataset entirely with newly generated, pruned data each round is effective for isolating the effects of iterative refinement beyond simple data accumulation.

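The pseudocode below is a minimal sketch of this loop under the settings reported above (temperature 0.8, 10 samples per problem, roughly 2000 retained examples per round); `sample_solutions`, `is_correct`, and `sft_finetune` are hypothetical helpers used for illustration, not functions from the paper's code.

```python
# Hypothetical sketch of the Think, Prune, Train loop (not the paper's released code).
def think_prune_train(model, problems, rounds=3,
                      samples_per_problem=10, temperature=0.8,
                      max_examples=2000):
    for _ in range(rounds):
        # Think: sample multiple CoT solutions per problem from the current model.
        candidates = [
            (p, sol)
            for p in problems
            for sol in sample_solutions(model, p.question,
                                        n=samples_per_problem,
                                        temperature=temperature)
        ]
        # Prune: keep only solutions whose final answer matches the ground truth.
        kept = [(p, sol) for p, sol in candidates if is_correct(sol, p.answer)]
        # Each round replaces the training set with freshly generated, pruned data.
        train_set = kept[:max_examples]
        # Train: supervised fine-tuning (SFT) on the pruned self-generated solutions.
        model = sft_finetune(model, train_set)
    return model
```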
The paper contrasts TPT with related work, including distillation (which relies on larger teacher models), other self-improvement methods (such as LLaMA 3.1's pipeline involving DPO and reward models, or STaR, which does not iteratively fine-tune the improved model), and various RL-based approaches (such as ReST, ReST-EM, and DeepSeek R1) that typically use reward models and explicit policy optimization. TPT is simpler, focusing solely on SFT guided by correctness pruning. The paper also offers a theoretical perspective, suggesting that SFT on correctness-filtered data can be seen as an implicit form of policy gradient optimization, where correctness serves as a reward signal.

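To make that correspondence concrete, here is a hedged sketch of the standard policy-gradient identity with a binary correctness reward; the notation is ours, not an equation reproduced from the paper.

```latex
% R(y) = 1 if solution y is correct for problem x, else 0.
\nabla_\theta\, \mathbb{E}_{y \sim p_\theta(\cdot \mid x)}\big[R(y)\big]
  = \mathbb{E}_{y \sim p_\theta(\cdot \mid x)}\big[R(y)\,\nabla_\theta \log p_\theta(y \mid x)\big]
  \approx \frac{1}{N}\sum_{i=1}^{N} R(y_i)\,\nabla_\theta \log p_\theta(y_i \mid x),
  \qquad y_i \sim p_\theta(\cdot \mid x).
```

Because R(y_i) is either 0 or 1, the surviving terms are exactly the log-likelihood gradients of the correct (pruned) samples, i.e., the SFT gradient on the filtered data, which is why correctness pruning behaves like an implicit reward signal when the data is generated on-policy each round.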
Experiments are conducted primarily on mathematical reasoning (GSM8K; Cobbe et al., 2021) and code generation (CodeContests; Li et al., 2022) tasks, using instruction-tuned variants of Gemma (2B, 9B) and LLaMA (1B, 70B) models.

Key findings from the experiments include:

  • Data Scaling Alone is Insufficient: Simply increasing the volume of synthetic data (even when generated by larger models) does not guarantee sustained performance gains. Performance can plateau or even decline beyond a certain point, indicating that data quality and filtering are more critical than sheer quantity (see the paper's data-scaling tables for GSM8K, CodeContests, and LeetCode).
  • Recursive TPT Improves Reasoning: Iterative application of the TPT process leads to significant performance improvements, particularly in first-attempt accuracy (Pass@1).
    • On GSM8K, Gemma2-2B improves from 41.9% to 57.6% Pass@1 over four iterations. Gemma2-9B reaches 82% Pass@1 after three iterations, surpassing LLaMA3.1-70B-Instruct's baseline of 78%. LLaMA3.1-70B further improves its Pass@1 from 78.6% to 91.5% after just one round of TPT, exceeding GPT-4o's reported performance (see the paper's recursive-training tables and Pass@1 figure for GSM8K).
    • On CodeContests, Gemma2-2B improves from 0.9% to 1.14% Pass@1 and Gemma2-9B from 5.1% to 7.9% Pass@1 over multiple iterations (see the paper's CodeContests synthetic-data table).
  • Pass@k Behavior: While Pass@1 consistently improves with recursive training, Pass@k for k > 1 (e.g., Pass@20, Pass@50) tends to plateau after the first few iterations (see the paper's recursive-training figures and tables). This suggests an effect similar to mode collapse: the model becomes better at producing a few correct solutions with high confidence, but does not necessarily increase the overall diversity of correct outputs. The authors argue this trade-off is acceptable for tasks where correctness is paramount. (A sketch of the Pass@k estimator follows this list.)
  • Pruning is Crucial: Training on unpruned self-generated data results in performance degradation compared to the baseline, confirming that correctness-based filtering is vital for stable self-improvement (see the paper's no-pruning ablation).
  • Model Size Matters: Larger models like LLaMA-70B achieve higher absolute performance and converge faster with TPT (requiring fewer iterations) compared to smaller models like Gemma2-2B, highlighting the interplay between base model capacity and the effectiveness of the self-improvement process.
  • Mixing Data: Mixing real and synthetic data can improve performance in the first round, but attempts to recursively fine-tune on mixed datasets failed, suggesting that approach is unstable (see the paper's mixed-dataset ablation).

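For reference, Pass@k results of this kind are typically computed with the unbiased estimator from the code-generation evaluation literature (Chen et al., 2021); the snippet below is a generic sketch of that estimator, not code from the paper.

```python
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased Pass@k estimator: probability that at least one of k
    completions, drawn without replacement from n generated samples of
    which c are correct, is correct (Chen et al., 2021)."""
    if n - c < k:
        return 1.0  # fewer than k incorrect samples exist, so any draw of k contains a correct one
    return 1.0 - comb(n - c, k) / comb(n, k)

# Example: 50 samples per problem, 20 correct
print(pass_at_k(50, 20, 1))   # 0.40
print(pass_at_k(50, 20, 20))  # ~1.00
```
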
Implementation considerations include:

  • The need for a reliable ground-truth verification mechanism for the target task. This makes TPT particularly well-suited for tasks like math and code, where correctness can be checked automatically (a hedged example verifier follows this list).
  • Careful tuning of SFT hyperparameters (learning rate, optimizer, epochs) to maintain training stability, especially across multiple recursive rounds.
  • Managing the dataset size per round (|n| in Algorithm 1) based on task complexity and model size. The paper uses 2000 examples for GSM8K and 1000 for CodeContests per round.
  • The computational cost involves generating solutions, verifying them, and performing SFT in each recursive loop. Generating multiple solutions per problem (e.g., 10 in the paper) increases the pool for pruning but also increases inference cost.

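As a concrete example of such a verifier, the sketch below checks a GSM8K-style final answer. It is our own illustration: the "#### <answer>" convention is GSM8K's reference-answer format, while the fallback to the last number in a model completion is an assumption rather than the paper's exact procedure.

```python
import re

def extract_final_answer(text: str) -> str | None:
    """Pull the final numeric answer from a GSM8K-style solution.
    Reference solutions end with '#### <answer>'; for model completions
    we fall back to the last number in the text (an assumption)."""
    match = re.search(r"####\s*(-?[\d,\.]+)", text)
    if match is not None:
        value = match.group(1)
    else:
        numbers = re.findall(r"-?\d[\d,]*\.?\d*", text)
        value = numbers[-1] if numbers else None
    if value is None:
        return None
    return value.replace(",", "").rstrip(".")

def is_correct(generated: str, reference: str) -> bool:
    """Strict ground-truth pruning: keep a sample only if its final answer
    exactly matches the reference answer after normalization."""
    gen, ref = extract_final_answer(generated), extract_final_answer(reference)
    return gen is not None and gen == ref
```
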
In conclusion, the paper demonstrates that the simple Think, Prune, Train framework, relying on structured prompting, correctness-based pruning, and iterative SFT on self-generated data, is an effective method for improving LLM reasoning capabilities without scaling model size or relying on complex RL or external teachers. The significant gains in Pass@1 performance, even surpassing strong benchmarks, highlight the potential of systematic data selection for self-improvement, particularly in domains where correctness is verifiable.