Entropy-Based Adaptive Weighting for Self-Training (2503.23913v1)
Abstract: The mathematical problem-solving capabilities of LLMs have become a focal point of research, with growing interest in leveraging self-generated reasoning paths as a promising way to refine and enhance these models. These paths capture step-by-step logical processes while requiring only the correct answer for supervision. The self-training method has been shown to be effective in reasoning tasks while eliminating the need for external models and manual annotations. However, optimizing the use of self-generated data for model training remains an open challenge. In this work, we propose Entropy-Based Adaptive Weighting for Self-Training (EAST), an adaptive weighting strategy designed to prioritize uncertain data during self-training. Specifically, EAST employs a mapping function with a tunable parameter that controls the sharpness of the weighting, assigning higher weights to data where the model exhibits greater uncertainty. This approach guides the model to focus on more informative and challenging examples, thereby enhancing its reasoning ability. We evaluate our approach on the GSM8K and MATH benchmarks. Empirical results show that, while the vanilla method yields virtually no improvement (0%) on MATH, EAST achieves around a 1% gain over the backbone model. On GSM8K, EAST attains a further 1-2% performance boost compared to the vanilla method.
Summary
- The paper presents EAST, a method that uses model entropy to prioritize uncertain yet correct self-generated training samples.
- It estimates the model's uncertainty on each correct reasoning path and assigns higher training weights to more uncertain (more challenging) examples.
- Experimental results show a further 1-2% gain over vanilla self-training on GSM8K. On MATH, vanilla self-training yields virtually no improvement, while EAST gains about 1% over the backbone model.
Self-training has emerged as a technique for enhancing the mathematical reasoning capabilities of LLMs by leveraging model-generated reasoning paths supervised only by final answers. While this avoids the need for external models or detailed human annotations, optimizing the utilization of this self-generated data remains a challenge. The paper "Entropy-Based Adaptive Weighting for Self-Training" (2503.23913) introduces EAST, a method designed to improve self-training efficacy by adaptively weighting training samples based on the model's uncertainty.
EAST Methodology
The core principle of EAST is to prioritize training samples where the model exhibits higher uncertainty during the self-training process. The rationale is that samples inducing higher model uncertainty are potentially more informative or challenging, and focusing on them can accelerate learning and improve generalization for complex reasoning tasks.
EAST quantifies model uncertainty using the entropy of the model's predictive distribution over the generated reasoning path (or solution). For a given input question q and a model-generated solution s, let P(s∣q;θ) be the probability assigned by the model with parameters θ. The entropy H(s∣q;θ) associated with this generation is used as a measure of uncertainty.
An adaptive weighting function w is then defined based on this entropy. This function maps the calculated entropy H to a weight w(H), assigning higher weights to samples with higher entropy. The paper proposes a specific mapping function involving a tunable hyperparameter β:
$$w(H) = f(H;\beta)$$
The abstract does not specify the exact function. One plausible form is based on exponentiated entropy, possibly normalized or rescaled, with β controlling how sharply the weight responds to differences in entropy. Examples are w(H)∝exp(βH) or a sigmoid applied to normalized entropy. A larger β makes the weight grow more steeply with entropy, so uncertain samples are emphasized more strongly. For the exponential form, β=0 gives every sample the same weight, which recovers vanilla self-training.
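To make the role of β concrete, here is a minimal sketch of the two candidate forms mentioned above. Both are illustrative assumptions, not the paper's published mapping:

- `exp_weights` uses w ∝ exp(βH), rescaled so the weights average to 1.
- `sigmoid_weights` applies a sigmoid to min-max normalized, zero-centered entropy.

The function names and the rescaling choices are hypothetical.

```python
import math


def exp_weights(entropies: list[float], beta: float) -> list[float]:
    """Hypothetical mapping w(H) proportional to exp(beta * H), rescaled to mean 1."""
    h_max = max(entropies)
    # Shifting by h_max avoids overflow; the shift cancels in the rescaling below
    raw = [math.exp(beta * (h - h_max)) for h in entropies]
    mean_raw = sum(raw) / len(raw)
    return [r / mean_raw for r in raw]


def sigmoid_weights(entropies: list[float], beta: float) -> list[float]:
    """Hypothetical mapping: sigmoid of min-max normalized, zero-centered entropy."""
    h_min, h_max = min(entropies), max(entropies)
    if h_max == h_min:
        return [1.0] * len(entropies)  # no spread in entropy: fall back to uniform
    weights = []
    for h in entropies:
        z = (h - h_min) / (h_max - h_min) - 0.5  # normalized entropy in [-0.5, 0.5]
        weights.append(1.0 / (1.0 + math.exp(-beta * z)))
    return weights


entropies = [0.2, 0.5, 1.1, 1.4]
print(exp_weights(entropies, beta=0.0))  # [1.0, 1.0, 1.0, 1.0]: uniform weighting
print(exp_weights(entropies, beta=2.0))  # increases with entropy, mean 1
print(sigmoid_weights(entropies, beta=4.0))
```

The exponential form is unbounded, so a few very uncertain samples can dominate at large β. The sigmoid form caps every weight below 1, which trades some sharpness for stability.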
The overall self-training process with EAST involves:
- Data Generation: The LLM generates potential solutions (reasoning paths) for a set of training problems.
- Filtering: Generated solutions are filtered on whether they reach the correct final answer. Only solutions with the correct final answer are retained for training.
- Uncertainty Estimation: For each retained pair (q, s_correct), the model computes the entropy H(s_correct∣q;θ) associated with generating that solution.
- Weight Calculation: Using the entropy H and the chosen weighting function f(H;β), a weight w(H) is computed for each sample.
- Weighted Training: The model is fine-tuned on the set of correctly solved problems, where the loss for each sample is multiplied by its corresponding weight w(H).
This weighted loss guides the optimization process to allocate more gradient influence to samples the model finds more uncertain (higher entropy), under the assumption that these represent key learning opportunities.
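The Filtering step needs an answer checker. The sketch below is one simple approach for GSM8K-style problems, whose reference solutions end in a line of the form `#### <answer>`. It assumes that:

- the last number in a generated solution is the model's final answer;
- that number can be compared numerically with the reference.

The helpers `extract_final_number` and `verify_correctness` are hypothetical. MATH answers (fractions, radicals, LaTeX expressions) would need a more careful normalizer.

```python
import re
from typing import Optional

# Signed integers or decimals, allowing thousands separators such as "1,250.5"
NUMBER_RE = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def extract_final_number(text: str) -> Optional[float]:
    """Return the last number appearing in text, or None if there is none."""
    matches = NUMBER_RE.findall(text)
    if not matches:
        return None
    return float(matches[-1].replace(",", ""))


def verify_correctness(solution: str, reference: str, tol: float = 1e-6) -> bool:
    """Hypothetical checker: compare the final numbers of solution and reference."""
    predicted = extract_final_number(solution)
    # GSM8K references put the answer after "####"; otherwise use the whole string
    gold = extract_final_number(reference.split("####")[-1])
    return predicted is not None and gold is not None and abs(predicted - gold) <= tol


samples = [
    "She buys 3 packs of 12 eggs, so 3 * 12 = 36 eggs. The answer is 36.",
    "3 + 12 = 15, so she has 15 eggs.",
]
reference = "3 * 12 = 36\n#### 36"
kept = [s for s in samples if verify_correctness(s, reference)]  # only the first survives
```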
Implementation Details
Implementing EAST requires integrating the entropy calculation and weighting mechanism into a standard self-training loop.
Entropy Calculation:
The entropy H(s∣q;θ) for a generated sequence s = (y_1, y_2, ..., y_T) can be estimated from the model's conditional probabilities at each step. Assuming an autoregressive model, the sequence probability factorizes as

$$P(s \mid q;\theta) = \prod_{t=1}^{T} P(y_t \mid y_{<t}, q;\theta)$$

A common approach is to use the average negative log-likelihood (NLL) of the generated tokens, which is closely tied to entropy. The expected NLL of a sequence sampled from the model equals the entropy of the model's sequence distribution. The NLL of one sampled solution is therefore a single-sample estimate of that entropy, and dividing by T gives a length-normalized estimate:

$$H(s \mid q;\theta) \approx -\frac{1}{T} \sum_{t=1}^{T} \log P(y_t \mid y_{<t}, q;\theta)$$
Alternatively, entropy could be calculated over the vocabulary distribution at each step and averaged, or sequence-level probability/perplexity could be used. The specific implementation choice might depend on computational trade-offs and empirical performance.
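Both estimators can be computed from the per-step logits of a single forward pass over (q, s). The sketch below is a standard-library illustration on toy logits. It assumes that:

- `step_logits[t]` holds the unnormalized scores over the vocabulary at step t;
- `token_ids[t]` is the token actually generated at that step.

A real implementation would read these from the model's output tensors. `avg_nll` implements the length-normalized NLL formula above. `mean_token_entropy` instead averages the entropy of the full next-token distribution at each step.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    """Numerically stable log-softmax over one vocabulary distribution."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def avg_nll(step_logits: list[list[float]], token_ids: list[int]) -> float:
    """-(1/T) * sum_t log P(y_t | y_<t, q): NLL of the tokens actually generated."""
    total = 0.0
    for logits, tok in zip(step_logits, token_ids):
        total -= log_softmax(logits)[tok]
    return total / len(token_ids)


def mean_token_entropy(step_logits: list[list[float]]) -> float:
    """(1/T) * sum_t H(P(. | y_<t, q)): entropy of each full next-token distribution."""
    total = 0.0
    for logits in step_logits:
        log_p = log_softmax(logits)
        total -= sum(math.exp(lp) * lp for lp in log_p)
    return total / len(step_logits)


# Toy example: vocabulary of 4 tokens, 3 generation steps
step_logits = [
    [0.0, 0.0, 0.0, 0.0],   # maximally uncertain step: entropy = log(4)
    [5.0, 0.0, 0.0, 0.0],   # confident step
    [2.0, 1.0, 0.0, -1.0],
]
token_ids = [1, 0, 2]
print(avg_nll(step_logits, token_ids), mean_token_entropy(step_logits))
```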
Weighted Loss Function:
During the fine-tuning phase, the standard cross-entropy loss for a sample (q, s_correct) is modified. Let the standard loss be L(q, s_correct; θ). The EAST loss L_EAST incorporates the calculated weight w(H):
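$$\mathcal{L}_{\text{EAST}}(q, s_{\text{correct}};\theta) = w\big(H(s_{\text{correct}} \mid q;\theta)\big) \times \mathcal{L}(q, s_{\text{correct}};\theta)$$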
The total loss over a batch B of correctly solved problems is the weighted average:
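$$\mathcal{L}_{\text{batch}} = \frac{1}{\sum_{(q,s)\in B} w\big(H(s \mid q;\theta)\big)} \sum_{(q,s)\in B} w\big(H(s \mid q;\theta)\big)\,\mathcal{L}(q, s;\theta)$$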
(Implementations may normalize differently, for example by dividing by the batch size |B| instead of by the sum of the weights.)
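A small numeric sketch of the batch loss above contrasts normalization by the sum of weights (as in the formula) with division by the batch size. The per-sample losses and weights are made-up toy values. The weights are treated as constants (no gradient flows through w(H)); that is an implementation assumption, not something stated in the abstract.

```python
def weighted_batch_loss(losses: list[float], weights: list[float],
                        normalize: str = "weight_sum") -> float:
    """Combine per-sample losses L(q, s; theta) with per-sample weights w(H)."""
    if len(losses) != len(weights):
        raise ValueError("losses and weights must have the same length")
    weighted_sum = sum(w * loss for w, loss in zip(weights, losses))
    if normalize == "weight_sum":   # (1 / sum w) * sum w * L, as in the formula
        return weighted_sum / sum(weights)
    if normalize == "batch_size":   # (1 / |B|) * sum w * L
        return weighted_sum / len(losses)
    raise ValueError(f"unknown normalization: {normalize}")


losses = [2.0, 1.0, 0.5]
weights = [3.0, 1.0, 1.0]
print(weighted_batch_loss(losses, weights))                # (6 + 1 + 0.5) / 5 = 1.5
print(weighted_batch_loss(losses, weights, "batch_size"))  # 7.5 / 3 = 2.5
```

With weight-sum normalization, multiplying all weights by a constant leaves the loss unchanged, so β only shifts the relative emphasis between samples. Dividing by |B| instead lets the overall magnitude of the weights also change the scale of the loss and its gradients.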
Pseudocode for EAST Self-Training Iteration:
```python
import torch
from torch.optim import AdamW

# Sketch only. These are placeholders:
#   M (the policy model), D_train (question/reference-answer pairs), K, N_generate, Beta.
# These are hypothetical helpers, not a specific library API:
#   M.generate, M.get_log_probs, M.calculate_loss, M.save_checkpoint,
#   verify_correctness, calculate_weight, create_weighted_dataloader.
for iteration in range(K):
    # 1. Data generation: sample N_generate solutions per training question
    generated_data = []
    for q, answer in D_train:
        solutions = M.generate(q, num_return_sequences=N_generate)
        generated_data.append((q, answer, solutions))

    # 2. Filtering and 3. Uncertainty estimation
    correct_samples = []
    for q, answer, solutions in generated_data:
        for s in solutions:
            if not verify_correctness(s, answer):  # keep only correct final answers
                continue
            # Forward pass without gradients to get per-token log-probabilities
            with torch.no_grad():
                log_probs = M.get_log_probs(q, s)
            H = -log_probs.mean().item()  # average NLL as the entropy estimate
            correct_samples.append({"question": q, "solution": s, "entropy": H})

    if not correct_samples:
        continue  # no correct solutions this round, so nothing to train on

    # 4. Weight calculation, with optional min-max normalization of entropy to [0, 1]
    entropies = [sample["entropy"] for sample in correct_samples]
    min_H, max_H = min(entropies), max(entropies)
    span = (max_H - min_H) or 1.0  # avoid division by zero if all entropies match
    for sample in correct_samples:
        H_normalized = (sample["entropy"] - min_H) / span
        sample["weight"] = calculate_weight(H_normalized, Beta)  # f(H; beta)

    # 5. Weighted training on the correct samples
    dataloader = create_weighted_dataloader(correct_samples)
    optimizer = AdamW(M.parameters())
    for batch in dataloader:
        optimizer.zero_grad()
        # Must return one loss per sample (shape [batch_size]), not a batch mean,
        # otherwise the per-sample weights cannot be applied
        per_sample_loss = M.calculate_loss(batch["question"], batch["solution"])
        weights = batch["weight"]  # constants: no gradient flows through w(H)
        # Weighted average normalized by the sum of weights, matching L_batch
        weighted_loss = (per_sample_loss * weights).sum() / weights.sum()
        weighted_loss.backward()
        optimizer.step()

    # Persist the updated model M, which generates the data for the next iteration
    M.save_checkpoint()
```
Experimental Setup and Results
The EAST method was evaluated on two standard mathematical reasoning benchmarks: GSM8K (grade-school math word problems) and MATH (more challenging competition mathematics problems). The abstract does not name the backbone LLMs used.
The empirical results demonstrate the benefit of the adaptive weighting strategy:
- GSM8K: EAST achieved a 1-2% performance increase compared to vanilla self-training (uniform weighting). This suggests that prioritizing uncertain samples is beneficial for improving performance on these types of problems.
- MATH: Here the contrast with vanilla self-training is sharper. Vanilla self-training reportedly yielded virtually no improvement (0%) over the backbone model on the challenging MATH dataset, whereas EAST achieved approximately a 1% gain over the backbone. This suggests that EAST can extract gains from self-generated data where uniform weighting stagnates, at least on this more complex benchmark.
These results suggest that uncertainty-based weighting is effective, particularly on harder problem distributions. There, the model's uncertainty may be a more discriminative signal for identifying valuable training examples.
Practical Considerations
- Tuning β: The hyperparameter β controls the sharpness of the weighting. Its optimal value likely depends on the dataset, the model, and the specific entropy calculation used. Tuning β (e.g., via a validation set) is crucial. Too low a β approximates uniform weighting, while too high a β might overly focus on a small subset of very uncertain samples, potentially leading to instability or overfitting to noisy entropy estimates.
- Entropy Calculation Fidelity: The accuracy and stability of the entropy estimation H(s∣q;θ) are important. Noisy or poorly calibrated probability estimates could lead to suboptimal weighting. Using well-calibrated models or robust entropy estimation techniques is beneficial.
- Computational Overhead: EAST introduces additional computational cost primarily during the uncertainty estimation phase, which requires a forward pass through the model for each correctly generated solution to compute its probability or related metrics. The weighting calculation itself is typically negligible. However, this overhead is often acceptable within the context of the larger cost of data generation and model fine-tuning in self-training regimes.
- Interaction with Filtering: EAST operates on samples already filtered for correctness. The interaction between the filtering mechanism (which selects for success) and the weighting mechanism (which selects for uncertainty among successes) is key to its function. It focuses learning on problems the model can solve but finds difficult.
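One way to see the trade-off described under Tuning β is to track how concentrated the weights become as β grows. The sketch below assumes the exponential form w ∝ exp(βH), an illustrative choice rather than the paper's confirmed mapping, and uses toy entropy values. It reports Kish's effective sample size, (∑w)²/∑w². This equals the number of samples under uniform weights and shrinks toward 1 as the weight piles onto a few high-entropy samples.

```python
import math


def exp_entropy_weights(entropies: list[float], beta: float) -> list[float]:
    """Illustrative weighting w(H) proportional to exp(beta * H)."""
    h_max = max(entropies)  # shift for numerical stability; cancels in the ESS ratio
    return [math.exp(beta * (h - h_max)) for h in entropies]


def effective_sample_size(weights: list[float]) -> float:
    """Kish effective sample size: (sum w)^2 / sum w^2."""
    return sum(weights) ** 2 / sum(w * w for w in weights)


entropies = [0.1, 0.2, 0.3, 0.4, 0.6, 0.8, 1.0, 1.5, 2.0, 3.0]
for beta in [0.0, 0.5, 1.0, 2.0, 5.0]:
    ess = effective_sample_size(exp_entropy_weights(entropies, beta))
    print(f"beta={beta:>4}: effective sample size = {ess:.2f} of {len(entropies)}")
```

A sharply falling effective sample size is one symptom of the "too high β" regime, where a handful of very uncertain samples dominate the gradient.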
Conclusion
EAST presents an adaptive weighting strategy for self-training LLMs on mathematical reasoning tasks. It uses model uncertainty, quantified via entropy, to assign higher weights to correctly solved samples that the model finds more challenging or informative. This lets EAST make better use of self-generated data. Empirical results on GSM8K and MATH show measurable improvements over vanilla self-training. The gains are most notable on the difficult MATH benchmark, where vanilla self-training gave virtually no gain over the backbone model. This highlights the potential of uncertainty-aware data utilization in self-training for complex reasoning.