PatchMerger: Efficient Token Merging
- PatchMerger is a learned module that compresses a large set of input tokens into a fixed, smaller set, reducing computational complexity in Vision Transformers.
- It uses a learnable weight matrix and softmax-based attention to merge tokens, allowing subsequent layers to process fewer tokens and achieve significant FLOP savings.
- Empirical results demonstrate a 40–60% reduction in FLOPs and runtime, with optimal placement and output token count (e.g., M=8) yielding minimal accuracy loss or slight improvements.
PatchMerger is a learned module designed to reduce computational overhead in Vision Transformers (ViT) by merging a set of input tokens into a fixed, smaller number of output tokens between consecutive encoder layers. By compressing the token set after initial stages of processing, PatchMerger allows the remaining layers to operate on substantially fewer tokens, which results in significant reductions in floating-point operations (FLOPs) and runtime, while matching or sometimes exceeding the original model accuracy in both upstream and downstream tasks.
1. Definition and Core Module Structure
PatchMerger is introduced as an architectural module situated between two Transformer encoder blocks. Its function is to reduce a variable number $N$ of input tokens (typically corresponding to image patches) to a fixed output size $M$, where $M \ll N$ in standard use. After PatchMerger, all subsequent layers of the Transformer process only $M$ tokens, which drastically decreases computational cost when $M$ is small relative to $N$.
The standard experimental protocol places a single PatchMerger module halfway through a stack of $L$ encoder blocks (typically after block $L/2$), with $M = 8$ as the default in the main experiments.
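Under the default placement, the number of tokens each block processes is easy to tabulate. A small sketch, assuming ViT-H/14 shapes (32 blocks, a 224×224 input cut into 16×16 = 256 patches, class token ignored); the helper `tokens_per_block` is hypothetical:

```python
def tokens_per_block(num_patches: int, num_blocks: int, merge_after: int, m: int) -> list[int]:
    """Token count seen by each encoder block when one merger follows block `merge_after`.

    Blocks are numbered 1..num_blocks; blocks 1..merge_after see all patches,
    every later block sees only the m merged tokens.
    """
    if not 0 < merge_after < num_blocks:
        raise ValueError("the merger must sit strictly inside the stack")
    if not 0 < m <= num_patches:
        raise ValueError("m must be in 1..num_patches")
    return [num_patches if b <= merge_after else m for b in range(1, num_blocks + 1)]


# Assumed ViT-H/14 setting: 224x224 input, 14x14 patches, 32 blocks, merger after block 16.
num_patches = (224 // 14) ** 2  # 256
schedule = tokens_per_block(num_patches, num_blocks=32, merge_after=16, m=8)
print(schedule[:16] == [256] * 16, schedule[16:] == [8] * 16)  # True True
```

Only the first half of the stack ever sees all 256 patches; the second half runs on 8 tokens, which is where the savings come from.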
2. Mathematical Formulation
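Let $X \in \mathbb{R}^{N \times D}$ denote the matrix of input tokens (each of dimension $D$). PatchMerger employs a learnable weight matrix $W \in \mathbb{R}^{D \times M}$, with one column per output token. The output tokens are computed as:

$$Y = \operatorname{softmax}\left(W^\top X^\top\right) X \in \mathbb{R}^{M \times D},$$

where the softmax is applied to each row of the $M \times N$ score matrix, i.e. over the $N$ input tokens.

Equivalently, writing $z_p$ for the embedding of input patch $p$ and $w_i$ for the $i$-th column of $W$, output token $i$ is a weighted average of all input patches:

$$y_i = \sum_{p=1}^{N} \frac{\exp(w_i^\top z_p)}{\sum_{q=1}^{N} \exp(w_i^\top z_q)}\, z_p, \qquad i = 1, \dots, M.$$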
This structure corresponds to a bottom-up attention mechanism whose queries (the columns of $W$) are fixed and input-independent, while the input tokens serve as both keys and values. No selection or top-$k$ routing is performed: every input token contributes to every merged output.
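A minimal NumPy sketch of $Y = \operatorname{softmax}(W^\top X^\top) X$, assuming the softmax runs over the $N$ input tokens (row-wise on the $M \times N$ score matrix). It is an illustration rather than the authors' code; the random inputs and the $1/\sqrt{D}$ scale for `W` are hypothetical choices.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


def patch_merger(X: np.ndarray, W: np.ndarray) -> np.ndarray:
    """Merge N tokens into M: X [N, D], W [D, M] -> Y [M, D].

    Each row of the weight matrix is a distribution over the N inputs,
    so every output token is a convex combination of all input tokens.
    """
    scores = W.T @ X.T                 # [M, N]: fixed queries scored against every token
    weights = softmax(scores, axis=1)  # normalize over the N input tokens
    return weights @ X                 # [M, D]


rng = np.random.default_rng(0)
N, D, M = 256, 64, 8
X = rng.normal(size=(N, D))
W = rng.normal(size=(D, M)) / np.sqrt(D)  # hypothetical initialization scale
Y = patch_merger(X, W)
print(Y.shape)  # (8, 64)
```

Because each row of weights is non-negative and sums to one, feeding $N$ copies of the same token returns $M$ copies of that token, which makes a quick sanity check.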
3. Implementation and Algorithmic Steps
The PatchMerger operation reduces to two matrix multiplications with a softmax in between, without iterative selection steps. Written out token by token:
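```python
import numpy as np

def patch_merger_loop(X, W):
    """Token-by-token form of Y = softmax(W^T X^T) X. X: [N, D], W: [D, M] -> Y: [M, D]."""
    N, D = X.shape
    M = W.shape[1]
    S = np.empty((M, N))
    for p in range(N):                 # score input token z_p = X[p] against all M queries
        S[:, p] = W.T @ X[p]           # M-dimensional vector of scores
    A = np.exp(S - S.max(axis=1, keepdims=True))
    A /= A.sum(axis=1, keepdims=True)  # softmax over the N input tokens, per output token
    Y = np.zeros((M, D))
    for p in range(N):                 # every input token contributes to every output token
        for i in range(M):
            Y[i] += A[i, p] * X[p]
    return Y

# (Optional) apply LayerNorm, residuals, etc., as in a normal transformer block.
```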
In practice, the per-token loops are replaced by efficient batched matrix operations that evaluate the main equation directly. By default, the module leaves the Transformer's residual and normalization structure unchanged, although optional normalization may be applied.
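A batched NumPy sketch of the same computation for inputs of shape [B, N, D], written with `np.einsum`. The optional parameter-free LayerNorm applied to the tokens before scoring and merging is an assumed placement for illustration, not a detail confirmed by the text; the batch shapes are also hypothetical.

```python
import numpy as np


def layer_norm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Parameter-free LayerNorm over the last (feature) axis."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def patch_merger_batched(X: np.ndarray, W: np.ndarray, normalize: bool = False) -> np.ndarray:
    """Batched merge. X: [B, N, D], W: [D, M] -> Y: [B, M, D].

    With `normalize=True`, tokens are LayerNorm-ed first (assumed placement);
    the normalized tokens are then both scored and merged.
    """
    if normalize:
        X = layer_norm(X)
    scores = np.einsum("bnd,dm->bmn", X, W)         # [B, M, N]
    scores -= scores.max(axis=-1, keepdims=True)    # stabilize the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over the N inputs
    return np.einsum("bmn,bnd->bmd", weights, X)    # [B, M, D]


rng = np.random.default_rng(1)
X = rng.normal(size=(4, 196, 32))  # hypothetical batch: 4 images, 196 tokens, D = 32
W = rng.normal(size=(32, 8)) * 0.1
Y = patch_merger_batched(X, W, normalize=True)
print(Y.shape)  # (4, 8, 32)
```

The two `einsum` contractions are the same two matrix products as in the single-example form, with the batch axis carried along.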
4. Computational Complexity and Empirical Savings
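The cost of one ViT layer on $n$ tokens of width $D$, denoted $C(n)$, is $O(n^2 D + n D^2)$: the self-attention products grow quadratically with the number of tokens, while the projections and MLP grow linearly. Introducing PatchMerger after layer $L_m$ of an $L$-layer encoder keeps the full cost $C(N)$ for the first $L_m$ layers but reduces it to $C(M)$ for each of the remaining $L - L_m$ layers. The fraction of the original compute used becomes:

$$\frac{L_m\,C(N) + (L - L_m)\,C(M)}{L\,C(N)} \approx \frac{L_m}{L} \quad \text{for } M \ll N.$$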
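A worked version of the compute-fraction formula. It assumes a common per-layer estimate of multiply-accumulates, $12nD^2$ for the projections and an MLP with ratio 4 plus $2n^2D$ for the attention products, and ignores the patch embedding, head and the merger itself. The ViT-H/14 shapes (256 patches at 224×224, $D = 1280$, 32 layers, merger after layer 16) are likewise assumptions; function names are hypothetical.

```python
def layer_cost(n: int, d: int) -> int:
    """Approximate multiply-accumulates for one encoder layer on n tokens of width d.

    Assumes MLP ratio 4: 12*n*d^2 for QKV/output projections and MLP,
    plus 2*n^2*d for the attention-score and weighted-value products.
    """
    return 12 * n * d * d + 2 * n * n * d


def compute_fraction(n: int, m: int, d: int, num_layers: int, merge_after: int) -> float:
    """Cost with one merger after `merge_after` layers, relative to the unmerged model."""
    merged = merge_after * layer_cost(n, d) + (num_layers - merge_after) * layer_cost(m, d)
    return merged / (num_layers * layer_cost(n, d))


# Assumed ViT-H/14 shapes at 224x224: N = 256, D = 1280, 32 layers, merger after layer 16.
frac = compute_fraction(n=256, m=8, d=1280, num_layers=32, merge_after=16)
print(f"{frac:.3f}")  # 0.515
```

This rough estimate lands near the 51.6% reported in Table 1. Moving the merger earlier lowers the fraction further (`merge_after=8` gives about 0.27 under the same assumptions), which is the compute side of the placement trade-off examined in the ablations.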
Empirical results show that for $L_m = L/2$ and $M = 8$, total FLOPs are reduced by 40–60%, with runtime reductions of similar magnitude. For example, Merger ViT-H/14 reduces cost from 4275.9 to 2207.3 ExaFLOPs (51.6% of the original), with runtime dropping to 63.5% of the original (Table 1).
| Model | ExaFLOPs | JFT Prec@1 | INet 10-shot | INet Finetuned |
|---|---|---|---|---|
| ViT-H/14 | 4275.9 | 56.56 % | 79.12 % | 87.97 % |
| Merger ViT-H/14 | 2207.3 | 57.27 % | 79.26 % | 87.90 % |
| Merger ViT-H/11 | 3464.2 | 58.15 % | 79.84 % | 88.24 % |
Merged models typically match or exceed unmerged baselines at substantially reduced computational cost; Merger ViT-H/11, for example, achieves higher accuracy than ViT-H/14 at only about 81% of its cost.
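The relative costs and accuracy differences quoted above can be recomputed directly from the table; a quick check, with differences in percentage points:

```python
# Figures copied from the table (ExaFLOPs; accuracies in %).
baseline = {"flops": 4275.9, "jft": 56.56, "inet10": 79.12, "ft": 87.97}
models = {
    "Merger ViT-H/14": {"flops": 2207.3, "jft": 57.27, "inet10": 79.26, "ft": 87.90},
    "Merger ViT-H/11": {"flops": 3464.2, "jft": 58.15, "inet10": 79.84, "ft": 88.24},
}

for name, row in models.items():
    rel_cost = row["flops"] / baseline["flops"]  # cost relative to ViT-H/14
    deltas = {k: round(row[k] - baseline[k], 2) for k in ("jft", "inet10", "ft")}
    print(f"{name}: {rel_cost:.1%} of ViT-H/14 cost, accuracy deltas (pp): {deltas}")
```

Merger ViT-H/14 gives up 0.07 points after fine-tuning while gaining upstream and in 10-shot; Merger ViT-H/11 improves on all three metrics.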
5. Training Protocols and Downstream Transfer
All models with and without PatchMerger are pre-trained on the JFT dataset (approximately 300 million images, 18,291 classes), following backbone schedules of Dosovitskiy et al. (2020) and Riquelme et al. (2021). Two downstream transfer protocols are used:
- Few-shot transfer to ImageNet: Only the head is trained (all ViT weights frozen) for 1/5/10 examples per class.
- Full fine-tuning: All parameters, including PatchMerger, are initialized from JFT pretraining and then trained on full ImageNet for a limited number of epochs.
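For the few-shot protocol, only a head on frozen features is trained. One common choice, used for few-shot evaluation by Dosovitskiy et al. (2020), is closed-form regularized least squares onto $\{-1, +1\}^K$ targets; the sketch below assumes that choice applies here as well. The features are made up stand-ins for frozen ViT embeddings, and the function names and regularization strength are hypothetical.

```python
import numpy as np


def fit_ridge_head(features: np.ndarray, labels: np.ndarray, num_classes: int,
                   l2: float = 1e-3) -> np.ndarray:
    """Closed-form least-squares head on frozen features.

    features: [S, F] for S labelled shots; labels: [S] ints. Targets are {-1, +1}^K.
    Returns weights of shape [F + 1, K]; the last row is the bias.
    """
    targets = -np.ones((len(labels), num_classes))
    targets[np.arange(len(labels)), labels] = 1.0
    A = np.hstack([features, np.ones((len(features), 1))])  # append a bias column
    reg = l2 * np.eye(A.shape[1])
    return np.linalg.solve(A.T @ A + reg, A.T @ targets)


def predict(weights: np.ndarray, features: np.ndarray) -> np.ndarray:
    """Class with the highest regression score."""
    A = np.hstack([features, np.ones((len(features), 1))])
    return (A @ weights).argmax(axis=1)


# Made-up "frozen" features: 3 classes x 5 shots, clustered around class prototypes.
rng = np.random.default_rng(0)
prototypes = rng.normal(size=(3, 16)) * 3.0
labels = np.repeat(np.arange(3), 5)
features = prototypes[labels] + 0.1 * rng.normal(size=(15, 16))
head = fit_ridge_head(features, labels, num_classes=3)
print((predict(head, features) == labels).mean())  # accuracy on the labelled shots
```

The backbone never receives gradients in this setup, so any difference between merged and unmerged models shows up purely through the quality of the frozen representation.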
Upstream and downstream results are reported for both setups. The merged models perform strongly, with negligible accuracy loss or even minor gains compared to the original baselines. The placement of the PatchMerger and the number of output tokens $M$ critically affect results.
6. Ablation Studies: Placement and Token Reduction Ratio
Ablation studies indicate:
- Placement: Placing the merger too early (at a shallow depth $L_m$) hurts accuracy, whereas inserting it halfway ($L_m = L/2$) nearly recovers full accuracy.
- Number of output tokens ($M$): Aggressively merging to very few tokens (e.g. $M = 2$) substantially harms performance. Optimal trade-offs are found at intermediate values of $M$, with diminishing gains as $M$ grows further.
7. Practical Recommendations and Observed Behavior
The recommended usage is a single PatchMerger positioned halfway through the encoder stack with $M = 8$. For larger models or higher input resolutions, starting with more patches before merging is beneficial: for instance, Merger ViT-H/11 achieves higher accuracy at reduced computational cost compared to ViT-H/14. In Mixture-of-Experts (V-MoE) architectures, PatchMerger primarily benefits the largest models due to reduced expert utilization post-merging. During fine-tuning at higher resolutions, only the pre-merger layers scale in compute, implying even larger savings during inference.
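To illustrate why higher resolutions favor the merged model, the sketch compares per-image encoder cost with and without a merger at two resolutions. Assumptions: a rough per-layer cost of $12nD^2 + 2n^2D$, ViT-H-like width 1280 and depth 32, patch size 14, merger after layer 16 with $M = 8$, class token ignored, and 448×448 as an example fine-tuning resolution. All names are hypothetical.

```python
from typing import Optional

D = 1280  # assumed ViT-H width


def layer_cost(n: int, d: int = D) -> int:
    """Rough multiply-accumulates of one encoder layer on n tokens (MLP ratio 4 assumed)."""
    return 12 * n * d * d + 2 * n * n * d


def encoder_cost(resolution: int, patch: int = 14, layers: int = 32,
                 merge_after: Optional[int] = 16, m: int = 8) -> int:
    """Encoder cost per image; merge_after=None gives the unmerged baseline."""
    n = (resolution // patch) ** 2  # patch tokens; class token ignored
    if merge_after is None:
        return layers * layer_cost(n)
    return merge_after * layer_cost(n) + (layers - merge_after) * layer_cost(m)


for res in (224, 448):
    base = encoder_cost(res, merge_after=None)
    merged = encoder_cost(res)
    post = (32 - 16) * layer_cost(8)  # post-merger layers cost the same at every resolution
    print(f"{res}px: merged/base = {merged / base:.3f}, post-merger share = {post / merged:.2%}")
```

Under these assumptions the merged-to-baseline ratio drifts toward 1/2 as resolution grows, because the post-merger half of the stack stays at $M$ tokens while the baseline's second half grows with $N$.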
A plausible implication is that PatchMerger enables efficient scaling of ViT models to higher input resolutions and larger parameter counts by strategically reducing token multiplicity after initial processing without significant loss of accuracy.