
Neural Collapse in the Intermediate Hidden Layers of Classification Neural Networks (2308.02760v1)

Published 5 Aug 2023 in cs.LG

Abstract: Neural Collapse (NC) gives a precise description of the representations of classes in the final hidden layer of classification neural networks. This description provides insights into how these networks learn features and generalize well when trained past zero training error. However, to date, (NC) has only been studied in the final layer of these networks. In the present paper, we provide the first comprehensive empirical analysis of the emergence of (NC) in the intermediate hidden layers of these classifiers. We examine a variety of network architectures, activations, and datasets, and demonstrate that some degree of (NC) emerges in most of the intermediate hidden layers of the network, where the degree of collapse in any given layer is typically positively correlated with the depth of that layer in the neural network. Moreover, we remark that: (1) almost all of the reduction in intra-class variance in the samples occurs in the shallower layers of the networks, (2) the angular separation between class means increases consistently with hidden layer depth, and (3) simple datasets require only the shallower layers of the networks to fully learn them, whereas more difficult ones require the entire network. Ultimately, these results provide granular insights into the structural propagation of features through classification neural networks.

Authors (4)
  1. Liam Parker (12 papers)
  2. Emre Onal (2 papers)
  3. Anton Stengel (2 papers)
  4. Jake Intrater (1 paper)
Citations (10)

Summary

This paper, "Neural Collapse in the Intermediate Hidden Layers of Classification Neural Networks" (Parker et al., 2023), investigates the phenomenon of Neural Collapse ($\mathcal{NC}$) not just in the final hidden layer, as is common in previous research, but also in the intermediate layers of classification neural networks. $\mathcal{NC}$ describes a geometric structure that emerges in the final layer's representations during the Terminal Phase of Training (TPT), where within-class variability collapses, class means converge to an Equiangular Tight Frame (ETF) structure, and the classifier simplifies to a Nearest Class Center (NCC) rule.

The authors conduct an empirical analysis across various network architectures (MLP6, VGG11, ResNet18), activation functions (ReLU, Tanh, LeakyReLU), and datasets (MNIST, CIFAR10, CIFAR100, SVHN, FashionMNIST). They train the networks past zero training error using MSE loss and analyze the post-activation representations from intermediate layers at various training points.

To quantify $\mathcal{NC}$ in intermediate layer $j$, they compute four metrics based on the flattened post-activation vectors $\bm h^j_{i,c}$:

  1. $\mathcal{NC}1$ (Intra-Class Variance Collapse): Measures the ratio of within-class covariance to between-class covariance using $\operatorname{Tr}(\bm\Sigma_B^{j+}\bm\Sigma_W^j/C)$, where $\bm\Sigma_B^{j+}$ denotes the Moore-Penrose pseudoinverse of the between-class covariance. A lower value indicates less intra-class variance relative to between-class variance.
  2. $\mathcal{NC}2$ (Convergence to Simplex ETF; see the sketch after this list):
    • Equal Norms: Quantified by the coefficient of variation of the class-mean norms, $\operatorname{std}^j_c(\|\bm\mu_c^j - \bm\mu_G^j\|_2)/\operatorname{avg}^j_c(\|\bm\mu_c^j-\bm\mu_G^j\|_2)$. Approaches zero as the norms become equal.
    • Maximal Angles: Measures the average deviation of the pairwise class-mean cosines from the Simplex ETF value $-1/(C-1)$ (e.g., $-1/9$ for the $C=10$ classes of CIFAR10) using $\operatorname{avg}_{c\neq c'}\left(\left| \frac{\langle\bm\mu_c^j-\bm\mu_G^j,\,\bm\mu_{c'}^j-\bm\mu_G^j\rangle}{\|\bm\mu_c^j-\bm\mu_G^j\|_2\,\|\bm\mu_{c'}^j-\bm\mu_G^j\|_2} + \frac{1}{C-1}\right|\right)$. Approaches zero at maximal angular separation.
  3. $\mathcal{NC}4$ (Simplification to Nearest Class-Center): Calculates the proportion of samples where the $j$-th layer's nearest-class-mean classification disagrees with the final network's classification: $1 - \operatorname{avg}_{i,c}\mathbb{1}\{f_\theta(x_{i,c}) = \arg\min_{c'}\|\bm h^j_{i,c}-\bm\mu_{c'}^j\|_2\}$. Approaches zero if the layer's representation is sufficient for classification via NCC.
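
As a minimal PyTorch sketch of the $\mathcal{NC}2$ and $\mathcal{NC}4$ quantities (assuming precomputed class means `mu_c` of shape (C, D), a global mean `mu_G` of shape (D,), flattened activations `h`, and the network's own predictions `net_preds`; this is an illustration, not the authors' code):

import torch

def compute_nc2(mu_c, mu_G):
    """Equal-Norms and Maximal-Angles quantities of NC2 for one layer.

    Args:
        mu_c (torch.Tensor): (C, D) class means for the layer.
        mu_G (torch.Tensor): (D,) global mean for the layer.
    """
    C = mu_c.shape[0]
    centered = mu_c - mu_G                        # mu_c - mu_G, shape (C, D)
    norms = centered.norm(dim=1)                  # ||mu_c - mu_G||_2

    # Equal Norms: coefficient of variation of the centered class-mean norms.
    equal_norms = (norms.std(unbiased=False) / norms.mean()).item()

    # Maximal Angles: average |cos(angle) + 1/(C-1)| over distinct class pairs.
    unit = centered / norms.unsqueeze(1)          # row-normalized class means
    cos = unit @ unit.T                           # (C, C) pairwise cosines
    off_diag = ~torch.eye(C, dtype=torch.bool, device=mu_c.device)
    max_angles = (cos[off_diag] + 1.0 / (C - 1)).abs().mean().item()
    return equal_norms, max_angles

def compute_nc4(h, net_preds, mu_c):
    """NC4: fraction of samples where the layer's nearest-class-center label
    disagrees with the full network's prediction.

    Args:
        h (torch.Tensor): (N, D) flattened post-activations for the layer.
        net_preds (torch.Tensor): (N,) argmax predictions of the network.
        mu_c (torch.Tensor): (C, D) class means for the layer.
    """
    ncc_preds = torch.cdist(h, mu_c).argmin(dim=1)  # nearest class mean
    return (ncc_preds != net_preds).float().mean().item()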

The empirical results show that some degree of $\mathcal{NC}$ emerges in most intermediate layers, generally increasing with depth. Specifically:

  • $\mathcal{NC}1$: Intra-class variance reduction primarily occurs in the shallower layers, plateauing in deeper layers. This suggests that early layers are highly effective at grouping same-class samples.
  • $\mathcal{NC}2$: Angular separation between class means consistently increases with layer depth throughout the network, indicating that deeper layers continue to refine the separation between different classes. Equal norms show a similar trend but plateau earlier than angular separation.
  • $\mathcal{NC}4$: The ability to classify correctly using a nearest class-center rule on the intermediate layer's representation improves with depth. For simpler datasets like MNIST and SVHN, the $\mathcal{NC}4$ mismatch reaches zero in the shallower layers, implying these datasets can be fully learned by the initial layers. For more complex datasets like CIFAR100, the mismatch only reaches zero in the final layers, highlighting the need for the entire network depth.
  • CIFAR100 is noted as an outlier, showing different behavior, particularly less $\mathcal{NC}1$ collapse.
  • Tanh and LeakyReLU activations exhibit trends similar to ReLU.

These findings provide a more granular understanding of how features evolve through a network: shallow layers reduce within-class variance, while deeper layers focus on increasing inter-class separation. The complexity of the dataset influences how much depth is required for the representations to support accurate classification via NCC.

Practical Implications:

Understanding the layer-wise emergence of $\mathcal{NC}$ can inform several practical aspects:

  • Feature Analysis: This research suggests that intermediate layers develop representations with different properties ($\mathcal{NC}1$ in shallow layers, $\mathcal{NC}2$ in deep ones). Analyzing these properties can provide insight into what kind of features are being learned at each stage.
  • Network Architecture Design: Knowing which layers contribute most to variance reduction versus class separation could guide architectural choices, such as the design of blocks or the distribution of parameters.
  • Model Debugging and Interpretation: If $\mathcal{NC}$ metrics behave unexpectedly in certain layers during training (e.g., a lack of $\mathcal{NC}1$ in shallow layers), it might indicate issues with the network, data, or training process.
  • Model Pruning and Compression: For simpler datasets where $\mathcal{NC}4$ is low in intermediate layers, shallower layers might already be sufficient for classification. This could support techniques like layer pruning or early-exit strategies (see the sketch after this list) to improve efficiency without significant performance loss.
  • Transfer Learning: Analyzing $\mathcal{NC}$ in pre-trained models might offer insights into which layers are most transferable or adaptable to new tasks. Layers with strong $\mathcal{NC}1$ might contain more general feature extractors, while layers with strong $\mathcal{NC}2$ are more task-specific.
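
The paper does not prescribe a pruning or early-exit procedure; purely as a hypothetical illustration, per-layer $\mathcal{NC}4$ values (e.g., from a `compute_nc4` helper as sketched earlier) could inform the choice of an exit layer:

def pick_exit_layer(nc4_by_depth, tol=0.01):
    """Return the shallowest layer whose NCC mismatch is within tolerance.

    Args:
        nc4_by_depth: list of per-layer NC4 values, indexed by depth.
        tol: acceptable disagreement with the full network (hypothetical).
    """
    for depth, mismatch in enumerate(nc4_by_depth):
        if mismatch <= tol:
            return depth
    return len(nc4_by_depth) - 1  # no layer qualifies; fall back to the last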

Implementing the $\mathcal{NC}$ analysis described involves:

  1. Saving model checkpoints during TPT.
  2. Loading a checkpoint and running the training data through the network with gradients disabled.
  3. Collecting the post-activation outputs from the desired intermediate layers (see the hook sketch after this list).
  4. Flattening these outputs into vectors.
  5. Grouping the vectors by class for each layer.
  6. Calculating the global mean ($\bm\mu_G^j$), class means ($\bm\mu_c^j$), within-class covariance ($\bm\Sigma_W^j$), and between-class covariance ($\bm\Sigma_B^j$) for each layer.
  7. Computing the $\mathcal{NC}1$, $\mathcal{NC}2$ (Equal Norms and Maximal Angles), and $\mathcal{NC}4$ metrics using the formulas provided in the paper. This requires a singular value decomposition for the pseudoinverse in $\mathcal{NC}1$.
  8. Plotting these metrics against layer depth and training progress to observe the collapse dynamics.
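
Steps 2-4 are commonly handled with PyTorch forward hooks. A minimal sketch (the `layers` name-to-module mapping and the function name are assumptions for illustration, not from the paper):

import torch

def collect_post_activations(model, layers, loader, device="cpu"):
    """Collects flattened post-activation outputs from selected modules.

    Args:
        model: trained classifier loaded from a checkpoint.
        layers (dict): maps a name to the nn.Module whose output to record.
        loader: DataLoader yielding (inputs, labels) from the training set.

    Returns:
        features (dict): name -> (N, D) tensor of flattened outputs.
        labels (torch.Tensor): (N,) class labels in matching order.
    """
    feats = {name: [] for name in layers}
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            feats[name].append(output.detach().flatten(start_dim=1).cpu())
        return hook

    for name, module in layers.items():
        handles.append(module.register_forward_hook(make_hook(name)))

    all_labels = []
    model.eval()
    with torch.no_grad():  # gradients disabled (step 2)
        for x, y in loader:
            model(x.to(device))
            all_labels.append(y)

    for handle in handles:
        handle.remove()

    features = {name: torch.cat(chunks) for name, chunks in feats.items()}
    return features, torch.cat(all_labels)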

Libraries like NumPy or PyTorch/TensorFlow can be used for the vector operations, covariance calculations, SVD, and metric computation.

For example, calculating $\mathcal{NC}1$ for layer $j$ in Python/PyTorch:

import torch
import numpy as np

def compute_nc1(hidden_layer_outputs, labels):
    """
    Computes the NC1 metric for a given layer's outputs.

    Args:
        hidden_layer_outputs (torch.Tensor): Tensor of shape (N, D) where N is the number
                                            of samples and D is the feature dimension
                                            of the layer.
        labels (torch.Tensor): Tensor of shape (N,) containing class labels (0 to C-1).

    Returns:
        float: The computed NC1 value.
    """
    C = int(labels.max().item() + 1)
    D = hidden_layer_outputs.shape[1]
    device = hidden_layer_outputs.device

    mu_G = torch.mean(hidden_layer_outputs, dim=0) # Global mean

    mu_c = torch.zeros(C, D, device=device)
    Sigma_W = torch.zeros(D, D, device=device)
    counts = torch.zeros(C, device=device)

    for c in range(C):
        class_samples = hidden_layer_outputs[labels == c]
        counts[c] = class_samples.shape[0]
        if counts[c] > 0:
            mu_c[c] = torch.mean(class_samples, dim=0) # Class mean
            centered_samples = class_samples - mu_c[c]
            # Sum of (h_i - mu_c)(h_i - mu_c)^T for class c
            Sigma_W_c = torch.matmul(centered_samples.transpose(0, 1), centered_samples)
            Sigma_W += Sigma_W_c # Sum across classes (will divide by N later)

    # Average Sigma_W over all samples N
    N = hidden_layer_outputs.shape[0]
    if N > 0:
        Sigma_W /= N
    else:
        return float('nan') # Handle case with no samples

    # Calculate Sigma_B following the paper's definition (average over the
    # classes): Sigma_B = (1/C) * sum_c (mu_c - mu_G)(mu_c - mu_G)^T
    centered_mu_c = mu_c - mu_G
    Sigma_B = torch.matmul(centered_mu_c.transpose(0, 1), centered_mu_c) / C

    # Compute the Moore-Penrose pseudoinverse of Sigma_B via NumPy's SVD-based
    # pinv; Sigma_B has rank at most C-1, so a pseudoinverse (rather than a
    # plain inverse) is required.
    try:
        Sigma_B_np = Sigma_B.detach().cpu().numpy()
        Sigma_B_pinv_np = np.linalg.pinv(Sigma_B_np)
        Sigma_B_pinv = torch.from_numpy(Sigma_B_pinv_np).to(device)

        # NC1 = Trace(Sigma_B_pinv * Sigma_W / C)
        trace_val = torch.trace(torch.matmul(Sigma_B_pinv, Sigma_W)) / C
        return trace_val.item()
    except np.linalg.LinAlgError:
         print("Warning: SVD did not converge.")
         return float('nan')
    except Exception as e:
        print(f"Error computing NC1: {e}")
        return float('nan')
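
Combined with the hook sketch above (all names assumed for illustration), the per-layer values could then be gathered as:

features, labels = collect_post_activations(model, layers, train_loader)
nc1_by_layer = {name: compute_nc1(h, labels) for name, h in features.items()}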


This analysis provides valuable insights into the internal workings of neural networks, particularly how features are processed layer by layer to achieve classification, and could potentially inform strategies for improving model efficiency and interpretability. The open question of how these observations generalize to unseen test data is highlighted as an important direction for future work.