LLM Pretraining with Continuous Concepts
(2502.08524v1)
Published 12 Feb 2025 in cs.LG and cs.CL
Abstract: Next token prediction has been the standard training objective used in LLM pretraining. Representations are learned as a result of optimizing for token-level perplexity. We propose Continuous Concept Mixing (CoCoMix), a novel pretraining framework that combines discrete next token prediction with continuous concepts. Specifically, CoCoMix predicts continuous concepts learned from a pretrained sparse autoencoder and mixes them into the model's hidden state by interleaving with token hidden representations. Through experiments on multiple benchmarks, including language modeling and downstream reasoning tasks, we show that CoCoMix is more sample efficient and consistently outperforms standard next token prediction, knowledge distillation, and inserting pause tokens. We find that combining both concept learning and interleaving in an end-to-end framework is critical to performance gains. Furthermore, CoCoMix enhances interpretability and steerability by allowing direct inspection and modification of the predicted concept, offering a transparent way to guide the model's internal reasoning process.
Summary
The paper introduces CoCoMix, a pretraining method that augments token prediction with latent continuous concept modeling derived from a sparse autoencoder.
It employs a TopK mechanism with attribution scores to select key semantic concepts, boosting next token prediction in weak-to-strong supervision scenarios.
Empirical results demonstrate enhanced interpretability and steerability, with CoCoMix outperforming standard NTP and knowledge distillation baselines.
The paper introduces Continuous Concept Mixing (CoCoMix), a pretraining framework for LLMs that combines next token prediction with continuous concepts learned from a pretrained sparse autoencoder (SAE). The motivation stems from the limitations of relying solely on token-level perplexity for learning high-level reasoning and conceptual understanding. CoCoMix aims to bridge semantic abstraction and fine-grained token-level guidance by augmenting the next token prediction objective with explicit modeling of concepts in a latent representation space.
The paper details the CoCoMix methodology, which involves extracting semantic concepts using a pretrained SAE and selecting the most influential ones based on attribution scores. The model is trained to predict these selected concepts from its hidden state using a cross-entropy loss. The predicted concepts are then compressed into a single continuous concept vector, which is interleaved with the token hidden representations. The SAE decomposes the hidden state into multiple dimensions, each representing a distinct concept, and uses a TopK activation function to enforce sparsity, isolating the most critical dimensions that explain the pretrained model's features. The reconstruction process of the SAE is defined as:
$h_t^{\mathtt{pre}} = E(h_t^{\mathtt{con}}), \qquad \hat{h}_t^{\mathtt{con}} = D\big(\mathrm{TopK}(h_t^{\mathtt{pre}})\big),$
where
$h_t^{\mathtt{con}}$ is the pretrained model's hidden state at position $t$,
$E$ is a linear encoder mapping $\mathbb{R}^{d_{\mathtt{con}}}$ to $\mathbb{R}^{C}$,
$D$ is a linear decoder mapping $\mathbb{R}^{C}$ to $\mathbb{R}^{d_{\mathtt{con}}}$,
$C$ is the dimension of the concept space,
$h_t^{\mathtt{pre}}$ is the pre-activation concept vector,
$\mathrm{TopK}(\cdot)$ zeros out all but the largest $K_{\mathtt{SAE}}$ entries,
$\hat{h}_t^{\mathtt{con}}$ is the reconstruction.
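To make this step concrete, here is a minimal PyTorch sketch of a TopK sparse autoencoder of this form. The class and argument names (SparseAutoencoder, d_con, n_concepts, k_sae) are illustrative and are not taken from the paper's code.

```python
# Minimal sketch of a TopK sparse autoencoder, assuming linear E and D.
import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    def __init__(self, d_con: int, n_concepts: int, k_sae: int):
        super().__init__()
        self.encoder = nn.Linear(d_con, n_concepts)   # E: R^{d_con} -> R^C
        self.decoder = nn.Linear(n_concepts, d_con)   # D: R^C -> R^{d_con}
        self.k_sae = k_sae

    def topk(self, h_pre: torch.Tensor) -> torch.Tensor:
        # Zero out all but the largest k_sae entries along the concept dimension.
        values, indices = torch.topk(h_pre, self.k_sae, dim=-1)
        return torch.zeros_like(h_pre).scatter(-1, indices, values)

    def forward(self, h_con: torch.Tensor) -> torch.Tensor:
        h_pre = self.encoder(h_con)      # pre-activation concept vector
        h_sparse = self.topk(h_pre)      # sparse concept activations
        return self.decoder(h_sparse)    # reconstruction of h_con
```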
The attribution score $s_t$ measures the influence of each concept on the output, based on a local linear approximation of the effect of changing the concept value, where $f\big(x_{t+1} \mid D(h_t), x_{<t}\big)$ is the probability of predicting the next token $x_{t+1}$ given the decoded concepts and the previous tokens.
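The paper's exact attribution formula is not reproduced here, but one common way to realize such a local linear approximation is an activation-times-gradient score. The sketch below assumes a hypothetical model_logprob_fn wrapper mapping a decoded hidden state to next-token log-probabilities, and applies the score to the (pre-activation) concept vector; both choices are assumptions for illustration, not the paper's specification.

```python
# Hedged sketch of an activation-times-gradient attribution score
# (local linear approximation of each concept's effect on the next-token
# log-probability). The paper's exact formulation may differ in detail.
import torch

def attribution_scores(h_pre, decoder, model_logprob_fn, next_token_id):
    """h_pre: (C,) concept vector for one position.
    decoder: the SAE decoder D.
    model_logprob_fn: hypothetical callable mapping a decoded hidden state
    to next-token log-probabilities under the pretrained model."""
    h_pre = h_pre.detach().requires_grad_(True)
    log_probs = model_logprob_fn(decoder(h_pre))   # log f(. | D(.), x_<t)
    log_probs[next_token_id].backward()
    # Influence of concept i ~= activation_i * d(log p) / d(activation_i)
    return (h_pre * h_pre.grad).detach()
```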
The indices of the concepts with the highest attribution scores are selected as discrete labels for concept prediction. A linear prediction head $M$ outputs the logit $l_t = M(h_t) \in \mathbb{R}^{C}$, where $h_t$ is the model's hidden state. The cross-entropy loss $\mathcal{L}_{\mathtt{concept}}$ is defined as:
$\mathcal{L}_{\mathtt{concept}}(h_t) = \frac{1}{K_{\mathtt{attr}}} \sum_{i \in I} \mathrm{CE}(l_t, i),$
where
$I$ is the set of indices corresponding to the top $K_{\mathtt{attr}}$ values of $s_t$,
$\mathrm{CE}$ is the cross-entropy loss.
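A minimal sketch of this loss, assuming a linear concept_head playing the role of $M$ and precomputed attribution scores; the function and variable names (concept_loss, k_attr) are illustrative.

```python
# Sketch of the concept prediction loss: each selected concept index is a
# classification target for the C-way logit l_t, averaged over K_attr targets.
import torch
import torch.nn.functional as F

def concept_loss(h_t, concept_head, attribution, k_attr):
    """h_t: (d,) model hidden state; concept_head: nn.Linear(d, C);
    attribution: (C,) attribution scores s_t; k_attr: number of selected concepts."""
    logits = concept_head(h_t)                             # l_t = M(h_t) in R^C
    target_idx = torch.topk(attribution, k_attr).indices   # index set I
    # Mean cross-entropy over the K_attr selected concept indices.
    return F.cross_entropy(logits.expand(k_attr, -1), target_idx)
```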
The concept prediction logit $l_t$ is sparsified with a TopK activation and compressed into a continuous concept vector $\hat{h}_t \in \mathbb{R}^{d}$:
$\hat{h}_t = \mathrm{TopK}(l_t)\, W + b,$
where
$W \in \mathbb{R}^{C \times d}$ and $b \in \mathbb{R}^{d}$ project the TopK-sparse vector to a $d$-dimensional embedding.
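A sketch of this compression step, reusing the TopK sparsification from the SAE sketch above; ConceptMixer and k_mix are illustrative names, and the returned vector is what gets interleaved with the token hidden representations.

```python
# Sketch of compressing predicted concept logits into a continuous concept vector.
import torch
import torch.nn as nn

def topk_sparsify(x: torch.Tensor, k: int) -> torch.Tensor:
    values, indices = torch.topk(x, k, dim=-1)
    return torch.zeros_like(x).scatter(-1, indices, values)

class ConceptMixer(nn.Module):
    def __init__(self, n_concepts: int, d_model: int, k_mix: int):
        super().__init__()
        self.proj = nn.Linear(n_concepts, d_model)   # W, b: sparse logits -> R^d
        self.k_mix = k_mix

    def forward(self, concept_logits: torch.Tensor) -> torch.Tensor:
        sparse = topk_sparsify(concept_logits, self.k_mix)
        # The resulting continuous concept vector \hat{h}_t is inserted into the
        # hidden-state sequence, interleaved with the token representations.
        return self.proj(sparse)
```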
The final training objective combines the standard next token prediction loss and the concept prediction term:
$\sum_{t=1}^{T-1} -\log f\big(x_{t+1} \mid x_{\le t}, \hat{h}_{\le t}\big) + \lambda\, \mathcal{L}_{\mathtt{concept}}(h_t),$
where
$\lambda$ is a tunable coefficient.
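Putting the two terms together, here is a hedged sketch of the combined objective for one sequence, assuming per-position concept losses have already been computed with a routine like concept_loss above; the names ntp_logits and lam are illustrative.

```python
# Sketch of the combined objective: next-token cross-entropy plus the
# concept prediction term weighted by lambda.
import torch
import torch.nn.functional as F

def cocomix_objective(ntp_logits, targets, concept_losses, lam=1.0):
    """ntp_logits: (T-1, V) next-token logits computed with the mixed hidden states;
    targets: (T-1,) gold next tokens x_{t+1};
    concept_losses: (T-1,) per-position L_concept values."""
    ntp_loss = F.cross_entropy(ntp_logits, targets, reduction="sum")  # sum_t -log f(x_{t+1} | ...)
    return ntp_loss + lam * concept_losses.sum()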
The paper presents an empirical evaluation of CoCoMix, examining its performance on next token prediction, weak-to-strong supervision, interpretability, and steerability. The training setup uses a pretrained open-source SAE trained on the 124M-parameter GPT-2. CoCoMix models are trained at several parameter sizes (68M, 386M, and 1.38B) with a context length of 1024, using OpenWebText as the pretraining corpus. Baselines include the standard next token prediction (NTP) procedure and knowledge distillation (KD).
The results demonstrate that CoCoMix improves the performance of next token prediction, particularly in weak-to-strong supervision scenarios. CoCoMix achieves comparable performance to NTP with fewer training tokens and shows improvements in downstream tasks. In weak-to-strong supervision, concepts extracted from a smaller model are used to supervise the training of a larger model. CoCoMix also enhances interpretability and steerability, allowing for the analysis and control of the model's output generation. Additionally, the paper analyzes the effectiveness of each component of CoCoMix, including the attribution score, concept prediction, and mixing.
The paper compares CoCoMix with KD across multiple scenarios, including a stronger teacher model teaching a smaller student model, weak-to-strong supervision, and distribution shift. CoCoMix demonstrates improvements over KD in all model configurations, particularly in weak-to-strong supervision. A weight analysis of the compression layer reveals that CoCoMix learns to ignore ineffective concepts. Both concept prediction and concept insertion are critical for performance improvement. Comparing concept conditioning methods, the insertion method, which interleaves the continuous concept, performs better than intervention, which adds the concept vector to the hidden state. CoCoMix also outperforms pause tokens, indicating that the inserted continuous concepts contain useful information.
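As a closing illustration of the steerability claim above, the following is a minimal, hypothetical sketch of how a predicted concept could be amplified or suppressed before it is mixed back into the hidden state; the function and argument names are illustrative and not from the paper.

```python
# Hedged sketch of concept steering: inspect the predicted concept logits and
# rescale a chosen concept before the mixing step.
import torch

def steer_concepts(concept_logits: torch.Tensor, concept_id: int, scale: float) -> torch.Tensor:
    steered = concept_logits.clone()
    steered[..., concept_id] = steered[..., concept_id] * scale
    return steered  # pass to the mixing step in place of the raw logits
```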