- The paper proposes a quantitative scaling law that characterizes CPT loss curves by integrating distribution shift and learning rate annealing effects.
- It finds that model size and replay ratio predictably influence both general and downstream losses, enabling hyperparameter optimization.
- The framework further predicts out-of-domain loss by approximating it as a linear combination of general and downstream validation losses.
This paper explores the learning dynamics of LLMs during Continual Pre-Training (CPT). CPT is a widely used technique for adapting powerful foundation models to specific downstream domains (such as coding, finance, or law) without the prohibitive cost of training from scratch. The authors focus on understanding how performance, measured by validation loss on both the general (Dpt) and downstream (Dcpt) datasets, evolves throughout the CPT process.
The core contribution is the development of a quantitative scaling law that describes the CPT loss curve. The authors observe that the CPT process represents a transition from a "hidden" pre-training curve on the original general dataset (Dpt) to another "hidden" pre-training curve on the new downstream dataset (Dcpt). This transition is characterized by two main factors:
- Distribution Shift: This describes the deviation of the CPT loss curve from the hidden PT curve on Dpt. It reflects the difference in data distribution between Dpt and Dcpt. The authors find that this shift follows a power-law form with respect to the forward area (sum of learning rates) accumulated during CPT. Notably, the magnitude of this distribution shift term appears independent of the specific checkpoint (transfer starting point) of the pre-trained model (Fig. 2).
- Learning Rate (LR) Annealing: Similar to standard pre-training dynamics, the LR schedule influences the loss curve by allowing for local drops in loss. This effect is incorporated using concepts from previous scaling laws that include an "annealing area" term, which depends on the history of learning rates.
Combining these two factors, the paper proposes a CPT scaling law (Eq. 4):
L(t) = L_base(t) + ΔL(t)
where L_base(t) represents the loss dynamics without distribution shift (following the PT scaling law with forward and annealing areas accumulated over both the PT and CPT phases) and ΔL(t) is the power-law distribution shift term based on the CPT forward area. The formula's parameters capture factors such as base loss, learning rates, annealing effects, and the distribution distance between datasets. The authors show that this law accurately fits and predicts loss curves for different LR schedules (constant, WSD, cosine) on both the Dpt and Dcpt validation sets (Fig. 3, Fig. 10, Fig. 11).
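To make the decomposition concrete, below is a minimal, hypothetical Python sketch of Eq. 4. It assumes the annealing-area form of prior LR-annealing scaling laws for L_base (L0 + A·S1^(−α) − C·S2) plus a power-law shift term in the CPT forward area; all parameter names, the simplified annealing-area definition, and the synthetic data are illustrative, not the paper's exact formulation.

```python
import numpy as np
from scipy.optimize import curve_fit

def areas(lr_schedule):
    """Per-step forward area S1 (cumulative sum of LRs) and a simplified annealing
    area S2 (cumulative decay below the peak LR)."""
    lr = np.asarray(lr_schedule, dtype=float)
    s1 = np.cumsum(lr)
    s2 = np.cumsum(np.maximum(lr[0] - lr, 0.0))
    return s1, s2

def cpt_loss(step, L0, A, alpha, C, B, beta, s1_pt, lr_cpt):
    """L(t) = L_base(t) + dL(t): PT-style base curve plus a power-law distribution-shift
    term in the CPT forward area (shape chosen here for a downstream-style validation curve)."""
    s1_cpt, s2_cpt = areas(lr_cpt)
    s1 = s1_pt + s1_cpt[step]                  # forward area accumulated over PT and CPT
    base = L0 + A * s1 ** (-alpha) - C * s2_cpt[step]
    shift = B * (1.0 + s1_cpt[step]) ** (-beta)
    return base + shift

steps = np.arange(1, 2001)
lr_cpt = 3e-4 * np.linspace(1.0, 0.1, len(steps))        # e.g. a linear-decay CPT schedule
fit_fn = lambda t, L0, A, alpha, C, B, beta: cpt_loss(
    t.astype(int) - 1, L0, A, alpha, C, B, beta, s1_pt=400.0, lr_cpt=lr_cpt)

# Fit the free coefficients to an observed loss curve (synthetic data of the same form here).
observed = fit_fn(steps, 2.3, 1.5, 0.4, 2.0, 0.6, 0.8) + 0.005 * np.random.randn(len(steps))
params, _ = curve_fit(fit_fn, steps, observed,
                      p0=[2.0, 1.0, 0.3, 1.0, 0.5, 0.5], bounds=(0, np.inf), maxfev=20000)
```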
The paper extends this scaling law to incorporate model size and replay ratio. For model size, they observe that the absolute value of the distribution shift term is similar across the tested model sizes, while the annealing effect scales with model size (Appendix E). The replay ratio, i.e., the percentage of the original Dpt mixed into Dcpt during CPT, is found to influence the distribution shift and annealing terms exponentially: higher replay ratios lead to smaller distribution shifts and slower increases in Dpt loss (Fig. 6b, 6c, Appendix D, H, Fig. 19). The proposed unified formula (Eq. 8 in Appendix H) describes the entire loss dynamics for different replay ratios, going beyond prior work that predicts only the final loss.
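The exact parameterization of Eq. 8 is not reproduced in this summary; as a rough illustration of the exponential modulation described above, a hedged sketch (the function name, the damping constant k, and all numeric values are hypothetical):

```python
import numpy as np

def shift_with_replay(s1_cpt, B, beta, k, replay_ratio):
    """Distribution-shift magnitude damped by the replay ratio r in [0, 1]
    (assumed exponential modulation; the paper's exact form is Eq. 8, Appendix H)."""
    return B * np.exp(-k * replay_ratio) * (1.0 + np.asarray(s1_cpt)) ** (-beta)

# Higher replay ratio -> smaller deviation from the general-domain curve, e.g.:
for r in (0.0, 0.1, 0.25):
    print(r, shift_with_replay(s1_cpt=0.3, B=0.6, beta=0.8, k=3.0, replay_ratio=r))
```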
A key practical application of this scaling law is hyperparameter optimization for CPT to balance performance on general and downstream tasks. The paper defines a combined objective based on the increase in Dpt loss (ΔL_Dpt) and the decrease in Dcpt loss (ΔL_Dcpt), weighted by coefficients λ1 and λ2 (Eq. 5). Using the scaling law, the authors analyze and predict the impact of key CPT factors on this balance (a toy sketch of the selection procedure follows the list below):
- Loss Potential of PT models: Defined as the potential for future loss drop via LR annealing (related to the final LR of the PT phase). Models with higher loss potential (e.g., models whose training was stopped before full annealing, or with a larger final LR) achieve lower final Dcpt loss (Fig. 5). The paper recommends releasing high-loss-potential versions of open-source models for better CPT adaptation (Finding 3).
- Peak Learning Rate: A higher peak LR for CPT accelerates the decrease in Dcpt loss but also the increase in Dpt loss (Fig. 7a, 7b). The scaling law can predict the optimal peak LR for a given balance objective (Fig. 8b).
- Continual Pre-Training Steps: The number of CPT steps affects both losses. The Dpt loss might continuously rise, or rise then fall, depending on the initial model state ("critical point") and distribution distance (Fig. 7c). More training steps don't always mean better general abilities (Finding 4).
- Optimal Replay Ratio: The scaling law can predict the optimal mix of Dpt and Dcpt data to balance performance based on the λ1, λ2 coefficients (Fig. 8c).
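As the toy illustration referenced above, the sketch below assumes the Eq. 5 objective is a weighted difference of the Dpt loss rise and the Dcpt loss drop and scans a small hyperparameter grid; objective(), predicted_deltas(), and all numbers are placeholders standing in for predictions from the fitted scaling law, not the paper's exact procedure.

```python
def objective(delta_L_dpt, delta_L_dcpt, lam1=1.0, lam2=1.0):
    """Smaller is better: penalize the rise in Dpt loss, reward the drop in Dcpt loss."""
    return lam1 * delta_L_dpt - lam2 * delta_L_dcpt

def predicted_deltas(peak_lr, replay_ratio):
    """Placeholder for scaling-law predictions of (Dpt loss increase, Dcpt loss decrease)
    at the end of CPT; in practice this would evaluate the fitted law (e.g. cpt_loss above)."""
    dpt_rise = 0.08 * (peak_lr / 3e-4) * (1.0 - replay_ratio)            # toy monotone trends
    dcpt_drop = 0.50 * (peak_lr / 3e-4) ** 0.3 * (1.0 - 0.4 * replay_ratio)
    return dpt_rise, dcpt_drop

candidates = [(lr, r) for lr in (1e-4, 3e-4, 1e-3) for r in (0.0, 0.1, 0.25, 0.5)]
best = min(candidates, key=lambda hp: objective(*predicted_deltas(*hp)))
print("selected (peak LR, replay ratio):", best)
```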
Furthermore, the paper explores how to apply the framework to predict validation loss on Out-of-Domain (OOD) datasets (Dood) that are neither the original Dpt nor the new Dcpt. They hypothesize that the OOD loss can be approximated as a linear combination of the Dpt and Dcpt validation losses (Eq. 6, Appendix J). The coefficients of this linear combination are specific to the OOD dataset and reflect its similarity to Dpt and Dcpt (Fig. 9, Fig. 21, Fig. 22). Once these coefficients are determined (e.g., by fitting to a few initial steps), predicting the OOD loss curve during CPT reduces to predicting the Dpt and Dcpt loss curves with the main scaling law, and optimizing OOD performance becomes equivalent to balancing the Dpt and Dcpt losses using the derived coefficients (Finding 5).
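A hedged sketch of how such a linear-combination fit might be done: fit the coefficients on a few early CPT checkpoints by least squares, then reuse them on predicted Dpt/Dcpt curves. The affine intercept and the lstsq-based fit are assumptions for illustration, not the paper's exact recipe.

```python
import numpy as np

def fit_ood_coeffs(loss_dpt_early, loss_dcpt_early, loss_ood_early):
    """Least-squares fit of L_ood ~ w1 * L_dpt + w2 * L_dcpt + b on early CPT observations."""
    X = np.column_stack([loss_dpt_early, loss_dcpt_early, np.ones(len(loss_ood_early))])
    coeffs, *_ = np.linalg.lstsq(X, np.asarray(loss_ood_early), rcond=None)
    return coeffs                              # (w1, w2, b)

def predict_ood(loss_dpt_pred, loss_dcpt_pred, coeffs):
    w1, w2, b = coeffs
    return w1 * np.asarray(loss_dpt_pred) + w2 * np.asarray(loss_dcpt_pred) + b

# Hypothetical losses from a few early checkpoints, then reuse of the coefficients later:
dpt_early  = np.array([2.80, 2.82, 2.84, 2.85])
dcpt_early = np.array([2.10, 1.95, 1.85, 1.78])
ood_early  = np.array([2.55, 2.50, 2.46, 2.43])
coeffs = fit_ood_coeffs(dpt_early, dcpt_early, ood_early)
ood_pred = predict_ood([2.87, 2.88], [1.60, 1.55], coeffs)   # from scaling-law-predicted curves
```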
Finally, the paper addresses the practical challenge of CPT using open-source pre-trained models where detailed training history (dataset distribution, exact training steps, loss potential) is often unknown. They propose simple methods to adapt their scaling law:
- Use a proxy general dataset (e.g., a common crawl subset like RedPajama-C4) as a stand-in for the unknown original Dpt validation set to measure general performance dynamics.
- Treat unknown PT parameters, such as the cumulative forward area (Spt), as parameters to be fitted from initial CPT steps.
- Assume the final PT LR is zero for models that are typically fully annealed for benchmarking.
They demonstrate the effectiveness of these methods by fitting and predicting the CPT loss curve for LLaMA3.2-1B using RedPajama-C4 as a proxy Dpt and Pile-of-Law as Dcpt (Fig. 18, Appendix G).
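As an illustration of this recipe, the sketch below (reusing the illustrative cpt_loss() helper from the earlier sketch) treats the unknown cumulative PT forward area S_pt as one more free parameter and fits it from the first few hundred CPT steps, with the final PT LR taken as zero; the loss values and LR schedule are synthetic stand-ins for a Pile-of-Law-style downstream curve.

```python
import numpy as np
from scipy.optimize import curve_fit

steps = np.arange(1, 501)                       # early CPT steps used for fitting
lr_cpt = 1e-4 * np.linspace(1.0, 0.5, len(steps))
dcpt_loss = 2.40 - 0.15 * np.log1p(steps / 50.0)   # hypothetical decreasing downstream loss

# S_pt joins the other coefficients as a free parameter (final PT LR assumed fully annealed to 0).
fit_fn = lambda t, L0, A, alpha, C, B, beta, S_pt: cpt_loss(
    t.astype(int) - 1, L0, A, alpha, C, B, beta, s1_pt=S_pt, lr_cpt=lr_cpt)
params, _ = curve_fit(fit_fn, steps, dcpt_loss,
                      p0=[2.0, 1.0, 0.3, 1.0, 0.5, 0.5, 100.0],
                      bounds=(0, np.inf), maxfev=50000)
```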
In essence, this work provides a data-driven, empirical framework to understand, model, and predict the dynamic behavior of LLMs during CPT. It offers practical tools and insights for hyperparameter optimization to navigate the trade-off between maintaining general capabilities and acquiring domain-specific knowledge, even when working with open-source models with incomplete histories. While acknowledging the empirical nature of the laws due to the complexity of LLM training, the paper presents extensive experimental evidence supporting the applicability and effectiveness of their proposed scaling law in practical CPT scenarios.