ReSTEM: EM Self-Training for LLMs
- ReST⁽ᴱᴹ⁾ is a self-training methodology that applies an Expectation-Maximization framework to iteratively refine language models using binary reward feedback.
- Empirical results show significant gains in mathematical reasoning and program synthesis, with test-accuracy improvements of roughly 6% over supervised fine-tuning.
- The approach leverages reward-verified training data and addresses challenges like sampling efficiency and overfitting through controlled EM iterations.
ReST⁽ᴱᴹ⁾ (hereafter ReST) is a self-training methodology for LLMs that uses an expectation-maximization (EM) approach to improve performance on tasks where automated scalar feedback, often binary correctness, is available. Rather than relying exclusively on human-generated supervision, it enables models to iteratively refine themselves by generating candidate outputs, filtering them with reward signals (typically correctness-based), and fine-tuning on the filtered outputs. The method sits in the context of reinforcement learning (RL) fine-tuning of language models, drawing its mathematical underpinnings from EM theory and evidence lower bound (ELBO) maximization. It has shown strong empirical improvements on advanced mathematical reasoning and program synthesis tasks, offering a scalable pathway to improve capabilities beyond human-annotated datasets (Singh et al., 2023).
1. Formalism and Algorithmic Structure
ReST alternates between data generation and model improvement, corresponding to the classic EM algorithm's E-step and M-step adapted to the RL fine-tuning context (a minimal code sketch of the loop follows the list below).
- E-step (Data Generation and Filtering):
- For each prompt $x$, the current model $p_\theta$ generates $K$ candidate outputs $y \sim p_\theta(\cdot \mid x)$ using temperature or top-$k$ sampling.
- Each output $y$ is scored by a reward function $r(x, y)$, typically binary (1 for correctness, 0 otherwise).
- Only those samples with $r(x, y) = 1$ are retained as positive examples for further training.
- M-step (Reward-Weighted Fine-Tuning):
- The model is fine-tuned (often restarting from the base pretrained checkpoint) on the pool of positive samples using a reward-weighted log-likelihood loss: $\mathcal{L}(\theta) = -\,\mathbb{E}_{x \sim \mathcal{D}}\,\mathbb{E}_{y \sim p_{\theta_t}(\cdot \mid x)}\big[r(x, y)\, \log p_\theta(y \mid x)\big]$, where $p_{\theta_t}$ denotes the model from the previous iteration.
- Minimizing this loss is equivalent to minimizing the KL divergence between a reward-reweighted target distribution (samples from $p_{\theta_t}$ reweighted by $r$) and the current model $p_\theta$.
- ELBO-Based Justification:
The procedure optimizes a lower bound on $\log p_\theta(O = 1 \mid x)$, where $O = 1$ indicates observing a high-reward (correct) output under $p_\theta$. Formally,
$$\log p_\theta(O = 1 \mid x) \;=\; \log \sum_{y} p_\theta(y \mid x)\, p(O = 1 \mid x, y) \;\ge\; \mathbb{E}_{q(y \mid x)}\!\left[\log \frac{p_\theta(y \mid x)\, p(O = 1 \mid x, y)}{q(y \mid x)}\right]$$
for any variational distribution $q(y \mid x)$. Applying Jensen's inequality yields this ELBO, whose alternating maximization over $q$ (E-step) and $\theta$ (M-step) motivates the updates above.
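To make the alternation concrete, the following is a minimal Python sketch of one way the loop could be organized. The `generate`, `reward`, and `finetune_from_base` callables are hypothetical placeholders standing in for model sampling, the task's correctness checker, and reward-weighted fine-tuning; they are not part of the original paper's code.

```python
from typing import Callable, List, Tuple


def restem_loop(
    prompts: List[str],
    generate: Callable[[str, int], List[str]],                     # hypothetical: sample K outputs per prompt
    reward: Callable[[str, str], float],                           # binary correctness check r(x, y) in {0, 1}
    finetune_from_base: Callable[[List[Tuple[str, str]]], None],   # hypothetical: fine-tune the base checkpoint
    num_iterations: int = 3,
    samples_per_prompt: int = 32,
) -> None:
    """Minimal sketch of the ReST(EM) alternation between E-step and M-step."""
    for _ in range(num_iterations):
        # E-step: sample candidates from the current policy and keep only
        # those the reward function verifies as correct (r(x, y) = 1).
        positives: List[Tuple[str, str]] = []
        for x in prompts:
            for y in generate(x, samples_per_prompt):
                if reward(x, y) == 1:
                    positives.append((x, y))

        # M-step: fine-tune, restarting from the base pretrained checkpoint,
        # on the reward-verified pool. With binary rewards this implements the
        # reward-weighted log-likelihood objective from the list above.
        finetune_from_base(positives)
```

Restarting from the base checkpoint in each M-step mirrors the regularization strategy discussed in Section 4.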
2. Empirical Performance and Scaling
On both the MATH benchmark (complex mathematical reasoning) and APPS (code synthesis), ReST demonstrates significant gains over traditional supervised fine-tuning (SFT):
- MATH Dataset:
Test accuracy improves by approximately 5.94% (PaLM 2-S) and 6.34% (PaLM 2-L) over SFT.
- APPS Dataset:
Highest pass@1 performance among compared methods, confirming the robustness of reward filtering, though incremental gains occur primarily in the first EM iteration for this coding domain.
Scaling trends indicate that larger LLMs not only benefit more from ReST in absolute performance but also enjoy increased relative improvements compared to SFT on human data alone.
Overview of Results for Advanced Reasoning Tasks
| Benchmark | Model | Gain over SFT |
|---|---|---|
| MATH | PaLM 2-S | +5.94% test accuracy |
| MATH | PaLM 2-L | +6.34% test accuracy |
| APPS | PaLM 2-S/L | Notable improvement in pass@k / majority-voting performance over the SFT baseline |
All statistics as reported in the original paper (Singh et al., 2023).
3. Mathematical and Conceptual Connections
ReST generalizes classical EM approaches for inference with missing data, viewing the process in information geometric terms as alternating projections between policy distributions and reward-filtered empirical distributions (Hino et al., 2022). The algorithm maximizes a reward-weighted expected log-likelihood, equivalent, for deterministic binary rewards, to optimizing over the subdistribution of self-generated, correct outputs.
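To spell out that equivalence, the short derivation below uses the notation from Section 1; the set $\mathcal{G}(x) = \{\, y : r(x, y) = 1 \,\}$ of verified-correct sampled outputs is introduced here purely for illustration.

```latex
% With a deterministic binary reward, the weight r(x, y) acts as an indicator,
% so the reward-weighted expected log-likelihood collapses to a plain
% log-likelihood over the filtered, self-generated correct outputs.
\begin{align*}
\mathbb{E}_{y \sim p_{\theta_t}(\cdot \mid x)}\big[ r(x, y)\, \log p_\theta(y \mid x) \big]
  &= \sum_{y} p_{\theta_t}(y \mid x)\, \mathbf{1}[r(x, y) = 1]\, \log p_\theta(y \mid x) \\
  &= \sum_{y \in \mathcal{G}(x)} p_{\theta_t}(y \mid x)\, \log p_\theta(y \mid x).
\end{align*}
```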
The method is allied to RL, Reward-Weighted Regression (RWR), and expert iteration, but differs by:
- Directly filtering for correct outputs as the E-step,
- Avoiding gradient-based policy optimization or value-based credit assignment for intermediate steps,
- Using reward functions with verifiable correctness (unit tests, solution checkers) for tasks like code synthesis and mathematical reasoning (see the sketch below).
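For instance, a binary reward for a math-style task can be implemented as a simple final-answer check. The sketch below is illustrative only: `extract_final_answer` is a hypothetical helper, and the `#### <answer>` formatting convention is an assumption of this sketch, not something specified in the source.

```python
import re
from typing import Optional


def extract_final_answer(solution: str) -> Optional[str]:
    """Hypothetical helper: pull the final answer from a generated solution,
    assuming (for this sketch) that it is written as '#### <answer>' at the end."""
    match = re.search(r"####\s*(.+?)\s*$", solution.strip())
    return match.group(1) if match else None


def binary_reward(reference_answer: str, generated_solution: str) -> int:
    """r(x, y): 1 if the final answer matches the reference exactly, else 0."""
    predicted = extract_final_answer(generated_solution)
    return int(predicted is not None and predicted == reference_answer)


# For code synthesis, the analogous reward runs hidden unit tests on the
# generated program and returns 1 only if every test passes.
```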
4. Model Selection, Regularization, and Overfitting
Observed empirical behavior indicates that excessive EM iterations can lead to overfitting, especially on datasets with limited task diversity (e.g., APPS). To mitigate this, ReST restarts fine-tuning from the base pretrained checkpoint for each iteration and limits the number of EM cycles—drawing on the observation that large performance gains often accrue in early rounds, with diminishing or negative returns if iterations continue. A plausible implication is that adaptive early stopping and dynamic sampling strategies could further stabilize training.
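One way to operationalize the limited-iteration strategy is a validation-based stopping check across EM rounds. The following is a minimal sketch under the assumption that a held-out evaluation (`evaluate_heldout`, hypothetical) and a single-iteration routine (`run_em_iteration`, hypothetical) are available; it is not a procedure prescribed by the paper.

```python
from typing import Callable


def run_em_with_early_stopping(
    run_em_iteration: Callable[[], None],   # hypothetical: one full E-step + M-step
    evaluate_heldout: Callable[[], float],  # hypothetical: held-out accuracy of the current model
    max_iterations: int = 5,
    min_improvement: float = 0.0,
) -> None:
    """Stop the EM loop once held-out accuracy stops improving, since most of the
    gain accrues in early rounds and later rounds risk overfitting."""
    best = evaluate_heldout()
    for _ in range(max_iterations):
        run_em_iteration()
        score = evaluate_heldout()
        if score - best <= min_improvement:
            break  # diminishing or negative returns: stop iterating
        best = score
```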
5. Broader Applications and Implications
While the core experiments evaluate mathematical reasoning and code generation tasks, the formulation applies to any domain where reward feedback, particularly binary correctness, is available and high-quality annotated data is expensive or limited. Potential domains include machine translation (where automatic metrics approximate reward), semantic parsing, logical reasoning, and preference tuning.
The methodology enables:
- Substantial reduction in dependence on human-annotated data.
- Generation of diverse, reward-verified training data (multiple correct outputs for a single context).
- Improved transferability: Models fine-tuned with ReST demonstrate robust performance not only on the source benchmark but also on out-of-domain tasks (e.g., GSM8K, Big-Bench Hard).
6. Limitations and Future Directions
- Sampling Limitations:
Current instantiations use temperature or top-$k$ sampling in the E-step, which may not efficiently cover the solution space. As suggested in the source (Singh et al., 2023), more sophisticated sampling/search methods, such as tree search or Monte Carlo search, may provide higher-quality or more diverse candidate traces.
- Reward Specification:
The method relies on externally provided binary rewards (e.g., automated correctness checks). Extensions might incorporate learned reward models or more nuanced feedback mechanisms for tasks lacking strict correctness criteria.
- Process-Level Evaluation and Process Reward:
ReST only filters whole traces by their final outcome. Subsequent work (e.g., ReST-MCTS* (Zhang et al., 2024)) augments this approach by assigning per-step rewards using tree search and value modeling, improving the granularity of supervision and further mitigating credit misassignment when a correct final answer is reached through flawed intermediate reasoning.
- Stability and Generalization:
Overfitting is possible when the dataset of high-reward outputs becomes small or unrepresentative. Regularization and dynamically controlled dataset curation are areas of recommended future exploration.
7. Relationship to Contemporary Methods
ReST shares conceptual territory with self-improving RL algorithms such as expert iteration and RWR, but is distinguished by its EM-style alternation and its reliance on reward-based filtering rather than policy gradient or explicit value models. Empirical comparisons indicate strong gains over pure supervised fine-tuning approaches and competitive or superior performance compared to other self-rewarding/self-training pipelines on the considered benchmarks.
Advancements beyond ReST, such as ReST-MCTS* (Zhang et al., 2024), integrate process-level supervision and mutual policy–value self-training, but fundamentally build on the alternating EM-based self-supervision paradigm established by ReST.
In sum, ReST operationalizes a principled EM approach for LLM self-training with scalar feedback, producing substantial improvements in problem-solving domains and providing a blueprint for scaling model performance beyond the limitations of human-annotated data (Singh et al., 2023).