
GSINA Framework: Graph Sinkhorn Attention

Updated 7 February 2026
  • GSINA is an optimal transport-based framework that extracts invariant subgraphs by balancing sparsity, softness, and differentiability for robust graph learning.
  • It formulates the extraction as a cardinality-constrained optimal transport problem solved by the entropic Sinkhorn algorithm to enable a differentiable, soft, and sparse attention mechanism.
  • Empirical results demonstrate that GSINA improves graph and node-level tasks, boosting classification accuracy by up to 10% and ROC-AUC by 1–2 points on various benchmarks.

Graph Sinkhorn Attention (GSINA) is an Optimal Transport-based attention framework designed for extracting invariant subgraphs in Graph Invariant Learning (GIL) settings. GSINA addresses the challenge of out-of-distribution (OOD) generalization in graph learning by selecting subgraphs whose relationship to predicted labels remains stable across multiple, unseen environments. The framework formulates subgraph extraction as a cardinality-constrained Optimal Transport problem, solved efficiently using the entropic Sinkhorn algorithm, yielding a fully differentiable, soft, and sparse attention mechanism for graph neural networks (GNNs) (Ding et al., 2024).

1. Motivation and Problem Definition

GSINA is developed in the context of Graph Invariant Learning, where the goal is to construct predictors that minimize the worst-case risk across multiple unknown environments. Consider independent samples $(G_i^e, Y_i^e)$ drawn from each (unobserved) environment $e$; the objective is to find

$$f^* = \arg\min_{f}\;\max_{e}\;\mathbb{E}_{(G,Y)\sim\mathcal{G}^e}[\ell(f(G),Y)]$$

Since environment labels $e$ are unobserved, GIL approaches extract an invariant subgraph $G_S \subseteq G$ whose relationship to $Y$ is presumed stable under distributional shifts. The subgraph extraction process focuses on discarding spurious or environment-specific graph structure, retaining only invariant, label-relevant nodes and edges (Ding et al., 2024).
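The min-max objective can be illustrated with a toy numeric example (the risks and predictor names `f_spurious`/`f_invariant` are hypothetical): a predictor that exploits environment-specific shortcuts may have a low average risk but a high worst-case risk, and the objective above rejects it.

```python
# Hypothetical illustration of the worst-case-risk objective: two candidate
# predictors evaluated on two environments. GIL selects the predictor whose
# WORST environment risk is lowest, not the one with the best average risk.
risks = {
    "f_spurious":  {"e1": 0.05, "e2": 0.60},  # exploits environment-specific cues
    "f_invariant": {"e1": 0.20, "e2": 0.25},  # relies on stable structure
}

worst_case = {f: max(env_risks.values()) for f, env_risks in risks.items()}
best = min(worst_case, key=worst_case.get)
print(best, worst_case[best])  # f_invariant 0.25
```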

2. Design Principles for Invariant Subgraph Extraction

GSINA is derived from three essential design principles for subgraph extractors:

  • Sparsity: The selected subgraph should be small, retaining few nodes and edges to ensure that non-invariant and noisy graph components are filtered out.
  • Softness: Instead of hard selections (such as top-$k$ edges), the framework assigns continuous attention weights $\alpha^E_e \in [0,1]$ to each edge, enlarging the solution space and preserving differentiability.
  • Differentiability: End-to-end differentiability is necessary so that both the subgraph mask and the predictor can be optimized jointly via gradient-based algorithms.

GSINA contrasts with earlier approaches: Information Bottleneck-based (IB) methods (e.g., GSAT) are soft and differentiable but lack enforced sparsity; top-$k$ methods (e.g., CIGA) ensure sparsity but use hard, non-differentiable selection. GSINA unifies sparsity, softness, and differentiability.

3. Methodology: Graph Sinkhorn Attention

The GSINA framework uses an Optimal Transport abstraction to implement a soft, sparse, and fully differentiable top-$r$ edge selection mechanism:

Edge Scoring

  • Node representations $\{n_i\}$ are generated by a lightweight GNN ($\mathrm{GNN}_\phi$).
  • Each edge $e=(i,j)$ is assigned a score $s_e = \mathrm{MLP}_\phi(n_i, n_j)$.
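As a rough sketch of the scoring step (the real $\mathrm{GNN}_\phi$ and $\mathrm{MLP}_\phi$ are learned networks; here a single random linear map stands in for the MLP, and `score_edges` is our name):

```python
import numpy as np

def score_edges(node_feats, edge_index, w):
    """Toy stand-in for s_e = MLP_phi(n_i, n_j): score each edge from the
    concatenated endpoint features with one linear layer.
    node_feats: (N, d); edge_index: (2, E); w: (2d,)."""
    src, dst = edge_index
    pairs = np.concatenate([node_feats[src], node_feats[dst]], axis=1)  # (E, 2d)
    return pairs @ w  # (E,) raw edge scores s_e

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4))                # 3 nodes, 4-dim features
edges = np.array([[0, 1, 2], [1, 2, 0]])   # directed edges 0->1, 1->2, 2->0
s = score_edges(x, edges, rng.normal(size=8))
print(s.shape)  # (3,)
```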

OT Formulation

  • With $N_e$ total edges, the objective is to allocate a mass of $r N_e$ (the invariant mass) to the highest-scoring edges.
  • Define the cost matrix $D \in \mathbb{R}^{2 \times N_e}$ as:

$$D = \begin{bmatrix} \tilde s_1 - \min(s) & \cdots & \tilde s_{N_e} - \min(s) \\ \max(s) - \tilde s_1 & \cdots & \max(s) - \tilde s_{N_e} \end{bmatrix}$$

where $\tilde s_e = s_e - \sigma \log(-\log u_e)$ introduces Gumbel noise ($u_e \sim U(0,1)$, with $\sigma > 0$ during training).
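A minimal NumPy sketch of this cost construction (`build_cost_matrix` is our helper name; for simplicity, min/max here are taken over the perturbed scores):

```python
import numpy as np

def build_cost_matrix(scores, sigma=0.0, rng=None):
    """Build the 2 x Ne cost matrix D. Row 0 (s~ - min) is cheap for
    LOW-scoring edges; row 1 (max - s~) is cheap for HIGH-scoring edges.
    With sigma > 0, Gumbel noise s~_e = s_e - sigma*log(-log u_e) is added."""
    s = np.asarray(scores, dtype=float)
    if sigma > 0:
        rng = rng or np.random.default_rng()
        u = rng.uniform(1e-9, 1.0, size=s.shape)
        s = s - sigma * np.log(-np.log(u))
    return np.stack([s - s.min(), s.max() - s])

D = build_cost_matrix([0.1, 2.0, -1.0])
print(D)  # the two rows sum elementwise to max(s) - min(s) = 3.0
```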

  • Marginals are set as $R = [(1-r)N_e,\ r N_e]^\top$ and $C = [1,1,\dots,1]^\top \in \mathbb{R}^{N_e}$.
  • The optimal transport plan $T$ solves:

$$\min_{T \in [0,1]^{2 \times N_e}} \langle T, D \rangle - \tau H(T)$$

subject to $T\mathbf{1} = R$ and $T^\top \mathbf{1} = C$, where $H(T) = -\sum_{i,j} T_{ij} \log T_{ij}$ is the entropy and the regularization strength $\tau$ controls softness.

Sinkhorn Normalization

  • Initialize $K = \exp(-D/\tau)$; iteratively normalize rows and columns to match the marginals (10 iterations are typical):

$$u^{(t)} = R / (K v^{(t-1)}), \quad v^{(t)} = C / (K^\top u^{(t)}), \quad T^{(t+1)} = \operatorname{diag}(u^{(t)})\, K\, \operatorname{diag}(v^{(t)})$$

This produces an approximate solution $T$ to the entropic OT problem.
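The iteration above can be sketched in a few lines of NumPy (a didactic version; `sinkhorn` is our name, and production implementations typically work in log-space for numerical stability):

```python
import numpy as np

def sinkhorn(D, r=0.5, tau=1.0, iters=100):
    """Entropic OT via Sinkhorn iterations: returns a plan T whose row sums
    approach R = [(1-r)*Ne, r*Ne] and whose column sums equal 1."""
    Ne = D.shape[1]
    R = np.array([(1.0 - r) * Ne, r * Ne])
    C = np.ones(Ne)
    K = np.exp(-D / tau)
    v = np.ones(Ne)
    for _ in range(iters):
        u = R / (K @ v)        # scale rows toward the marginals R
        v = C / (K.T @ u)      # scale columns toward the marginals C
    return np.diag(u) @ K @ np.diag(v)

D = np.array([[0.0, 1.0, 2.0, 3.0],
              [3.0, 2.0, 1.0, 0.0]])
T = sinkhorn(D, r=0.5, tau=0.5, iters=200)
print(T.sum(axis=0))  # ~[1, 1, 1, 1]
print(T.sum(axis=1))  # ~[2, 2]
```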

Extracting Attention Weights

  • The row of $T$ carrying the invariant marginal $r N_e$ (i.e., $T[1,:]$ under zero-based indexing, matching the second row of the cost matrix above) yields the edge attention $\alpha^E \in \mathbb{R}^{N_e}$.
  • Node attention is computed by aggregating incident edge attention, e.g., $\alpha^V_i = \max_{(i,j) \in \mathcal{E}} \alpha^E_{(i,j)}$.
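The max-aggregation from edges to nodes can be sketched as follows (`node_attention` is our helper name; in a real GNN library this would be a scatter-max):

```python
import numpy as np

def node_attention(edge_attn, edge_index, num_nodes):
    """alpha_i^V = max over edges incident to node i of alpha^E (0 if isolated)."""
    alpha_v = np.zeros(num_nodes)
    src, dst = edge_index
    for e, a in enumerate(edge_attn):
        alpha_v[src[e]] = max(alpha_v[src[e]], a)
        alpha_v[dst[e]] = max(alpha_v[dst[e]], a)
    return alpha_v

edges = np.array([[0, 1], [1, 2]])   # edges (0,1) and (1,2)
alpha_e = np.array([0.9, 0.2])
alpha_v = node_attention(alpha_e, edges, num_nodes=3)
print(alpha_v)  # [0.9 0.9 0.2]
```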

4. Integration with Graph Neural Networks

GSINA functions as a modular attention layer positioned between the GNN feature extractor and the final predictor:

  • Message Passing: At each GNN layer $l$, messages from neighbors $j \in \mathcal{N}_i$ to node $i$ are modulated by $\alpha^E_{ij}$:

$$h_i^{(l+1)} = \mathrm{UPDATE}\left( h_i^{(l)}, \sum_{j \in \mathcal{N}_i} \alpha^E_{ij}\, \mathrm{MSG}(h_i^{(l)}, h_j^{(l)}) \right)$$

  • Readout: After $L$ layers, node features are aggregated under node attention:

$$h_G = \mathrm{READOUT}(\{ \alpha^V_i\, h_i^{(L)} \}_{i \in \mathcal{V}})$$

with final prediction $\hat Y = P_\theta(Y \mid h_G)$.

  • Training: End-to-end optimization is performed by backpropagating through all GSINA operations, including the Sinkhorn normalization.
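A minimal sketch of how the attention weights enter message passing and readout (sum aggregation and a trivial UPDATE stand in for the backbone GNN; all function names are ours):

```python
import numpy as np

def attn_message_passing(h, edge_index, alpha_e):
    """One layer of attention-modulated sum aggregation: each message
    h_src -> dst along edge e is scaled by alpha_e[e]."""
    out = np.zeros_like(h)
    src, dst = edge_index
    for e in range(edge_index.shape[1]):
        out[dst[e]] += alpha_e[e] * h[src[e]]
    return out

def readout(h, alpha_v):
    """Attention-weighted sum readout: h_G = sum_i alpha_i^V * h_i."""
    return (alpha_v[:, None] * h).sum(axis=0)

h = np.eye(3)                               # one-hot node features
edges = np.array([[0, 1, 2], [1, 2, 0]])    # 0->1, 1->2, 2->0
alpha_e = np.array([1.0, 0.5, 0.0])         # edge 2->0 fully suppressed
h_G = readout(attn_message_passing(h, edges, alpha_e), np.ones(3))
print(h_G)  # [1.  0.5 0. ]
```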

5. Hyperparameters and Regularization

Several hyperparameters control the operation and inductive bias of GSINA:

| Parameter | Description | Typical Range |
|---|---|---|
| $\tau$ | Entropy regularization; higher yields smoother attention | $\tau \approx 1$, tuned |
| $r$ | Fraction of total edge mass assigned to the invariant subgraph | $r \in [0.2, 0.8]$ |
| $\sigma$ | Gumbel noise scale for exploration | $\sigma > 0$ (training only) |

Smaller $\tau$ approaches hard (binary) selection, while larger $\tau$ yields softer masks. Gumbel noise ($\sigma$) is applied during training to escape poor local minima. In practice, $\tau \approx 1$ and $r \in [0.2, 0.8]$ work well when chosen via validation (Ding et al., 2024).
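The effect of $\tau$ can be seen directly by running the Sinkhorn attention at two temperatures (a self-contained sketch without Gumbel noise; `sinkhorn_attn` is our name):

```python
import numpy as np

def sinkhorn_attn(scores, r=0.5, tau=1.0, iters=300):
    """Edge attention from raw scores via entropic OT; returns the row of
    the transport plan that carries the invariant mass r*Ne."""
    s = np.asarray(scores, dtype=float)
    Ne = s.size
    D = np.stack([s - s.min(), s.max() - s])   # 2 x Ne cost matrix
    R = np.array([(1.0 - r) * Ne, r * Ne])
    K = np.exp(-D / tau)
    v = np.ones(Ne)
    for _ in range(iters):
        u = R / (K @ v)
        v = np.ones(Ne) / (K.T @ u)
    T = u[:, None] * K * v[None, :]
    return T[1]                                # row with marginal r*Ne

s = [3.0, 2.0, 0.5, 0.1]
soft = sinkhorn_attn(s, tau=1.0)    # graded weights in (0, 1)
hard = sinkhorn_attn(s, tau=0.05)   # nearly binary top-2 selection
print(np.round(soft, 2), np.round(hard, 2))
```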

6. Empirical Results and Ablation Studies

GSINA achieves state-of-the-art results on both graph-level and node-level OOD benchmarks:

  • Graph-level tasks: On synthetic Spurious-Motif ($b=0.5, 0.7, 0.9$), MNIST-75sp, Graph-SST2, OGBG-MolHIV, and additional molecular datasets, GSINA (with GIN or PNA backbones) surpasses other GIL methods. Notably, it outperforms GSAT by up to 10% classification accuracy on Spurious-Motif and improves ROC-AUC by 1–2 points on molecular datasets.
  • Node-level tasks: On datasets such as Cora, Amazon-Photo, Twitch, Facebook-100, Elliptic, and OGB-ArXiv, GSINA yields substantial gains. For instance, it improves upon ERM and matches or exceeds EERM in most settings (e.g., +20% ACC on Cora).

Ablation experiments indicate that omitting either Gumbel noise or node attention decreases performance by 5–10% ACC on Spurious-Motif, demonstrating that both softness and multi-level attention are critical.

7. Analysis, Limitations, and Future Directions

The superior generalization of GSINA is attributed to its balance of sparsity (removing spurious substructure), softness (enabling a rich solution space and stable gradients), and full differentiability (jointly optimizing masks and predictions). However, performance is sensitive to the choice of $r$, and the OT-based top-$r$ formulation may underperform information bottleneck approaches on interpretability metrics that rely on hard, binary subgraph extraction. Future work includes integrating explicit connectivity or completeness constraints, learning $r$ or $\tau$ jointly, and extending the framework to causal invariance discovery at node and edge levels (Ding et al., 2024).

In summary, GSINA advances GIL by introducing an OT-based attention model that is simultaneously sparse, soft, and differentiable, enabling robust and interpretable OOD generalization across diverse graph learning tasks.
