
Power Lines: Scaling Laws for Weight Decay and Batch Size in LLM Pre-training (2505.13738v1)

Published 19 May 2025 in cs.LG, cs.AI, and cs.CL

Abstract: Efficient LLM pre-training requires well-tuned hyperparameters (HPs), including learning rate $\eta$ and weight decay $\lambda$. We study scaling laws for HPs: formulas for how to scale HPs as we scale model size $N$, dataset size $D$, and batch size $B$. Recent work suggests the AdamW timescale, $B/(\eta\lambda D)$, should remain constant across training settings, and we verify the implication that optimal $\lambda$ scales linearly with $B$, for a fixed $N, D$. However, as $N, D$ scale, we show the optimal timescale obeys a precise power law in the tokens-per-parameter ratio, $D/N$. This law thus provides a method to accurately predict $\lambda_{opt}$ in advance of large-scale training. We also study scaling laws for optimal batch size $B_{opt}$ (the $B$ enabling lowest loss at a given $N, D$) and critical batch size $B_{crit}$ (the $B$ beyond which further data parallelism becomes ineffective). In contrast with prior work, we find both $B_{opt}$ and $B_{crit}$ scale as power laws in $D$, independent of model size, $N$. Finally, we analyze how these findings inform the real-world selection of Pareto-optimal $N$ and $D$ under dual training time and compute objectives.

Summary

  • The paper shows that tuning weight decay via a derived power law for the AdamW timescale is key for efficient LLM pre-training.
  • It finds that optimal and critical batch sizes scale as power laws in dataset size (approximately $D^{0.4}$ and $D^{0.5}$), independent of model size.
  • The study offers a practical roadmap using μP and loss-data scaling laws to balance training time and compute costs in large-scale training.

Efficiently pre-training LLMs at scale requires careful tuning of hyperparameters like learning rate ($\eta$) and weight decay ($\lambda$). Traditional hyperparameter sweeping is often infeasible for the largest models due to computational costs. This paper presents scaling laws derived from hundreds of training runs to guide the selection of weight decay, batch size, model size ($N$), and dataset size ($D$) for optimal training performance, particularly focusing on practical tradeoffs between training time and total compute.

The paper investigates the scaling behavior of the AdamW timescale, defined as $\mathcal{T}_{epoch} = B / (\eta\lambda D)$. This metric represents the effective fraction of the training data over which weight updates are averaged. While previous work suggested keeping this constant for multi-epoch training, this paper demonstrates that for LLM pre-training (typically one epoch), the optimal $\mathcal{T}_{epoch}$ is not constant but follows a precise power law in the tokens-per-parameter ratio ($TPP = D/N$). The fitted law shows $\mathcal{T}_{epoch} \propto TPP^{-0.5}$, meaning the optimal timescale decreases as models are trained on more data relative to their size.
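
To make this concrete, the exponent and coefficient of such a law can be fitted by least squares in log-log space from a small hyperparameter sweep. The sketch below uses invented sweep values purely for illustration; it does not reproduce the paper's fit.

```python
import numpy as np

# Invented small-scale sweep results: for each (N, D) setting, the
# tokens-per-parameter ratio and the empirically best AdamW timescale
# T_epoch = B / (eta * lambda * D).  Placeholder numbers only.
tpp = np.array([5.0, 10.0, 20.0, 40.0, 80.0])
t_epoch_opt = np.array([0.92, 0.66, 0.45, 0.33, 0.23])

# Fit T_epoch_opt = a * TPP^(-b) by linear regression in log-log space.
slope, intercept = np.polyfit(np.log(tpp), np.log(t_epoch_opt), 1)
a, b = np.exp(intercept), -slope
print(f"fitted law: T_epoch ~= {a:.3f} * TPP^(-{b:.3f})")

def predict_timescale(tokens: float, params: float) -> float:
    """Predict the optimal AdamW timescale for a given tokens-per-parameter ratio."""
    return a * (tokens / params) ** (-b)
```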

A key practical implication is that, when using the AdamW optimizer and the Maximal Update Parameterization ($\mu$P) framework to set the learning rate, the weight decay ($\lambda$) should be the primary hyperparameter adjusted as batch size ($B$) and dataset size ($D$) change. The paper shows empirically that tuning $\lambda$ to maintain the optimal $\mathcal{T}_{epoch}$ is more effective than tuning $\eta$ as $B$ or $D$ varies. This provides a concrete recipe for practitioners: use $\mu$P to set $\eta$ based on model size $N$, and then set $\lambda$ using the derived scaling law for $\mathcal{T}_{epoch}$ and the formula $\lambda_{opt} = \frac{B}{\eta \cdot D \cdot \mathcal{T}_{epoch}(D/N)}$. The linear relationship between optimal $\lambda$ and $B$ for fixed $N, D$ holds up to a certain batch size, further supporting adjusting $\lambda$ with $B$.
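
A minimal helper for this recipe might look as follows, with batch size and dataset size both measured in tokens. The fitted constants stand in for a practitioner's own fit (as in the previous sketch) and are not the paper's values.

```python
# Hypothetical fitted constants for T_epoch = A * (D/N)^(-ALPHA); real values
# would come from a fit like the previous sketch, not from the paper.
A, ALPHA = 2.0, 0.5

def weight_decay_for(batch_tokens: float, lr: float, tokens: float,
                     params: float) -> float:
    """Solve T_epoch = B / (eta * lambda * D) for lambda."""
    t_epoch = A * (tokens / params) ** (-ALPHA)
    return batch_tokens / (lr * tokens * t_epoch)

# Example: ~1B parameters, 20B tokens (20 TPP), a muP-transferred lr of 1e-2,
# and a 2M-token batch; all values are placeholders.
print(weight_decay_for(batch_tokens=2**21, lr=1e-2, tokens=20e9, params=1e9))
```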

The paper also provides insights into optimal batch size ($B_{opt}$) and critical batch size ($B_{crit}$). $B_{opt}$ is the batch size that achieves the lowest loss for a given $N$ and $D$. $B_{crit}$ is defined based on an empirical model of the tradeoff between the number of tokens ($D$) and the number of optimization steps ($S$) required to reach a target loss $L$. The model is expressed as $S/S_{min} - 1 = (D/D_{min} - 1)^{-1}$, where $D_{min}$ and $S_{min}$ are the minimum tokens and steps, respectively, and $B_{crit} = D_{min}/S_{min}$. The paper introduces a novel method to estimate $B_{crit}$ by fitting batch-size-specific loss-data scaling laws and interpolating the data needed for a target loss, avoiding the need for dense checkpoint evaluation or constant learning rates.
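
The sketch below shows one way such a tradeoff fit could be set up with scipy, using invented $(D, S)$ pairs for a single target loss; the paper's actual fitting procedure and data are not reproduced here.

```python
import numpy as np
from scipy.optimize import curve_fit

# Invented (tokens, steps) pairs that all reach one target loss at different
# batch sizes; real values would be interpolated from the batch-size-specific
# loss-data scaling laws L_B(D) described above.
D_obs = np.array([2.2e9, 2.5e9, 3.0e9, 4.0e9, 8.0e9])   # tokens consumed
S_obs = np.array([4.4e4, 2.0e4, 1.2e4, 8.0e3, 5.3e3])   # optimizer steps

def tradeoff(D, D_min, S_min):
    # S/S_min - 1 = (D/D_min - 1)^(-1), rearranged for S.
    return S_min * (1.0 + 1.0 / (D / D_min - 1.0))

(D_min, S_min), _ = curve_fit(tradeoff, D_obs, S_obs, p0=[2.0e9, 5.0e3],
                              bounds=([1e8, 1e2], [2.1e9, 1e5]))
B_crit = D_min / S_min
print(f"D_min={D_min:.3g} tokens, S_min={S_min:.3g} steps, B_crit={B_crit:.3g} tokens")
```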

Contrary to some prior work that suggested $B_{opt}$ and $B_{crit}$ scale primarily with total compute ($C$) or target loss ($L$), this paper finds that both $B_{opt}$ and $B_{crit}$ scale as power laws in the dataset size $D$, largely independent of model size $N$. Specifically, the findings suggest $B_{opt} \propto D^{0.4}$ and $B_{crit} \propto D^{0.5}$. This aligns with recent concurrent work (2410.21676), reinforcing the fundamental dependence of optimal and critical batch sizes on the amount of data used. Practically, this means practitioners can estimate these values from small-scale runs and extrapolate based on the training data size. The $D$-$S$ tradeoff equation $D = D_{min}(1 + B/B_{crit})$ can then be used to understand the computational cost (proportional to $D$) and training time (proportional to $S = D/B$ and $N$) implications of choosing a particular batch size $B$ for a given $N$ and target loss $L$.
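
As an illustration of the extrapolation step, the sketch below fits a power law for $B_{crit}$ in $D_{min}$ from invented small-scale fits and then applies the $D$-$S$ tradeoff to price a chosen batch size; all constants are placeholders rather than the paper's fitted values.

```python
import numpy as np

# Invented (D_min, B_crit) pairs fitted at several small-scale target losses,
# e.g. via the curve_fit sketch above; roughly consistent with exponent 0.5.
D_min_pts  = np.array([5.0e8, 2.0e9, 8.0e9, 3.2e10])
B_crit_pts = np.array([2.6e5, 5.2e5, 1.0e6, 2.1e6])

slope, intercept = np.polyfit(np.log(D_min_pts), np.log(B_crit_pts), 1)
print(f"B_crit ~= {np.exp(intercept):.3g} * D_min^{slope:.2f}")

def b_crit(d_min: float) -> float:
    """Extrapolated critical batch size (tokens) for a target-loss data requirement."""
    return float(np.exp(intercept) * d_min ** slope)

def cost_of_batch(d_min: float, batch_tokens: float):
    """Tokens and steps needed at a given batch size via D = D_min * (1 + B/B_crit)."""
    tokens = d_min * (1.0 + batch_tokens / b_crit(d_min))
    return tokens, tokens / batch_tokens

tokens, steps = cost_of_batch(d_min=1.0e11, batch_tokens=2**23)  # placeholder target
print(f"tokens needed ~= {tokens:.3g}, steps ~= {steps:.3g}")
```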

Using these derived scaling laws, the paper analyzes the Pareto-optimal configurations of $N, D, B$ for achieving a target loss $L$ while balancing training time and compute. For a fixed total compute budget, traditional advice suggests training at roughly 20 TPP to minimize loss. However, once training time is considered (larger $B$ reduces the number of steps but increases the total $D$, and hence compute, via the $D$-$S$ tradeoff), the analysis reveals that smaller, over-trained models (TPP > 20) can be Pareto-optimal. Because over-trained models are trained on larger datasets, they have higher $B_{crit}$ and can therefore exploit larger batch sizes more efficiently to reduce training time, even if their total compute exceeds the 20 TPP optimum. The paper shows it is Pareto-inefficient to target an actual training TPP of 20 when using very large batch sizes ($B \gg B_{opt}$).
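
A toy version of this analysis, reusing a hypothetical $B_{crit}$ law like the one in the previous sketch together with invented loss-matched $(N, D_{min})$ pairs, shows how the Pareto frontier over compute (approximated as $6ND$ FLOPs) and a serial-time proxy ($\propto N \cdot D/B$) can shift toward smaller, over-trained models at large batch sizes.

```python
# Hypothetical B_crit law (same placeholder coefficient/exponent as above) and
# invented (N, D_min) pairs that reach the same target loss; the smaller,
# over-trained models need more tokens.  Illustrative numbers only.
b_crit = lambda d_min: 11.5 * d_min ** 0.5
configs = {1.0e9: 8.0e10, 2.0e9: 3.0e10, 4.0e9: 1.6e10}   # N -> D_min (tokens)

candidates = []
for n_params, d_min in configs.items():
    for batch in (2**20, 2**22, 2**24):                   # batch sizes in tokens
        tokens = d_min * (1.0 + batch / b_crit(d_min))    # D-S tradeoff
        flops = 6.0 * n_params * tokens                   # standard compute estimate
        time_proxy = n_params * tokens / batch            # ~ N * steps
        candidates.append((n_params, batch, flops, time_proxy))

# Keep configurations that no other candidate beats on both compute and time.
pareto = [c for c in candidates
          if not any(o[2] <= c[2] and o[3] <= c[3] and o is not c for o in candidates)]
for n_params, batch, flops, t in sorted(pareto, key=lambda c: c[2]):
    print(f"N={n_params:.1e}  B={batch:.2e}  FLOPs={flops:.2e}  time~{t:.2e}")
```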

The core practical takeaways are:

  • When using AdamW and $\mu$P, fix the learning rate based on model width and tune weight decay ($\lambda$) based on the derived $\mathcal{T}_{epoch}$ scaling law ($\propto (D/N)^{-0.5}$) and batch size ($B$).
  • Optimal and critical batch sizes ($B_{opt}$, $B_{crit}$) scale with dataset size $D$, not total compute $C$ or target loss $L$. Estimate their scaling from small runs ($\approx D^{0.4}$ and $\approx D^{0.5}$, respectively) and use these laws to select $B$ for large-scale training.
  • Leverage the $D$-$S$ tradeoff model and the $B_{crit}$ scaling law to plan training runs that optimally balance total compute (FLOPs) and training time, especially considering the benefits of larger batch sizes. This may favor smaller, over-trained models for faster training times at a given performance level.

Implementation requires the following steps (an end-to-end sketch follows the list):

  1. Train a proxy model with $\mu$P to find base hyperparameters, including $\eta_{base}$.
  2. For target model width $W$, set peak $\eta = \eta_{base} \cdot (W_{proxy}/W)$.
  3. From limited small-scale experiments across various $N, D, B, \lambda$ and losses, fit:
    • The optimal $\mathcal{T}_{epoch}$ power law in TPP [(2505.13738), Eq. 3].
    • The $B$-specific loss-data power laws $L_B(D)$ [(2505.13738), Fig. 4].
    • The $D$-$S$ tradeoff curve $S/S_{min} - 1 = (D/D_{min} - 1)^{-1}$ [(2505.13738), Eq. 6], yielding $D_{min}$ and $S_{min}$ for various losses. Calculate $B_{crit} = D_{min}/S_{min}$.
    • The $B_{crit}$ power law in $D_{min}$ [(2505.13738), Eq. 7].
  4. For a desired target loss $L$, model size $N$, and dataset size $D$:
    • Predict the target $D_{min}$ for loss $L$ at size $N$ (e.g., using a loss scaling law like [(2203.02155), Eq. 1]).
    • Predict $B_{crit}$ for this $D_{min}$ using the fitted power law.
    • Choose a batch size $B$. The required dataset size will be $D = D_{min}(1 + B/B_{crit})$. Ensure the actual dataset used is at least this size.
    • Calculate the required steps $S = D/B$. Configure the LR schedule accordingly.
    • Calculate the target $\mathcal{T}_{epoch}$ for the chosen $D/N$ ratio.
    • Set $\lambda = B / (\eta \cdot D \cdot \mathcal{T}_{epoch})$.
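
An end-to-end sketch of this recipe is given below. Every constant (the timescale law, the $B_{crit}$ law, the $\mu$P base learning rate and proxy width) is a placeholder that a practitioner would replace with their own fitted values.

```python
# Placeholder fitted constants; substitute your own small-scale fits.
TIMESCALE_COEF, TIMESCALE_EXP = 2.0, -0.5    # T_epoch = coef * (D/N)^exp
BCRIT_COEF, BCRIT_EXP = 11.5, 0.5            # B_crit  = coef * D_min^exp
ETA_BASE, WIDTH_PROXY = 1e-2, 256            # muP base LR measured at proxy width

def plan_run(n_params, width, d_min, batch_tokens):
    """Given the data requirement d_min for a target loss (from a loss scaling
    law), return dataset size, steps, learning rate, and weight decay."""
    b_crit = BCRIT_COEF * d_min ** BCRIT_EXP
    tokens = d_min * (1.0 + batch_tokens / b_crit)     # D-S tradeoff
    steps = tokens / batch_tokens                      # configure the LR schedule for this
    eta = ETA_BASE * WIDTH_PROXY / width               # muP width scaling
    t_epoch = TIMESCALE_COEF * (tokens / n_params) ** TIMESCALE_EXP
    lam = batch_tokens / (eta * tokens * t_epoch)      # invert the AdamW timescale
    return {"tokens": tokens, "steps": steps, "lr": eta, "weight_decay": lam}

print(plan_run(n_params=3e9, width=4096, d_min=1.2e11, batch_tokens=2**22))
```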

This systematic approach, grounded in empirical scaling laws, offers a more predictable and efficient way to navigate the complex hyperparameter space for large-scale LLM pre-training, enabling better control over training costs and time. The findings also highlight that achieving the fastest training time for a given performance level might involve training models on significantly more data than the compute-optimal minimum.

Some limitations mentioned include the focus on AdamW and a specific LR schedule shape, the need for more data and architectural variations in future work, and the empirical observation that small batches degrade performance more than theory suggests, potentially requiring tuning beyond $\lambda$. The practical systems constraints of very large batch sizes are also not explicitly modeled.