- The paper demonstrates that key neural collapse properties emerge with model scaling, leading to reduced within-class variability and improved geometric uniformity.
- The methodology employs Transformer-based CLMs with varied widths, depths, and weight decay on the TinyStories dataset, using metrics like CDNV and cosine similarity.
- The study finds significant correlations between neural collapse metrics and generalization, highlighting their potential role in enhancing model performance independent of scaling.
This paper, "Linguistic Collapse: Neural Collapse in (Large) Language Models" (28 May 2024), investigates the phenomenon of Neural Collapse (NC) in causal language models (CLMs). Neural Collapse is a set of behaviors observed in deep neural networks trained for classification, emerging during the terminal phase of training towards zero loss on balanced, noise-free data where the number of classes does not significantly exceed the embedding dimension. The key properties of NC traditionally include:
- (NC1) Within-class variability collapse: Top-layer representations for inputs from the same class converge to their class mean.
- (NC2) Convergence to a simplex ETF: The class means, when centered, tend towards an equinorm and equiangular configuration (Simplex Equiangular Tight Frame).
- (NC3) Convergence to self-duality: Top-layer classifiers align with their corresponding class means.
- (NC4) Nearest decision rule: The standard linear classifier becomes equivalent to a Nearest-Class Center (NCC) classifier.
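For intuition, the ideal NC2 geometry can be built in closed form (this is the standard construction from the NC literature, not code from the paper): for C classes, the columns of sqrt(C/(C-1)) * (I - 11ᵀ/C) have equal norm and pairwise cosine -1/(C-1). A minimal NumPy sketch:

```python
import numpy as np

C = 5  # number of classes; a perfect simplex ETF needs C <= d + 1
# Canonical simplex ETF: the C columns of M are the class-mean directions.
M = np.sqrt(C / (C - 1)) * (np.eye(C) - np.ones((C, C)) / C)

norms = np.linalg.norm(M, axis=0)
G = (M / norms).T @ (M / norms)  # pairwise cosine similarities

assert np.allclose(norms, norms[0])           # equinorm (NC2)
off_diag = G[~np.eye(C, dtype=bool)]
assert np.allclose(off_diag, -1 / (C - 1))    # equiangular (NC2)
```

The assertions verify the two halves of NC2: all class means share one norm, and every pair meets at the same (maximally separated) angle. When C ≫ d + 1, as in language modeling, no such configuration exists, which is why the paper turns to generalized uniformity measures instead.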
The authors note that training language models via next-token prediction is essentially a classification task over the vocabulary. However, the conditions under which LMs are typically trained starkly contrast with those traditionally favoring NC:
- Many classes (¬C1): The vocabulary size (C, tens of thousands) is much larger than the embedding dimension (d). A perfect simplex ETF requires C≤d+1.
- Imbalanced classes (¬C2): Token distributions in natural language are highly imbalanced.
- Ambiguous contexts (¬C3): Similar contexts can lead to different valid next tokens (e.g., "Once upon a time" followed by "," or " in").
- Undertraining (¬C4): LMs, especially large ones, are often not trained to full convergence or past zero error/loss in practice.
Given these conflicting conditions, the paper empirically investigates whether NC properties emerge in CLMs despite these challenges and how they relate to model scaling and generalization.
Empirical Investigation and Methodology
The paper trains a suite of Transformer-based CLMs (similar to GPT-Neo) on the TinyStories dataset (TinyStories: How Small Can Language Models Be and Still Speak Coherent English?, 2023). They vary model width (d∈{64,128,…,1024}), depth (L∈{1,2,…,12}), and training epochs (1, 3, 10). They also experiment with different weight decay factors.
For each trained model, the authors collect top-layer context embeddings for validation data and the model's linear classifiers (the final output layer weights). They then compute metrics to quantify the degree of NC based on these embeddings and classifiers, adapting some metrics for the LM context:
- NC1 (Within-Class Variability): Measured using the average Class-Distance Normalized Variance (CDNV) across token classes. Lower CDNV indicates less within-class variability relative to between-class distance.
- GNC2 (Geometric Structure): Beyond traditional Equinormness (CoV of mean norms) and Equiangularity (CoV of pairwise interference), they specifically measure Hyperspherical Uniformity (GNC2) using variation in pairwise logarithmic distances between normalized class means. This is motivated by prior work on generalized NC when C>d+1.
- UNC3 (Duality): Instead of just measuring the difference between normalized class means and classifiers (NC3, self-duality), they calculate the cosine similarity between each normalized class mean and its corresponding classifier vector. They introduce Uniform Duality (UNC3) as the minimization of the Coefficient of Variation (CoV) of these similarities, indicating a more consistent alignment across classes.
- NC4 (Classifier Agreement): Calculated as the proportion of validation samples where the linear classifier's prediction matches that of an implicit Nearest-Class Center (NCC) classifier based on the learned class means.
Generalization is measured by validation loss (next-token prediction cross-entropy).
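The four metrics above can be sketched on raw arrays as follows. This is an illustrative NumPy implementation, not the paper's code: the function name and shapes are my own, and it makes simplifications such as ignoring classifier biases and restricting to classes present in the data.

```python
import numpy as np

def nc_metrics(H, y, W):
    """Illustrative NC metrics. H: (N, d) top-layer embeddings,
    y: (N,) integer token-class labels, W: (C, d) output-layer rows."""
    classes = np.unique(y)
    means = np.stack([H[y == c].mean(axis=0) for c in classes])
    variances = np.array([H[y == c].var(axis=0).sum() for c in classes])
    i, j = np.triu_indices(len(classes), k=1)  # all class pairs

    # NC1: average Class-Distance Normalized Variance (CDNV) over pairs.
    d2 = np.sum((means[i] - means[j]) ** 2, axis=1)
    cdnv = np.mean((variances[i] + variances[j]) / (2 * d2))

    # GNC2: spread of pairwise log-distances between normalized means
    # (hyperspherical uniformity; lower spread = more uniform).
    U = means / np.linalg.norm(means, axis=1, keepdims=True)
    gnc2 = np.std(np.log(np.linalg.norm(U[i] - U[j], axis=1)))

    # UNC3: CoV of cosine similarity between each class mean and its
    # classifier row (uniform duality).
    Wn = W[classes] / np.linalg.norm(W[classes], axis=1, keepdims=True)
    sims = np.sum(U * Wn, axis=1)
    unc3 = sims.std() / sims.mean()

    # NC4: agreement between the linear classifier and an implicit
    # nearest-class-center (NCC) classifier.
    linear_pred = classes[np.argmax(H @ W[classes].T, axis=1)]
    ncc_pred = classes[np.argmin(
        np.linalg.norm(H[:, None, :] - means[None, :, :], axis=2), axis=1)]
    nc4 = np.mean(linear_pred == ncc_pred)
    return cdnv, gnc2, unc3, nc4
```

On synthetic data with tight, well-separated clusters and classifiers equal to the class means, CDNV approaches zero, UNC3's CoV approaches zero, and NC4 agreement reaches 1.0.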
To investigate the relationship between NC and generalization independent of scale, they train multiple instances of a single architecture (2-layer, 768-wide) with different random seeds for data shuffling and initialization, then perform a permutation test on the correlation between NC metrics and validation loss.
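A permutation test of this kind can be sketched as follows; this is a generic implementation (the function name is hypothetical, not from the paper's code):

```python
import numpy as np

def perm_test(metric, loss, n_perm=10_000, seed=0):
    """Two-sided permutation test for the Pearson correlation between a
    per-seed NC metric and per-seed validation loss."""
    rng = np.random.default_rng(seed)
    observed = np.corrcoef(metric, loss)[0, 1]
    # Null distribution: correlations after shuffling the pairing.
    null = np.array([np.corrcoef(rng.permutation(metric), loss)[0, 1]
                     for _ in range(n_perm)])
    # p-value: fraction of shuffled correlations at least as extreme.
    p_value = np.mean(np.abs(null) >= np.abs(observed))
    return observed, p_value
```

Because each model shares the same architecture and training budget, a significant p-value here isolates the metric-generalization link from scale effects, which is exactly the point of the paper's seed-variation experiment.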
Key Findings and Practical Implications
The empirical results reveal several key insights:
- Emergence of NC with Scaling: Despite the challenging conditions, several NC properties emerge or strengthen as model size (width and depth) and training epochs increase.
- NC1 (CDNV) consistently decreases with scale and training, indicating reduced within-class variability.
- Mean embedding norms grow, and their variation (Equinormness, part of NC2) decreases with scale.
- Average interference decreases, but variation in interference (Equiangularity, traditional NC2) doesn't consistently decrease with scale, supporting the idea that a perfect simplex ETF is not formed when C≫d+1.
- Hyperspherical Uniformity (GNC2, variation in logarithmic distances) consistently improves with scale and training, confirming its relevance in this setting.
- Average similarity between class means and classifiers (NC3, self-duality) shows weak trends with scale, but variation in similarity (UNC3, uniform duality) decreases with width and training.
- NC4 (Classifier Agreement) improves significantly with scale and training.
- Correlation with Generalization: The observed developments in NC properties are strongly correlated with improved validation performance (lower validation loss).
- NC1, GNC2, UNC3, and NC4 show notable correlations with generalization.
- Traditional NC2 (Equiangularity) and NC3 (Self-Duality) show weaker correlations compared to their generalized/uniform counterparts in this LM setting.
- NC and Generalization Independent of Scale: The permutation test on models with identical architecture but different random seeds reveals that several NC properties (NC1, GNC2, NC3, NC4, and traditional NC2 Equiangularity) are statistically significantly correlated with generalization performance even when scale and training time are fixed. This suggests that NC is not merely a side effect of scaling and training, but potentially a more fundamental aspect of model performance and generalization in LMs.
- Weight Decay: Stronger weight decay appeared to promote the development of NC properties.
Implementation Considerations and Future Work
This research is primarily an empirical analysis rather than proposing a new implementation technique. However, the methodology suggests practical ways to analyze the feature space and classifiers of existing or newly trained LMs:
- Monitoring NC Metrics: Developers can implement the described metrics (CDNV for NC1, log-distance variation for GNC2, similarity CoV for UNC3, classifier agreement for NC4) during or after training to gain insights into the model's feature learning and its potential for generalization.
- Feature Space Analysis: The metrics provide low-level interpretability by quantifying aspects like class separability (NC1), the geometric arrangement of classes (GNC2), and the consistency of classifier alignment (UNC3). This could help diagnose issues like poor separation for specific tokens or groups of tokens.
- Potential for New Objectives: The findings could inspire research into training objectives that explicitly encourage certain NC properties in LMs, similar to how feature regularization is used in imbalanced image classification. For instance, adding terms to the loss that penalize high CDNV or high CoV of log-distances might promote better generalization.
- Understanding Ambiguity and Compression: The persistent noise (NC1) due to ambiguous contexts might relate to LLMs' ability to model aleatoric uncertainty or their function as data compression systems, as suggested by the authors.
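As a purely speculative sketch of such an objective (not proposed or evaluated in the paper), a regularizer could penalize the CoV of pairwise log-distances between normalized class means, directly targeting GNC2. Shown in NumPy for clarity; a real training loss would be written in an autodiff framework, and the function name is hypothetical.

```python
import numpy as np

def logdist_uniformity_penalty(means, eps=1e-8):
    """Hypothetical GNC2-style regularizer: coefficient of variation of
    pairwise log-distances between normalized class means. Lower values
    indicate a more hyperspherically uniform arrangement."""
    U = means / (np.linalg.norm(means, axis=1, keepdims=True) + eps)
    i, j = np.triu_indices(len(means), k=1)
    logd = np.log(np.linalg.norm(U[i] - U[j], axis=1) + eps)
    return logd.std() / (np.abs(logd.mean()) + eps)

# Hypothetical usage inside a training step (pseudocode):
# total_loss = cross_entropy + lam * logdist_uniformity_penalty(class_means)
```

On a simplex ETF, where all pairwise distances are equal, this penalty is (numerically) zero; on random class means it is strictly positive, so gradient pressure would push the geometry toward uniformity. Whether this actually improves generalization is an open question the paper only gestures at.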
Limitations
The authors acknowledge limitations, including that the chosen NC metrics may not capture every aspect of collapse in language modeling. The paper focuses on basic causal language modeling and does not include experiments on more complex settings such as encoder-decoder models, multi-modal models, or instruction-tuned models. The scale-independent correlation analysis was performed only on a single, relatively small architecture, and the results might not translate directly to much larger models.
In summary, the paper successfully adapts the Neural Collapse framework to the challenging domain of language modeling, providing empirical evidence that NC properties emerge with scale and training and correlate with generalization, even independent of scale. This work lays the groundwork for a deeper understanding of, and potentially improved architectures for, language models based on NC-related insights.