MiLe Loss: Reweighting for Rare Token Learning
- MiLe Loss is a training objective that dynamically up-weights difficult, rare tokens to address the inherent token frequency imbalance in language models.
- It leverages the entropy of predicted token distributions as a proxy for difficulty, allowing the model to focus on ambiguous, informative tokens.
- Empirical results demonstrate that MiLe Loss improves performance on tasks like closed-book QA and common-sense reasoning with minimal computational overhead.
MiLe Loss is a training objective designed to mitigate the bias toward frequent and easy-to-learn tokens in generative LLM pretraining. Standard autoregressive LLMs, which learn next-token distributions via cross-entropy loss, tend to be dominated during optimization by the most common tokens due to the long-tailed, Zipfian nature of natural language corpora. MiLe Loss introduces a principled, entropy-driven approach to dynamically up-weight examples corresponding to harder, less frequent tokens, thereby directly addressing the imbalanced gradient signal observed in traditional maximum likelihood learning frameworks (Su et al., 2023).
1. Motivation: Token Frequency Imbalance in LLM Training
Large-scale text corpora such as The Pile exhibit strong frequency imbalance: a small fraction of tokens appears extremely frequently, while the rest form a long tail of rare events. An empirical analysis of the Pile with a LLaMA-7B model shows that the most frequent 80% of tokens have a markedly lower average per-token perplexity (PPL) than the rarest 5%. This disparity reflects a genuine difference in learning difficulty. Rare tokens are harder to predict, yet under standard cross-entropy they receive proportionally less training signal, because every position contributes equally to the loss and frequent tokens account for most positions. Consequently, generative models tend to underperform on information-rich, infrequent tokens, which often contribute disproportionately to challenging downstream tasks.
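The per-bucket perplexity comparison described above can be reproduced on any model's outputs with a few lines of bookkeeping. The sketch below is a minimal illustration, not the study's evaluation code. The helper names `frequency_buckets` and `bucket_perplexity` are hypothetical. It assumes the buckets partition distinct token types ranked by corpus frequency in an 80/15/5 split, whereas the original analysis may define them differently (for example, by share of token occurrences).

```python
import math
from collections import Counter


def frequency_buckets(corpus_ids, splits=(0.80, 0.15, 0.05)):
    """Map each token id to a bucket index (0 = most frequent) by frequency rank.

    ASSUMPTION: buckets partition the distinct token types in the corpus,
    ranked from most to least frequent, in the given proportions.
    """
    ranked = [tok for tok, _ in Counter(corpus_ids).most_common()]
    bucket_of = {}
    start = 0
    for b, frac in enumerate(splits):
        # The last bucket absorbs any rounding remainder.
        if b == len(splits) - 1:
            end = len(ranked)
        else:
            end = start + round(frac * len(ranked))
        for tok in ranked[start:end]:
            bucket_of[tok] = b
        start = end
    return bucket_of


def bucket_perplexity(targets, nlls, bucket_of, n_buckets=3):
    """Per-bucket PPL = exp(mean negative log-likelihood, in nats, of targets in that bucket).

    Assumes every target token appears in `bucket_of`; empty buckets yield nan.
    """
    sums = [0.0] * n_buckets
    counts = [0] * n_buckets
    for tok, nll in zip(targets, nlls):
        b = bucket_of[tok]
        sums[b] += nll
        counts[b] += 1
    return [math.exp(s / c) if c else float("nan") for s, c in zip(sums, counts)]
```

Given per-position negative log-likelihoods `nlls` for the target tokens `targets`, the call `bucket_perplexity(targets, nlls, frequency_buckets(corpus_ids))` returns one PPL value per bucket.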
2. MiLe Loss: Mathematical Formulation
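Given the model's predicted distribution $p = (p_1, \dots, p_V)$ over a vocabulary of size $V$ and a ground-truth target index $t$, MiLe Loss defines the per-token objective as

$$\mathcal{L}_{\text{MiLe}} = -\,H(p)^{\gamma} \log p_t,$$

where

$$H(p) = -\sum_{i=1}^{V} p_i \log p_i$$

is the information entropy of $p$ and $\gamma \ge 0$ is a focusing hyperparameter. The dynamic weight $H(p)^{\gamma}$ up-weights tokens associated with high-entropy (i.e., ambiguous or hard) predictive distributions. When $\gamma = 0$, MiLe Loss reduces to standard cross-entropy.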
| Term | Definition | Role in MiLe Loss |
|---|---|---|
| $p$ | Softmax distribution over tokens | Predictive probabilities |
| $H(p)$ | Information entropy of $p$ | Quantifies uncertainty/difficulty |
| $\gamma$ | Focusing hyperparameter | Controls sensitivity to difficulty |
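As a worked check of the formula, the sketch below computes the per-token loss from a single vector of logits in plain Python. It computes $-H(p)^{\gamma} \log p_t$ with $H(p) = -\sum_i p_i \log p_i$, assuming natural logarithms. The helper names are hypothetical. The goal is to make the arithmetic concrete, not to reproduce the study's implementation.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Information entropy H(p) = -sum_i p_i log p_i in nats (0 log 0 taken as 0)."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def cross_entropy(logits, target):
    """Standard per-token cross-entropy -log p_t."""
    return -math.log(softmax(logits)[target])


def mile_loss(logits, target, gamma):
    """Per-token MiLe loss -H(p)**gamma * log p_t (clamp guards against round-off)."""
    p = softmax(logits)
    weight = max(entropy(p), 0.0) ** gamma
    return -weight * math.log(p[target])


confident = [4.0, 0.0, 0.0, 0.0]  # peaked distribution, low entropy
ambiguous = [1.0, 0.9, 0.8, 0.7]  # near-uniform distribution, high entropy
for name, z in (("confident", confident), ("ambiguous", ambiguous)):
    print(name, round(cross_entropy(z, 0), 4), round(mile_loss(z, 0, gamma=1.0), 4))
```

With $\gamma = 0$ the weight is 1 and the result matches plain cross-entropy. Written this way, the weight is not confined to $[0, 1]$. With natural logarithms, $H(p)$ ranges from 0 to $\log V$. When $\gamma > 0$, this has two effects:

- The confident logits above ($H \approx 0.26$) are scaled down.
- The near-uniform logits ($H \approx 1.38$, close to the maximum $\log 4 \approx 1.39$) are scaled up.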
3. Entropy as a Proxy for Token-Level Difficulty
MiLe Loss uses the information entropy $H(p)$ of the full predicted token distribution, rather than only the probability or logit assigned to the correct token. This choice reflects the multi-label ambiguity of natural language: several next tokens may be valid, and the concentration of probability mass indicates how confident the model is. High-entropy output indicates that the model is confused or that the next token is genuinely ambiguous or rare, and such cases correlate empirically with harder-to-learn tokens. Earlier approaches such as Focal Loss reweight solely via the target-token probability $p_t$, using the factor $(1 - p_t)^{\gamma}$. They therefore cannot distinguish a misprediction caused by genuine ambiguity from one caused by model error. By reweighting on the entropy of the whole distribution, MiLe Loss selectively focuses on training instances where the model is genuinely uncertain.
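The contrast with Focal Loss can be made concrete with two hypothetical predictive distributions. Both assign the same probability to the target but spread the remaining mass differently. The focal weight $(1 - p_t)^{\gamma}$ is identical for the two, while the entropy weight $H(p)^{\gamma}$ separates them. The distributions and helper names are illustrative assumptions, not taken from the study.

```python
import math


def entropy(probs):
    """H(p) = -sum_i p_i log p_i, skipping zero-probability entries (0 log 0 = 0)."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def focal_weight(probs, target, gamma):
    """Focal Loss weight (1 - p_t)**gamma: depends only on the target probability."""
    return (1.0 - probs[target]) ** gamma


def mile_weight(probs, gamma):
    """Entropy weight H(p)**gamma: depends on the whole predictive distribution."""
    return entropy(probs) ** gamma


# Both hypothetical distributions give the target (index 0) the same p_t = 0.4.
two_way = [0.4, 0.6, 0.0, 0.0, 0.0, 0.0]      # mass narrowed to one strong competitor
spread = [0.4, 0.12, 0.12, 0.12, 0.12, 0.12]  # remaining mass spread over many tokens

for name, p in (("two_way", two_way), ("spread", spread)):
    print(name, focal_weight(p, 0, 1.0), round(mile_weight(p, 1.0), 3))
```

Both focal weights equal 0.6, while the entropy weights are about 0.67 and 1.64. Only the entropy-based weight distinguishes the case where the model has narrowed the choice to two tokens from the case where it is uncertain across many.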
4. Dynamic Scaling and Training Workflow
The training procedure with MiLe Loss requires, for each prediction:
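- Compute the model logits $z \in \mathbb{R}^{V}$ and softmax probabilities $p = \mathrm{softmax}(z)$.
- Calculate the information entropy $H(p) = -\sum_{i} p_i \log p_i$.
- Compute the scaling factor $w = H(p)^{\gamma}$, with $H(p)$ clamped at zero to ensure non-negativity under floating-point round-off.
- Multiply the standard cross-entropy $-\log p_t$ by $w$.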
- Aggregate and backpropagate as usual.
The approach has a single hyperparameter, $\gamma$. The original study uses a fixed default value, and ablations show robust performance for $\gamma$ up to approximately $2.0$ without over-focusing. The following sketch outlines how MiLe Loss integrates into the training loop of a transformer-based autoregressive model:
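```python
import torch.nn.functional as F

# TransformerLM, loader, optimizer and gamma are assumed to be defined elsewhere.
for x in loader:                                    # x: [B, T] token ids
    inputs, target = x[:, :-1], x[:, 1:]            # next-token shift
    logits = TransformerLM(inputs)                  # [B, T-1, V]
    log_probs = F.log_softmax(logits, dim=-1)       # numerically stable log p_i
    probs = log_probs.exp()                         # p_i
    # Per-position cross-entropy: -log p_t
    ce = -log_probs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)  # [B, T-1]
    # Per-position information entropy: H(p) = -sum_i p_i log p_i
    ent = -(probs * log_probs).sum(dim=-1)          # [B, T-1]
    # MiLe weight H(p)**gamma; the clamp guards against tiny negative round-off
    w = ent.clamp(min=0) ** gamma
    loss = (w * ce).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```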
Compared with standard cross-entropy, MiLe Loss adds only a per-token entropy computation and an element-wise weighting, so its computational overhead is negligible (Su et al., 2023).
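The entropy term can be computed directly from the logits without taking the logarithm of probabilities that may have underflowed to zero. Since $\log p_i = z_i - \log \sum_j e^{z_j}$, the entropy satisfies $H(p) = \log \sum_j e^{z_j} - \sum_i p_i z_i$. The plain-Python sketch below, with hypothetical helper names, checks this identity against the direct definition. In a tensor framework, a fused log-softmax such as PyTorch's `torch.nn.functional.log_softmax` serves the same purpose.

```python
import math


def logsumexp(z):
    """Stable log(sum(exp(z))) via the max-shift trick."""
    m = max(z)
    return m + math.log(sum(math.exp(x - m) for x in z))


def entropy_from_logits(z):
    """H(softmax(z)) = logsumexp(z) - sum_i p_i * z_i, never evaluating log(0)."""
    lse = logsumexp(z)
    probs = [math.exp(x - lse) for x in z]
    return lse - sum(p * x for p, x in zip(probs, z))


def entropy_naive(z):
    """Direct definition; fails if some p_i underflows to exactly 0."""
    exps = [math.exp(x) for x in z]
    total = sum(exps)
    return -sum((e / total) * math.log(e / total) for e in exps)
```

For moderate logits the two functions agree. For logits such as `[0.0, -800.0]`, `math.exp(-800.0)` underflows to `0.0`, so `entropy_naive` raises `ValueError` on `math.log(0.0)`, while `entropy_from_logits` returns `0.0`.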
5. Experimental Protocol and Implementation Details
MiLe Loss was validated by pretraining autoregressive LLMs at 468M, 1.2B, and 6.7B parameters on The Pile (825 GB of English text, split into 1,024-token blocks). Training used the LLaMA tokenizer (32K vocabulary) and the AdamW optimizer with a 2K-step warmup followed by cosine learning-rate decay. Models were trained to convergence on 100B tokens. MiLe Loss used the default $\gamma$ unless otherwise indicated.
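The learning-rate schedule described above (2K warmup steps, then cosine decay) can be sketched as a plain function of the step count. This section does not give the peak learning rate, total step count, or final floor, so here they are hypothetical parameters. The linear warmup shape and the default floor of 0 are common conventions, not necessarily the study's choices.

```python
import math


def warmup_cosine_lr(step, peak_lr, total_steps, warmup_steps=2000, min_lr=0.0):
    """Warmup to peak_lr, then cosine decay to min_lr at total_steps.

    Steps are 0-indexed. peak_lr, total_steps and min_lr are hypothetical inputs;
    the linear warmup shape is an assumption.
    """
    if step < warmup_steps:
        # Linear warmup: reaches peak_lr at step warmup_steps - 1.
        return peak_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

In PyTorch, a function like this can drive `torch.optim.lr_scheduler.LambdaLR`. Set the optimizer's base learning rate to the peak value and pass `lr_lambda=lambda s: warmup_cosine_lr(s, 1.0, total_steps)` so that the function returns a multiplier.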
6. Empirical Results and Analysis
Across scales and diverse benchmarks, MiLe Loss provides consistent improvements over the cross-entropy (CE) baseline:
- Zero-/few-shot common-sense reasoning (8 datasets, accuracy):
  - 468M: CE 49.14% → MiLe 49.93%
  - 1.2B: CE 51.69% → MiLe 52.48%
  - 6.7B: CE 57.59% → MiLe 57.97%
- Closed-book QA (6.7B, exact match):
  - TriviaQA, 0-shot: CE 17.09 → MiLe 20.64
  - WebQuestions, 5-shot: CE 14.17 → MiLe 14.57
- Massive Multitask Language Understanding (MMLU, 6.7B, 5-shot):
  - CE 29.38% → MiLe 29.68%
- Per-token perplexity by frequency bucket (6.7B, CE → MiLe; lower is better):
  - Easy (top 80%): 4.323 → 4.349
  - Medium (next 15%): 13.541 → 13.459
  - Hard (bottom 5%): 15.517 → 15.371
- Longer pretraining (200B tokens): CE 61.75% → MiLe 63.08% (5-shot, +1.33 percentage points).
Performance gains are most pronounced for rare tokens, at the cost of slightly higher perplexity on the most common tokens. MiLe Loss also yields higher accuracy on challenging, information-rich QA datasets.
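To make the size of these shifts explicit, the relative perplexity change per frequency bucket can be computed directly from the reported 6.7B numbers. This is simple arithmetic on the published values. `relative_change` is a hypothetical helper name.

```python
def relative_change(ce, mile):
    """Percent change from the CE baseline to MiLe (negative = lower perplexity)."""
    return 100.0 * (mile - ce) / ce


# 6.7B per-token perplexities as reported: (CE, MiLe).
reported_ppl = {
    "easy (top 80%)": (4.323, 4.349),
    "medium (next 15%)": (13.541, 13.459),
    "hard (bottom 5%)": (15.517, 15.371),
}

for bucket, (ce, mile) in reported_ppl.items():
    print(f"{bucket}: {mile - ce:+.3f} PPL ({relative_change(ce, mile):+.2f}%)")
```

The relative changes come to roughly +0.60% for easy tokens, -0.61% for medium tokens, and -0.94% for hard tokens. The improvement therefore grows toward the rare end of the vocabulary, while frequent tokens pay a small price.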
7. Ablation, Interpretability, and Limitations
Ablating the focusing hyperparameter shows that gains are robust up to $\gamma \approx 2.0$. Larger values degrade generalization, plausibly because highly ambiguous positions become over-weighted. Perplexity breakdowns show that MiLe Loss systematically reallocates modeling capacity toward hard tokens and reduces excessive gravitation toward frequent, easy predictions. The improvements on closed-book QA and common-sense reasoning benchmarks suggest that MiLe Loss encodes rare, semantically informative content more effectively. Taken together, these results indicate that a simple change to loss weighting, grounded in the entropy of the predicted distribution, can measurably rebalance learning in LLMs and mitigate the foundational issue of token-level data imbalance (Su et al., 2023).
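The over-focusing explanation can be illustrated with a toy calculation. As $\gamma$ grows, the weights $H^{\gamma}$ concentrate more and more of the total loss weight on the few highest-entropy positions. The per-position entropies below are hypothetical values chosen for illustration, not data from the study, and `top_share` is a hypothetical helper.

```python
def top_share(entropies, gamma, k=1):
    """Fraction of the total MiLe weight, sum of H**gamma, carried by the k highest-entropy positions."""
    weights = sorted((h ** gamma for h in entropies), reverse=True)
    return sum(weights[:k]) / sum(weights)


# Hypothetical per-position entropies (in nats) for one sequence; not data from the study.
entropies = [0.1, 0.3, 0.5, 1.0, 2.0, 4.0]

for gamma in (0.0, 0.5, 1.0, 2.0, 4.0):
    print(f"gamma={gamma}: top-1 share = {top_share(entropies, gamma):.2f}")
```

In this toy sequence, the share carried by the single highest-entropy position rises from about 0.17 at $\gamma = 0$ to about 0.94 at $\gamma = 4$. At large $\gamma$, a handful of ambiguous positions dominate the gradient.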