CKA-SR Regularization for Deep Learning
- CKA-SR Regularization is a spectral technique that uses centered kernel alignment to measure and control the similarity of internal feature maps in deep networks.
- It enhances training by integrating a CKA-based loss, reducing inter-layer mutual information and compressing representations to promote sparsity and efficiency.
- The method has proven effective across sparse training, Bayesian ensembles, and federated learning, yielding improved accuracy and robustness.
CKA-SR Regularization is a family of spectral regularization techniques leveraging Centered Kernel Alignment (CKA) as a functional similarity metric for neural network representations. The core motivation is to directly intervene on the similarity structure of internal feature maps—across layers, models, or distributed learners—enabling improved sparsity, diversity, robustness, and transfer in deep learning. The most widely explored instantiations are CKA-SR for sparse training, for diversity in Bayesian ensembles, and for representation alignment in federated learning.
1. Mathematical Foundations of Centered Kernel Alignment
Centered Kernel Alignment (CKA) is a normalized similarity measure for comparing sets of representations. For feature maps $X \in \mathbb{R}^{n \times p_1}$ and $Y \in \mathbb{R}^{n \times p_2}$ (each row an example), CKA operates via Gram matrices:
- $K = XX^\top$ and $L = YY^\top$.
- Centering: for $n$ examples, use the centering matrix $H = I_n - \frac{1}{n}\mathbf{1}\mathbf{1}^\top$, forming $K' = HKH$ and $L' = HLH$.
- Hilbert–Schmidt Independence Criterion: $\mathrm{HSIC}(K, L) = \frac{1}{(n-1)^2}\,\mathrm{tr}(KHLH)$ for the linear kernel $k(x, y) = x^\top y$.
- The normalized CKA: $\mathrm{CKA}(K, L) = \dfrac{\mathrm{HSIC}(K, L)}{\sqrt{\mathrm{HSIC}(K, K)\,\mathrm{HSIC}(L, L)}}$.
This yields a value in $[0, 1]$, reflecting representational similarity invariant to orthogonal transformations and isotropic scaling (Ni et al., 2023, Smerkous et al., 31 Oct 2024, Son et al., 2021).
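As a concrete reference for these definitions, here is a minimal NumPy sketch of linear CKA (the function and variable names are illustrative, not taken from the cited papers):

```python
import numpy as np

def linear_cka(X: np.ndarray, Y: np.ndarray) -> float:
    """Linear CKA between feature matrices X (n x p1) and Y (n x p2)."""
    # Column-centering the features is equivalent to double-centering the Gram matrices.
    X = X - X.mean(axis=0, keepdims=True)
    Y = Y - Y.mean(axis=0, keepdims=True)
    # With a linear kernel, HSIC reduces to squared Frobenius norms of cross-products;
    # the 1/(n-1)^2 factors cancel in the normalized ratio.
    hsic_xy = np.linalg.norm(Y.T @ X, "fro") ** 2
    hsic_xx = np.linalg.norm(X.T @ X, "fro") ** 2
    hsic_yy = np.linalg.norm(Y.T @ Y, "fro") ** 2
    return hsic_xy / np.sqrt(hsic_xx * hsic_yy)

# Sanity checks for the stated invariances.
rng = np.random.default_rng(0)
X = rng.standard_normal((64, 128))
Q, _ = np.linalg.qr(rng.standard_normal((128, 128)))       # random orthogonal matrix
print(linear_cka(X, 3.0 * X))                               # ~1.0: isotropic scaling
print(linear_cka(X, X @ Q))                                 # ~1.0: orthogonal transform
print(linear_cka(X, rng.standard_normal((64, 200))))        # small: unrelated features
```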
2. CKA-SR Loss Construction and Sparse Training
CKA-SR was introduced to induce sparsity by minimizing inter-layer feature similarity in deep networks (Ni et al., 2023). The regularizer is formulated as:

$$\mathcal{L}_C = \sum_{(i,j)\in\mathcal{P}} w_{ij}\,\mathrm{CKA}_{\mathrm{Linear}}(X_i, X_j),$$

where $X_i, X_j$ are layer features, $\mathcal{P}$ is the set of regularized layer pairs, and $w_{ij}$ are layer-pair weights. The total loss is:

$$\mathcal{L} = \mathcal{L}_E + \beta\,\mathcal{L}_C.$$

Here, $\beta$ is a regularization parameter tuned over a small range. Typical layer pairs include adjacent or within-stage layers, keeping the computational overhead linear in network depth. CKA is estimated on a small “few-shot” subset ($m \ll B$ samples per batch) for efficiency. CKA-SR can be combined with $\ell_1$ or $\ell_2$ weight regularization.
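As a minimal PyTorch sketch of this regularizer (the helper names `linear_cka_torch` and `cka_sr_loss`, and the uniform default weights, are illustrative rather than the reference implementation):

```python
import torch

def linear_cka_torch(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Differentiable linear CKA between (m x p_i) and (m x p_j) activations."""
    X = X - X.mean(dim=0, keepdim=True)
    Y = Y - Y.mean(dim=0, keepdim=True)
    hsic_xy = (Y.t() @ X).pow(2).sum()
    hsic_xx = (X.t() @ X).pow(2).sum()
    hsic_yy = (Y.t() @ Y).pow(2).sum()
    return hsic_xy / (hsic_xx * hsic_yy).sqrt().clamp_min(1e-12)

def cka_sr_loss(acts, layer_pairs, weights=None):
    """Weighted sum of linear CKA over the regularized layer pairs P."""
    total = acts[0].new_zeros(())
    for k, (i, j) in enumerate(layer_pairs):
        w = 1.0 if weights is None else weights[k]
        # Flatten spatial dimensions so each row is one example's feature vector.
        Xi = acts[i].flatten(start_dim=1)
        Xj = acts[j].flatten(start_dim=1)
        total = total + w * linear_cka_torch(Xi, Xj)
    return total
```

The total objective is then ℒ_E + β·`cka_sr_loss(...)`, with gradients flowing back through the activations of the regularized layers.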
3. Information-Theoretic Link: Compression and Sparsity
The adoption of CKA-SR in sparse learning is theoretically grounded by the information bottleneck principle (Ni et al., 2023):
- For Gaussian features, minimizing inter-layer CKA aligns with minimizing the mutual information $I(X_i; X_j)$ between the corresponding layer representations.
- Reducing $I(X_i; X_j)$ (compression) enforces sparsity by driving down $\|W\|_0$ for the layer’s weight matrix $W$, maximizing the number of zeros ($\ell_0$-sparsity).
- The established chain is: lower inter-layer CKA $\Rightarrow$ lower mutual information $\Rightarrow$ stronger representation compression $\Rightarrow$ smaller $\|W\|_0$ (greater weight sparsity).
This direct connection enables layer-wise control of both compressive representations and parameter sparsity.
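To make the Gaussian link concrete, consider a scalar illustration (an assumption-laden special case, not the paper’s full derivation): for one-dimensional centered features $x_i, x_j$ that are jointly Gaussian with correlation $\rho$,

$$\mathrm{CKA}_{\mathrm{Linear}}(x_i, x_j) = \frac{\langle x_i, x_j\rangle^2}{\|x_i\|^2\,\|x_j\|^2} \approx \rho^2, \qquad I(x_i; x_j) = -\tfrac{1}{2}\ln\!\left(1 - \rho^2\right),$$

so mutual information is monotonically increasing in $\rho^2$, and pushing CKA toward zero pushes the Gaussian mutual information toward zero as well.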
4. CKA-SR Extensions: FedCKA and Hyperspherical Repulsion
CKA-based regularization extends to federated and Bayesian deep learning scenarios.
FedCKA: Federated Learning on Heterogeneous Data
FedCKA applies CKA regularization to federated setups, aligning representations of selected “important” layers (typically the first two convolutional layers or initial blocks) between local models and the global server model (Son et al., 2021). The local objective combines the supervised loss and a CKA-based contrastive term:
- For layer $\ell$, the key quantity is the CKA similarity between local and global representations, $\mathrm{CKA}\!\left(X_\ell^{\mathrm{local}}, X_\ell^{\mathrm{global}}\right)$, which the contrastive term encourages to be high.
Regularization is restricted to early layers, avoiding constraints on task-specialized deeper layers and preserving scalability and efficiency.
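A hedged PyTorch sketch of such a per-layer term follows; the MOON-style contrastive shape (the global model’s activations as the positive, the previous round’s local activations as the negative), the names `cka_contrastive_term`, `tau`, `mu`, and the reuse of `linear_cka_torch` from the earlier sketch are assumptions for illustration, not the exact FedCKA formulation.

```python
import torch
import torch.nn.functional as F

def cka_contrastive_term(x_local, x_global, x_prev, tau: float = 0.5):
    """Contrastive term over CKA similarities for one regularized layer.

    x_local, x_global, x_prev: (m x p) activations of the same layer from the
    current local model, the global server model, and the previous local model.
    """
    pos = linear_cka_torch(x_local, x_global)   # pull toward the global representation
    neg = linear_cka_torch(x_local, x_prev)     # push away from the stale local representation
    logits = torch.stack([pos, neg]) / tau
    target = torch.zeros(1, dtype=torch.long, device=logits.device)  # class 0 = positive
    return F.cross_entropy(logits.unsqueeze(0), target)

# Local objective for one client step (mu weights the alignment term):
# loss = F.cross_entropy(local_logits, labels) \
#        + mu * sum(cka_contrastive_term(a_l, a_g, a_p) for (a_l, a_g, a_p) in layer_acts)
```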
CKA-SR with Hyperspherical Energy in Bayesian Ensembles
Smerkous et al. introduce hyperspherical energy (HE) minimization over CKA-centered Gram vectors to enforce diversity of particle-based Bayesian deep ensembles (Smerkous et al., 31 Oct 2024):
- Gram matrices of layer features are centered, flattened, and normalized to sit on the unit sphere.
- HE is computed using the geodesic (arccosine) distance, yielding an energy sum over all model pairs and layers:

$$\mathcal{L}_{\mathrm{HE}} = \sum_{\ell} \lambda_\ell \sum_{i \neq j} f\!\left(\arccos\!\big(\langle \hat{v}_i^{\ell}, \hat{v}_j^{\ell}\rangle\big)\right),$$

where $\hat{v}_i^{\ell}$ is the centered, flattened, unit-normalized Gram vector of layer $\ell$ in ensemble member $i$, $f$ is a decreasing (Riesz-style) repulsion kernel with a small smoothing constant, and $\lambda_\ell$ weights the layer contributions.
- Gradients are non-vanishing even as models become similar, ensuring repulsive diversity when standard cosine repulsion may fail.
- The framework also incorporates OOD-specific HE and entropy terms to penalize overconfident predictions on synthetic outliers.
- Typical hyperparameters include the ensemble size $M$ (set to $M = 10$ in the CIFAR experiments reported below), the repulsion-kernel order, the smoothing constant, and the per-layer weights $\lambda_\ell$; a code sketch follows this list.
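Below is a minimal sketch of this repulsion term, assuming the Riesz-style geodesic kernel written above; `gram_unit_vector`, `he_repulsion`, and the default hyperparameter values are illustrative, not those of the paper.

```python
import torch

def gram_unit_vector(X: torch.Tensor) -> torch.Tensor:
    """Centered, flattened, unit-normalized Gram vector of activations X (m x p)."""
    X = X - X.mean(dim=0, keepdim=True)
    K = X @ X.t()                          # centered linear Gram matrix (m x m)
    v = K.flatten()
    return v / v.norm().clamp_min(1e-12)   # point on the unit hypersphere

def he_repulsion(layer_acts, layer_weights=None, s: float = 1.0, eps: float = 1e-4):
    """Hyperspherical-energy repulsion over an ensemble.

    layer_acts[l][i]: activations of layer l from ensemble member i, shape (m x p_l).
    """
    energy = 0.0
    for l, acts in enumerate(layer_acts):
        w = 1.0 if layer_weights is None else layer_weights[l]
        vs = [gram_unit_vector(a) for a in acts]
        for i in range(len(vs)):
            for j in range(len(vs)):
                if i == j:
                    continue
                cos = torch.clamp(vs[i] @ vs[j], -1 + 1e-7, 1 - 1e-7)
                geo = torch.arccos(cos)                    # geodesic distance on the sphere
                energy = energy + w * (geo + eps).pow(-s)  # Riesz-style repulsive energy
    return energy
```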
5. Computational Strategies and Algorithmic Procedures
CKA-SR methods maintain scalability via several implementation tactics:
- Use linear CKA (Frobenius-norm based) for efficient computation: with centered features, $\mathrm{CKA}_{\mathrm{Linear}}(X_i, X_j) = \|X_j^\top X_i\|_F^2 \,/\, \big(\|X_i^\top X_i\|_F\,\|X_j^\top X_j\|_F\big)$, whose cost scales with the $m$ few-shot samples per layer pair rather than the full batch.
- Restrict to intra-stage or adjacent pairs of layers, keeping regularization terms linear in depth.
- In federated learning (FedCKA), only selected layers are regularized, with overall per-round overhead ≈17% relative to FedAvg on deep models.
- In Bayesian ensemble settings, gradient stability is maintained via small smoothing constants in the arccosine energy and weighted per-layer contributions.
CKA-SR Pseudocode for Sparse Training (Ni et al., 2023)
```
Given: network f with L layers, loss ℒ_E, reg-weight β,
       set of layer pairs 𝒫 = {(i, j)}, few-shot sample size m.
Initialize weights W.
for each training step do
    (1) Sample mini-batch of size B.
    (2) Randomly choose m ≪ B examples to compute CKA.
    (3) Forward pass to obtain activations {X₀, …, X_L}.
    (4) Compute ℒ_E on the full batch.
    (5) For each (i, j) ∈ 𝒫:
            extract sub-activations X_i ∈ ℝ^{m×p_i}, X_j ∈ ℝ^{m×p_j}
            compute CKA_Linear(X_i, X_j).
        Sum ℒ_C = ∑_{(i,j)} CKA_Linear(X_i, X_j).
    (6) Total loss ℒ = ℒ_E + β·ℒ_C.
    (7) Backpropagate and update W.
```
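For concreteness, a runnable PyTorch rendering of one such training step is sketched below, reusing `linear_cka_torch` and `cka_sr_loss` from the earlier sketch; the hook-based activation capture and all names are illustrative, not the reference implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def train_step(model: nn.Module, layers, layer_pairs, batch, optimizer,
               beta: float = 1e-3, m: int = 8):
    """One CKA-SR step: task loss on the full batch, CKA penalty on m examples."""
    inputs, targets = batch
    acts = {}

    # Capture intermediate activations with forward hooks (illustrative mechanism).
    handles = [
        layer.register_forward_hook(
            lambda mod, inp, out, idx=idx: acts.__setitem__(idx, out)
        )
        for idx, layer in enumerate(layers)
    ]

    logits = model(inputs)                        # (1)-(3) forward pass on the mini-batch
    task_loss = F.cross_entropy(logits, targets)  # (4) ℒ_E on the full batch

    # (5) CKA regularizer on a random few-shot subset of the batch.
    idx = torch.randperm(inputs.size(0), device=inputs.device)[:m]
    few_shot_acts = [acts[k][idx] for k in range(len(layers))]
    reg = cka_sr_loss(few_shot_acts, layer_pairs)

    loss = task_loss + beta * reg                 # (6) total loss
    optimizer.zero_grad()
    loss.backward()                               # (7) update
    optimizer.step()

    for h in handles:
        h.remove()
    return loss.item()
```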
6. Empirical Results and Application Domains
CKA-SR regularization yields quantifiable improvements in multiple regimes.
Sparse Training and Pruning (Ni et al., 2023)
On CIFAR-100 (ResNet32/20), accuracy improvements are consistent at high sparsity (top-1 accuracy in % at the given sparsity ratios):

| Method | 70% | 85% | 90% | 95% | 98% | 99.8% |
|------------|-------|-------|-------|-------|-------|-------|
| LTH | 72.28 | 70.64 | 69.63 | 66.48 | 60.22 | × |
| LTH+CKA-SR | 72.67 | 71.90 | 70.11 | 67.07 | 60.36 | × |
Combinations with filter/channel pruning show smaller accuracy drops for pruned models pretrained with CKA-SR.
Bayesian Ensembles and OOD Detection (Smerkous et al., 31 Oct 2024)
HE minimization over CKA kernels offers improved uncertainty quantification and OOD detection. Example (CIFAR-10 vs SVHN, ResNet32, M=10):
- SVGD+RBF: AUROC_PE ≈ 82.5%
- SVGD+HE: AUROC_PE ≈ 89.2%
- Ensemble+OOD_HE: AUROC_PE ≈ 96.5%, MI ≈ 96.6%
Federated Learning (Son et al., 2021)
FedCKA achieves consistently higher accuracy versus prior state-of-the-art on CIFAR-10, CIFAR-100, and Tiny-ImageNet under heterogeneity, with a modest computational overhead.
7. Best Practices and Implementation Insights
- In sparse training, CKA-SR is most effective when combined with standard $\ell_1$ or $\ell_2$ penalties; a small $\beta$ is typical, and an overly large $\beta$ may degrade useful representational flow.
- Apply CKA regularization only to layers where feature similarity is beneficial (early convolutional blocks in federated learning).
- In ensemble settings, use hyperspherical repulsion for stable diversity, and separate inlier/OOD terms for comprehensive calibration.
- CKA-SR requires only intermediate feature maps and is “plug-and-play” with standard training flows; no modifications to pruning schedules or communication protocols are necessary.
CKA-SR regularization systematically broadens representational diversity and compressive efficiency in deep models, with theoretical guarantees via the information bottleneck, empirical robustness across training paradigms, and pragmatic guidelines for scalable implementation (Ni et al., 2023, Smerkous et al., 31 Oct 2024, Son et al., 2021).