
Only Token Mean Loss in Vision Transformers

Updated 15 December 2025
  • Only Token Mean Loss (OTM) is a metric that quantifies the importance of individual tokens in Vision Transformers by measuring the change in cross-entropy loss when a token is masked.
  • It employs a supervised feature selection approach by generating pseudo-labels using a loss differential threshold (ρ) and training a dedicated MLP filter to prune less critical tokens.
  • Empirical results on DeiT models demonstrate that OTM filtering reduces compute (FLOPs and latency) while maintaining competitive accuracy.

The only token mean loss (OTM), also termed delta-loss ($\Delta\mathcal{L}$), is a metric quantifying the marginal influence of individual tokens on the loss of a fixed-weight Vision Transformer (ViT). Specifically, OTM assesses the importance of patch tokens in visual input sequences by measuring the change in cross-entropy loss induced by masking each token in isolation, enabling principled, data-driven token filtering prior to self-attention. This approach frames token selection as a supervised feature-selection problem and yields efficient, accurate filtered ViTs through a single token-pruning stage before any attention computation (Wang et al., 2023).

1. Formal Definition of Only Token Mean Loss

Let $X=\{x_1,\dots,x_N\}$ be the set of $N$ patch tokens for an input image, with $x_i\in\mathbb{R}^d$, where $d$ is the token embedding dimension. Given a pretrained ViT backbone $\text{Transformer}(\cdot)$, compute the network output $\hat{y}=\text{Transformer}(X)$ and let $\mathcal{L}=\text{CrossEntropy}(\hat{y}, y)$ denote the loss on the original (unmasked) image-label pair $(X, y)$. For each token $i$, generate a masked input $X_{-i}=\{x_1,\dots,x_{i-1}, 0, x_{i+1},\dots,x_N\}$, i.e., with $x_i$ replaced by zero. The masked loss is $\mathcal{L}_i = \text{CrossEntropy}(\text{Transformer}(X_{-i}), y)$. The per-token delta-loss (OTM) is

$$\Delta\mathcal{L}_i \coloneqq \mathcal{L}_i - \mathcal{L}.$$

A high positive $\Delta\mathcal{L}_i$ indicates that masking $x_i$ significantly increases the loss, so $x_i$ is important; a near-zero or negative $\Delta\mathcal{L}_i$ implies $x_i$ is expendable. The expected OTM across the data distribution $\mathcal{D}$ is

$$\operatorname{OTM}(i) = \mathbb{E}_{(X, y)\sim\mathcal{D}}\left[\,\mathcal{L}(X_{-i}) - \mathcal{L}(X)\,\right],$$

computed in practice as the sample mean over the training set (Wang et al., 2023).
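
As a concrete reference, the following minimal PyTorch-style sketch computes the per-token delta-loss for a single image. It assumes `transformer` is a frozen callable that maps a batch of patch-token embeddings to class logits; the function and variable names are illustrative, not taken from the paper.

```python
import torch
import torch.nn.functional as F

def delta_loss_per_token(transformer, tokens, label):
    """Per-token delta-loss (OTM) for one image.

    tokens: (N, d) patch-token embeddings; label: 0-dim class-index tensor.
    `transformer` is assumed to map a (1, N, d) token sequence to logits.
    The backbone weights stay frozen; no gradients are required.
    """
    with torch.no_grad():
        base_logits = transformer(tokens.unsqueeze(0))
        base_loss = F.cross_entropy(base_logits, label.unsqueeze(0))

        deltas = torch.empty(tokens.size(0))
        for i in range(tokens.size(0)):
            masked = tokens.clone()
            masked[i] = 0.0                      # zero out token i in isolation
            logits_i = transformer(masked.unsqueeze(0))
            loss_i = F.cross_entropy(logits_i, label.unsqueeze(0))
            deltas[i] = loss_i - base_loss       # positive => token i is important
    return deltas
```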

2. Token Labeling, Feature Construction, and Filter Architecture

OTM is leveraged for token selection using a two-stage pipeline:

a) Token Labeling: For every image in the training set, obtain $\mathcal{L}$ via a standard forward pass. For each token $i$, mask it, recompute $\mathcal{L}_i$, and set a pseudo-label:

  • $\text{label}(x_i) = 1$ if $\Delta\mathcal{L}_i > \rho$,
  • $\text{label}(x_i) = 0$ otherwise, where the threshold $\rho$ (typically $\approx 10^{-3}$) is selected by tuning on validation data (see the snippet below).
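
Continuing the hypothetical names from the sketch in Section 1, the pseudo-labels are simply a thresholding of the per-token deltas:

```python
rho = 1e-3                                              # loss-differential threshold
deltas = delta_loss_per_token(transformer, tokens, label)
pseudo_labels = (deltas > rho).long()                   # 1 = important (keep), 0 = prune
```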

b) Feature for Each Token: To distinguish tokens in visually similar regions, a global image feature $x_{\text{global}} = \frac{1}{N}\sum_{k=1}^N x_k$ is computed. The input to the filter MLP is $x'_i = [x_i ; x_{\text{global}}] \in \mathbb{R}^{2d}$.

c) Token-Filter MLP: The MLP consists of three fully connected layers, $2d \to 384 \to 100 \to 1$, with ReLU activations after the first two layers and a sigmoid on the output. The final output $p_i = \text{Sigmoid}(\text{MLP}(x'_i)) \in (0, 1)$ is interpreted as the predicted probability that $x_i$ is important.
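
A minimal PyTorch sketch of this filter, assuming only the layer sizes stated above; the class name and interface are illustrative:

```python
import torch.nn as nn

class TokenFilterMLP(nn.Module):
    """Token-filter MLP with the 2d -> 384 -> 100 -> 1 layout described above.

    The input is the concatenation [x_i ; x_global], hence dimension 2d.
    """
    def __init__(self, d: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * d, 384),
            nn.ReLU(),
            nn.Linear(384, 100),
            nn.ReLU(),
            nn.Linear(100, 1),
            nn.Sigmoid(),          # p_i in (0, 1): predicted importance
        )

    def forward(self, x_cat):      # x_cat: (..., 2d)
        return self.net(x_cat).squeeze(-1)
```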

d) Loss to Train MLP: Binary cross-entropy loss is used:

$$\mathcal{L}_{\text{MLP}} = -\big[\text{label}(x_i)\log p_i + (1-\text{label}(x_i))\log(1-p_i)\big],$$

with the ViT backbone weights held fixed during this phase.
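
A possible single training step, under the same assumptions as the earlier sketches; the `filter_mlp` and `optimizer` objects and the per-image batching are illustrative choices, not the paper's exact setup:

```python
import torch
import torch.nn.functional as F

def filter_training_step(filter_mlp, optimizer, tokens, labels):
    """One BCE training step for the token filter; the ViT backbone is untouched.

    tokens: (N, d) patch tokens of one image; labels: (N,) OTM pseudo-labels in {0, 1}.
    """
    x_global = tokens.mean(dim=0, keepdim=True).expand_as(tokens)    # (N, d)
    x_cat = torch.cat([tokens, x_global], dim=-1)                    # (N, 2d)
    p = filter_mlp(x_cat)                                            # (N,)
    loss = F.binary_cross_entropy(p, labels.float())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```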

3. Token Filtering Algorithm and Inference Workflow

The methodology is operationalized through two main algorithms (cf. paper pseudocode):

Algorithm | Input | Output
Token Labeling | Pretrained Transformer, training set $\{X, y\}$, threshold $\rho$ | Token pseudo-labels $\text{label}(x_i)\in\{0,1\}$
Filter Training | Training tokens $\{x_i\}$, pseudo-labels $\{\text{label}(x_i)\}$ | MLP filter parameters $W$

Token Labeling: For each image, after a forward pass to obtain $\mathcal{L}$, each token is masked in turn, the losses are recomputed, and the OTM is thresholded to produce hard labels.

Filter Training: For each minibatch, the global token feature is pooled and concatenated with each individual token, MLP predictions are computed, and the binary cross-entropy is minimized until convergence.

Inference: At test time (and during fine-tuning), all input tokens are passed once through the trained MLP filter. Tokens with $p_i$ below a cutoff are zeroed and dropped from all downstream computation. The filtering is performed once, prior to any self-attention, reducing the computational burden on subsequent layers from $O(N^2)$ to $O(M^2)$, where $M \ll N$ is the number of retained tokens (Wang et al., 2023).
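
A sketch of this inference-time filtering pass, again with hypothetical names and an assumed cutoff of 0.5 (the exact cutoff is not specified here):

```python
import torch

def filter_tokens(filter_mlp, tokens, cutoff=0.5):
    """Single filtering pass before any self-attention (inference / fine-tuning).

    tokens: (N, d) patch tokens; returns the M <= N retained tokens, so every
    later attention block costs O(M^2) instead of O(N^2).
    """
    with torch.no_grad():
        x_global = tokens.mean(dim=0, keepdim=True).expand_as(tokens)
        p = filter_mlp(torch.cat([tokens, x_global], dim=-1))        # (N,)
    keep = p >= cutoff
    return tokens[keep]                                              # (M, d)

# kept = filter_tokens(filter_mlp, tokens)
# logits = transformer(kept.unsqueeze(0))   # backbone now attends over M tokens
```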

4. Theoretical Context: Feature Selection Perspective

OTM constitutes a direct, "wrapper"-style metric for token utility, analogous to marginal-contribution, Shapley-value, or leave-one-out feature importance in traditional feature selection. By construction, tokens whose masking yields negligible change in ViT loss are considered dispensable. Empirical analysis demonstrates that, for standard vision tasks (e.g., ImageNet1K), the majority of image tokens exhibit $\Delta\mathcal{L} \approx 0$, with a small subset dominating the network's predictive performance. This suggests that most tokens can be discarded without significant adverse effect on accuracy (Wang et al., 2023).

5. Empirical Performance and Comparative Results

Key empirical results on ImageNet1K using DeiT-based backbones:

  • Backbone: DeiT-T (71.4% top-1, 1.3 GFLOPs baseline)
    • Filtered (DL-ViT-T, $\rho=0.002$): FLOPs reduced by 46% to 0.7 G; throughput +41% to 4,565 img/s; top-1 drops 0.3% to 71.1%.
  • Backbone: DeiT-S (79.8% top-1, 4.6 GFLOPs baseline)
    • Filtered (DL-ViT-S, $\rho=0.001$): FLOPs reduced by 15% to 3.9 G; throughput +7% to 1,602 img/s; top-1 at 79.6% (−0.2%).

Comparison to Dynamic-ViT, A-ViT, and E-ViT indicates that DL-ViT offers the best trade-off of accuracy, FLOPs, and latency under comparable conditions.

Ablation studies underscore the necessity of the OTM-driven MLP filter: substituting the filter with a randomly initialized or random-drop mechanism yields a drastic accuracy loss (e.g., 56.6% top-1 with a random MLP versus 71.1% with OTM-filtered selection), demonstrating the centrality of faithful $\Delta\mathcal{L}$-based pseudo-labeling.

6. Significance and Practical Integration

The OTM metric facilitates lightweight, accurate token pruning in ViTs through a purely post-hoc, data- and backbone-driven labeling strategy applicable to pretrained models. The approach obviates the need for end-to-end retraining from scratch, and requires only a single filtering module pre-attention, providing substantial FLOPs and latency gains with minimal accuracy degradation. Empirical evidence confirms that the vast majority of tokens can be safely eliminated under this regime, with retained performance closely tracking the unpruned baseline (Wang et al., 2023). A plausible implication is that principled, loss-driven feature attribution may generalize to other transformer modalities beyond vision.
