Token Relation Distillation (TRD)
- Token Relation Distillation (TRD) is a method that transfers fine-grained token-level relational structures from a teacher model to a smaller student, improving semantic alignment.
- It constructs a Token-level Relationship Graph (TRG) to capture both intra-instance context and cross-instance token similarities, enabling robust and spatially-aware knowledge transfer.
- Empirical results demonstrate that TRD boosts accuracy and robustness in both CNN and ViT architectures, especially under imbalanced classification scenarios.
Token Relation Distillation (TRD) introduces a methodology for enhancing knowledge distillation by explicitly transferring the token-level relational structure learned by a powerful teacher network to a typically smaller student model. Departing from conventional distillation approaches that emphasize either logits or instance-level relationships, TRD leverages a Token-level Relationship Graph (TRG) to encapsulate both intra-instance semantic context and cross-instance token-wise similarities. This graph-centric strategy enables the student to emulate fine-grained, higher-order semantic dependencies from the teacher, with demonstrated advantages on balanced and imbalanced classification tasks across CNN and ViT architectures (Zhang et al., 2023).
1. Motivation and Theoretical Foundations
Traditional knowledge distillation (KD), as introduced by Hinton et al. (2015), focuses on transferring class probability distributions through softened logits:

$$p_i = \frac{\exp(z_i/\tau)}{\sum_j \exp(z_j/\tau)},$$

yielding the standard loss formulation:

$$\mathcal{L}_{\mathrm{KD}} = \tau^2 \, \mathrm{KL}\!\left(p^{T} \,\Vert\, p^{S}\right).$$
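For concreteness, a minimal PyTorch sketch of this softened-logit objective (the temperature value below is illustrative):

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor, tau: float = 4.0) -> torch.Tensor:
    """Hinton-style KD: KL divergence between tau-softened class distributions.

    The tau**2 factor keeps gradient magnitudes comparable across temperatures.
    """
    log_p_student = F.log_softmax(student_logits / tau, dim=-1)
    p_teacher = F.softmax(teacher_logits / tau, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * tau ** 2
```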
Extensions incorporating feature-based alignment or instance-level relational graphs (e.g., RKD [Park et al., 2019], IRG [Liu et al., 2019]) improved transfer, but failed to explicitly model intra-image structure and higher-order semantic patterns.
The central hypothesis motivating TRD is that transferring the rich, token-level relational information—particularly relevant for patch-based architectures (e.g., ViT [Dosovitskiy et al., 2021]) or feature-aggregating CNNs—enables more complete knowledge transfer. This is especially beneficial in long-tailed settings, where rare classes may share semantic micro-patterns best captured at the token level.
Token-level relationships encode:
- Inner-instance semantic context: How patches or regions within an image are related.
- Cross-instance patch-to-patch similarity: How a token in one image relates to tokens in others.
By explicitly distilling this graph-structured data, TRD aims to bridge the capability gap between teacher and student, surpassing instance- or feature-level approaches (Zhang et al., 2023).
2. Construction of the Token-level Relationship Graph (TRG)
Token Representation
- ViT-like networks: Each image is partitioned into $N$ patches, giving teacher tokens $\{t_i^{T}\}_{i=1}^{N} \subset \mathbb{R}^{d_T}$ and student tokens $\{t_i^{S}\}_{i=1}^{N} \subset \mathbb{R}^{d_S}$.
- CNN-like networks: The penultimate feature map $F \in \mathbb{R}^{d \times H \times W}$ is split into $N = H \times W$ spatial patches, each treated as a token $t_i \in \mathbb{R}^{d}$.
Every token, whether from teacher or student, is a fixed-dimensional vector ($d_T$ for the teacher, $d_S$ for the student) used as a node attribute in the graph.
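A minimal PyTorch sketch of both tokenization routes; the tensor layouts assumed here ($(B, d, H, W)$ feature maps for CNNs, $(B, 1+N, d)$ sequences with a leading CLS token for ViTs) are implementation assumptions, not specified by the paper:

```python
import torch

def cnn_tokens(feat: torch.Tensor) -> torch.Tensor:
    """Flatten a CNN penultimate feature map (B, d, H, W) into
    N = H*W tokens per image -> (B, N, d)."""
    return feat.flatten(2).transpose(1, 2)

def vit_tokens(seq: torch.Tensor, has_cls: bool = True) -> torch.Tensor:
    """Keep only the N patch tokens from a ViT output (B, 1+N, d) -> (B, N, d)."""
    return seq[:, 1:, :] if has_cls else seq
```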
Random Token Sampling
A full batch of $B$ images contains $B \times N$ tokens, which is often computationally unwieldy. TRD therefore samples $n < N$ tokens per image using a shared random mask, so the same spatial positions are kept for both models, yielding $K = B \times n$ corresponding tokens for teacher and student.
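One way to realize the shared mask in PyTorch (the per-image random permutation trick and the helper name are ours):

```python
import torch

def sample_shared_tokens(t_teacher: torch.Tensor, t_student: torch.Tensor, n: int):
    """Sample the same n token indices per image from teacher tokens (B, N, d_T)
    and student tokens (B, N, d_S), keeping them in one-to-one correspondence.
    Returns flattened (K, d_T) and (K, d_S) tensors with K = B * n."""
    b, num_tokens, _ = t_teacher.shape
    # A random permutation per image; the first n entries form the shared mask.
    idx = torch.rand(b, num_tokens, device=t_teacher.device).argsort(dim=1)[:, :n]
    take = lambda t: t.gather(1, idx.unsqueeze(-1).expand(-1, -1, t.size(-1))).flatten(0, 1)
    return take(t_teacher), take(t_student)
```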
Graph Construction
Two attributed graphs, $\mathcal{G}^{T} = (\mathcal{V}, A^{T})$ and $\mathcal{G}^{S} = (\mathcal{V}, A^{S})$, share the vertex set $\mathcal{V}$ of $K$ sampled tokens, with adjacency matrices $A^{T}, A^{S} \in \{0,1\}^{K \times K}$ built by $k$-nearest-neighbor search in the respective embedding spaces. For $m \in \{T, S\}$:

$$A^{m}_{ij} = \begin{cases} 1 & \text{if } t_j^{m} \in \mathrm{kNN}_k(t_i^{m}), \\ 0 & \text{otherwise.} \end{cases}$$

A dense token-wise relational similarity matrix can also be formed from cosine similarities:

$$R^{m}_{ij} = \frac{\langle t_i^{m}, t_j^{m} \rangle}{\lVert t_i^{m} \rVert \, \lVert t_j^{m} \rVert},$$

optionally normalized row-wise with a softmax:

$$\hat{R}^{m}_{ij} = \frac{\exp(R^{m}_{ij})}{\sum_{l=1}^{K} \exp(R^{m}_{il})}.$$
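The construction above could be implemented as follows; using the cosine metric for the neighbor search is an assumption, chosen to stay consistent with the similarity matrix $R$:

```python
import torch
import torch.nn.functional as F

def knn_adjacency(tokens: torch.Tensor, k: int = 8) -> torch.Tensor:
    """Binary k-NN adjacency over K sampled tokens (K, d):
    A[i, j] = 1 iff token j is among the k nearest neighbors of token i."""
    z = F.normalize(tokens, dim=-1)
    sim = z @ z.t()                       # dense cosine similarities R
    sim.fill_diagonal_(float("-inf"))     # exclude self-loops
    nn_idx = sim.topk(k, dim=-1).indices  # (K, k)
    return torch.zeros_like(sim).scatter_(1, nn_idx, 1.0)

def normalized_similarity(tokens: torch.Tensor) -> torch.Tensor:
    """Row-softmax-normalized dense similarity matrix R-hat, shape (K, K)."""
    z = F.normalize(tokens, dim=-1)
    return (z @ z.t()).softmax(dim=-1)
```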
3. TRG-based Distillation Objectives
The total loss combines logit-based KD with several graph- and token-level objectives:

$$\mathcal{L} = \mathcal{L}_{\mathrm{KD}} + \lambda_{\mathrm{lp}} \mathcal{L}_{\mathrm{lp}} + \lambda_{\mathrm{gc}} \mathcal{L}_{\mathrm{gc}} + \lambda_{\mathrm{tc}} \mathcal{L}_{\mathrm{tc}},$$

with hyperparameters $\lambda_{\mathrm{lp}}, \lambda_{\mathrm{gc}}, \lambda_{\mathrm{tc}} \geq 0$.
3.1 Local Preserving Loss
This term matches local neighborhood structure between student and teacher graphs by penalizing similarity mismatches on the teacher's edges:

$$\mathcal{L}_{\mathrm{lp}} = \frac{1}{K} \left\lVert A^{T} \odot \left( \hat{R}^{T} - \hat{R}^{S} \right) \right\rVert_F^2,$$

where $\odot$ denotes the element-wise product.
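A direct PyTorch translation of this term (the helper name is ours):

```python
import torch

def local_preserving_loss(adj_teacher: torch.Tensor,
                          r_hat_teacher: torch.Tensor,
                          r_hat_student: torch.Tensor) -> torch.Tensor:
    """Penalize similarity mismatches only on the teacher graph's edges:
    (1/K) * || A^T * (R_hat^T - R_hat^S) ||_F^2 over (K, K) matrices."""
    diff = adj_teacher * (r_hat_teacher - r_hat_student)
    return diff.pow(2).sum() / adj_teacher.size(0)
```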
3.2 Global Contrastive Loss
Employing an InfoNCE objective, corresponding tokens across teacher and student are aligned while non-corresponding tokens are pushed apart. Student tokens are first projected to the teacher dimension by a learnable map $g(\cdot)$ when $d_S \neq d_T$:

$$\mathcal{L}_{\mathrm{gc}} = -\frac{1}{K} \sum_{i=1}^{K} \log \frac{\exp\!\left(\mathrm{sim}(g(t_i^{S}), t_i^{T}) / \tau_e\right)}{\sum_{j=1}^{K} \exp\!\left(\mathrm{sim}(g(t_i^{S}), t_j^{T}) / \tau_e\right)},$$

where $\mathrm{sim}(\cdot,\cdot)$ denotes cosine similarity. A dynamic temperature $\tau_e$, where $e$ indexes the training epoch, is adopted in place of a fixed value.
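A PyTorch sketch of the InfoNCE term; `proj` stands in for the student-to-teacher projection (e.g., `nn.Linear(d_S, d_T)`, or identity when $d_S = d_T$), and `tau_e` is supplied by whichever epoch schedule is used:

```python
import torch
import torch.nn.functional as F

def global_contrastive_loss(t_student: torch.Tensor, t_teacher: torch.Tensor,
                            proj, tau_e: float) -> torch.Tensor:
    """InfoNCE over K corresponding tokens: the i-th (projected) student token
    should match the i-th teacher token against all other teacher tokens."""
    z_s = F.normalize(proj(t_student), dim=-1)  # (K, d_T)
    z_t = F.normalize(t_teacher, dim=-1)        # (K, d_T)
    logits = z_s @ z_t.t() / tau_e              # (K, K) scaled cosine similarities
    targets = torch.arange(z_s.size(0), device=z_s.device)
    return F.cross_entropy(logits, targets)
```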
3.3 Token-wise Contextual Loss
To transfer inner-instance semantic context, the method matches the self-similarity matrices of patch tokens within each image. For image $b$ with $\ell_2$-normalized penultimate tokens $\bar{t}_1, \dots, \bar{t}_N$, the context matrix $C_b^{m} \in \mathbb{R}^{N \times N}$ has entries $(C_b^{m})_{ij} = \langle \bar{t}_i^{m}, \bar{t}_j^{m} \rangle$, and

$$\mathcal{L}_{\mathrm{tc}} = \frac{1}{B N^{2}} \sum_{b=1}^{B} \left\lVert C_b^{T} - C_b^{S} \right\rVert_F^{2}.$$
This constrains the student to preserve internal patch arrangements consistent with the teacher.
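A compact sketch of this matching in PyTorch; `F.mse_loss` with its default mean reduction absorbs the $1/(B N^2)$ factor:

```python
import torch
import torch.nn.functional as F

def contextual_loss(tokens_teacher: torch.Tensor, tokens_student: torch.Tensor) -> torch.Tensor:
    """Match per-image token self-similarity matrices.
    Inputs are full (unsampled) token tensors of shape (B, N, d)."""
    def self_sim(t: torch.Tensor) -> torch.Tensor:
        z = F.normalize(t, dim=-1)
        return z @ z.transpose(1, 2)  # (B, N, N) context matrix C
    return F.mse_loss(self_sim(tokens_student), self_sim(tokens_teacher))
```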
4. Empirical Evaluation
Datasets and Architectures
- Datasets: CIFAR-100, CIFAR-100-LT (imbalance ratios 10, 50, 100), ImageNet-1K, ImageNet-LT (imbalance 10).
- CNN-based models: ResNet-32×4 → ResNet-8×4, ResNet56 → ResNet20, VGG13 → VGG8, WRN-40-2 → WRN-40-1, ShuffleNet → MobileNet.
- ViT-based models: DeiT-Tiny/Small students, ResNet-101 or CeiT-Base teachers.
Training Configurations
- CIFAR-100: SGD with Nesterov momentum ($0.9$) and weight decay, 240 epochs, with the learning rate reduced at epochs 150, 180, and 210; the loss weights $\lambda_{\mathrm{lp}}, \lambda_{\mathrm{gc}}, \lambda_{\mathrm{tc}}$ and a short warm-up period are also used (optimizer setup sketched below).
- ImageNet & ViTs: CNNs – 100 epochs, cosine LR schedule, batch 128×4. ViTs – AdamW, 200 epochs, 10-epoch warm-up. All experiments run on 4×RTX 3090 GPUs.
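An illustrative PyTorch setup for the CIFAR-100 recipe; the learning rate and weight decay values are assumptions, since only the schedule is stated above, and `train_one_epoch` is a hypothetical helper:

```python
import torch
from torch import nn

def make_cifar_optim(student: nn.Module):
    """SGD + Nesterov (momentum 0.9), LR decayed at epochs 150/180/210.
    lr=0.05 and weight_decay=5e-4 are illustrative assumptions."""
    opt = torch.optim.SGD(student.parameters(), lr=0.05,
                          momentum=0.9, nesterov=True, weight_decay=5e-4)
    sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[150, 180, 210], gamma=0.1)
    return opt, sched

# Usage over the 240-epoch run:
# opt, sched = make_cifar_optim(student)
# for epoch in range(240):
#     train_one_epoch(student, teacher, loader, opt)  # hypothetical helper
#     sched.step()
```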
Performance Summary
TRD achieves superior accuracy and robustness, notably:
| Setting (teacher → student) | TRD Top-1 (%) | Strongest baselines (Top-1 %) |
|---|---|---|
| ResNet-32×4 → ResNet-8×4 (CIFAR-100) | 76.42 | HKD: 76.21, KD: 74.12 |
| ResNet-32×4 → ShuffleV1 (CIFAR-100) | 76.42 | DKD: 76.42, HKD: 75.99 |
| ResNet34 → ResNet18 (ImageNet) | 71.31 | KD (TRD +1.56%) |
| ResNet101 → DeiT-Tiny | 75.5 | KD: 74.8, HKD: 75.2 |
| CeiT-Base → DeiT-Small | 81.8 | HKD: 81.3 |
On long-tailed variants:
- CIFAR-100-LT: TRD degrades less as imbalance increases and can surpass teacher accuracy.
- ImageNet-LT: TRD reaches 50.32% Top-1 vs. 46.70% for KD, with a smaller accuracy drop relative to the balanced setting.
Ablation studies confirm that each loss term ($\mathcal{L}_{\mathrm{lp}}$, $\mathcal{L}_{\mathrm{gc}}$, $\mathcal{L}_{\mathrm{tc}}$) contributes additional accuracy, with token-level graphs outperforming instance-level graphs. The dynamic temperature $\tau_e$ provides smoother optimization and lower embedding divergence.
5. Analysis: Contextualization and Visualizations
Several key insights arise from empirical analysis:
- Larger token sample sizes (e.g., 512 sampled tokens) strengthen the graph structure and representational signal, yielding modest additional accuracy gains.
- t-SNE projections show TRD features are more class-separable than those from KD, IRG, or HKD, indicating successful transfer of fine-grained relational information.
- The addition of each loss component leads to measurable, compositional gains, confirming efficacy of the multi-term objective (Zhang et al., 2023).
6. Limitations and Prospective Directions
Identified challenges and future research include:
- Computational demands: Full construction of $k$-NN graphs over all $K$ sampled tokens is resource-intensive. Practical deployments may require efficient or approximate graph-building techniques such as locality-sensitive hashing.
- Hyperparameter sensitivity: Optimal settings for the loss weights $\lambda_{\mathrm{lp}}, \lambda_{\mathrm{gc}}, \lambda_{\mathrm{tc}}$, the neighborhood size $k$, the tokens-per-image budget $n$, and the temperature schedules must be tuned per dataset and architecture.
- Beyond classification: Extending TRD to object detection, semantic segmentation, or temporal token graphs for video.
- Graph topology learning: Instead of a fixed $k$-NN construction, learning the adjacency matrix during training.
- Self-distillation and multiscale transfer: Applying token-level relations for intra-network feature alignment.
- Cross-modal distillation: Adapting the framework to transfer relational information across modalities (e.g., image-text pairs in vision-LLMs).
7. Related Work and Positioning
TRD advances the progression from basic logit-based KD [Hinton et al., 2015] to feature and relation-based approaches, such as RKD [Park et al., 2019], IRG [Liu et al., 2019], and graph-based distillation [Zhou et al., 2021]. Distinct from prior work that either matches holistic features or instance relationships, TRD’s explicit modeling of token-level graphs enables new forms of fine-grained and spatially-aware knowledge transfer, especially applicable to architectures with patch- or region-based representations (e.g., ViT [Dosovitskiy et al., 2021]). Use of InfoNCE and contrastive paradigms is aligned with the research trajectory outlined in [Tian et al., 2019; Wang & Isola, 2020].
A plausible implication is that the token-centric relational framework may generalize to settings with structured semantic dependencies beyond image classification, suggesting future exploration into relational and multi-modal knowledge transfer regimes (Zhang et al., 2023).