
Continual Pre-training of Language Models

Published 7 Feb 2023 in cs.CL, cs.AI, cs.LG, and cs.NE | (2302.03241v4)

Abstract: Language models (LMs) have been instrumental for the rapid advance of natural language processing. This paper studies continual pre-training of LMs, in particular, continual domain-adaptive pre-training (or continual DAP-training). Existing research has shown that further pre-training an LM using a domain corpus to adapt the LM to the domain can improve the end-task performance in the domain. This paper proposes a novel method to continually DAP-train an LM with a sequence of unlabeled domain corpora to adapt the LM to these domains to improve their end-task performances. The key novelty of our method is a soft-masking mechanism that directly controls the update to the LM. A novel proxy is also proposed to preserve the general knowledge in the original LM. Additionally, it contrasts the representations of the previously learned domain knowledge (including the general knowledge in the pre-trained LM) and the knowledge from the current full network to achieve knowledge integration. The method not only overcomes catastrophic forgetting, but also achieves knowledge transfer to improve end-task performances. Empirical evaluation demonstrates the effectiveness of the proposed method.


Summary

  • The paper presents DAS, a method for continual domain-adaptive pre-training that preserves general language knowledge while enabling knowledge transfer with soft-masking and contrastive loss.
  • The method computes unit importance via gradient-based proxy scores and applies soft-masking to mitigate catastrophic forgetting in transformer architectures.
  • Experimental results demonstrate that DAS outperforms existing approaches by achieving negative forgetting rates and superior performance across multiple domain tasks.

Continual Pre-training of Language Models

This paper introduces a novel approach, DAS (Continual DA-pre-training of LMs with Soft-masking), for continual domain-adaptive pre-training (DAP-training) of LMs. The method addresses catastrophic forgetting (CF) and encourages knowledge transfer (KT) across domains by employing a soft-masking mechanism and a contrastive learning approach.

DAS Methodology

The DAS technique has two primary components: preserving general language knowledge and learned domain knowledge via soft-masking, and promoting complementary representations of the current and previous domains for knowledge integration (Figure 1).

Figure 1: Illustration of DAS, highlighting the initialization, domain training, and importance computation steps.

The overall learning process has an initialization phase and a continual learning phase. Initialization computes the importance of units for general language knowledge. For each new domain, the continual learning phase first DAP-trains the LM on the domain corpus, soft-masked by the accumulated importance scores, and then computes the importance of units for that domain.

Initialization: Computing Importance of Units

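This phase computes the importance of units (attention heads and neurons) within the Transformer for the general knowledge stored in the original LM. For layer $l$, the unit outputs are multiplied element-wise by a virtual parameter $\bm{g}_{l}$ (one element per unit, initialized to all ones, so the forward pass is unchanged), and the importance of the units is the average magnitude of the gradient of a loss with respect to these gates: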

$\bm{I}_{l} = \frac{1}{N}\sum_{n=1}^{N}\left|\frac{\partial\mathcal{L}_{\text{impt}}(\bm{x}_n, y_n)}{\partial \bm{g}_{l}}\right|$
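The importance formula can be illustrated on a toy model where the gradient has a closed form. This is a minimal sketch under stated assumptions, not the paper's implementation: the "layer" is hypothetical, with one unit per input feature (unit output $w_i x_i$, gated by $g_i = 1$), and squared error stands in for $\mathcal{L}_{\text{impt}}$. In a real Transformer the gradient with respect to $\bm{g}_l$ would come from backpropagation. The final rescaling to $[0, 1]$ is also an assumption, added because soft-masking later uses $1 - \bm{I}_l$.

```python
from typing import List, Sequence, Tuple

Sample = Tuple[Sequence[float], float]


def unit_importance(weights: Sequence[float], samples: List[Sample]) -> List[float]:
    """Average |dL/dg_i| over samples, evaluated at g = 1 (hypothetical toy layer).

    Toy model: y_hat = sum_i g_i * w_i * x_i, loss L = (y_hat - y)^2.
    """
    totals = [0.0] * len(weights)
    for x, y in samples:
        # Forward pass with every virtual gate fixed at 1.
        y_hat = sum(w * xi for w, xi in zip(weights, x))
        residual = y_hat - y
        # Chain rule through the gate: dL/dg_i = 2 * residual * w_i * x_i.
        for i, (w, xi) in enumerate(zip(weights, x)):
            totals[i] += abs(2.0 * residual * w * xi)
    return [t / len(samples) for t in totals]


def normalize(scores: Sequence[float]) -> List[float]:
    """Rescale scores to [0, 1] by the maximum (an assumed normalization)."""
    top = max(scores)
    return [s / top if top > 0 else 0.0 for s in scores]


weights = [0.5, -1.0, 0.0]
samples = [([1.0, 2.0, 3.0], 0.0), ([2.0, 0.0, 1.0], 1.0)]
scores = unit_importance(weights, samples)
print(scores)             # [0.75, 3.0, 0.0]
print(normalize(scores))  # [0.25, 1.0, 0.0]
```

In this toy, the unit with zero weight gets zero importance, and the second sample contributes nothing because the model already fits it exactly.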

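Because the original pre-training data is not available, DAS uses a proxy KL-divergence loss based on model robustness as $\mathcal{L}_{\text{impt}}$. A sample $\bm{x}^{\text{sub}}_n$ from a small subset of domain data is fed to the LM twice, and $f^1_{\text{LM}}$ and $f^2_{\text{LM}}$ denote the two forward passes with different dropout masks: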

$\mathcal{L}_{\text{impt}} = \text{KL}\left(f^1_{\text{LM}}(\bm{x}^{\text{sub}}_n), f^2_{\text{LM}}(\bm{x}^{\text{sub}}_n)\right)$
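The proxy loss itself can be sketched as follows. This is an illustrative sketch, not the paper's code: `toy_lm_logits` is a hypothetical stand-in for $f_{\text{LM}}$ (inverted dropout on input features followed by a linear map to vocabulary logits). It shows only how the same input passed through two independent dropout masks yields a KL value. In DAS, the gradient of this loss with respect to $\bm{g}_l$ would then be fed into the importance formula.

```python
import math
import random
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for discrete distributions with strictly positive q."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def toy_lm_logits(x, weight_rows, drop_p, rng):
    """Hypothetical f_LM: inverted dropout on x, then one logit per vocabulary row."""
    if not 0.0 <= drop_p < 1.0:
        raise ValueError("drop_p must be in [0, 1)")
    keep = 1.0 - drop_p
    dropped = [xi / keep if rng.random() < keep else 0.0 for xi in x]
    return [sum(w * d for w, d in zip(row, dropped)) for row in weight_rows]


def proxy_impt_loss(x, weight_rows, drop_p, rng) -> float:
    """Two forward passes with independent dropout masks; KL between the outputs."""
    p1 = softmax(toy_lm_logits(x, weight_rows, drop_p, rng))  # f^1_LM
    p2 = softmax(toy_lm_logits(x, weight_rows, drop_p, rng))  # f^2_LM
    return kl_divergence(p1, p2)


x = [1.0, -0.5, 2.0]
W = [[0.2, 0.1, -0.3], [0.5, -0.2, 0.1], [-0.1, 0.4, 0.2]]
print(proxy_impt_loss(x, W, drop_p=0.1, rng=random.Random(0)))
```

With `drop_p=0.0`, both passes are identical and the loss is exactly zero. Units whose gates strongly affect the disagreement between the two passes receive high importance.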

Training: Soft-masking and Contrastive Loss

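During DAP-training on domain $t$, DAS preserves previously learned knowledge by soft-masking the updates according to the accumulated importance $\bm{I}_{l}^{(\le t-1)}$. This accumulated importance merges the importance from the previous domain with the earlier accumulated importance by element-wise max (EMax):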

$\bm{I}_{l}^{(\le t-1)} = \text{EMax}\left(\{\bm{I}_{l}^{(t-1)}, \bm{I}_{l}^{(\le t-2)}\}\right)$
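EMax is simply an element-wise maximum over per-unit scores. Once any domain (or the general knowledge) marks a unit as important, that unit stays protected. The scores below are hypothetical values for a layer with four units.

```python
from typing import List, Sequence


def emax(current: Sequence[float], accumulated: Sequence[float]) -> List[float]:
    """Element-wise max of two importance vectors (one score per unit)."""
    if len(current) != len(accumulated):
        raise ValueError("importance vectors must cover the same units")
    return [max(c, a) for c, a in zip(current, accumulated)]


i_general = [0.9, 0.1, 0.0, 0.3]  # hypothetical I^(0): general knowledge
i_domain1 = [0.2, 0.8, 0.0, 0.4]  # hypothetical I^(1): first domain
print(emax(i_domain1, i_general))  # [0.9, 0.8, 0.0, 0.4]
```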

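Soft-masking scales the gradient of each unit element-wise by one minus its accumulated importance. Highly important units are therefore barely updated, while unimportant units remain free to learn the new domain: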

$\hat{\nabla}_{l} = (1-\bm{I}_{l}^{(\le t-1)}) \otimes \nabla_{l}$
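Below is a minimal sketch of one soft-masked SGD step. It assumes importance scores already lie in $[0, 1]$ and, for simplicity, gives each unit a single parameter. In a real Transformer, the unit's scale would apply to all parameters feeding that head or neuron. The `sgd_step` helper and the learning rate are hypothetical.

```python
from typing import List, Sequence


def soft_masked_grad(grad: Sequence[float], acc_importance: Sequence[float]) -> List[float]:
    """Scale each unit's gradient by (1 - accumulated importance)."""
    return [(1.0 - imp) * g for g, imp in zip(grad, acc_importance)]


def sgd_step(params, grad, acc_importance, lr: float) -> List[float]:
    """Plain SGD update using the soft-masked gradient (hypothetical helper)."""
    masked = soft_masked_grad(grad, acc_importance)
    return [p - lr * m for p, m in zip(params, masked)]


params = [1.0, 1.0, 1.0]
grad = [0.5, 0.5, 0.5]
acc = [1.0, 0.5, 0.0]  # fully important, half important, unimportant
print(soft_masked_grad(grad, acc))  # [0.0, 0.25, 0.5]
print(sgd_step(params, grad, acc, lr=0.1))
```

A unit with importance 1 is frozen, and a unit with importance 0 receives the full gradient. Unlike a hard mask, intermediate scores still allow partial updates.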

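DAS integrates previously learned knowledge with current domain knowledge by contrasting two kinds of representations. The first, $\bm{o}^{\text{prev}}$, carries the previously learned knowledge (including the general knowledge). The second, $\bm{o}^{\text{full}}$, comes from the current full network. For each input, a second forward pass of the full network with a different dropout mask gives the positive $\bm{o}^{\text{full+}}$, while the previous-knowledge representations act as negatives. This pushes the full representation to capture knowledge complementary to what was already learned. The contrastive loss is: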

$\mathcal{L}_{\text{contrast}} = -\frac{1}{N}\sum_{n=1}^{N}\log\frac{e^{\text{sim}(\bm{o}_n^{\text{full}},\bm{o}_n^{\text{full+}})/\tau}}{\sum_{j=1}^{N}\left(e^{\text{sim}(\bm{o}_n^{\text{full}},\bm{o}_j^{\text{full+}})/\tau}+e^{\text{sim}(\bm{o}_n^{\text{full}},\bm{o}_j^{\text{prev}})/\tau}\right)}$
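The contrastive loss can be computed directly from the formula. This sketch assumes cosine similarity for $\text{sim}$ and uses a temperature `tau` chosen only for illustration, not taken from the paper. The representations are tiny hypothetical vectors. A production version would typically use a log-sum-exp formulation for numerical stability.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine(u: Vector, v: Vector) -> float:
    """Cosine similarity, assumed here as sim(., .)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def contrast_loss(o_full: List[Vector], o_full_plus: List[Vector],
                  o_prev: List[Vector], tau: float) -> float:
    """L_contrast for a batch of N samples.

    o_full[n], o_full_plus[n]: two dropout views of sample n from the full network.
    o_prev[j]: sample j's representation carrying previously learned knowledge.
    """
    n_samples = len(o_full)
    total = 0.0
    for n in range(n_samples):
        # Positive pair: two views of the same sample from the full network.
        positive = math.exp(cosine(o_full[n], o_full_plus[n]) / tau)
        # Denominator: every full+ view and every prev representation in the batch.
        denominator = sum(
            math.exp(cosine(o_full[n], o_full_plus[j]) / tau)
            + math.exp(cosine(o_full[n], o_prev[j]) / tau)
            for j in range(n_samples)
        )
        total += math.log(positive / denominator)
    return -total / n_samples


full = [[1.0, 0.0], [0.0, 1.0]]
aligned_prev = [[1.0, 0.0], [0.0, 1.0]]
opposed_prev = [[-1.0, 0.0], [0.0, -1.0]]
# Previous-knowledge representations that match the full ones are penalized more.
print(contrast_loss(full, full, aligned_prev, tau=0.1)
      > contrast_loss(full, full, opposed_prev, tau=0.1))  # True
```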

The final loss function combines the MLM loss and the contrastive loss:

$\mathcal{L}_{\text{DAP-train}} = \mathcal{L}_{\text{MLM}} + \lambda\mathcal{L}_{\text{contrast}}$

Compute Importance of Units for the Current Domain

After training on a new domain $t$, DAS computes the importance of units for that domain with the importance equation above, using the MLM loss as $\mathcal{L}_{\text{impt}}$. The resulting $\bm{I}^{(t)}_{l}$ is merged into the accumulated importance with EMax and used to soft-mask learning on the next domain.
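The overall continual loop ties these steps together. This is a sketch of the control flow only. `train_domain` and `compute_importance` are hypothetical hooks standing in for soft-masked DAP-training and MLM-based importance computation, and the domain names and scores in the usage are made up.

```python
from typing import Callable, Dict, List, Sequence

Importance = List[float]


def emax(a: Sequence[float], b: Sequence[float]) -> Importance:
    return [max(x, y) for x, y in zip(a, b)]


def continual_dap_train(
    domains: Sequence[str],
    general_importance: Importance,
    train_domain: Callable[[str, Importance], None],
    compute_importance: Callable[[str], Importance],
) -> Importance:
    """Outer loop over a sequence of domains.

    general_importance: I^(0), from the proxy KL loss at initialization.
    train_domain(domain, acc): hypothetical hook; DAP-trains with soft-masking by acc.
    compute_importance(domain): hypothetical hook; returns I^(t) via the MLM loss.
    """
    acc = list(general_importance)  # I^(<= 0)
    for domain in domains:
        train_domain(domain, acc)             # soft-masked by I^(<= t-1)
        current = compute_importance(domain)  # I^(t)
        acc = emax(current, acc)              # I^(<= t)
    return acc


seen = []
fake_scores: Dict[str, Importance] = {"domain_a": [0.1, 0.7, 0.2], "domain_b": [0.3, 0.2, 0.6]}
final = continual_dap_train(
    ["domain_a", "domain_b"],
    [0.5, 0.0, 0.1],
    train_domain=lambda d, acc: seen.append((d, list(acc))),
    compute_importance=lambda d: fake_scores[d],
)
print(seen)   # [('domain_a', [0.5, 0.0, 0.1]), ('domain_b', [0.5, 0.7, 0.2])]
print(final)  # [0.5, 0.7, 0.6]
```

Each domain is trained under the mask accumulated from everything before it, and its own importance is folded in only after its training finishes.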

Experimental Results

The paper evaluates DAS on six unlabeled domain corpora and their corresponding end-task classification datasets. Baselines include non-continual learning (Non-CL) and continual learning (CL) methods. DAS outperforms all baselines and achieves the best knowledge transfer, as shown by its negative forgetting rate: end-task performance on earlier domains improves after later domains are learned. On average, DAS is also slightly better than Pool, a Non-CL baseline that DAP-trains a single model on all domain corpora pooled together, and it achieves both forgetting prevention and knowledge transfer. Learning the domains directly within the LM lets DAS outperform adapter- and prompt-based methods, and using the full LM to learn all tasks makes it more effective than methods that use sub-networks. The effectiveness of the proxy KL-divergence loss is shown by comparing it against using a sample set of $D_0$, the original pre-training data, and by comparing the general-knowledge importance computed from different domain corpora. Ablation studies confirm the contribution of initialization, soft-masking, and contrastive learning.
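A negative forgetting rate can be made concrete with a common definition from the continual-learning literature, which we assume matches the paper's usage. It is the average, over all but the last domain, of the end-task score right after learning a domain minus the score on that domain after learning the final one. The accuracy matrices below are hypothetical, not results from the paper.

```python
from typing import Sequence


def forgetting_rate(acc_matrix: Sequence[Sequence[float]]) -> float:
    """Assumed definition: mean of A[t][t] - A[T-1][t] over the first T-1 domains.

    acc_matrix[i][t]: end-task score on domain t after DAP-training through
    domain i (0-indexed, T x T). Negative values mean later training helped.
    """
    T = len(acc_matrix)
    if T < 2:
        raise ValueError("need at least two domains")
    drops = [acc_matrix[t][t] - acc_matrix[T - 1][t] for t in range(T - 1)]
    return sum(drops) / (T - 1)


# Hypothetical scores: earlier domains improve after later training.
improving = [[80.0, 0.0, 0.0], [81.0, 70.0, 0.0], [82.0, 71.0, 60.0]]
print(forgetting_rate(improving))  # -1.5
# Hypothetical scores: the first domain degrades (ordinary forgetting).
print(forgetting_rate([[80.0, 0.0], [75.0, 70.0]]))  # 5.0
```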

Conclusion

The paper introduces DAS, a method for continual DAP-training of LMs that incorporates soft-masking and contrastive learning. The key ideas involve preserving previous knowledge, using a novel proxy to compute unit importance, and learning complementary representations. The method achieves gains in both forgetting prevention and knowledge transfer.
