
Neural Collapse for Unconstrained Feature Model under Cross-entropy Loss with Imbalanced Data (2309.09725v2)

Published 18 Sep 2023 in stat.ML, cs.LG, and math.OC

Abstract: Recent years have witnessed the huge success of deep neural networks (DNNs) in various tasks of computer vision and text processing. Interestingly, these DNNs with massive number of parameters share similar structural properties on their feature representation and last-layer classifier at terminal phase of training (TPT). Specifically, if the training data are balanced (each class shares the same number of samples), it is observed that the feature vectors of samples from the same class converge to their corresponding in-class mean features and their pairwise angles are the same. This fascinating phenomenon is known as Neural Collapse (NC), first termed by Papyan, Han, and Donoho in 2019. Many recent works manage to theoretically explain this phenomenon by adopting so-called unconstrained feature model (UFM). In this paper, we study the extension of NC phenomenon to the imbalanced data under cross-entropy loss function in the context of unconstrained feature model. Our contribution is multi-fold compared with the state-of-the-art results: (a) we show that the feature vectors exhibit collapse phenomenon, i.e., the features within the same class collapse to the same mean vector; (b) the mean feature vectors no longer form an equiangular tight frame. Instead, their pairwise angles depend on the sample size; (c) we also precisely characterize the sharp threshold on which the minority collapse (the feature vectors of the minority groups collapse to one single vector) will take place; (d) finally, we argue that the effect of the imbalance in datasize diminishes as the sample size grows. Our results provide a complete picture of the NC under the cross-entropy loss for the imbalanced data. Numerical experiments confirm our theoretical analysis.

Citations (10)

Summary

  • The paper proves that within-class feature collapse persists under imbalanced data, ensuring that samples converge to a common class mean.
  • The paper shows that imbalanced data breaks the equiangular structure of class means, forming block patterns that underlie bias toward majority classes.
  • The paper establishes a precise threshold for minority collapse and demonstrates that increasing overall data can asymptotically recover the symmetric ETF structure.

Neural Collapse (NC) is an intriguing phenomenon observed in deep neural networks (DNNs) towards the end of training (the Terminal Phase of Training, TPT), particularly in classification tasks. It describes how the features extracted by the network collapse within each class (all samples from the same class have very similar features), and how the mean features of different classes arrange themselves in a highly structured way, such as an Equiangular Tight Frame (ETF) in the balanced case. This paper (Neural Collapse for Unconstrained Feature Model under Cross-entropy Loss with Imbalanced Data, 2023) extends the study of NC to the more realistic scenario of imbalanced training data under the widely used cross-entropy (CE) loss, using the theoretical framework of the Unconstrained Feature Model (UFM).

The Unconstrained Feature Model simplifies the analysis of complex DNNs by treating the features of the training data as free variables to be optimized, alongside the last layer's weights and bias. The rationale is that sufficiently overparameterized DNNs can approximate any function, including arbitrary feature maps. The paper formulates this as a convex optimization problem (Equation 2.3), which minimizes the CE loss plus nuclear norm regularization on the prediction matrix $Z = W^\top H$ and L2 regularization on the bias vector $b$. The nuclear norm regularization encourages low-rank solutions for the prediction matrix, linking the feature and weight matrices $W$ and $H$ to the singular value decomposition (SVD) of $Z$. Specifically, the paper leverages the identity $\min_{W^\top H = Z} \frac{\lambda_W}{2}\|W\|_F^2 + \frac{\lambda_H}{2}\|H\|_F^2 = \sqrt{\lambda_W \lambda_H}\,\|Z\|_*$. By setting $\lambda_Z = \sqrt{\lambda_W \lambda_H}$, the UFM problem becomes convex and more tractable.
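
To make the formulation concrete, the following toy sketch optimizes this convex UFM objective directly, treating the prediction matrix $Z$ and bias $b$ as free variables; the class sizes, step counts, and regularization values are illustrative assumptions, not settings from the paper.

    # Toy sketch of the convex UFM objective (Equation 2.3): CE loss plus
    # nuclear-norm regularization on Z and L2 regularization on b.
    # All sizes and hyperparameters below are illustrative assumptions.
    import torch
    import torch.nn.functional as F

    K = 4                               # number of classes (arbitrary)
    n = [100, 100, 10, 10]              # imbalanced class sample sizes
    N = sum(n)
    lam_Z, lam_b = 5e-3, 1e-2           # lambda_Z = sqrt(lambda_W * lambda_H)

    Z = (0.01 * torch.randn(K, N)).requires_grad_()  # prediction matrix Z = W^T H
    b = torch.zeros(K, requires_grad=True)           # bias vector
    labels = torch.tensor([k for k in range(K) for _ in range(n[k])])

    opt = torch.optim.Adam([Z, b], lr=0.1)
    for step in range(2000):
        logits = Z.T + b                # sample i's logits are Z[:, i] + b
        loss = (F.cross_entropy(logits, labels)
                + lam_Z * torch.linalg.matrix_norm(Z, ord='nuc')
                + 0.5 * lam_b * b.square().sum())
        opt.zero_grad()
        loss.backward()
        opt.step()
    # At the optimum, columns of Z within a class coincide (within-class
    # collapse), and the pairwise angles between class means display the
    # block structure described below.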

Here are the key findings and their practical implications:

  1. Within-Class Collapse ($\mathcal{NC}_1$) Persists: Even with imbalanced data, the paper theoretically proves that the feature vectors for all samples within the same class converge to a single mean feature vector (Theorem 3.1a). This means the first aspect of Neural Collapse, the collapse of within-class variability, still holds.
    • Practical Implication: This confirms that standard CE training, even on imbalanced data, drives the network to learn class-specific feature representations where all instances of a class cluster tightly in the feature space. This is a desirable property for classification, making classes separable.
  2. Loss of Equiangularity ($\mathcal{NC}_2$ and $\mathcal{NC}_3$ modified by imbalance): Unlike the balanced case, where mean features form an ETF (equal pairwise angles and lengths), imbalanced data breaks this symmetry. The paper shows that the mean prediction vectors and bias terms exhibit a block structure based on class sample sizes (Theorem 3.1b). Classes with the same number of samples form "clusters," and within each cluster the mean feature/prediction vectors share similar pairwise angles and magnitudes, while the angles and magnitudes differ between clusters.
    • Practical Implication: This explains why standard CE training on imbalanced data often leads to biased performance favoring majority classes. The geometry of the learned feature space is distorted by the data distribution. The decision boundaries implicitly learned by the classifier will be influenced by these class-dependent structures, potentially making separation of minority classes more difficult. The weight matrix $W$ and mean feature matrix $\bar{H}$ inherit this block structure (Theorem 3.1d).
  3. Sharp Minority Collapse Threshold: The paper provides a precise theoretical threshold for a phenomenon called "minority collapse." This occurs when the mean prediction vectors for all classes within a minority cluster collapse to a single vector. In this state, the model effectively learns to predict the same output for all minority classes, making them indistinguishable. The threshold depends on the regularization parameter $\lambda_Z$, the number of classes in each cluster ($k_A$, $k_B$), and their respective sample sizes ($n_A$, $n_B$) (Theorem 3.2 and Corollary 3.1).
    • Practical Implication: This is a critical negative finding. It provides a specific condition under which standard CE training on imbalanced data fails for minority classes. The derived threshold ($r \geq \frac{1}{k_A}(\frac{1}{\sqrt{n_B}\,\lambda_Z} - k_B)$ in terms of the imbalance ratio $r = n_A/n_B$) gives practitioners a way to predict when this failure will occur based on dataset characteristics and hyperparameter settings; see the helper sketch after this list. This knowledge is vital for diagnosing issues in imbalanced classification and for guiding mitigation strategies. For instance, if the imbalance ratio $r$ is high or the nuclear norm regularization $\lambda_Z$ is too large relative to the minority sample size $n_B$, minority collapse is likely.
  4. Asymptotic Behavior Approaching ETF: The paper shows that if the imbalance ratio $r = n_A/n_B$ is fixed but the total sample size $N$ goes to infinity (i.e., both $n_A$ and $n_B$ increase), the mean prediction vectors asymptotically converge towards the symmetric ETF structure seen in the balanced case (Theorem 3.3). The effect of imbalance diminishes as data scale increases.
    • Practical Implication: This offers a hopeful perspective. While imbalance is detrimental at smaller scales, collecting significantly more data for all classes (including minority ones, while maintaining the ratio) can help the learned feature space geometry recover the desirable symmetric structure, potentially leading to better generalization, especially on minority classes.
  5. Benign Optimization Landscape: The paper proves that the optimization problem for the UFM (Equation 2.1), which is non-convex, still possesses a "benign" landscape, meaning all local minima are also global minima (Theorem 3.4).
    • Practical Implication: While this applies strictly to the UFM, it suggests that if a real DNN behaves similarly to the UFM in the TPT regime, standard gradient-based optimization algorithms (like SGD) are likely to find the global optimum of the UFM objective function, corresponding to the theoretically characterized collapse state. This property contributes to the predictability of the NC phenomenon in this model.
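
The minority collapse condition in item 3 is straightforward to evaluate for a given dataset configuration. Below is a small hypothetical helper encoding the threshold $r \geq \frac{1}{k_A}(\frac{1}{\sqrt{n_B}\,\lambda_Z} - k_B)$ from Corollary 3.1; the function name and the example numbers are ours, not the paper's.

    # Hypothetical helper: predict minority collapse for a two-cluster
    # dataset (k_A majority classes with n_A samples each, k_B minority
    # classes with n_B samples each) under regularization strength lam_Z.
    import math

    def minority_collapse_expected(n_A, n_B, k_A, k_B, lam_Z):
        r = n_A / n_B  # imbalance ratio
        threshold = (1.0 / k_A) * (1.0 / (math.sqrt(n_B) * lam_Z) - k_B)
        return r >= threshold

    # Example: 5 majority classes with 5000 samples each, 5 minority
    # classes with 50 samples each; r = 100 exceeds the threshold (~4.7),
    # so minority collapse is predicted.
    print(minority_collapse_expected(5000, 50, k_A=5, k_B=5, lam_Z=5e-3))  # True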

Implementation Considerations:

  • UFM as an Analytical Tool: The UFM is primarily a theoretical framework. You would typically not implement the UFM directly for training a classification system. Instead, the insights gained from analyzing the UFM (like the block structure and minority collapse threshold) inform your understanding and strategy for training actual DNNs on imbalanced data.
  • DNN Training: To validate these findings empirically, one trains standard DNN architectures (e.g., ResNet, VGG) using libraries like PyTorch or TensorFlow. The training process involves minimizing the CE loss on the imbalanced dataset with optimizers like SGD, often with weight decay (L2 regularization). The paper's experiments use specific settings for learning rates, momentum, and regularization parameters ($\lambda_W$, $\lambda_H$, $\lambda_b$) to align with the UFM formulation's assumptions.
    
    # Example (PyTorch-like sketch; hyperparameters follow common practice,
    # not necessarily the paper's exact experimental settings)
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torchvision.models import resnet18

    # A standard DNN classifier; `K` and `num_epochs` are assumed defined,
    # and weight decay supplies the L2 regularization in the UFM formulation
    model = resnet18(num_classes=K)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

    # `imbalanced_train_loader` yields batches from the imbalanced dataset
    for epoch in range(num_epochs):
        for inputs, labels in imbalanced_train_loader:
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Adjust learning rate, evaluate NC metrics, etc.
  • Measuring Collapse: To observe the NC phenomenon empirically in a trained DNN, you would extract the penultimate-layer features $h_{ki}$ and the output logits $z_{ki} = W^\top h_{ki} + b$, then compute the class-wise mean features $\bar{h}_k$ and mean predictions $\bar{z}_k$. Metrics like the $\mathcal{NC}_1$ ratio (Equation 4.1), or visualizations of the pairwise correlations between mean prediction vectors ($\bar{Z}$) or mean feature vectors ($\bar{H}$), can demonstrate collapse and the block structure.
    
    # Example (conceptual PyTorch code to extract features and predictions)
    # Assumes `model` is trained and exposes a `get_features` method for the
    # penultimate-layer features; `train_loader` iterates over training data
    import numpy as np
    import torch

    features = {}
    predictions = {}
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for inputs, labels in train_loader:
            feats = model.get_features(inputs)  # penultimate features h_ki
            logits = model(inputs)              # logits z_ki = W^T h_ki + b

            for i in range(inputs.size(0)):
                label = labels[i].item()
                if label not in features:
                    features[label] = []
                    predictions[label] = []
                features[label].append(feats[i].cpu().numpy())
                predictions[label].append(logits[i].cpu().numpy())

    # Compute class-wise mean features/predictions
    mean_features = {label: np.mean(f_list, axis=0) for label, f_list in features.items()}
    mean_predictions = {label: np.mean(p_list, axis=0) for label, p_list in predictions.items()}

    # Stack mean prediction vectors into Z_bar (one column per class) and
    # inspect pairwise correlations; the block structure should be visible
    Z_bar_matrix = np.stack([mean_predictions[k] for k in sorted(mean_predictions.keys())], axis=1)
    correlation_matrix = np.corrcoef(Z_bar_matrix.T)  # K x K matrix to visualize
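
    # Sketch of an NC1-style variability metric, tr(Sigma_W @ pinv(Sigma_B)) / K,
    # the standard within-/between-class ratio from the NC literature; the
    # paper's Equation 4.1 may be defined slightly differently, so treat this
    # as a representative measure. Uses `features` and `mean_features` above.
    global_mean = np.mean(list(mean_features.values()), axis=0)
    d = global_mean.shape[0]
    Sigma_W, Sigma_B, n_total = np.zeros((d, d)), np.zeros((d, d)), 0
    for label, f_list in features.items():
        diff_b = (mean_features[label] - global_mean)[:, None]
        Sigma_B += diff_b @ diff_b.T              # between-class scatter
        for h in f_list:
            diff_w = (h - mean_features[label])[:, None]
            Sigma_W += diff_w @ diff_w.T          # within-class scatter
            n_total += 1
    Sigma_W /= n_total
    Sigma_B /= len(features)
    nc1 = np.trace(Sigma_W @ np.linalg.pinv(Sigma_B)) / len(features)
    # nc1 near zero indicates within-class collapse. For the class-mean
    # geometry, recall that a perfect ETF has pairwise cosine -1/(K-1);
    # the off-diagonals of `correlation_matrix` should approach this only
    # in the balanced or large-sample regime.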
  • Resource Requirements: Training DNNs on imbalanced data requires standard GPU resources depending on model size and dataset scale. Computing NC metrics post-training is less resource-intensive.
  • Deployment: The UFM analysis itself doesn't change the deployment of the trained DNN. However, awareness of potential minority collapse based on the threshold allows practitioners to anticipate poor performance on minority classes in deployment and potentially implement class-aware post-processing or re-train with mitigation strategies.

Trade-offs and Limitations:

  • UFM vs. Real DNNs: The primary trade-off is using a simplified UFM for theoretical tractability. While experiments show correspondence, real DNNs have complexities (nonlinearity, finite capacity, specific architectures) not captured by the UFM. The paper's numerical results show that while the minority collapse threshold is predicted well by the UFM, the complete collapse threshold (where all classes collapse) is overestimated by the UFM theory compared to empirical DNN behavior.
  • Terminal Phase: The analysis focuses on the state after training loss is nearly zero. Training dynamics and behavior before this phase are not directly addressed by the UFM analysis.
  • Specific Imbalance: The detailed threshold analysis is given for a dataset structure with classes grouped into distinct clusters by sample size. Applying the exact formulas derived might be less straightforward for more complex, continuous variations in class sizes.
  • Generalization: The paper focuses on the collapse phenomenon itself, not a direct proof of its impact on test generalization error in the imbalanced setting. However, minority collapse (indistinguishable predictions for minority classes) is clearly detrimental to generalization on those classes.

In summary, this paper provides valuable theoretical insights into how imbalanced data affects Neural Collapse under standard CE loss. It rigorously shows that within-class collapse persists but the desirable symmetric structure of mean features is broken, replaced by a block structure reflecting sample counts. Crucially, it provides a sharp, practically useful threshold for the onset of minority collapse, a severe failure mode in imbalanced learning. The work, while based on a simplified model, is supported by empirical results on real DNNs and highlights the interplay between data distribution, regularization, and the learned representation geometry in deep classification. This understanding can guide practitioners in diagnosing imbalance issues and potentially inform the design of more robust training strategies.