
Learning Dynamics in Continual Pre-Training for Large Language Models (2505.07796v2)

Published 12 May 2025 in cs.CL, cs.AI, and cs.LG

Abstract: Continual Pre-Training (CPT) has become a popular and effective method to apply strong foundation models to specific downstream tasks. In this work, we explore the learning dynamics throughout the CPT process for LLMs. We specifically focus on how general and downstream domain performance evolves at each training step, with domain performance measured via validation losses. We have observed that the CPT loss curve fundamentally characterizes the transition from one curve to another hidden curve, and could be described by decoupling the effects of distribution shift and learning rate annealing. We derive a CPT scaling law that combines the two factors, enabling the prediction of loss at any (continual) training steps and across learning rate schedules (LRS) in CPT. Our formulation presents a comprehensive understanding of several critical factors in CPT, including loss potential, peak learning rate, training steps, replay ratio, etc. Moreover, our approach can be adapted to customize training hyper-parameters to different CPT goals such as balancing general and domain-specific performance. Extensive experiments demonstrate that our scaling law holds across various CPT datasets and training hyper-parameters.

Summary

  • The paper proposes a quantitative scaling law that characterizes CPT loss curves by integrating distribution shift and learning rate annealing effects.
  • It finds that model size and replay ratio predictably influence both general and downstream losses, enabling hyperparameter optimization.
  • The framework further predicts out-of-domain loss by approximating it as a linear combination of general and downstream validation losses.

This paper explores the learning dynamics of LLMs during Continual Pre-Training (CPT). CPT is a widely used technique to adapt powerful foundation models to specific downstream domains (like coding, finance, or legal) without the prohibitive cost of training from scratch. The authors focus on understanding how performance, measured by validation loss on both general (Dpt) and downstream (Dcpt) datasets, evolves throughout the CPT process.

The core contribution is the development of a quantitative scaling law that describes the CPT loss curve. The authors observe that the CPT process represents a transition from a "hidden" pre-training curve on the original general dataset (Dpt) to another "hidden" pre-training curve on the new downstream dataset (Dcpt). This transition is characterized by two main factors:

  1. Distribution Shift: This describes the deviation of the CPT loss curve from the hidden PT curve on Dpt. It reflects the difference in data distribution between Dpt and Dcpt. The authors find that this shift follows a power-law form with respect to the forward area (sum of learning rates) accumulated during CPT. Notably, the magnitude of this distribution shift term appears independent of the specific checkpoint (transfer starting point) of the pre-trained model (Fig. 2).
  2. Learning Rate (LR) Annealing: Similar to standard pre-training dynamics, the LR schedule influences the loss curve by allowing for local drops in loss. This effect is incorporated using concepts from previous scaling laws that include an "annealing area" term, which depends on the history of learning rates.
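
The two areas above can be computed directly from the LR schedule. Here is a minimal sketch, assuming a simplified annealing-area definition (the running drop below the peak LR seen so far); the paper follows prior annealing scaling-law work for the exact definition:

```python
import numpy as np

def lr_areas(lrs):
    """Given a per-step learning-rate schedule, return the cumulative
    forward area S1 (sum of LRs) and a simplified annealing area S2
    (cumulative drop below the running peak LR)."""
    lrs = np.asarray(lrs, dtype=float)
    s1 = np.cumsum(lrs)                   # forward area at each step
    peak = np.maximum.accumulate(lrs)     # running peak LR
    s2 = np.cumsum(peak - lrs)            # annealing area at each step
    return s1, s2

# Example: a WSD-style schedule (stable phase, then linear decay)
lrs = np.concatenate([
    np.full(800, 3e-4),                   # stable phase
    np.linspace(3e-4, 3e-5, 200),         # decay phase
])
S1, S2 = lr_areas(lrs)
```

During the stable phase S2 stays at zero; it only grows once the LR anneals below its peak, which is what lets the law attribute the decay-phase loss drop to annealing.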

Combining these two factors, the paper proposes a CPT scaling law (Eq. 4):

L(t) = L_{base}(t) + \Delta L(t)

where L_base(t) represents the loss dynamics without distribution shift (following the PT scaling law, with forward and annealing areas accumulated over both the PT and CPT phases) and ΔL(t) is the power-law distribution-shift term based on the CPT forward area. The formula includes parameters representing factors such as the base loss, learning rates, annealing effects, and the distribution distance between datasets. The authors show that this law accurately fits and predicts loss curves for different LR schedules (constant, WSD, cosine) on both the Dpt and Dcpt validation sets (Fig. 3, Fig. 10, Fig. 11).
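
To make the two-term structure of Eq. 4 concrete, here is a hedged sketch: L_base follows an annealing-style PT scaling law in the total accumulated areas, and ΔL adds a power-law shift term in the CPT forward area. All constants (L0, A, alpha, C, B, beta, d) are illustrative placeholders, not the paper's fitted values:

```python
def cpt_loss(s1_total, s2_total, s1_cpt,
             L0=1.8, A=0.5, alpha=0.4, C=0.05, B=0.6, beta=0.3, d=1.0):
    """Illustrative sketch of Eq. 4: a PT-style base term in the total
    forward area s1_total and annealing area s2_total (PT + CPT),
    plus a power-law distribution-shift term in the CPT forward area
    s1_cpt. Constants are placeholders, not fitted values."""
    L_base = L0 + A * s1_total ** (-alpha) - C * s2_total
    # On the downstream (Dcpt) validation set, the shift term decays
    # toward the hidden Dcpt curve as CPT area accumulates; its
    # direction depends on which validation set is measured.
    delta_L = B * (s1_cpt + d) ** (-beta)
    return L_base + delta_L
```

Evaluating this at successive (s1_total, s2_total, s1_cpt) values traces out a full predicted loss curve for any LR schedule, which is what allows the law to extrapolate across schedules.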

The paper extends this scaling law to incorporate model size and replay ratio. For model size, they observe that the absolute value of the distribution shift term is similar across the tested model sizes, while the annealing effect scales with model size (Appendix E). The replay ratio, where a percentage of the original Dpt data is mixed with Dcpt during CPT, is found to influence the distribution shift and annealing terms exponentially: higher replay ratios lead to smaller distribution shifts and slower increases in the Dpt loss (Fig. 6b, 6c, Appendix D, H, Fig. 19). The proposed unified formula (Eq. 8 in Appendix H) describes the entire loss dynamics across replay ratios, going beyond prior work that predicts only the final loss.
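
The reported exponential influence of the replay ratio can be sketched as a simple modulation of the shift term. The functional form and constants below are illustrative placeholders, not the paper's fitted parameterization:

```python
import numpy as np

def shift_magnitude(r, B=0.6, k=3.0):
    """Illustrative sketch: the distribution-shift magnitude shrinks
    roughly exponentially as the replay ratio r (fraction of original
    Dpt data mixed into the CPT stream) increases. B and k are
    placeholder constants, not fitted values."""
    return B * np.exp(-k * r)

ratios = np.array([0.0, 0.1, 0.25, 0.5])
shifts = shift_magnitude(ratios)
```

With r = 0 (no replay) the shift is at its full magnitude; raising r trades slower downstream adaptation for better preservation of general (Dpt) performance.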

A key practical application of this scaling law is hyperparameter optimization for CPT to balance performance on general and downstream tasks. The paper defines a combined objective based on the increase in the Dpt loss (ΔL_Dpt) and the decrease in the Dcpt loss (ΔL_Dcpt), weighted by coefficients λ1 and λ2 (Eq. 5). Using the scaling law, the authors analyze and predict the impact of key CPT factors on this balance:

  • Loss Potential of PT models: Defined as the potential for a future loss drop via LR annealing (related to the final LR of the PT phase). Models with higher loss potential (e.g., models whose training was stopped before full annealing, or with a larger final LR) achieve a lower final Dcpt loss (Fig. 5). The paper recommends releasing high-loss-potential versions of open-source models for better CPT adaptation (Finding 3).
  • Peak Learning Rate: A higher peak LR for CPT accelerates the decrease in the Dcpt loss but also the increase in the Dpt loss (Fig. 7a, 7b). The scaling law can predict the optimal peak LR for a given balance objective (Fig. 8b).
  • Continual Pre-Training Steps: The number of CPT steps affects both losses. The Dpt loss might continuously rise, or rise then fall, depending on the initial model state ("critical point") and distribution distance (Fig. 7c). More training steps don't always mean better general abilities (Finding 4).
  • Optimal Replay Ratio: The scaling law can predict the optimal mix of Dpt and Dcpt data to balance performance for given λ1, λ2 coefficients (Fig. 8c).
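
The peak-LR trade-off above reduces to a one-dimensional search over a fitted objective. The two ΔL predictors below are hypothetical stand-ins for the fitted scaling law, with constants chosen only to illustrate the Eq. 5-style balance:

```python
import numpy as np

# Hypothetical end-of-CPT loss changes as functions of the CPT peak LR
# (shapes and constants are illustrative, not the paper's fits).
def delta_L_general(peak_lr):
    return 0.20 * (peak_lr / 3e-4) ** 0.6      # Dpt loss rises with peak LR

def delta_L_downstream(peak_lr):
    return -0.30 * (peak_lr / 3e-4) ** 0.4     # Dcpt loss falls with peak LR

def objective(peak_lr, lam1=1.0, lam2=1.0):
    # Eq. 5-style weighted trade-off: penalize the general-loss
    # increase, reward the downstream-loss decrease.
    return lam1 * delta_L_general(peak_lr) + lam2 * delta_L_downstream(peak_lr)

grid = np.geomspace(1e-5, 1e-3, 200)
best_lr = grid[np.argmin([objective(lr) for lr in grid])]
```

Because both terms are smooth power laws in the peak LR, a coarse grid search suffices; changing λ1/λ2 shifts the optimum toward preserving general ability or toward faster domain adaptation.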

Furthermore, the paper shows how to apply the framework to predict validation loss on out-of-domain (OOD) datasets (Dood) that are neither the original Dpt nor the new Dcpt. They hypothesize that the OOD loss can be approximated as a linear combination of the Dpt and Dcpt validation losses (Eq. 6, Appendix J). The coefficients of this combination are specific to the OOD dataset and reflect its similarity to Dpt and Dcpt (Fig. 9, Fig. 21, Fig. 22). Once these coefficients are determined (e.g., by fitting to a few initial steps), predicting the OOD loss curve during CPT reduces to predicting the Dpt and Dcpt loss curves with the main scaling law, and optimizing OOD performance becomes equivalent to balancing the Dpt and Dcpt losses using the derived coefficients (Finding 5).
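
Mechanically, the Eq. 6 hypothesis is an ordinary least-squares fit over a few early CPT steps. The loss trajectories below are synthetic stand-ins for measured validation losses, used only to show the fitting procedure:

```python
import numpy as np

# Synthetic stand-ins for measured loss curves over 200 CPT steps.
t = np.arange(1, 201)
L_dpt  = 2.0 + 0.10 * np.log1p(t / 50)     # slowly rising general loss
L_dcpt = 2.5 - 0.60 * (t / 200) ** 0.5     # falling downstream loss
L_ood  = 0.4 * L_dpt + 0.6 * L_dcpt        # "true" OOD mixture (illustrative)

# Fit the mixture weights from the first k steps only...
k = 20
X = np.stack([L_dpt[:k], L_dcpt[:k]], axis=1)
w, *_ = np.linalg.lstsq(X, L_ood[:k], rcond=None)

# ...then predict the full OOD curve from the two fitted/predicted curves.
L_ood_pred = np.stack([L_dpt, L_dcpt], axis=1) @ w
```

Once w is known, no further OOD evaluation is needed: the OOD trajectory follows from the two curves the main scaling law already predicts.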

Finally, the paper addresses the practical challenge of CPT using open-source pre-trained models where detailed training history (dataset distribution, exact training steps, loss potential) is often unknown. They propose simple methods to adapt their scaling law:

  1. Use a proxy general dataset (e.g., a common crawl subset like RedPajama-C4) as a stand-in for the unknown original Dpt validation set to measure general performance dynamics.
  2. Treat unknown PT parameters, such as the cumulative forward area (Spt), as parameters to be fitted from initial CPT steps.
  3. Assume the final PT LR is zero for models that are typically released fully annealed for benchmarking.

They demonstrate the effectiveness of these methods by fitting and predicting the CPT loss curve for LLaMA3.2-1B, using RedPajama-C4 as a proxy Dpt and Pile-of-Law as Dcpt (Fig. 18, Appendix G).
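
Step 2 above can be sketched as fitting the unknown accumulated PT forward area from early CPT measurements. The loss form and all constants are illustrative, and the "observed" losses here are synthetic:

```python
import numpy as np

def predicted_loss(s_cpt, s_pt, L0=1.7, A=0.8, alpha=0.35):
    """Illustrative base-loss form: the unknown PT forward area s_pt
    simply offsets the CPT forward area s_cpt. Constants are
    placeholders, not fitted values."""
    return L0 + A * (s_pt + s_cpt) ** (-alpha)

true_s_pt = 120.0                               # hidden "ground truth"
s_cpt = np.linspace(0.1, 10.0, 50)              # early-CPT forward area
observed = predicted_loss(s_cpt, true_s_pt)     # synthetic measurements

# Treat S_pt as a free parameter and fit it by 1-D grid search.
candidates = np.linspace(10, 500, 2000)
errs = [np.mean((predicted_loss(s_cpt, c) - observed) ** 2)
        for c in candidates]
s_pt_hat = candidates[np.argmin(errs)]
```

With the fitted S_pt in hand, the same law used for models with known histories can extrapolate the rest of the CPT run for the open-source checkpoint.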

In essence, this work provides a data-driven, empirical framework to understand, model, and predict the dynamic behavior of LLMs during CPT. It offers practical tools and insights for hyperparameter optimization to navigate the trade-off between maintaining general capabilities and acquiring domain-specific knowledge, even when working with open-source models with incomplete histories. While acknowledging the empirical nature of the laws due to the complexity of LLM training, the paper presents extensive experimental evidence supporting the applicability and effectiveness of their proposed scaling law in practical CPT scenarios.
