Neural Collapse in the Intermediate Hidden Layers of Classification Neural Networks
(2308.02760v1)
Published 5 Aug 2023 in cs.LG
Abstract: Neural Collapse (NC) gives a precise description of the representations of classes in the final hidden layer of classification neural networks. This description provides insights into how these networks learn features and generalize well when trained past zero training error. However, to date, NC has only been studied in the final layer of these networks. In the present paper, we provide the first comprehensive empirical analysis of the emergence of NC in the intermediate hidden layers of these classifiers. We examine a variety of network architectures, activations, and datasets, and demonstrate that some degree of NC emerges in most of the intermediate hidden layers of the network, where the degree of collapse in any given layer is typically positively correlated with the depth of that layer in the neural network. Moreover, we remark that: (1) almost all of the reduction in intra-class variance in the samples occurs in the shallower layers of the networks, (2) the angular separation between class means increases consistently with hidden layer depth, and (3) simple datasets require only the shallower layers of the networks to fully learn them, whereas more difficult ones require the entire network. Ultimately, these results provide granular insights into the structural propagation of features through classification neural networks.
This paper, "Neural Collapse in the Intermediate Hidden Layers of Classification Neural Networks" (Parker et al., 2023), investigates the phenomenon of Neural Collapse (NC) not just in the final hidden layer, as is common in previous research, but also in the intermediate layers of classification neural networks. NC describes a geometric structure that emerges in the final layer's representations during the Terminal Phase of Training (TPT), where within-class variability collapses, class means converge to an Equiangular Tight Frame (ETF) structure, and the classifier simplifies to a Nearest Class Center (NCC) rule.
The authors conduct an empirical analysis across various network architectures (MLP6, VGG11, ResNet18), activation functions (ReLU, Tanh, LeakyReLU), and datasets (MNIST, CIFAR10, CIFAR100, SVHN, FashionMNIST). They train the networks past zero training error using MSE loss and analyze the post-activation representations from intermediate layers at various training points.
To quantify NC in an intermediate layer j, they compute four metrics based on the flattened post-activation vectors h_{i,c}^j, where i indexes the samples within class c and j indexes the layer:
NC1 (Intra-Class Variance Collapse): Measures the within-class covariance relative to the between-class covariance via Tr((Σ_B^j)^+ Σ_W^j / C), where (·)^+ denotes the Moore–Penrose pseudoinverse. A lower value indicates less intra-class variance relative to between-class variance.
NC2 (Convergence to Simplex ETF):
Equal Norms: Quantified by the coefficient of variation of the class-mean norms, std_c(‖μ_c^j − μ_G^j‖₂) / avg_c(‖μ_c^j − μ_G^j‖₂). Approaches zero as the norms become equal.
Maximal Angles: Measures the average deviation of the pairwise cosines between centered class means from the Simplex ETF value −1/(C−1), via avg_{c≠c′} |⟨μ_c^j − μ_G^j, μ_{c′}^j − μ_G^j⟩ / (‖μ_c^j − μ_G^j‖₂ ‖μ_{c′}^j − μ_G^j‖₂) + 1/(C−1)|. Approaches zero at maximal angular separation.
NC4 (Simplification to Nearest Class-Center): Calculates the proportion of samples for which classification by the j-th layer's nearest class mean disagrees with the final network's classification: 1 − avg_{i,c} 1{f_θ(x_{i,c}) = argmin_{c′} ‖h_{i,c}^j − μ_{c′}^j‖₂}. Approaches zero if the layer's representation is sufficient for classification via NCC.
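For concreteness, here is a minimal PyTorch sketch of the two NC2 quantities, assuming the per-layer class means μ_c^j and global mean μ_G^j have already been computed; the function name and argument layout are illustrative, not the authors' code:

import torch

def compute_nc2(mu_c, mu_G):
    """Equal-norms and maximal-angles NC2 metrics from class means.

    Args:
        mu_c (torch.Tensor): (C, D) class means for layer j.
        mu_G (torch.Tensor): (D,) global mean for layer j.

    Returns:
        tuple of floats: (equal_norms, maximal_angles); both approach zero
        as the centered class means converge to a Simplex ETF.
    """
    C = mu_c.shape[0]
    centered = mu_c - mu_G        # Centered class means
    norms = centered.norm(dim=1)  # ||mu_c - mu_G||_2 per class

    # Equal Norms: coefficient of variation of the class-mean norms
    equal_norms = (norms.std(unbiased=False) / norms.mean()).item()

    # Maximal Angles: average |cos(angle) + 1/(C-1)| over distinct pairs
    cosines = (centered @ centered.T) / (norms[:, None] * norms[None, :])
    off_diag = ~torch.eye(C, dtype=torch.bool, device=mu_c.device)
    maximal_angles = (cosines[off_diag] + 1.0 / (C - 1)).abs().mean().item()

    return equal_norms, maximal_angles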
The empirical results show that some degree of NC emerges in most intermediate layers, generally increasing with depth. Specifically:
NC1: Intra-class variance reduction primarily occurs in shallower layers, plateauing in deeper layers. This suggests that early layers are highly effective at grouping same-class samples.
NC2: Angular separation between class means consistently increases with layer depth throughout the network, indicating that deeper layers continue to refine the separation between different classes. Equal norms also show a similar trend but plateau earlier than angular separation.
NC4: The ability to classify correctly using a nearest class-center rule based on the intermediate layer's representation improves with depth. For simpler datasets like MNIST and SVHN, the NC4 mismatch reaches zero in shallower layers, implying these datasets can be fully learned by the initial layers. For more complex datasets like CIFAR100, this mismatch only reaches zero in the final layers, highlighting the need for the entire network depth.
CIFAR100 is noted as an outlier, exhibiting different behavior from the other datasets, most notably weaker NC1 collapse.
Tanh and LeakyReLU activations exhibit similar trends to ReLU.
These findings provide a more granular understanding of how features evolve through a network: shallow layers reduce within-class variance, while deeper layers focus on increasing inter-class separation. The complexity of the dataset influences how much depth is required for the representations to support accurate classification via NCC.
Practical Implications:
Understanding the layer-wise emergence of NC can inform several practical aspects:
Feature Analysis: This research suggests that intermediate layers develop representations with different properties (NC1 in shallow, NC2 in deep). Analyzing these properties can provide insights into what kind of features are being learned at each stage.
Network Architecture Design: Knowing which layers contribute most to variance reduction vs. class separation could guide architectural choices, such as the design of blocks or the distribution of parameters.
Model Debugging and Interpretation: If NC metrics behave unexpectedly in certain layers during training (e.g., lack of NC1 in shallow layers), it might indicate issues with the network, data, or training process.
Model Pruning and Compression: For simpler datasets where NC4 is low in intermediate layers, it suggests that shallower layers might be sufficient for classification. This could support techniques like layer pruning or early-exit strategies to improve efficiency without significant performance loss.
Transfer Learning: Analyzing NC in pre-trained models might offer insights into which layers are most transferable or adaptable to new tasks. Layers with strong NC1 might contain more general feature extractors, while layers with strong NC2 are more task-specific.
Implementing the NC analysis described involves:
Saving model checkpoints during TPT.
Loading a checkpoint and running training data through the network with gradients disabled.
Collecting the post-activation outputs from desired intermediate layers.
Flattening these outputs into vectors.
Grouping the vectors by class for each layer.
Calculating the global mean (μ_G^j), class means (μ_c^j), within-class covariance (Σ_W^j), and between-class covariance (Σ_B^j) for each layer.
Computing the NC1, NC2 (Equal Norms and Maximal Angles), and NC4 metrics using the formulas provided in the paper. This requires singular value decomposition for the pseudoinverse in NC1.
Plotting these metrics against layer depth and training progress to observe the collapse dynamics.
Libraries like NumPy or PyTorch/TensorFlow can be used for the vector operations, covariance calculations, SVD, and metric computation.
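As one way to realize the activation-collection steps above, the following sketch uses PyTorch forward hooks; the choice of which modules to hook (e.g., each post-activation ReLU) and the function signature are assumptions for illustration:

import torch

@torch.no_grad()
def collect_activations(model, loader, layers, device="cpu"):
    """Collects flattened post-activation outputs from the given modules.

    Args:
        model: trained classifier (nn.Module).
        loader: DataLoader over the training set.
        layers: dict mapping a name to the nn.Module whose output to record.

    Returns:
        dict name -> (N, D) tensor of flattened activations, plus (N,) labels.
    """
    acts = {name: [] for name in layers}
    handles = []
    for name, module in layers.items():
        # Each hook stores the flattened output whenever the module runs
        handles.append(module.register_forward_hook(
            lambda m, inp, out, name=name: acts[name].append(out.flatten(1).cpu())
        ))

    labels = []
    model.eval()
    for x, y in loader:
        model(x.to(device))  # Forward pass only; hooks capture activations
        labels.append(y)

    for h in handles:
        h.remove()
    return {n: torch.cat(v) for n, v in acts.items()}, torch.cat(labels)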
For example, calculating NC1 for layer j in Python/PyTorch:
import torch
import numpy as np

def compute_nc1(hidden_layer_outputs, labels):
    """Computes the NC1 metric for a given layer's outputs.

    Args:
        hidden_layer_outputs (torch.Tensor): Tensor of shape (N, D), where N is
            the number of samples and D is the feature dimension of the layer.
        labels (torch.Tensor): Tensor of shape (N,) containing class labels
            (0 to C-1).

    Returns:
        float: The computed NC1 value.
    """
    C = int(labels.max().item() + 1)
    N, D = hidden_layer_outputs.shape
    device = hidden_layer_outputs.device

    if N == 0:
        return float('nan')  # No samples to analyze

    mu_G = torch.mean(hidden_layer_outputs, dim=0)  # Global mean
    mu_c = torch.zeros(C, D, device=device)         # Class means
    Sigma_W = torch.zeros(D, D, device=device)      # Within-class covariance

    for c in range(C):
        class_samples = hidden_layer_outputs[labels == c]
        if class_samples.shape[0] > 0:
            mu_c[c] = torch.mean(class_samples, dim=0)
            centered_samples = class_samples - mu_c[c]
            # Accumulate sum of (h_i - mu_c)(h_i - mu_c)^T for class c
            Sigma_W += torch.matmul(centered_samples.transpose(0, 1), centered_samples)

    # Average the within-class scatter over all N samples
    Sigma_W /= N

    # Between-class covariance: Ave_c (mu_c - mu_G)(mu_c - mu_G)^T,
    # per the paper's definition of an average over the C classes
    centered_mu_c = mu_c - mu_G
    Sigma_B = torch.matmul(centered_mu_c.transpose(0, 1), centered_mu_c) / C

    # Pseudoinverse of Sigma_B via SVD; Sigma_B is typically rank-deficient
    # (rank at most C-1), so numerical stability matters here
    try:
        Sigma_B_np = Sigma_B.detach().cpu().numpy()
        Sigma_B_pinv = torch.from_numpy(np.linalg.pinv(Sigma_B_np)).to(device)
        # NC1 = Tr(Sigma_B^+ Sigma_W) / C
        trace_val = torch.trace(torch.matmul(Sigma_B_pinv, Sigma_W)) / C
        return trace_val.item()
    except np.linalg.LinAlgError:
        print("Warning: SVD did not converge.")
        return float('nan')
    except Exception as e:
        print(f"Error computing NC1: {e}")
        return float('nan')
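A comparable sketch for the NC4 mismatch, again illustrative rather than the paper's implementation, compares nearest-class-mean predictions on the layer's features against the trained network's own predictions:

import torch

def compute_nc4(hidden_layer_outputs, network_preds, mu_c):
    """NC4: fraction of samples where nearest-class-mean classification
    on layer j's features disagrees with the full network's prediction.

    Args:
        hidden_layer_outputs (torch.Tensor): (N, D) flattened activations.
        network_preds (torch.Tensor): (N,) argmax predictions of the network.
        mu_c (torch.Tensor): (C, D) class means for layer j.

    Returns:
        float in [0, 1]; approaches zero when the layer already supports
        classification via the nearest class-center rule.
    """
    dists = torch.cdist(hidden_layer_outputs, mu_c)  # (N, C) distances
    ncc_preds = dists.argmin(dim=1)                  # Nearest-center labels
    return (ncc_preds != network_preds).float().mean().item()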
This analysis provides valuable insights into the internal workings of neural networks, particularly how features are processed layer by layer to achieve classification, and could inform strategies for improving model efficiency and interpretability. The open question of how these observations generalize to unseen test data is highlighted as an important direction for future work.