- The paper shows that tuning weight decay via a derived power law on the AdamW timescale is key for efficient LLM pre-training.
- It finds that optimal and critical batch sizes scale as power laws in dataset size (approximately D^0.4 and D^0.5), independent of model size.
- The study offers a practical roadmap using μP and loss-data scaling laws to balance training time and compute costs in large-scale training.
Efficiently pre-training LLMs at scale requires careful tuning of hyperparameters like learning rate (η) and weight decay (λ). Traditional hyperparameter sweeping is often infeasible for the largest models due to computational costs. This paper presents scaling laws derived from hundreds of training runs to guide the selection of weight decay, batch size, model size (N), and dataset size (D) for optimal training performance, particularly focusing on practical tradeoffs between training time and total compute.
The paper investigates the scaling behavior of the AdamW timescale, defined as T_epoch = B/(η·λ·D). This metric represents the effective fraction of the training data over which weight updates are averaged. While previous work suggested keeping this quantity constant for multi-epoch training, this paper demonstrates that for LLM pre-training (typically one epoch), the optimal T_epoch is not constant but follows a precise power law in the tokens-per-parameter ratio (TPP = D/N). The fitted law shows T_epoch ∝ TPP^(-0.5), meaning the optimal timescale decreases as models are trained on more data relative to their size.
A key practical implication is that, when using the AdamW optimizer and the Maximal Update Parameterization (μP) framework to set the learning rate, the weight decay (λ) should be the primary hyperparameter adjusted as batch size (B) and dataset size (D) change. The paper shows empirically that tuning λ to maintain the optimal T_epoch is more effective than tuning η as B or D varies. This provides a concrete recipe for practitioners: use μP to set η based on model size N, and then set λ using the derived scaling law for T_epoch and the formula λ_opt = B/(η·D·T_epoch(D/N)). The linear relationship between optimal λ and B for fixed N, D holds up to a certain batch size, further supporting adjusting λ with B.
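As a minimal sketch, this recipe can be wired together in a few lines. The μP base learning rate, proxy width, and the prefactor `ALPHA` of the T_epoch law are placeholder assumptions here; in practice they would come from proxy-model sweeps and the paper's fitted law, not the values below.

```python
ALPHA = 1.0  # assumed prefactor of the T_epoch ~ TPP^-0.5 law (hypothetical)

def mup_lr(eta_base: float, width_proxy: int, width: int) -> float:
    """mu-P learning-rate transfer: eta scales inversely with model width."""
    return eta_base * (width_proxy / width)

def optimal_timescale(d_tokens: float, n_params: float) -> float:
    """Fitted law: T_epoch proportional to (D/N)^-0.5."""
    tpp = d_tokens / n_params
    return ALPHA * tpp ** -0.5

def optimal_weight_decay(batch: float, eta: float, d_tokens: float, n_params: float) -> float:
    """Invert T_epoch = B / (eta * lambda * D) for lambda."""
    t_epoch = optimal_timescale(d_tokens, n_params)
    return batch / (eta * d_tokens * t_epoch)

# Example: 2B-parameter model, 400B tokens, 2M-token batches (illustrative values).
eta = mup_lr(eta_base=0.02, width_proxy=256, width=4096)
lam = optimal_weight_decay(batch=2**21, eta=eta, d_tokens=4e11, n_params=2e9)
```

Note how λ depends linearly on B here, matching the empirical observation that λ should be scaled up with batch size at fixed N, D.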
The paper also provides insights into optimal batch size (B_opt) and critical batch size (B_crit). B_opt is the batch size that achieves the lowest loss for a given N and D. B_crit is defined based on an empirical model of the tradeoff between the number of tokens (D) and the number of optimization steps (S) required to reach a target loss L. The model is expressed as S/S_min − 1 = (D/D_min − 1)^(−1), where D_min and S_min are the minimum tokens and steps, respectively, and B_crit = D_min/S_min. The paper introduces a novel method to estimate B_crit by fitting batch-size-specific loss-data scaling laws and interpolating the data needed for a target loss, avoiding the need for dense checkpoint evaluation or constant learning rates.
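The tradeoff model is convenient to fit: S/S_min − 1 = (D/D_min − 1)^(−1) rearranges to D_min/D + S_min/S = 1, which is linear in (D_min, S_min) and solvable with ordinary least squares. A sketch, using synthetic (B, D_B) pairs generated from the model itself as stand-ins for the paper's interpolated token requirements at one target loss:

```python
def fit_b_crit(pairs):
    """pairs: (batch_size, tokens_needed) at a fixed target loss.

    Fits D_min * (1/D) + S_min * (1/S) = 1 by least squares
    (normal equations for the two unknowns D_min, S_min).
    """
    sxx = sxy = syy = sx = sy = 0.0
    for b, d in pairs:
        s = d / b                      # optimization steps at this batch size
        x, y = 1.0 / d, 1.0 / s
        sxx += x * x; sxy += x * y; syy += y * y
        sx += x; sy += y
    det = sxx * syy - sxy * sxy
    d_min = (sx * syy - sy * sxy) / det
    s_min = (sy * sxx - sx * sxy) / det
    return d_min, s_min, d_min / s_min  # B_crit = D_min / S_min

# Synthetic data from the implied relation D = D_min * (1 + B / B_crit).
TRUE_D_MIN, TRUE_B_CRIT = 1e10, 2**20
pairs = [(b, TRUE_D_MIN * (1 + b / TRUE_B_CRIT)) for b in (2**16, 2**18, 2**20, 2**22)]
d_min, s_min, b_crit = fit_b_crit(pairs)
```

On noiseless data the fit recovers the generating parameters exactly; real interpolated data would carry noise, but the same linear fit applies.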
Contrary to some prior work that suggested B_opt and B_crit scale primarily with total compute (C) or target loss (L), this paper finds that both B_opt and B_crit scale as power laws in the dataset size D, largely independent of model size N. Specifically, the findings suggest B_opt ∝ D^0.4 and B_crit ∝ D^0.5. This aligns with recent concurrent work (2410.21676), reinforcing the fundamental dependence of optimal and critical batch sizes on the amount of data used. Practically, this means practitioners can estimate these values from small-scale runs and extrapolate based on the training data size. The D–S tradeoff equation D = D_min(1 + B/B_crit) can then be used to understand the computational cost (proportional to D) and training time (proportional to S = D/B and N) implications of choosing a particular batch size B for a given N and target loss L.
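Extrapolation plus the tradeoff equation gives a quick cost estimate. In this sketch the small-scale anchor point (`b_crit_small` at `d_small`) and the large-run `d_min` are assumed placeholder numbers, not fitted values from the paper:

```python
def extrapolate_b_crit(b_crit_small, d_small, d_large, exponent=0.5):
    """B_crit ~ D^0.5 power law, anchored at a small-scale estimate."""
    return b_crit_small * (d_large / d_small) ** exponent

def cost_of_batch(b, d_min, b_crit):
    """D-S tradeoff: tokens D = D_min*(1 + B/B_crit), steps S = D/B."""
    d = d_min * (1.0 + b / b_crit)
    return d, d / b

# Anchor: B_crit ~ 0.5M tokens measured at D = 10B tokens (hypothetical numbers).
b_crit = extrapolate_b_crit(b_crit_small=2**19, d_small=1e10, d_large=1e12)
tokens, steps = cost_of_batch(2**22, d_min=8e11, b_crit=b_crit)
```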
Using these derived scaling laws, the paper analyzes the Pareto-optimal configurations of N, D, B for achieving a target loss L while balancing training time and compute. For a fixed total compute budget, traditional advice suggests training at roughly 20 TPP to minimize loss. However, once training time is considered (larger B decreases time but increases total D and compute via the D–S tradeoff), the analysis reveals that smaller, over-trained models (TPP > 20) can be Pareto-optimal. Because overtrained models are trained on larger datasets, they have higher B_crit, which allows larger batch sizes to be used efficiently to reduce training time, even if total compute exceeds the 20 TPP optimum. Targeting an actual training TPP of 20 is shown to be Pareto-inefficient when using very large batch sizes (B ≫ B_opt).
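The mechanism can be illustrated numerically. The two configurations below are hypothetical, assumed to reach the same target loss; per-step wall time is treated as proportional to N, and all D_min/B_crit values are made-up round numbers chosen only to show the direction of the effect:

```python
def steps_needed(d_min, b_crit, batch):
    """Steps at a given batch size, via D = D_min*(1 + B/B_crit)."""
    d = d_min * (1.0 + batch / b_crit)
    return d / batch

B = 2**22  # a very large batch size, same for both configurations

# Compute-optimal-style config: larger N, smaller D_min, hence smaller B_crit.
s_20tpp = steps_needed(d_min=2e11, b_crit=2**20, batch=B)
# Overtrained config: half-size model, larger D_min, hence larger B_crit.
s_over = steps_needed(d_min=4e11, b_crit=2**21, batch=B)

# Relative wall time ~ steps * per-step cost (per-step cost ~ N, normalized).
time_20 = s_20tpp * 1.0
time_over = s_over * 0.5   # half-size model: roughly half the per-step cost
```

The overtrained model takes more optimizer steps but each step is cheaper, and its larger B_crit means the huge batch inflates its token budget by a smaller factor, so it finishes sooner.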
The core practical takeaways are:
- When using AdamW and μP, fix the learning rate based on model width and tune weight decay (λ) based on the derived T_epoch scaling law (∝ (D/N)^(-0.5)) and batch size (B).
- Optimal and critical batch sizes (B_opt, B_crit) scale with dataset size D, not total compute C or target loss L. Estimate their scaling from small runs (≈ D^0.4 and ≈ D^0.5) and use these laws to select B for large-scale training.
- Leverage the D–S tradeoff model and the B_crit scaling law to plan training runs that optimally balance total compute (FLOPs) and training time, especially considering the benefits of larger batch sizes. This may favor smaller, over-trained models for faster training times at a given performance level.
Implementation requires:
- Train a proxy model with μP to find base hyperparameters, including η_base.
- For target model width W, set peak η = η_base · (W_proxy / W).
- From limited small-scale experiments across various N, D, B, λ and losses, fit:
  - The optimal T_epoch power law in TPP [(2505.13738), Eq. 3].
  - The B-specific loss-data power laws L_B(D) [(2505.13738), Fig. 4].
  - The D–S tradeoff curve S/S_min − 1 = (D/D_min − 1)^(−1) [(2505.13738), Eq. 6], yielding D_min and S_min for various losses. Calculate B_crit = D_min/S_min.
  - The B_crit power law in D_min [(2505.13738), Eq. 7].
- For a desired target loss L, model size N, and dataset size D:
  - Predict the target D_min for loss L at size N (e.g., using a loss scaling law like [(2203.02155), Eq. 1]).
  - Predict B_crit for this D_min using the fitted power law.
  - Choose a batch size B. The required dataset size will be D = D_min(1 + B/B_crit). Ensure the actual dataset used is at least this size.
  - Calculate the required steps S = D/B. Configure the LR schedule accordingly.
  - Calculate the target T_epoch for the chosen D/N ratio.
  - Set λ = B/(η·D·T_epoch).
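The planning steps above can be sketched end-to-end. All fitted constants below (the inverted loss-data law, the B_crit and T_epoch coefficients) are placeholders standing in for values one would fit from small-scale runs; `d_min_for_loss` in particular is a hypothetical stand-in for inverting a real L(N, D) scaling law:

```python
def d_min_for_loss(target_loss, n_params):
    """Hypothetical inverted loss-data law: tokens needed when B << B_crit."""
    return 1e9 * (2.5 / target_loss) ** 8 * (n_params / 1e9) ** 0.2

def b_crit_law(d_min):
    """B_crit ~ D_min^0.5 power law (coefficient assumed)."""
    return 2**20 * (d_min / 1e10) ** 0.5

def t_epoch_law(tpp):
    """T_epoch ~ TPP^-0.5 power law (coefficient assumed)."""
    return 0.5 * tpp ** -0.5

def plan_run(target_loss, n_params, batch, eta):
    d_min = d_min_for_loss(target_loss, n_params)   # step 1: target D_min
    b_crit = b_crit_law(d_min)                      # step 2: predicted B_crit
    d = d_min * (1.0 + batch / b_crit)              # step 3: required tokens
    steps = d / batch                               # step 4: LR schedule length
    t_epoch = t_epoch_law(d / n_params)             # step 5: target timescale
    lam = batch / (eta * d * t_epoch)               # step 6: weight decay
    return {"tokens": d, "steps": steps, "weight_decay": lam, "b_crit": b_crit}

plan = plan_run(target_loss=2.2, n_params=1e9, batch=2**21, eta=0.01)
```

With real fitted constants in place of the placeholders, this returns the full (D, S, λ) configuration for a run before any large-scale training begins.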
This systematic approach, grounded in empirical scaling laws, offers a more predictable and efficient way to navigate the complex hyperparameter space for large-scale LLM pre-training, enabling better control over training costs and time. The findings also highlight that achieving the fastest training time for a given performance level might involve training models on significantly more data than the compute-optimal minimum.
Some limitations mentioned include the focus on AdamW and a specific LR schedule shape, the need for more data and architectural variation in future work, and the empirical observation that small batches degrade performance more than theory suggests, potentially requiring tuning beyond λ. Practical systems constraints at very large batch sizes are also not explicitly modeled.