Neural Collapse Under MSE Loss: Proximity to and Dynamics on the Central Path (2106.02073v4)

Published 3 Jun 2021 in cs.LG, cs.AI, math.DG, math.OC, and stat.ML

Abstract: The recently discovered Neural Collapse (NC) phenomenon occurs pervasively in today's deep net training paradigm of driving cross-entropy (CE) loss towards zero. During NC, last-layer features collapse to their class-means, both classifiers and class-means collapse to the same Simplex Equiangular Tight Frame, and classifier behavior collapses to the nearest-class-mean decision rule. Recent works demonstrated that deep nets trained with mean squared error (MSE) loss perform comparably to those trained with CE. As a preliminary, we empirically establish that NC emerges in such MSE-trained deep nets as well through experiments on three canonical networks and five benchmark datasets. We provide, in a Google Colab notebook, PyTorch code for reproducing MSE-NC and CE-NC: at https://colab.research.google.com/github/neuralcollapse/neuralcollapse/blob/main/neuralcollapse.ipynb. The analytically-tractable MSE loss offers more mathematical opportunities than the hard-to-analyze CE loss, inspiring us to leverage MSE loss towards the theoretical investigation of NC. We develop three main contributions: (I) We show a new decomposition of the MSE loss into (A) terms directly interpretable through the lens of NC and which assume the last-layer classifier is exactly the least-squares classifier; and (B) a term capturing the deviation from this least-squares classifier. (II) We exhibit experiments on canonical datasets and networks demonstrating that term-(B) is negligible during training. This motivates us to introduce a new theoretical construct: the central path, where the linear classifier stays MSE-optimal for feature activations throughout the dynamics. (III) By studying renormalized gradient flow along the central path, we derive exact dynamics that predict NC.

Authors (3)
  1. X. Y. Han (6 papers)
  2. Vardan Papyan (26 papers)
  3. David L. Donoho (25 papers)
Citations (123)

Summary

  • The paper demonstrates that Neural Collapse emerges in MSE-trained networks via a novel loss decomposition into least-squares and deviation terms.
  • It reveals that network classifiers follow a central path where feature collapse and convergence to a Simplex ETF structure occur during the terminal phase of training.
  • Empirical results on datasets like MNIST and CIFAR10 show that MSE loss achieves comparable accuracy to cross-entropy and can enhance adversarial robustness.

This paper, "Neural Collapse Under MSE Loss: Proximity to and Dynamics on the Central Path" (Han et al., 2021), investigates the Neural Collapse (NC) phenomenon in deep neural networks trained with Mean Squared Error (MSE) loss, in contrast to the more common Cross-Entropy (CE) loss. NC is a set of empirical observations made during the final stage of deep network training (the Terminal Phase of Training, or TPT), characterized by four properties:

  1. NC1 (Within-class variability collapse): Last-layer features for examples belonging to the same class converge to their class mean, effectively reducing within-class variance.
  2. NC2 (Convergence to Simplex ETF): The class means of the features, and the last-layer classifiers, converge to the vertices of a Simplex Equiangular Tight Frame (ETF). This is a highly symmetric geometric configuration where vectors have equal norms and equal pairwise angles.
  3. NC3 (Convergence to self-duality): The direction of the last-layer classifier for a class aligns with the direction of its class mean (relative to the global mean).
  4. NC4 (Simplification to nearest class center): The decision rule of the network converges to simply classifying an input based on which class mean its feature vector is closest to in Euclidean distance.
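
As a rough way to track these properties empirically (not the paper's exact measurements), one can estimate class means and covariances from last-layer activations and report a within-class-variability statistic (NC1) together with the spread of pairwise cosines between centred class means (NC2). A NumPy sketch with illustrative names:

```python
import numpy as np

def nc_diagnostics(features, labels):
    """Rough NC1/NC2 diagnostics from last-layer features.

    features: (N, d) array of last-layer activations
    labels:   (N,) integer class labels
    """
    classes = np.unique(labels)
    d = features.shape[1]
    global_mean = features.mean(axis=0)
    means = np.stack([features[labels == c].mean(axis=0) for c in classes])
    centred = means - global_mean

    # Within-class and between-class covariance estimates.
    sigma_w = np.zeros((d, d))
    for c, mu in zip(classes, means):
        diffs = features[labels == c] - mu
        sigma_w += diffs.T @ diffs / len(features)
    sigma_b = centred.T @ centred / len(classes)

    # NC1: within-class variability measured against between-class variability
    # (one common choice of statistic, not necessarily the paper's).
    nc1 = np.trace(sigma_w @ np.linalg.pinv(sigma_b)) / len(classes)

    # NC2: pairwise cosines of the centred class means; for a Simplex ETF with
    # C classes they should all approach -1/(C-1), so their spread should shrink.
    unit = centred / np.linalg.norm(centred, axis=1, keepdims=True)
    cos = unit @ unit.T
    nc2_spread = cos[~np.eye(len(classes), dtype=bool)].std()

    return nc1, nc2_spread
```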

While NC was initially observed with CE loss, this paper empirically demonstrates that NC also emerges in deep networks trained with MSE loss on standard datasets (MNIST, FashionMNIST, CIFAR10, SVHN, STL10) and architectures (VGG, ResNet, DenseNet). The paper provides code to reproduce these findings, which is a crucial practical contribution. The empirical results show that MSE-trained networks achieve comparable test accuracy to CE-trained networks and also exhibit similar NC properties, sometimes collapsing faster (NC1) or achieving better adversarial robustness.

The core theoretical contribution is a new decomposition of the MSE loss tailored for classification: $L = L_{\text{LS}} + L_{\text{Dev}}$.

  • $L_{\text{LS}}$: This term represents the minimum possible MSE loss achievable if the last-layer classifier were exactly the least-squares optimal classifier given the current features.
  • $L_{\text{Dev}}$: This term measures the deviation of the actual classifier from this least-squares optimum.

Empirical measurements presented in the paper show that during training, $L_{\text{Dev}}$ quickly becomes negligible compared to $L_{\text{LS}}$. This observation motivates the concept of the central path, a theoretical construct where the classifier is always precisely the least-squares optimal classifier for the current features. The empirical evidence suggests that real-world networks, during TPT, closely follow this central path. This simplifies the analysis, as one can focus on minimizing $L_{\text{LS}}$ with respect to features, assuming the classifier is implicitly determined.
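
Because $L_{\text{LS}}$ is, by definition, the smallest MSE any linear head could attain on the current features, the deviation term can be monitored as the gap between the actual loss and that minimum. Below is a minimal NumPy sketch of such a monitor; it assumes one-hot targets, folds the bias into the least-squares fit by appending a constant feature, and uses a 1/2-scaled MSE convention, so it illustrates the idea rather than reproducing the paper's exact formulas.

```python
import numpy as np

def ls_decomposition(features, targets, W, b):
    """Split the MSE loss into L_LS (least-squares optimum) and L_Dev (the gap).

    features: (N, d) last-layer activations
    targets:  (N, C) one-hot labels
    W, b:     the network's actual last-layer weights (C, d) and bias (C,)
    """
    N = features.shape[0]
    # Append a constant 1 so the bias is folded into the least-squares fit.
    H = np.hstack([features, np.ones((N, 1))])          # (N, d+1)

    # Least-squares classifier for the current features.
    W_ls, *_ = np.linalg.lstsq(H, targets, rcond=None)  # (d+1, C)

    def mse(pred):
        return 0.5 * np.mean(np.sum((pred - targets) ** 2, axis=1))

    L = mse(features @ W.T + b)   # loss of the actual linear head
    L_ls = mse(H @ W_ls)          # minimum attainable with any linear head
    L_dev = L - L_ls              # deviation from the least-squares head
    return L, L_ls, L_dev
```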

The $L_{\text{LS}}$ term is further decomposed into two components with intuitive interpretations:

  • $L_{\text{NC1}}$: Related to within-class variance, this term encourages feature collapse (NC1). Minimizing this term, for fixed class means, drives features towards their class means.
  • $L_{\text{NC2+NC3}}$: Related to the configuration of class means and classifiers, this term encourages convergence towards the Simplex ETF structure (NC2 and NC3).

The paper then investigates the dynamics of features on this central path, focusing on the zero-global-mean setting (motivated by the properties of the least-squares classifier). They discover an invariance property: the class predictions and MSE loss on the central path are invariant under certain linear transformations of the features. This leads to the idea of analyzing renormalized features, where features are transformed such that their within-class covariance is the identity matrix (a process conceptually related to whitening or batch normalization).
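
For illustration, the renormalization described above can be approximated by whitening the features with the (pseudo-)inverse square root of their within-class covariance; the sketch below is a generic whitening step, not the paper's exact renormalization scheme.

```python
import numpy as np

def renormalize_features(features, labels):
    """Whiten features so their within-class covariance becomes (close to) the identity."""
    classes = np.unique(labels)
    d = features.shape[1]

    # Within-class covariance estimate.
    sigma_w = np.zeros((d, d))
    for c in classes:
        diffs = features[labels == c] - features[labels == c].mean(axis=0)
        sigma_w += diffs.T @ diffs / len(features)

    # Pseudo-inverse square root of the within-class covariance.
    evals, evecs = np.linalg.eigh(sigma_w)
    inv_sqrt_evals = np.zeros_like(evals)
    good = evals > 1e-12
    inv_sqrt_evals[good] = evals[good] ** -0.5
    inv_sqrt = (evecs * inv_sqrt_evals) @ evecs.T

    return features @ inv_sqrt
```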

Analyzing the continually renormalized gradient flow on these features, the paper derives exact, closed-form dynamics for the singular values of the Signal-to-Noise Ratio (SNR) matrix. The SNR matrix, defined as the pseudoinverse of the within-class covariance multiplied by the between-class covariance ($\Sigma_W^{\dagger}\Sigma_B$), naturally arises in the least-squares classifier formulation and captures the separation of class means relative to within-class noise. The derived dynamics for the SNR singular values take the form:

$$c_1 \log\big(\omega_j(t)\big) + c_2\,\omega_j^2(t) + c_3\,\omega_j^4(t) = a_j + t$$

where the $\omega_j(t)$ are the singular values, $t$ is training time, and $c_1, c_2, c_3, a_j$ are constants.

The solution to this ordinary differential equation shows that as $t \to \infty$:

  1. All non-zero singular values $\omega_j(t)$ grow without bound, at a rate proportional to $t^{1/4}$.
  2. The ratio between the maximum and minimum singular values converges to 1, meaning all non-zero singular values converge to equality.
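
These two limits can be checked numerically by inverting the implicit relation for a few illustrative constants (here $c_1 = c_2 = c_3 = 1$ and arbitrary offsets $a_j$; a toy check rather than fitted values):

```python
import numpy as np
from scipy.optimize import brentq

def omega(t, a, c1=1.0, c2=1.0, c3=1.0):
    """Solve c1*log(w) + c2*w**2 + c3*w**4 = a + t for w > 0."""
    f = lambda w: c1 * np.log(w) + c2 * w**2 + c3 * w**4 - (a + t)
    return brentq(f, 1e-9, 1e6)

offsets = [0.0, 5.0, 20.0]  # different a_j, i.e. different initial separations
for t in [1e1, 1e3, 1e5, 1e7]:
    ws = np.array([omega(t, a) for a in offsets])
    # w_j / t^(1/4) approaches a common constant, and max/min approaches 1,
    # matching the two asymptotic claims above.
    print(f"t={t:.0e}  w/t^0.25={np.round(ws / t**0.25, 3)}  "
          f"max/min={ws.max() / ws.min():.4f}")
```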

These theoretical dynamics directly imply the emergence of Neural Collapse:

  • Infinitely large singular values mean the "signal" (separation between class means) is infinitely large relative to the "noise" (within-class variance), which corresponds to NC1 (within-class variability collapse).
  • Singular values converging to equality, combined with the structure of the SNR matrix, implies that the renormalized class means converge to a Simplex ETF configuration, corresponding to NC2.
  • As shown in prior work and re-established here for the central path, NC1 and NC2 together imply NC3 (self-duality) and NC4 (nearest class center rule).
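
Tracking these predictions during training amounts to recomputing the SNR singular values over TPT. Assuming within- and between-class covariance estimates like those in the earlier diagnostic sketch (the names sigma_w and sigma_b are illustrative), a minimal version is:

```python
import numpy as np

def snr_singular_values(sigma_w, sigma_b, num_classes):
    """Singular values of an empirical SNR matrix, pinv(Sigma_W) @ Sigma_B."""
    snr = np.linalg.pinv(sigma_w) @ sigma_b
    # The between-class covariance has rank at most C-1 once the global mean
    # is removed, so only the leading C-1 singular values are informative.
    return np.linalg.svd(snr, compute_uv=False)[: num_classes - 1]

# Per the theory, these values should grow (roughly like t^(1/4)) during TPT
# while their max/min ratio approaches 1.
```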

Practical Implications for Developers and Practitioners:

  1. MSE as a Classifier Loss: The empirical results confirm that MSE loss can be a viable alternative to CE for multi-class classification, achieving comparable accuracy. In some settings, such as adversarial robustness (shown in the appendix), MSE may even offer advantages. This suggests MSE is a practical option to consider, especially given its analytical tractability explored in the paper; a minimal training sketch appears after this list.
  2. Understanding Training Dynamics: The loss decomposition ($L_{\text{LS}}$, $L_{\text{Dev}}$, $L_{\text{NC1}}$, $L_{\text{NC2+NC3}}$) provides a framework for analyzing what the network is optimizing at different stages. Observing these components during training can offer insights into the training process: for instance, confirming whether the classifier stays close to the least-squares optimum ($L_{\text{Dev}}$ negligible) or identifying which NC property is being minimized fastest ($L_{\text{NC1}}$ vs. $L_{\text{NC2+NC3}}$).
  3. Theoretical Model for Late-Stage Training: The concept of the central path and the derived dynamics provide a simplified theoretical model for understanding the behavior of deep networks during TPT. While based on assumptions (continuous flow, zero global mean, specific renormalization), this model offers a tractable explanation for NC, a complex empirical phenomenon. This could potentially guide the development of better optimization algorithms or regularization techniques aimed at achieving beneficial late-stage properties.
  4. Relevance of Normalization: The theoretical analysis highlights the role of feature normalization (like the explicit renormalization studied, related to Batch Norm) in shaping the geometric structure of features. The findings connect the idea of whitening features to the dynamics that lead to the Simplex ETF structure and feature collapse.
  5. Reproducibility: The authors provide PyTorch code and data on Google Colab and Stanford Digital Repository, making it straightforward for practitioners to reproduce their empirical findings on MSE-NC and the loss decomposition. This allows others to verify the phenomenon on their own datasets and architectures.
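
Following up on point 1, a plain-MSE classification loop in PyTorch only requires swapping one-hot targets and MSE in for cross-entropy. The sketch below uses placeholder names for the model, loader, and optimizer, and omits any target rescaling the paper's experiments may apply:

```python
import torch
import torch.nn.functional as F

def train_epoch_mse(model, loader, optimizer, num_classes, device="cpu"):
    """One epoch of classification training with plain MSE on one-hot targets.

    `model`, `loader`, and `optimizer` are placeholders for any backbone,
    data loader, and optimizer.
    """
    model.train()
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        targets = F.one_hot(labels, num_classes).float()  # one-hot targets
        outputs = model(images)                           # (batch, num_classes)
        loss = F.mse_loss(outputs, targets)               # MSE in place of CE
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```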

Implementation Considerations:

  • Monitoring Loss Components: Implementing the calculation of $L_{\text{LS}}$ and $L_{\text{Dev}}$ (which requires computing the least-squares classifier at each step) during training adds computational overhead, but it can be a valuable diagnostic tool.
  • Computational Resources: The experimental appendix details the use of GPU clusters for training, indicating standard deep learning resource requirements.
  • Limitations: The theoretical analysis of the dynamics makes simplifying assumptions (continuous gradient flow, zero global mean, ignoring weight decay). While motivated empirically, these are approximations of real SGD training with biases and weight decay. The empirical results also show that some dataset/architecture combinations (such as STL10 with ResNet or DenseNet) exhibit outlier behavior, suggesting the theory may not apply universally or may require adjustments for certain problems. The analysis of NC on test data (shown in the appendix) indicates slower collapse than on training data, suggesting the connection between NC and generalization is complex and an area for further research.

In summary, this paper validates that Neural Collapse is a phenomenon not exclusive to CE loss but also occurs under MSE loss. By carefully decomposing the MSE loss and analyzing feature dynamics on an empirically motivated "central path", the paper provides a tractable theoretical framework that derives the core NC properties from first principles, offering valuable insights into the geometric structure that emerges during the late stages of deep network training. The provided code enhances the practical utility by enabling reproduction and further exploration.