CARE: Class-Aware Representation Refinement
- The paper introduces CARE, a framework that refines instance embeddings using class prototypes to boost discriminative power.
- CARE employs optimal cross-domain assignment and pseudo-label refinement to align source and target features in unsupervised domain adaptation.
- For graph classification, CARE integrates set encoding to generate class prototypes, enhancing both intra-class compactness and inter-class separation.
Class-Aware Representation Refinement (CARE) is a framework for improving discriminative power and generalization in representation learning by explicitly incorporating class information throughout the learning process. Distinct from conventional approaches that treat examples or graphs independently, CARE leverages class prototypes or global class structure to refine instance-level representations, thereby enhancing class separability and mitigating overfitting. The paradigm has been instantiated in both unsupervised domain adaptation for images and supervised graph classification, demonstrating substantial gains via optimal assignment, pseudo-label refinement, and explicit class-aware loss terms (Zhang et al., 2022, Xu et al., 2022).
1. Motivation and Scope
CARE addresses limitations inherent in standard representation learning frameworks, including:
- In unsupervised domain adaptation (UDA), source domain bias and suboptimal feature alignment can degrade pseudo-label quality and hinder adaptation due to mismatched class structures across domains (Zhang et al., 2022).
- In graph classification, standard Graph Neural Networks (GNNs) process each input graph or subgraph independently, neglecting relationships between graphs of the same class and failing to explicitly encourage intra-class clustering or inter-class separation. This can result in overfitting and less transferable embeddings (Xu et al., 2022).
The CARE methodology systematically injects class structure at the representation level, refining latent embeddings using global class prototypes and alignment strategies. In both domains, empirical and theoretical results validate improvements in accuracy, discriminability, and generalization.
2. Core Components and Methodology
CARE is instantiated with distinct but converging methodologies in UDA and graph classification:
2.1 Optimal Cross-Domain Assignment (Image UDA Context)
Given labeled source embeddings $\{f^s_i\}$ with labels $y^s_i$ and unlabeled target embeddings $\{f^t_j\}$, CARE clusters the target samples via $K$-means to obtain target centroids $\{c^t_1, \dots, c^t_K\}$, and computes source class centroids $\{c^s_1, \dots, c^s_K\}$ by averaging embeddings per class. A bipartite assignment matrix $M \in \{0,1\}^{K \times K}$ aligns target clusters with source classes using the Hungarian algorithm:

$$M^{*} = \arg\min_{M} \sum_{k=1}^{K} \sum_{l=1}^{K} M_{kl}\, d\!\left(c^s_k, c^t_l\right) \quad \text{s.t.}\;\; M\mathbf{1} = \mathbf{1},\; M^{\top}\mathbf{1} = \mathbf{1},$$

where $d(\cdot,\cdot)$ is a distance between centroids (e.g., Euclidean) and $K$ is the number of classes. Pseudo-labels for target samples are then assigned via their nearest target centroid and the cluster-to-class mapping defined by $M^{*}$ (Zhang et al., 2022).
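A minimal sketch of this assignment step using off-the-shelf $K$-means and the Hungarian solver; variable names and the Euclidean cost are illustrative, not the authors' exact implementation:

```python
import numpy as np
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment

def assign_pseudo_labels(src_feats, src_labels, tgt_feats, num_classes):
    """Sketch of CARE-style optimal cross-domain assignment.

    src_feats:  (N_s, d) labeled source embeddings
    src_labels: (N_s,)   integer class labels in [0, num_classes)
    tgt_feats:  (N_t, d) unlabeled target embeddings
    Returns pseudo-labels for the target samples.
    """
    # 1. Cluster target embeddings into K clusters (K = number of classes).
    km = KMeans(n_clusters=num_classes, n_init=10, random_state=0).fit(tgt_feats)
    tgt_centroids = km.cluster_centers_            # (K, d)
    tgt_cluster_ids = km.labels_                   # (N_t,)

    # 2. Source class centroids: mean embedding per class.
    src_centroids = np.stack(
        [src_feats[src_labels == c].mean(axis=0) for c in range(num_classes)]
    )                                              # (K, d)

    # 3. Cost matrix between source class centroids and target cluster centroids,
    #    solved as a bipartite matching with the Hungarian algorithm.
    cost = np.linalg.norm(src_centroids[:, None, :] - tgt_centroids[None, :, :], axis=-1)
    class_idx, cluster_idx = linear_sum_assignment(cost)

    # 4. Map each target cluster to its matched source class.
    cluster_to_class = np.empty(num_classes, dtype=int)
    cluster_to_class[cluster_idx] = class_idx

    # 5. Pseudo-label = class of the cluster containing each target sample.
    return cluster_to_class[tgt_cluster_ids]
```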
2.2 Class Prototypes and Set Encoders (Graph Classification Context)
For each class $c$, CARE maintains a bag $\mathcal{B}_c$ of subgraph embeddings from all training graphs with label $c$. A permutation-invariant Set-Encoder (DeepSets) summarizes $\mathcal{B}_c$ into a class prototype $p_c$:

$$p_c = \phi\big(\rho(\mathcal{B}_c)\big),$$

where $\rho$ is a mean-pooling operator and $\phi$ is an MLP with ReLU activation (Xu et al., 2022).
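A compact sketch of such a set encoder in PyTorch (a minimal DeepSets-style variant; layer sizes and names are illustrative):

```python
import torch
import torch.nn as nn

class SetEncoder(nn.Module):
    """Permutation-invariant set encoder: mean-pool a bag of subgraph
    embeddings, then transform with an MLP + ReLU."""

    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, bag):            # bag: (n_subgraphs_in_class, in_dim)
        pooled = bag.mean(dim=0)       # mean pooling is permutation invariant
        return self.mlp(pooled)        # class prototype p_c: (hidden_dim,)

# One prototype per class, computed from all training subgraph embeddings
# sharing that label (`bags` is a dict: class id -> tensor of embeddings):
# prototypes = {c: set_encoder(bags[c]) for c in bags}
```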
3. Representation Refinement and Loss Functions
3.1 Injection Mechanism
CARE refines each instance's embedding $h_i$ by concatenating it with its class prototype $p_{y_i}$:

$$\tilde{h}_i = \psi\big([\,h_i \,\Vert\, p_{y_i}\,]\big),$$

where $\psi$ is an MLP+ReLU transformation and $\Vert$ denotes concatenation. This encourages instance embeddings to move closer to their class centroids, increasing within-class compactness (Xu et al., 2022).
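A corresponding sketch of the injection step, assuming each instance embedding is concatenated with its class prototype before an MLP+ReLU transform (module and argument names are ours):

```python
import torch
import torch.nn as nn

class PrototypeInjection(nn.Module):
    """Refine an instance embedding by concatenating it with its class
    prototype and passing the result through an MLP + ReLU."""

    def __init__(self, emb_dim, proto_dim, out_dim):
        super().__init__()
        self.transform = nn.Sequential(
            nn.Linear(emb_dim + proto_dim, out_dim),
            nn.ReLU(),
        )

    def forward(self, h, p):            # h: (B, emb_dim), p: (B, proto_dim)
        return self.transform(torch.cat([h, p], dim=-1))
```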
3.2 Pseudo-Label Refinement and Confidence Filtering
In UDA, CARE employs a target-only auxiliary network for pseudo-label refinement, trained on target data and the current pseudo-labels to avoid source bias. A self-paced learning objective initially includes only "easy" samples for which the model is highly confident, and gradually admits harder samples by increasing a pace-threshold parameter each epoch (Zhang et al., 2022). A final confidence-filtering step retains only samples whose model confidence on the assigned pseudo-label exceeds a fixed threshold.
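A minimal sketch of this selection logic, assuming the pace parameter thresholds the per-sample loss and the final filter thresholds predicted confidence (the exact criterion and schedule in CA-UDA may differ):

```python
import torch
import torch.nn.functional as F

def self_paced_mask(logits, pseudo_labels, lam):
    """Self-paced selection (sketch): keep samples whose current loss is
    below the pace parameter `lam`; raising `lam` each epoch gradually
    admits harder samples."""
    losses = F.cross_entropy(logits, pseudo_labels, reduction="none")
    return losses <= lam                     # boolean mask of "easy" samples

def confidence_mask(logits, pseudo_labels, tau):
    """Final filtering (sketch): keep only samples whose predicted
    probability for their pseudo-label exceeds the threshold `tau`."""
    probs = F.softmax(logits, dim=1)
    conf = probs.gather(1, pseudo_labels.unsqueeze(1)).squeeze(1)
    return conf > tau
```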
3.3 Class-Aware Loss Design
CARE introduces explicit loss components to promote class structure at the representation level:
- In graph classification:
- Intra-class similarity loss: Average cosine similarity between subgraph embeddings and their class prototype.
- Inter-class similarity loss: Average cosine similarity between different class prototypes.
- Combined class-aware loss: the inter-class similarity minus the intra-class similarity, so that minimizing it simultaneously pulls instances toward their own prototype and pushes distinct prototypes apart.
- Total loss per batch: the supervised classification loss plus a weighted class-aware term (a sketch of these terms follows this list) (Xu et al., 2022).
- In image UDA:
- Center-to-Center (C2C) MMD: RKHS distance between source and target class centroids.
- Probability-to-Probability (P2P) MMD: RKHS distance between averaged class-conditional predicted probabilities in source and target.
- Total loss includes cross-entropy on source, self-paced pseudo-label refinement, optimal assignment cost, and weighted C2C/P2P alignment terms (Zhang et al., 2022).
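A sketch of the graph-side class-aware terms described above, built from cosine similarities to the matching prototype and between distinct prototypes; the weighting $\alpha$ and the exact combination used in the paper may differ:

```python
import torch
import torch.nn.functional as F

def class_aware_loss(subgraph_emb, labels, prototypes, alpha=1.0):
    """Class-aware loss sketch: reward intra-class similarity (embedding vs.
    its own class prototype), penalize inter-class similarity (between
    different prototypes)."""
    # Intra-class: average cosine similarity to the matching prototype.
    intra = F.cosine_similarity(subgraph_emb, prototypes[labels], dim=-1).mean()

    # Inter-class: average pairwise cosine similarity between distinct prototypes.
    p = F.normalize(prototypes, dim=-1)
    sim = p @ p.t()
    off_diag = sim[~torch.eye(len(p), dtype=torch.bool, device=p.device)]
    inter = off_diag.mean()

    # Minimizing (inter - intra) increases compactness and separation.
    return alpha * (inter - intra)

# Illustrative total batch loss: cross-entropy plus the class-aware term.
# loss = F.cross_entropy(logits, labels) + class_aware_loss(emb, labels, prototypes)
```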
4. Integration with Backbone Architectures
CARE is designed to operate as a plug-in on top of standard GNN, CNN, or other feature extraction backbones:
- Graph domain: CARE can be integrated as a "global" block after the last pooling or readout, or hierarchically at each layer if the backbone supports multi-level pooling. The only additional computational cost arises from the subgraph selector and two small MLPs for set encoding and transformation, often resulting in negligible or even reduced total training time due to faster convergence (Xu et al., 2022); a minimal plug-in sketch follows this list.
- Image domain (UDA): CARE is implemented using standard architectures such as ResNet-50/101, replacing the FC output layer with a $K$-way classifier head, where $K$ is the number of classes. Auxiliary networks replicate the primary architecture but maintain independent batch-normalization layers. No adversarial discriminator is used; domain alignment is realized via the aforementioned MMD losses (Zhang et al., 2022).
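A structural sketch of the plug-in integration, reusing the PrototypeInjection sketch above; how prototypes are produced at inference time (e.g., cached from training) is left to the surrounding training loop and is an assumption here:

```python
import torch
import torch.nn as nn

class CAREWrapper(nn.Module):
    """Illustrative wrapper showing where CARE sits in an existing pipeline:
    backbone readout -> prototype injection -> classifier. The backbone
    itself (GNN or CNN feature extractor) is left untouched."""

    def __init__(self, backbone, injection, out_dim, num_classes):
        super().__init__()
        self.backbone = backbone                  # produces (B, emb_dim) embeddings
        self.injection = injection                # PrototypeInjection sketch above
        self.classifier = nn.Linear(out_dim, num_classes)

    def forward(self, inputs, prototypes):        # prototypes: (B, proto_dim)
        h = self.backbone(inputs)                 # readout / pooled embeddings
        refined = self.injection(h, prototypes)   # class-aware refinement
        return self.classifier(refined)
```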
5. Theoretical Properties
A key theoretical result is that CARE provides improved generalization guarantees relative to its backbone. VC-dimension analysis shows that, when calibrated for equal parameter counts, CARE yields a strictly lower upper bound on the VC dimension than the plain backbone:

$$\mathrm{VCdim}(\text{CARE}) \;<\; \mathrm{VCdim}(\text{backbone}).$$
This reduction implies a smaller upper bound on the generalization gap, offering formal justification for the reduced overfitting observed empirically (Xu et al., 2022).
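For intuition on why this matters, recall the classical VC generalization bound (a standard textbook result, not one specific to either CARE paper): with probability at least $1-\delta$ over $N$ i.i.d. training samples, every hypothesis $h$ in a class of VC dimension $d$ satisfies

$$R(h) \;\le\; \widehat{R}(h) + \sqrt{\frac{d\left(\ln\frac{2N}{d} + 1\right) + \ln\frac{4}{\delta}}{N}},$$

and since the right-hand side grows with $d$, a smaller VC dimension at matched parameter count directly tightens the bound on the gap $R(h) - \widehat{R}(h)$.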
6. Empirical Evaluation and Results
CARE has been extensively evaluated in both UDA and graph contexts:
| Task Type | Datasets | Baselines | Key Gains |
|---|---|---|---|
| Image UDA | Office-31, ImageCLEF-DA, VisDA-2017, Digit-Five | DAN, JAN, DANN, MADA, RevGrad, etc. | +3–10 pts vs. backbone (e.g., Office-31: 76.1% → 94.0%) (Zhang et al., 2022) |
| Graph Classification | DD, PROTEINS, MUTAG, NCI1, OGB-MOLHIV | GCN, GIN, GraphSAGE, GAT, etc. | 84/88 pairs improved; +1–11% accuracy; +1–5 AUC on OGB (Xu et al., 2022) |
Further analysis confirms high class separability (as measured by Silhouette, Hypothesis Margin), reduced boundary errors, and stable training behavior, with CARE demonstrating resilience to overfitting and effective cluster formation across different architectures and domains.
7. Practical Considerations and Best Practices
Hyperparameter choices are robust within reasonable intervals, e.g., hidden size $32$–$256$, subgraph selector pool ratio $0.25$–$0.75$, and the weight on the class-aware loss term. For optimization, the Adam optimizer with early stopping on validation performance is recommended. Class prototype updates are online, requiring no additional momentum or batch-level synchronization. For deployment in GNNs, insertion after global pooling suffices for standard architectures, while hierarchical models benefit from per-layer integration (Xu et al., 2022). CARE imposes minimal computational overhead and is compatible with mainstream deep learning frameworks. In UDA, careful tuning of the self-paced learning schedule and assignment parameters further improves accuracy and stability (Zhang et al., 2022).
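An illustrative starting configuration within the ranges reported as robust; the specific values below are examples rather than the papers' tuned settings:

```python
# Example starting point for a CARE-style graph-classification run
# (all values are illustrative defaults within the reported ranges).
config = {
    "hidden_dim": 128,          # robust roughly within 32-256
    "pool_ratio": 0.5,          # subgraph selector pool ratio, ~0.25-0.75
    "class_loss_weight": 1.0,   # weight on the class-aware loss term (assumed)
    "optimizer": "adam",
    "early_stopping": True,     # stop when validation performance plateaus
}
```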
References:
- (Zhang et al., 2022): "CA-UDA: Class-Aware Unsupervised Domain Adaptation with Optimal Assignment and Pseudo-Label Refinement"
- (Xu et al., 2022): "A Class-Aware Representation Refinement Framework for Graph Classification"