Only Token Mean Loss in Vision Transformers
- Only Token Mean Loss (OTM) is a metric that quantifies the importance of individual tokens in Vision Transformers by measuring the change in cross-entropy loss when a token is masked.
- It employs a supervised feature selection approach by generating pseudo-labels using a loss differential threshold (ρ) and training a dedicated MLP filter to prune less critical tokens.
- Empirical results on DeiT models demonstrate that OTM filtering reduces compute (FLOPs and latency) while maintaining competitive accuracy.
The only token mean loss (OTM), also termed delta-loss (ΔL), is a metric quantifying the marginal influence of individual tokens on the loss function of a fixed-weight (frozen) Vision Transformer (ViT) model. Specifically, OTM is used to assess the importance of patch tokens in visual input sequences by measuring the change in cross-entropy loss induced by masking each token in isolation, thus enabling principled, data-driven token filtering prior to self-attention. This approach frames token selection as a supervised feature-selection problem and yields efficient, accurate filtered ViTs through a single token-pruning stage before any attention computation (Wang et al., 2023).
1. Formal Definition of Only Token Mean Loss
Let $X = \{x_1, x_2, \ldots, x_N\}$ be the set of patch tokens for an input image, with $x_i \in \mathbb{R}^d$, where $d$ is the token embedding dimension. Given a pretrained ViT backbone $f_\theta$, compute the network output and let $L(X, y)$ denote the cross-entropy loss on the original (unmasked) image-label pair $(X, y)$. For each token $x_i$, generate a masked input $X_{\setminus i}$, i.e., $X$ with $x_i$ replaced by zero. The masked loss is $L(X_{\setminus i}, y)$. The per-token delta-loss (OTM) is
$$\Delta L_i = L(X_{\setminus i}, y) - L(X, y).$$
A high positive $\Delta L_i$ indicates that masking $x_i$ significantly increases the loss, and thus that $x_i$ is important; a near-zero or negative $\Delta L_i$ implies $x_i$ is expendable. The expected OTM across the data distribution $\mathcal{D}$ is
$$\mathrm{OTM}_i = \mathbb{E}_{(X, y) \sim \mathcal{D}}\!\left[L(X_{\setminus i}, y) - L(X, y)\right],$$
computed in practice as the sample mean over the training set (Wang et al., 2023).
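The following PyTorch sketch illustrates the per-token ΔL computation under stated assumptions: `vit` is any frozen callable mapping a token sequence of shape (1, N, d) to class logits, and masking is implemented by zeroing the token embedding. It is an illustration of the definition above, not the paper's reference implementation.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def token_delta_loss(vit, tokens, label):
    """Compute the OTM / delta-loss for every patch token of one image.

    vit    : frozen callable, (1, N, d) token sequence -> (1, num_classes) logits
    tokens : patch-token embeddings, shape (1, N, d)
    label  : ground-truth class index, shape (1,)
    """
    base_loss = F.cross_entropy(vit(tokens), label)                  # L(X, y)
    delta = torch.zeros(tokens.shape[1], device=tokens.device)
    for i in range(tokens.shape[1]):
        masked = tokens.clone()
        masked[:, i, :] = 0.0                                        # replace token i with zeros
        delta[i] = F.cross_entropy(vit(masked), label) - base_loss   # ΔL_i
    return delta
```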
2. Token Labeling, Feature Construction, and Filter Architecture
OTM is leveraged for token selection using a two-stage pipeline:
a) Token Labeling: For every image in the training set, obtain $L(X, y)$ via a standard forward pass. For each token $x_i$, mask it, recompute $L(X_{\setminus i}, y)$, and set a pseudo-label:
- $z_i = 1$ if $\Delta L_i > \rho$,
- $z_i = 0$ otherwise, where the threshold $\rho$ is selected via tuning on validation data.
b) Feature for Each Token: To distinguish tokens in visually similar regions, a global image feature $g$ (pooled over all patch tokens of the image) is computed. The input to the filter MLP is the concatenation $[x_i; g]$.
c) Token-Filter MLP: The MLP consists of three fully connected layers, with ReLU activations after the first two layers and a sigmoid activation on the output. The final output $\hat{z}_i \in (0, 1)$ is interpreted as the predicted probability of $x_i$ being important.
d) Loss to Train MLP: The filter is trained with binary cross-entropy loss,
$$\mathcal{L}_{\text{filter}} = -\frac{1}{N} \sum_{i=1}^{N} \left[ z_i \log \hat{z}_i + (1 - z_i) \log\left(1 - \hat{z}_i\right) \right],$$
with the ViT backbone weights held fixed during this phase.
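A minimal sketch of steps (a)–(c) is given below. The hidden width of 256 and mean pooling for the global feature are placeholder assumptions (the paper's exact layer sizes are not reproduced here); only the thresholding, the $[x_i; g]$ input construction, and the three-layer ReLU/sigmoid MLP follow the description above. The BCE objective of step (d) appears in the training-loop sketch in the next section.

```python
import torch
import torch.nn as nn

def pseudo_labels(delta_loss, rho):
    """Step (a): hard labels z_i = 1 if ΔL_i > rho, else 0."""
    return (delta_loss > rho).float()

def filter_inputs(tokens):
    """Step (b): concatenate each token with a pooled global image feature.
    Mean pooling is used here as one plausible choice of pooling."""
    g = tokens.mean(dim=1, keepdim=True).expand_as(tokens)   # global feature g, broadcast to all tokens
    return torch.cat([tokens, g], dim=-1)                     # [x_i ; g], shape (B, N, 2d)

class TokenFilterMLP(nn.Module):
    """Step (c): three fully connected layers, ReLU after the first two,
    sigmoid on the output; the hidden width is a placeholder value."""
    def __init__(self, token_dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * token_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1), nn.Sigmoid(),
        )

    def forward(self, token_features):
        return self.net(token_features).squeeze(-1)            # predicted importance probability ẑ_i
```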
3. Token Filtering Algorithm and Inference Workflow
The methodology is operationalized through two main algorithms (cf. paper pseudocode):
| Algorithm | Input | Output |
|---|---|---|
| Token Labeling | Pretrained Transformer, training set $\mathcal{D}_{\text{train}}$, threshold $\rho$ | Token pseudo-labels $\{z_i\}$ |
| Filter Training | Training tokens $\{x_i\}$, pseudo-labels $\{z_i\}$ | MLP filter parameters |
Token Labeling: For each image, after a forward pass for $L(X, y)$, each token is masked, the losses are recomputed, and the resulting OTM values are thresholded at $\rho$ to produce hard labels.
Filter Training: For each minibatch, global token features are pooled and concatenated to individual tokens, MLP predictions computed, and binary cross-entropy minimized until convergence.
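As a sketch of this filter-training loop (reusing the hypothetical `TokenFilterMLP`, `filter_inputs`, and `pseudo_labels` helpers above; the optimizer, learning rate, and epoch count are illustrative assumptions, not the paper's settings):

```python
import torch
import torch.nn as nn

def train_filter(filter_mlp, token_batches, rho, epochs=10, lr=1e-3):
    """Train only the MLP filter; the ViT backbone stays frozen and is not touched here.
    token_batches is assumed to yield (tokens, delta_loss) pairs with
    tokens of shape (B, N, d) and per-token ΔL of shape (B, N)."""
    opt = torch.optim.Adam(filter_mlp.parameters(), lr=lr)
    bce = nn.BCELoss()
    for _ in range(epochs):
        for tokens, delta_loss in token_batches:
            labels = pseudo_labels(delta_loss, rho)    # ΔL thresholded into hard labels z_i
            preds = filter_mlp(filter_inputs(tokens))  # pooled global feature + per-token prediction
            loss = bce(preds, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return filter_mlp
```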
Inference: At test time (and during fine-tuning), all input tokens are passed once through the trained MLP filter. Tokens with $\hat{z}_i$ below a cutoff are zeroed and dropped from all downstream computation. The filtering is performed once, prior to any self-attention, reducing the per-layer cost of attention from $O(N^2)$ to $O(k^2)$, where $k \le N$ is the number of retained tokens (Wang et al., 2023).
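At inference, the filtering pass might look like the following (the 0.5 cutoff, the helpers from the previous sketches, and `vit_blocks` as a stand-in for the downstream transformer layers are all assumptions of this sketch):

```python
import torch

@torch.no_grad()
def filter_and_forward(filter_mlp, vit_blocks, tokens, cutoff=0.5):
    """Score all tokens once with the trained MLP filter, drop those below
    the cutoff, and run only the retained k <= N tokens through the
    subsequent transformer blocks."""
    scores = filter_mlp(filter_inputs(tokens))   # (1, N) importance probabilities ẑ_i
    keep = scores[0] >= cutoff                   # boolean mask over the N tokens
    retained = tokens[:, keep, :]                # shape (1, k, d)
    return vit_blocks(retained)                  # attention now scales with k rather than N
```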
4. Theoretical Context: Feature Selection Perspective
OTM constitutes a direct, "wrapper"-style metric for token utility, analogous to marginal contribution, Shapley-value, or leave-one-out feature importance in traditional feature selection. By construction, tokens whose masking yields negligible change in ViT loss are considered dispensable. Empirical analysis demonstrates that, for standard vision tasks (e.g., ImageNet1K), the majority of image tokens exhibit $\Delta L_i \approx 0$, with a small subset dominating the network's decisional performance. This suggests that most tokens can be discarded without significant adverse effect on predictive accuracy (Wang et al., 2023).
5. Empirical Performance and Comparative Results
Key empirical results on ImageNet1K using DeiT-based backbones:
- Backbone DeiT-T ($1.3$ GFLOPs baseline): the OTM-filtered DL-ViT-T reduces FLOPs to $0.7$ G (roughly a $46\%$ reduction) and raises throughput to $4,565$ img/s, with only a small drop in top-1 accuracy relative to the baseline.
- Backbone DeiT-S ($4.6$ GFLOPs baseline): the OTM-filtered DL-ViT-S reduces FLOPs to $3.9$ G (roughly a $15\%$ reduction) and raises throughput to $1,602$ img/s, with top-1 accuracy remaining close to the baseline.
Comparison to Dynamic-ViT, A-ViT, and E-ViT indicates that DL-ViT offers the best trade-off of accuracy, FLOPs, and latency under comparable conditions.
Ablation studies underscore the necessity of the OTM-driven MLP filter: substituting the filter with a randomly initialized MLP or a random token-drop mechanism causes a drastic drop in top-1 accuracy relative to OTM-filtered selection, demonstrating the centrality of faithful ΔL-based pseudo-labeling.
6. Significance and Practical Integration
The OTM metric facilitates lightweight, accurate token pruning in ViTs through a purely post-hoc, data- and backbone-driven labeling strategy applicable to pretrained models. The approach obviates the need for end-to-end retraining from scratch, and requires only a single filtering module pre-attention, providing substantial FLOPs and latency gains with minimal accuracy degradation. Empirical evidence confirms that the vast majority of tokens can be safely eliminated under this regime, with retained performance closely tracking the unpruned baseline (Wang et al., 2023). A plausible implication is that principled, loss-driven feature attribution may generalize to other transformer modalities beyond vision.