A Neural Collapse Perspective on Feature Evolution in Graph Neural Networks (2307.01951v2)

Published 4 Jul 2023 in cs.LG, cs.AI, cs.IT, math.IT, math.OC, and stat.ML

Abstract: Graph neural networks (GNNs) have become increasingly popular for classification tasks on graph-structured data. Yet, the interplay between graph topology and feature evolution in GNNs is not well understood. In this paper, we focus on node-wise classification, illustrated with community detection on stochastic block model graphs, and explore the feature evolution through the lens of the "Neural Collapse" (NC) phenomenon. When training instance-wise deep classifiers (e.g. for image classification) beyond the zero training error point, NC demonstrates a reduction in the deepest features' within-class variability and an increased alignment of their class means to certain symmetric structures. We start with an empirical study that shows that a decrease in within-class variability is also prevalent in the node-wise classification setting, however, not to the extent observed in the instance-wise case. Then, we theoretically study this distinction. Specifically, we show that even an "optimistic" mathematical model requires that the graphs obey a strict structural condition in order to possess a minimizer with exact collapse. Interestingly, this condition is viable also for heterophilic graphs and relates to recent empirical studies on settings with improved GNNs' generalization. Furthermore, by studying the gradient dynamics of the theoretical model, we provide reasoning for the partial collapse observed empirically. Finally, we present a study on the evolution of within- and between-class feature variability across layers of a well-trained GNN and contrast the behavior with spectral methods.

Citations (7)

Summary

  • The paper demonstrates that in GNNs trained on SBM graphs, within-class feature collapse is only partial, in line with theoretical predictions tied to graph structure.
  • It introduces a graph-based Unconstrained Features Model to show that exact collapse requires strict neighborhood uniformity—a condition rarely met in real-world graphs.
  • The paper also reveals that deeper GNN layers progressively reduce within-class variability, thus enhancing feature separability and mitigating over-smoothing.

Graph neural networks (GNNs) have become standard tools for tasks on graph-structured data, but the precise interaction between graph topology and how GNNs learn node features is still an active research area. This paper (2307.01951) investigates this interaction by examining the phenomenon of Neural Collapse (NC) in GNNs, particularly in the context of supervised node classification, using the stochastic block model (SBM) as a controlled environment.
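To make the controlled setting concrete, the following is a minimal sketch (not the paper's code) for sampling a symmetric SBM graph and forming a symmetrically normalized adjacency matrix; the class sizes and edge probabilities p and q are illustrative choices.

    import numpy as np

    def sample_ssbm(n_per_class, num_classes, p, q, rng=None):
        """Sample a symmetric stochastic block model (SSBM) graph.

        Nodes of the same class connect with probability p, nodes of different
        classes with probability q. Returns the adjacency matrix A and labels y.
        """
        rng = np.random.default_rng() if rng is None else rng
        n = n_per_class * num_classes
        y = np.repeat(np.arange(num_classes), n_per_class)
        prob = np.where(y[:, None] == y[None, :], p, q)
        upper = np.triu(rng.random((n, n)) < prob, k=1)   # keep i < j entries only
        A = (upper | upper.T).astype(float)               # symmetric, no self-loops
        return A, y

    def normalized_adjacency(A):
        """One common choice of normalization: D^{-1/2} A D^{-1/2}."""
        deg = A.sum(axis=1)
        d_inv_sqrt = np.zeros_like(deg)
        d_inv_sqrt[deg > 0] = deg[deg > 0] ** -0.5
        return d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]

    # Example: 2 communities of 100 nodes each (p and q chosen for illustration).
    A, y = sample_ssbm(n_per_class=100, num_classes=2, p=0.1, q=0.02)
    A_hat = normalized_adjacency(A)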

Neural Collapse, originally observed in standard deep neural networks (DNNs) trained for instance-wise classification beyond zero training error (the terminal phase of training, or TPT), describes several properties: (NC1) within-class feature variability collapses (features of samples from the same class converge to their class mean), (NC2) class means align to specific symmetric structures (such as a simplex equiangular tight frame), and (NC3) the last-layer classifier weights align with the class means. A practical consequence is that the learned classifier approximates a nearest-class-center decision rule. This paper adapts NC metrics to the node-wise setting, defining within- and between-class covariance matrices ($\mathbf{\Sigma}_W(\cdot)$, $\mathbf{\Sigma}_B(\cdot)$) and NC1 metrics ($\mathcal{C}_1$, $\widetilde{\mathcal{C}}_1$) for both the standard node features $\mathbf{H}$ and the neighborhood-aggregated features $\mathbf{H}\widehat{\mathbf{A}}$, where $\widehat{\mathbf{A}}$ is the normalized adjacency matrix.
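As a concrete reference, the sketch below computes within- and between-class covariances and one common NC1-style collapse measure, $\text{Tr}(\mathbf{\Sigma}_W \mathbf{\Sigma}_B^{\dagger})/C$, from a feature matrix. The paper's exact metrics $\mathcal{C}_1$ and $\widetilde{\mathcal{C}}_1$ may use a different normalization, so this is an illustrative variant rather than the authors' code.

    import numpy as np

    def nc1_metrics(H, y):
        """Within-/between-class covariances and an NC1-style collapse metric.

        H: (d, N) feature matrix whose columns are node features.
        y: (N,) integer class labels.
        Returns (Sigma_W, Sigma_B, nc1) with nc1 = Tr(Sigma_W @ pinv(Sigma_B)) / C.
        """
        classes = np.unique(y)
        C = len(classes)
        d = H.shape[0]
        mu_global = H.mean(axis=1, keepdims=True)
        Sigma_W = np.zeros((d, d))
        Sigma_B = np.zeros((d, d))
        for c in classes:
            Hc = H[:, y == c]
            mu_c = Hc.mean(axis=1, keepdims=True)
            Sigma_W += (Hc - mu_c) @ (Hc - mu_c).T / Hc.shape[1]
            Sigma_B += (mu_c - mu_global) @ (mu_c - mu_global).T
        Sigma_W /= C
        Sigma_B /= C
        nc1 = np.trace(Sigma_W @ np.linalg.pinv(Sigma_B)) / C
        return Sigma_W, Sigma_B, nc1

    # The same function applies to aggregated features, e.g. nc1_metrics(H @ A_hat, y).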

The authors conduct empirical studies on GNNs trained for community detection on SSBM graphs. They observe that during TPT, the deepest-layer features in GNNs exhibit a decrease in within-class variability, similar to standard DNNs, but the collapse is only partial, meaning the NC1 metrics plateau at a value significantly above zero. This partial collapse is observed for both tested GNN architectures, one using the $\mathbf{I} + \widehat{\mathbf{A}}$ operator ($\psi_\Theta^{\mathcal{F}}$) and another using only $\widehat{\mathbf{A}}$ ($\psi_\Theta^{\mathcal{F}'}$).

To theoretically understand this distinction from plain DNNs, the paper introduces a graph-based Unconstrained Features Model (gUFM). This model treats the deepest-layer features as free optimization variables but incorporates the graph structure via the $\widehat{\mathbf{A}}$ matrix in the objective function, which is the mean squared error (MSE) loss combined with L2 regularization. The analysis of the gUFM minimizers reveals a key theoretical finding (Theorem 3.1): for a graph to possess a minimizer with exact collapse (zero within-class variability), it must satisfy a strict structural condition (Condition C). This condition states that, for any given class, all nodes within that class must have the same distribution of neighbors across all classes. For instance, in a 2-class setting, every node in class 1 must connect to the same number of nodes in class 1 and the same number of nodes in class 2 as any other node in class 1 (though these counts can differ for nodes in class 2).
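Condition C is straightforward to test programmatically. The sketch below checks it for a labeled graph by comparing per-class neighbor counts within each class; it is a direct reading of the condition stated above, not code from the paper.

    import numpy as np

    def satisfies_condition_c(A, y):
        """Check the strict structural condition for exact collapse (Condition C).

        A: (N, N) adjacency matrix (0/1), y: (N,) integer class labels.
        Returns True iff, for every class k, all nodes of class k have the same
        number of neighbors in each class.
        """
        classes = np.unique(y)
        # counts[i, c] = number of neighbors of node i belonging to class c
        counts = np.stack([A[:, y == c].sum(axis=1) for c in classes], axis=1)
        for k in classes:
            rows = counts[y == k]
            if not (rows == rows[0]).all():
                return False
        return True

    # Graphs sampled from the SSBM almost never satisfy this condition (Theorem 3.2):
    # satisfies_condition_c(A, y) is typically False for the graph sampled above.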

The paper then demonstrates (Theorem 3.2) that graphs sampled from the SSBM distribution (which are used to model real-world network structures) are highly unlikely to satisfy Condition C, especially as the graph size increases. This theoretical result provides a strong explanation for the empirical observation that NC in GNNs is typically only partial when trained on practical graph structures.

Despite the lack of exact collapse in minimizers for typical graphs, the authors study the gradient dynamics of the gUFM (Theorem 3.3) to explain the partial collapse observed empirically. For a simplified case (2 classes, small perturbation from the expected SSBM graph), they show that training along the gradient flow still leads to a decrease in the trace of the within-class covariance matrix and an increase in the trace of the between-class covariance matrix, which are key components of NC1. This indicates that while exact collapse is rare, the training process naturally drives the features towards better class separability by reducing within-class scatter.

The paper also investigates feature evolution across layers during inference on well-trained GNNs. Empirically, they find that NC1 metrics progressively decrease from the input layer to the final layer, suggesting that deeper layers learn more separable features. This depthwise behavior is contrasted with spectral clustering methods (such as the Normalized Laplacian and Bethe-Hessian applied via projected power iterations). While spectral methods also implicitly separate features, the paper shows empirical differences in how within- and between-class covariance traces evolve across layers/iterations. GNNs show decreasing ratios of successive layer covariance traces ($\text{Tr}(\mathbf{\Sigma}^{(l)})/\text{Tr}(\mathbf{\Sigma}^{(l-1)})$), unlike spectral methods, where these ratios remain constant. Theoretical analysis (Theorem 4.1) provides bounds on these trace ratios for GNN layers, showing their dependence on weight matrices and graph properties, which helps explain the observed differences between the $\psi_\Theta^{\mathcal{F}}$ and $\psi_\Theta^{\mathcal{F}'}$ architectures. The $\mathbf{I}$ term in $\psi_\Theta^{\mathcal{F}}$ (identity connection) appears to slow down the reduction in within-class variability compared to $\psi_\Theta^{\mathcal{F}'}$.
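These layer-wise diagnostics can be computed directly from stored intermediate features. The sketch below collects the within-class covariance trace at every layer and the successive ratios $\text{Tr}(\mathbf{\Sigma}_W^{(l)})/\text{Tr}(\mathbf{\Sigma}_W^{(l-1)})$ discussed above; it reuses the illustrative nc1_metrics helper sketched earlier and is not the paper's evaluation code.

    import numpy as np

    def layerwise_trace_ratios(feature_list, y):
        """feature_list: per-layer feature matrices [(d_l, N), ...]; y: labels.

        Returns the within-class covariance trace at each layer and the ratios
        Tr(Sigma_W^{(l)}) / Tr(Sigma_W^{(l-1)}) between successive layers.
        """
        traces = []
        for H in feature_list:
            Sigma_W, _, _ = nc1_metrics(H, y)   # illustrative helper defined earlier
            traces.append(np.trace(Sigma_W))
        ratios = [traces[l] / traces[l - 1] for l in range(1, len(traces))]
        return traces, ratios

    # For a well-trained GNN these ratios tend to shrink with depth, whereas for the
    # projected power iterations of the spectral methods they stay roughly constant.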

Practical Implications and Implementation Considerations:

  • Understanding GNN Limitations: The finding that exact NC requires strict graph structures highlights that standard NC might not be the best lens for evaluating GNNs on arbitrary graphs. However, partial NC, indicating improved feature separability, is a valuable training outcome.
  • Connection to Graph Structure & Generalization: Condition C, which ensures uniform neighborhood structures within classes, aligns with empirical findings suggesting that such neighborhood patterns are crucial for GNN performance, even on heterophilic graphs [Ma et al., ICLR 2022]. This work provides a theoretical link between graph structure, feature collapse, and potentially generalization.
  • Mitigating Over-smoothing: The desirable outcome of decreasing within-class variability and increasing between-class variability observed in NC is inherently linked to combating over-smoothing, where the features of all nodes converge regardless of class (Appendix A.1). Promoting partial NC could therefore be a strategy for alleviating over-smoothing.
  • Graph Rewiring: The theoretical insights on Condition C suggest that graph rewiring techniques (Appendix A.2) could be explored to modify input graphs to better satisfy this condition, potentially improving GNN performance and inducing stronger collapse.
  • Model Implementation: The paper uses a standard message-passing GNN architecture; a PyTorch Geometric sketch of the two layer variants ($\psi_\Theta^{\mathcal{F}}$ and $\psi_\Theta^{\mathcal{F}'}$) follows.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch_geometric.nn import MessagePassing
    from torch_geometric.utils import add_self_loops, degree
    
    class GNNLayer(MessagePassing):
        def __init__(self, in_channels, out_channels, use_identity=True):
            super(GNNLayer, self).__init__(aggr='add') # "Add" aggregation (sum)
            self.linear1 = nn.Linear(in_channels, out_channels) if use_identity else None
            self.linear2 = nn.Linear(in_channels, out_channels)
            self.use_identity = use_identity
    
        def forward(self, x, edge_index):
            # Add self-loops to the adjacency matrix.
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
    
            # Calculate degree matrix
            row, col = edge_index
            deg = degree(col, x.size(0), dtype=x.dtype)
            deg_inv_sqrt = deg.pow(-0.5)
            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
            norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
    
            # Message passing
            out = self.propagate(edge_index, x=x, norm=norm)
    
            # Linear transformation (based on paper's eq 3 & 4)
            if self.use_identity: # psi_Theta^F
                # H^{(l)} = W1 H^{(l-1)} + W2 H^{(l-1)} A_hat
                # Here, out corresponds to H^{(l-1)} A_hat
                out = self.linear1(x) + self.linear2(out)
            else: # psi_Theta^F'
                # H^{(l)} = W2 H^{(l-1)} A_hat
                out = self.linear2(out)
    
            return out
    
        def message(self, x_j, norm):
            # Normalized message
            return norm.view(-1, 1) * x_j
    
    # Example Usage:
    # in_dim, hidden_dim, out_dim = ...
    # layer = GNNLayer(in_dim, hidden_dim, use_identity=True)
    # x = torch.randn(num_nodes, in_dim) # Node features
    # edge_index = torch.tensor([...], dtype=torch.long) # Adjacency list
    # x_next = layer(x, edge_index)
    # x_next = F.relu(x_next) # Apply ReLU
    # # Then apply Instance Normalization as in the paper
  • Training with MSE Loss: The paper highlights that MSE loss is used for theoretical tractability and empirically achieves classification performance comparable to cross-entropy. This is a practical choice for implementers aiming for theoretical analysis or exploring alternative loss functions; a minimal training sketch with MSE on one-hot targets is given after this list.
  • Handling Deep Architectures: Instance normalization is crucial for training deep GNNs, as noted in the paper and as is common in practice, to mitigate issues such as over-smoothing and training instability.
  • Computational Costs: Computing NC metrics involves matrix operations (covariance, pseudo-inverse, trace) on potentially large feature matrices ($d \times N$), which can be computationally expensive for large graphs (large $N$) or high-dimensional features (large $d$). This should be considered for large-scale applications.
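As referenced in the MSE-loss item above, the following is a minimal training sketch that stacks the GNNLayer defined earlier, applies ReLU and instance normalization between layers, and regresses one-hot targets with MSE. The depth, dimensions, optimizer, and weight decay are illustrative assumptions, not the paper's hyperparameters.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GNN(nn.Module):
        """A small stack of GNNLayer blocks (defined above) with ReLU + instance norm."""
        def __init__(self, in_dim, hidden_dim, num_classes, depth=4, use_identity=True):
            super().__init__()
            dims = [in_dim] + [hidden_dim] * (depth - 1) + [num_classes]
            self.layers = nn.ModuleList(
                [GNNLayer(dims[i], dims[i + 1], use_identity=use_identity) for i in range(depth)]
            )
            self.norms = nn.ModuleList([nn.InstanceNorm1d(dims[i + 1]) for i in range(depth - 1)])

        def forward(self, x, edge_index):
            for layer, norm in zip(self.layers[:-1], self.norms):
                x = F.relu(layer(x, edge_index))
                # Normalize each feature channel across the graph's nodes.
                x = norm(x.t().unsqueeze(0)).squeeze(0).t()
            return self.layers[-1](x, edge_index)

    # Training with MSE on one-hot targets (hyperparameters are illustrative):
    # model = GNN(in_dim, 64, num_classes)
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
    # targets = F.one_hot(y, num_classes).float()   # y: LongTensor of node labels
    # for _ in range(num_epochs):
    #     optimizer.zero_grad()
    #     loss = F.mse_loss(model(x, edge_index), targets)
    #     loss.backward()
    #     optimizer.step()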

The paper concludes by suggesting future work including formally linking NC behavior to generalization in GNNs and exploring graph rewiring strategies to induce desirable NC properties. The findings regarding Condition C and its rarity provide a foundational understanding of why GNNs may exhibit partial rather than exact collapse, distinguishing their feature learning behavior from standard DNNs.