- The paper shows that tuning weight decay via a derived power law on the AdamW timescale is key for efficient LLM pre-training.
- It finds that optimal and critical batch sizes scale as power laws in dataset size (approximately D^0.4 and D^0.5), independent of model size.
- The study offers a practical roadmap using μP and loss-data scaling laws to balance training time and compute costs in large-scale training.
Efficiently pre-training LLMs at scale requires careful tuning of hyperparameters like learning rate (η) and weight decay (λ). Traditional hyperparameter sweeping is often infeasible for the largest models due to computational costs. This paper presents scaling laws derived from hundreds of training runs to guide the selection of weight decay, batch size, model size (N), and dataset size (D) for optimal training performance, particularly focusing on practical tradeoffs between training time and total compute.
The paper investigates the scaling behavior of the AdamW timescale, defined as T_epoch = B/(η·λ·D). This metric represents the effective fraction of the training data over which weight updates are averaged. While previous work suggested keeping this constant for multi-epoch training, this study demonstrates that for LLM pre-training (typically one epoch), the optimal T_epoch is not constant but follows a precise power law in the tokens-per-parameter ratio (TPP = D/N). The fitted law shows T_epoch ∝ TPP^−0.5, meaning the optimal timescale decreases as models are trained on more data relative to their size.
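A minimal sketch of applying this timescale law; the prefactor `a` below is a hypothetical placeholder, not the paper's fitted constant:

```python
def optimal_timescale(tokens: float, params: float, a: float = 1.0) -> float:
    """Optimal AdamW timescale (fraction of an epoch) via T_epoch ∝ TPP^-0.5.

    `a` is a hypothetical prefactor standing in for the fitted constant.
    """
    tpp = tokens / params  # tokens-per-parameter ratio, TPP = D/N
    return a * tpp ** -0.5
```

One consequence of the −0.5 exponent: quadrupling the data relative to model size halves the optimal timescale.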
A key practical implication is that, when using the AdamW optimizer and the Maximal Update Parameterization (μP) framework to set the learning rate, the weight decay (λ) should be the primary hyperparameter adjusted as batch size (B) and dataset size (D) change. The paper shows empirically that tuning λ to maintain the optimal T_epoch is more effective than tuning η as B or D varies. This provides a concrete recipe for practitioners: use μP to set η based on model size N, and then set λ using the derived scaling law for T_epoch and the formula λ = B/(η·T_epoch·D). The linear relationship between optimal λ and B for fixed η and D holds up to a certain batch size, further supporting adjusting λ with B.
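This recipe reduces to inverting the timescale definition; a sketch, assuming η comes from μP transfer and T_epoch from the fitted law:

```python
def weight_decay_from_timescale(batch_size: float, lr: float,
                                tokens: float, t_epoch: float) -> float:
    """Invert T_epoch = B / (eta * lambda * D) to solve for weight decay lambda."""
    return batch_size / (lr * t_epoch * tokens)
```

Note that at fixed η, D, and T_epoch, doubling B doubles λ, consistent with the linear λ-B relationship mentioned above.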
The study also provides insights into optimal batch size (B_opt) and critical batch size (B_crit). B_opt is the batch size that achieves the lowest loss for a given N and D. B_crit is defined via an empirical model of the tradeoff between the number of tokens (D) and the number of optimization steps (S) required to reach a target loss L. The model is expressed as (D/D_min − 1)(S/S_min − 1) = 1, where D_min and S_min are the minimum tokens and steps, respectively, and B_crit = D_min/S_min. The paper introduces a novel method to estimate B_crit by fitting batch-size-specific loss-data scaling laws and interpolating the data needed for a target loss, avoiding the need for dense checkpoint evaluation or constant learning rates.
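The tradeoff model yields a closed form for the data required at a given batch size. A sketch under the definitions above, with D_min and S_min assumed fitted per target loss:

```python
def critical_batch_size(d_min: float, s_min: float) -> float:
    """B_crit = D_min / S_min, from the empirical D-S tradeoff model."""
    return d_min / s_min

def tokens_required(batch_size: float, d_min: float, s_min: float) -> float:
    """Solve (D/D_min - 1)(S/S_min - 1) = 1 with S = D/B for D.

    Substituting S = D/B and simplifying gives D = D_min * (1 + B / B_crit).
    """
    return d_min * (1.0 + batch_size / critical_batch_size(d_min, s_min))
```

At B = B_crit both excess factors equal 2: the run costs twice the minimum tokens and takes twice the minimum steps.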
Contrary to some prior work suggesting that B_opt and B_crit scale primarily with total compute (C) or target loss (L), this paper finds that both B_opt and B_crit scale as power laws in the dataset size D, largely independent of model size N. Specifically, the findings suggest B_opt ∝ D^0.4 and B_crit ∝ D^0.5. This aligns with recent concurrent work (Zhang et al., 2024), reinforcing the fundamental dependence of optimal and critical batch sizes on the amount of data used. Practically, this means practitioners can estimate these values from small-scale runs and extrapolate based on the training data size. The D-S tradeoff equation (D/D_min − 1)(S/S_min − 1) = 1 can then be used to understand the computational cost (proportional to D) and training time (proportional to the number of steps S = D/B) implications of choosing a particular batch size B for a given N and target loss L.
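Extrapolating from small runs is then a one-line power law. The exponents 0.4 and 0.5 are the paper's fits; the reference points in the example are hypothetical:

```python
def extrapolate_batch_size(b_ref: float, d_ref: float,
                           d_target: float, exponent: float) -> float:
    """Scale a batch size measured at dataset size d_ref to d_target via B ∝ D^exponent."""
    return b_ref * (d_target / d_ref) ** exponent
```

For example, with exponent 0.5, a critical batch size measured on 1B tokens would be about 10x larger at 100B tokens.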
Using these derived scaling laws, the paper analyzes the Pareto-optimal configurations of N and D that achieve a target loss L while balancing training time and compute. For a fixed total compute budget, traditional advice suggests training at roughly 20 TPP to minimize loss. However, once training time is considered (time decreases with larger B, but larger B increases the total D and compute via the D-S tradeoff), the analysis reveals that smaller, over-trained models (TPP > 20) can be Pareto-optimal. This is because over-trained models are trained on larger datasets, leading to higher B_crit and thus allowing more efficient use of larger batch sizes to reduce training time, even if the total compute exceeds the 20-TPP optimum. Targeting 20 TPP is shown to be Pareto-inefficient when using very large batch sizes.
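The compute-time tradeoff behind this analysis can be made concrete. The 6·N·D FLOPs estimate is the standard transformer approximation; the inputs here are hypothetical:

```python
def cost_and_time(n_params: float, tokens: float, batch_size: float):
    """Return (total training FLOPs ~ 6*N*D, time proxy in optimizer steps S = D/B)."""
    return 6.0 * n_params * tokens, tokens / batch_size
```

A smaller model trained on more tokens spends more FLOPs, but its higher B_crit (∝ D^0.5) permits a larger B and hence fewer sequential steps, which is the essence of the Pareto argument above.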
The core practical takeaways are:
- When using AdamW and μP, fix the learning rate based on model width and tune weight decay (λ) via the derived T_epoch scaling law (T_epoch ∝ TPP^−0.5) and the batch size, through λ = B/(η·T_epoch·D).
- Optimal and critical batch sizes (B_opt, B_crit) scale with dataset size D, not total compute C or target loss L. Estimate their scaling from small runs (B_opt ∝ D^0.4 and B_crit ∝ D^0.5) and use these laws to select B for large-scale training.
- Leverage the D-S tradeoff model and the B_crit scaling law to plan training runs that optimally balance total compute (FLOPs) and training time, especially considering the benefits of larger batch sizes. This may favor smaller, over-trained models for faster training times at a given performance level.
Implementation requires:
- Train a proxy model with μP to find base hyperparameters, including the base learning rate.
- For the target model width, set the peak learning rate η via μP transfer.
- From limited small-scale experiments across various B, D, and losses, fit:
  - The optimal T_epoch power law with TPP [(2505.13738), Eq. 3].
  - The batch-size-specific loss-data power laws L(D; B) [(2505.13738), Fig. 4].
  - The D-S tradeoff curve (D/D_min − 1)(S/S_min − 1) = 1 [(2505.13738), Eq. 6], yielding D_min and S_min for various losses. Calculate B_crit = D_min/S_min.
  - The B_crit power law with D [(2505.13738), Eq. 7].
- For a desired target loss L, model size N, and dataset size D:
  - Predict the target D_min for loss L at size N (e.g., using a loss scaling law like [(Ouyang et al., 2022), Eq. 1]).
  - Predict B_crit for this D_min using the fitted power law.
  - Choose a batch size B. The required dataset size is then D = D_min·(1 + B/B_crit). Ensure the actual dataset used is at least this size.
  - Calculate the required steps S = D/B. Configure the LR schedule accordingly.
  - Calculate the target T_epoch for the chosen TPP = D/N ratio.
  - Set λ = B/(η·T_epoch·D).
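The steps above can be sketched end to end. Every fitted constant here (`timescale_coeff`, `d_min`, `s_min`) is a hypothetical placeholder for a value obtained from the small-scale fits:

```python
def plan_run(n_params: float, lr: float, batch_size: float,
             d_min: float, s_min: float,
             timescale_coeff: float = 1.0):
    """Plan a training run: required tokens, steps, and weight decay.

    d_min/s_min come from the fitted D-S tradeoff for the target loss;
    timescale_coeff stands in for the fitted T_epoch-vs-TPP prefactor.
    """
    b_crit = d_min / s_min                               # critical batch size
    tokens = d_min * (1.0 + batch_size / b_crit)         # required D = D_min * (1 + B/B_crit)
    steps = tokens / batch_size                          # S = D/B, sets the LR schedule length
    tpp = tokens / n_params                              # tokens-per-parameter ratio
    t_epoch = timescale_coeff * tpp ** -0.5              # fitted timescale law
    weight_decay = batch_size / (lr * t_epoch * tokens)  # invert T_epoch = B/(eta*lambda*D)
    return tokens, steps, weight_decay
```

The returned tuple gives the dataset size to provision, the schedule length in optimizer steps, and the weight decay to configure.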
This systematic approach, grounded in empirical scaling laws, offers a more predictable and efficient way to navigate the complex hyperparameter space for large-scale LLM pre-training, enabling better control over training costs and time. The findings also highlight that achieving the fastest training time for a given performance level might involve training models on significantly more data than the compute-optimal minimum.
Some limitations mentioned include the focus on AdamW and a specific LR schedule shape, the need for more data/architectural variations in future work, and the empirical observation that small batches degrade performance more than theory suggests, potentially requiring tuning beyond T_epoch. Practical systems constraints of very large batches are also not explicitly modeled.