Centroid Transformers: Efficient Clustering
- Centroid Transformers are neural architectures that use a learnable clustering mechanism to compress token representations and lower computational costs.
- They integrate clustering within self-attention, enabling scalable abstraction and enhanced performance across tasks like vision modeling, language summarization, and object detection.
- Empirical studies demonstrate significant MAC reductions and competitive accuracy, confirming their potential in efficient transformer design.
Centroid Transformers are a class of neural architectures that generalize the standard self-attention mechanism by introducing a clustering-driven abstraction layer, reducing the number of output tokens via a learnable set of centroids. This paradigm achieves computational efficiency, scalable abstraction, and improved performance in tasks ranging from vision and language modeling to scientific object detection. Centroid-based modeling is instantiated in several forms, including soft K-means-inspired attention modules, centroid-driven tokenization strategies in masked image modeling (MIM), and hybrid convolution-transformer models for object localization. These approaches have demonstrated empirical advantages in efficiency, accuracy, and robustness across diverse benchmarks.
1. Centroid Attention: Theory and Formulation
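Centroid Transformers (Wu et al., 2021) replace the standard quadratic-complexity self-attention, which maps $N$ input tokens to $N$ outputs, with centroid attention, which maps $N$ inputs to $M$ centroids ($M \le N$). The core process leverages a differentiable clustering objective, typically a soft K-means energy function:

$$E(\{u_j\}_{j=1}^{M}) = -\tau \sum_{k=1}^{H} \sum_{j=1}^{M} \log \sum_{i=1}^{N} \exp\!\big(\phi_k(x_i, u_j)/\tau\big)$$

where $\{x_i\}_{i=1}^{N}$ are inputs, $\{u_j\}_{j=1}^{M}$ are centroid embeddings, $\phi_k$ is a similarity function for attention head $k \in \{1,\dots,H\}$, and the temperature $\tau > 0$ controls the sharpness. Gradient-descent steps on $E$, $u_j \leftarrow u_j - \epsilon\,\nabla_{u_j} E$, are amortized into network layers:

$$u_j \leftarrow u_j + \epsilon \sum_{k=1}^{H} \sum_{i=1}^{N} s_{ij}^{k}\, v_k(x_i, u_j), \qquad s_{ij}^{k} = \frac{\exp\!\big(\phi_k(x_i, u_j)/\tau\big)}{\sum_{i'=1}^{N} \exp\!\big(\phi_k(x_{i'}, u_j)/\tau\big)}$$

with attention weights $s_{ij}^{k}$ and value terms $v_k(x_i, u_j) = \nabla_{u_j} \phi_k(x_i, u_j)$. This formulation connects centroid attention to self-attention: it reduces to standard attention under appropriate choices of $\phi_k$ (a rank-1 bilinear form) and of the value gradient.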
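The claim that each layer performs one gradient-descent step can be checked numerically. The NumPy sketch below assumes the log-sum-exp energy $E = -\tau\sum_k\sum_j\log\sum_i \exp(\phi_k(x_i,u_j)/\tau)$, with the softmax taken over inputs $i$. It also assumes a bilinear similarity $\phi_k(x,u) = x^\top W_k u$, chosen only for illustration, so that $v_k(x_i,u_j) = W_k^\top x_i$. The helper names `scores`, `energy`, and `centroid_step` are hypothetical.

```python
import numpy as np

def scores(X, U, W):
    """phi_k(x_i, u_j) = x_i^T W_k u_j for every head k; returns shape (H, N, M)."""
    return np.einsum("nd,hde,me->hnm", X, W, U)

def energy(X, U, W, tau):
    """E(U) = -tau * sum_k sum_j log sum_i exp(phi_k(x_i, u_j) / tau)."""
    z = scores(X, U, W) / tau
    zmax = z.max(axis=1, keepdims=True)              # stabilise log-sum-exp over inputs i
    lse = zmax[:, 0, :] + np.log(np.exp(z - zmax).sum(axis=1))
    return -tau * lse.sum()

def centroid_step(X, U, W, tau, eps):
    """One amortized step: u_j + eps * sum_k sum_i s_ij^k v_k(x_i, u_j)."""
    z = scores(X, U, W) / tau
    s = np.exp(z - z.max(axis=1, keepdims=True))
    s /= s.sum(axis=1, keepdims=True)                # softmax over inputs i, per head and centroid
    V = np.einsum("hde,nd->hne", W, X)               # v_k(x_i, u_j) = W_k^T x_i for bilinear phi
    return U + eps * np.einsum("hnm,hne->me", s, V)  # sum over heads k and inputs i

# Toy sizes (arbitrary): N=6 inputs, M=3 centroids, d=4 features, H=2 heads.
rng = np.random.default_rng(0)
X = rng.normal(size=(6, 4))
U = rng.normal(size=(3, 4))
W = 0.3 * rng.normal(size=(2, 4, 4))
tau = 0.5

U_next = centroid_step(X, U, W, tau, eps=1e-3)
print(energy(X, U, W, tau), energy(X, U_next, W, tau))  # the second value is lower
```

With a bilinear $\phi_k$, each log-sum-exp term is convex in $U$, so $E$ is concave. A small step along $-\nabla E$ therefore always lowers the energy.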
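In canonical implementations, centroid queries are obtained from the centroids $U$ via a learned projection, and input keys/values are derived from the inputs $X$ through distinct projections. Initialization strategies for centroids include random sampling, farthest-point sampling, mean-pooling, or learned projections. A typical update is:

$$u_j \leftarrow u_j + \sum_{i=1}^{N} \operatorname{softmax}_i\!\left(\frac{q_j^\top k_i}{\sqrt{d}}\right) v_i$$

where $q_j = W_Q u_j$, $k_i = W_K x_i$, $v_i = W_V x_i$, and $d$ is the key dimension. This yields a cost of $O(NM)$ per attention module, providing substantial savings over the $O(N^2)$ cost of self-attention, particularly when $M \ll N$.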
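The single-head update $u_j \leftarrow u_j + \sum_i \operatorname{softmax}_i(q_j^\top k_i/\sqrt{d})\,v_i$ can be sketched in NumPy as follows. The sketch initializes centroids by mean-pooling. It omits multiple heads, output projections, normalization, and feed-forward layers, and the function names are hypothetical. It is not Wu et al.'s implementation.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def centroid_attention(X, U, Wq, Wk, Wv):
    """Single-head centroid attention: M centroid queries attend over N inputs.

    X: (N, d) inputs, U: (M, d) current centroids, W*: (d, d) projections.
    Returns the updated centroids (M, d) and the (M, N) attention matrix.
    """
    Q = U @ Wq.T                      # q_j = W_Q u_j
    K = X @ Wk.T                      # k_i = W_K x_i
    V = X @ Wv.T                      # v_i = W_V x_i
    A = softmax(Q @ K.T / np.sqrt(K.shape[1]), axis=1)   # normalise over inputs i
    return U + A @ V, A               # residual update; cost O(M N d) instead of O(N^2 d)

rng = np.random.default_rng(0)
N, M, d = 8, 4, 16
X = rng.normal(size=(N, d))
U0 = X.reshape(M, N // M, d).mean(axis=1)      # mean-pooling initialisation
Wq, Wk, Wv = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
U1, A = centroid_attention(X, U0, Wq, Wk, Wv)
print(U1.shape, A.shape)   # (4, 16) (4, 8)
```

The score matrix is $M \times N$ rather than $N \times N$, which is where the savings come from. Setting $M = N$ and $U = X$ recovers ordinary single-head self-attention with a residual connection.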
2. Centroid-based Tokenization and Masked Modeling
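The CCViT framework (Yan et al., 2023) introduces non-parametric centroid tokenization for Vision Transformer (ViT) pre-training. In MIM, image patches are clustered (typically via k-means with $K$ centroids) in embedding space, yielding centroids $\{c_k\}_{k=1}^{K}$. Each centroid serves both as a visual token and as an approximate patch prototype. For the $i$-th patch $x_i$, the token $t_i$ is its nearest centroid:

$$t_i = \arg\min_{k \in \{1,\dots,K\}} \lVert x_i - c_k \rVert_2^2, \qquad x_i \approx c_{t_i}$$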
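A non-parametric centroid tokenizer amounts to running k-means once and then mapping each patch to its nearest centroid. The sketch below is a minimal version under several assumptions: it uses plain Lloyd iterations on raw flattened patches, farthest-point seeding, and toy data. CCViT's exact feature space, value of $K$, and initialization are not specified here, and the helper names are hypothetical.

```python
import numpy as np

def tokenize(P, C):
    """t_i = argmin_k ||p_i - c_k||^2, computed as ||p||^2 - 2 p.c + ||c||^2."""
    d2 = (P ** 2).sum(axis=1, keepdims=True) - 2.0 * P @ C.T + (C ** 2).sum(axis=1)
    return d2.argmin(axis=1)

def farthest_point_init(P, K, seed=0):
    """Seed centroids by farthest-point sampling (an assumption; any k-means init works)."""
    rng = np.random.default_rng(seed)
    idx = [int(rng.integers(len(P)))]
    d2 = ((P - P[idx[0]]) ** 2).sum(axis=1)
    for _ in range(K - 1):
        idx.append(int(d2.argmax()))
        d2 = np.minimum(d2, ((P - P[idx[-1]]) ** 2).sum(axis=1))
    return P[idx].astype(float)

def build_tokenizer(P, K, iters=10):
    """Plain Lloyd k-means over flattened patches; the centroids form the token vocabulary."""
    C = farthest_point_init(P, K)
    for _ in range(iters):
        t = tokenize(P, C)
        for k in range(K):
            members = P[t == k]
            if len(members):          # keep the previous centroid if a cluster empties
                C[k] = members.mean(axis=0)
    return C

# Toy data: 300 flattened 4x4 "patches" scattered around three well-separated prototypes.
rng = np.random.default_rng(0)
protos = 10.0 * np.eye(16)[:3]
labels = rng.integers(0, 3, size=300)
P = protos[labels] + 0.1 * rng.normal(size=(300, 16))

C = build_tokenizer(P, K=3)
tokens = tokenize(P, C)               # one discrete visual token per patch
```

No gradient-based training is involved, which is why building such a tokenizer is cheap compared with fitting a VQ-VAE.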
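Tokenization avoids training heavy auxiliary models (e.g., VQ-VAE), enabling rapid tokenizer construction and efficient inference. CCViT incorporates patch masking and centroid replacement to corrupt inputs:

$$\tilde{x}_i = \begin{cases} m, & i \in \mathcal{M} \\ c_{t_i}, & i \in \mathcal{R} \\ x_i, & \text{otherwise} \end{cases}$$

with $\mathcal{M}$ and $\mathcal{R}$ denoting the masked and replaced indices and $m$ a mask embedding. Centroids impart patch-level local invariance: because the k-means assignment is computed independently for each patch, a change to one patch affects only that patch's assignment. CCViT optimizes a joint loss over token classification (cross-entropy) and pixel reconstruction (MSE):

$$\mathcal{L} = \mathcal{L}_{\text{CE}} + \mathcal{L}_{\text{MSE}}$$

where $\mathcal{L}_{\text{CE}}$ is the cross-entropy between predicted and target centroid indices $t_i$, and $\mathcal{L}_{\text{MSE}}$ is the mean squared error between reconstructed and original pixels.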
This centroid-driven pipeline establishes non-parametric centroids as viable and efficient visual tokens, with empirical evidence of high token-prediction accuracy (40.8%) and fast tokenizer creation (150 s).
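The corruption and the joint objective can be sketched as follows. Several choices here are assumptions: masked patches are set to a constant standing in for a learned mask embedding, and each replaced patch is swapped for its own nearest centroid rather than some other centroid. Both loss terms are averaged over all corrupted positions with equal weight. Function names, ratios, and data are hypothetical.

```python
import numpy as np

def corrupt(P, tokens, C, mask_ratio, replace_ratio, mask_value=0.0, seed=0):
    """Build x~: masked patches -> mask_value, replaced patches -> their own centroid."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(P))
    n_mask, n_rep = int(mask_ratio * len(P)), int(replace_ratio * len(P))
    masked, replaced = perm[:n_mask], perm[n_mask:n_mask + n_rep]   # disjoint index sets
    P_tilde = P.copy()
    P_tilde[masked] = mask_value              # stand-in for a learned mask embedding m
    P_tilde[replaced] = C[tokens[replaced]]   # centroid replacement (own centroid assumed)
    return P_tilde, masked, replaced

def joint_loss(logits, tokens, recon, P, idx):
    """Cross-entropy on centroid indices plus pixel MSE, both over the positions in idx."""
    z = logits[idx] - logits[idx].max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    ce = -log_probs[np.arange(len(idx)), tokens[idx]].mean()
    mse = ((recon[idx] - P[idx]) ** 2).mean()
    return ce + mse

rng = np.random.default_rng(1)
K, n, d = 4, 20, 16
C = rng.normal(size=(K, d))                      # centroid vocabulary
tokens = rng.integers(0, K, size=n)
P = C[tokens] + 0.05 * rng.normal(size=(n, d))   # patches near their centroids
P_tilde, masked, replaced = corrupt(P, tokens, C, mask_ratio=0.4, replace_ratio=0.1)
idx = np.concatenate([masked, replaced])

# An uninformed model (uniform logits, perfect pixels) pays exactly log K in cross-entropy.
print(joint_loss(np.zeros((n, K)), tokens, P, P, idx))   # ≈ 1.386 = log(4)
```

Replacing a patch with its own centroid leaves its target token unchanged, which is one way to read the local-invariance argument above.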
3. Hybrid Centroid Transformers in Structured Detection
CellCentroidFormer (Wagner et al., 2022) demonstrates centroid-based modeling in object localization and detection. Here, the centroid "token" is the parameter vector representing the center, axes, and orientation of detected structures (e.g., cells as ellipses).
The architecture integrates a convolutional backbone (EfficientNetV2-S) for local feature extraction and MobileViT blocks for global reasoning via transformer-style multi-head self-attention. Prediction heads regress centroid heatmaps and elliptical extents through fully convolutional layers, enabling end-to-end differentiable training.
Centroid regression involves:
- Heatmap extraction for centroids via 1x1 convolution and sigmoid activation.
- Axis regression via additional 1x1 convolutional heads.
- Composite loss: $\mathcal{L} = \mathcal{L}_{\text{heatmap}} + \mathcal{L}_{\text{axes}}$, where each term is a Huber (smooth L1) loss.
- Optionally, the orientation angle can be predicted with an additional regression head.
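The sketch below shows the training loss and a simple way to turn the regressed maps back into ellipses. The Huber loss follows the standard definition. Several parts are assumptions made for illustration: equal loss weights, 3×3 local-maximum decoding with a 0.5 threshold, a synthetic heatmap in place of the sigmoid output, and the omission of the optional angle head. The helper names are hypothetical.

```python
import numpy as np

def huber(pred, target, delta=1.0):
    """Mean Huber (smooth L1) loss."""
    r = np.abs(pred - target)
    return np.where(r <= delta, 0.5 * r ** 2, delta * (r - 0.5 * delta)).mean()

def composite_loss(heat_pred, heat_gt, axes_pred, axes_gt):
    """L = L_heatmap + L_axes with equal weights (the weighting is an assumption)."""
    return huber(heat_pred, heat_gt) + huber(axes_pred, axes_gt)

def decode(heat, axes, thresh=0.5):
    """Centroids = local maxima of the heatmap above thresh; axes are read at those pixels.

    heat: (H, W) centroid heatmap; axes: (2, H, W) maps of the two ellipse axis lengths.
    """
    H, W = heat.shape
    padded = np.pad(heat, 1, mode="constant", constant_values=-np.inf)
    neighbours = np.stack([padded[dy:dy + H, dx:dx + W] for dy in range(3) for dx in range(3)])
    peaks = (heat >= neighbours.max(axis=0)) & (heat > thresh)
    ys, xs = np.nonzero(peaks)
    return [(int(y), int(x), float(axes[0, y, x]), float(axes[1, y, x])) for y, x in zip(ys, xs)]

# Toy prediction: two blob-shaped centroid responses on a 40x40 grid.
yy, xx = np.mgrid[0:40, 0:40]
heat = np.exp(-((yy - 10) ** 2 + (xx - 12) ** 2) / 8.0) + np.exp(-((yy - 30) ** 2 + (xx - 5) ** 2) / 8.0)
axes = np.zeros((2, 40, 40))
axes[:, 10, 12] = (6.0, 4.0)
axes[:, 30, 5] = (3.0, 2.0)
print(decode(heat, axes))   # [(10, 12, 6.0, 4.0), (30, 5, 3.0, 2.0)]
```

Each detection is a compact (row, column, axis, axis) vector, which is the centroid "token" described above.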
This hybrid design exploits both local convolutional operations and global attention for spatial context, yielding improved precision, recall, and F1 scores in biological detection tasks.
4. Computational Efficiency and Scaling Properties
Centroid attention modules reduce the computational load of transformer models. Standard self-attention has $O(N^2)$ complexity for $N$ inputs because every pair of tokens interacts. Centroid attention compresses the outputs to $M$ centroids, reducing complexity to $O(NM)$, with $M$ chosen separately for text, vision, and point clouds (Wu et al., 2021). Empirical results show roughly halved multiply-accumulate operations (MACs) and larger maximal batch sizes (e.g., Gigaword summarization: MACs 523M → 263M, batch size 192 → 230).
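To make the $O(N^2)$ versus $O(NM)$ comparison concrete, the snippet below counts MACs for the two matrix products inside one attention head ($QK^\top$ and the attention-weighted sum of values). It deliberately ignores the Q/K/V and output projections, the softmax, and feed-forward layers, so its ratios are not meant to reproduce the reported whole-encoder figures. The sizes and function names are illustrative assumptions.

```python
def attention_macs(n_queries, n_keys, d):
    """MACs for the two matrix products in one attention head: Q K^T and A V."""
    return 2 * n_queries * n_keys * d

def compare(N, M, d):
    self_attn = attention_macs(N, N, d)   # N queries over N keys: O(N^2 d)
    centroid = attention_macs(M, N, d)    # M centroid queries over N keys: O(N M d)
    return self_attn, centroid, centroid / self_attn

# Illustrative sizes (not taken from the papers): d = 64, M = 256 centroids.
print(compare(512, 256, 64))    # (33554432, 16777216, 0.5): M = N/2 halves the attention MACs
print(compare(1024, 256, 64))   # (134217728, 33554432, 0.25)
```

Doubling $N$ at fixed $M$ quadruples the self-attention cost but only doubles the centroid-attention cost, so the savings grow with sequence length.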
In CCViT (Yan et al., 2023), non-parametric centroid tokenizers are 10–20× faster to create and 2–8× faster at inference than parametric tokenizers, and they use under 1 GB of memory. On point clouds, centroid transformers (Wu et al., 2021) achieve a 4× MAC reduction over attention-heavy baselines. These results suggest that centroid-based methods are well suited to regimes where output compression and memory efficiency are critical.
5. Empirical Performance Across Domains
Centroid transformers exhibit competitive or superior accuracy compared to standard transformer and CNN-only architectures. Key results include:
| Application | Baseline | Centroid Transformer | Relative Gain |
|---|---|---|---|
| Gigaword Summarization | ROUGE-1 32.99 | ROUGE-1 34.65 | +1.66 ROUGE-1 at 50% of encoder MACs |
| ModelNet40 Point Cloud | SepNet-W15: 93.1% | 93.2%, 4× MAC reduction | Comparable, fewer params |
| ImageNet (DeiT-small) | 79.9% @4.7G MAC | 79.8% @3.0G MAC | 36% MAC reduction, minimal drop |
| ImageNet (ViT-B/16, CCViT) | BEiT: 82.9%, MAE: 83.6% | CCViT: 84.3% | +1.4% vs BEiT, +0.7% vs MAE |
| Semantic Segmentation (ADE20K) | BEiT: 44.7% mIoU | CCViT: 48.4%, 51.6%† | +3.7 to +6.9 mIoU |
| Cell Detection (HeLa) | Dual U-Net/CircleNet: 0.81–0.84 | 0.91 F1, 0.89 AP | +5–10% (F1, AP) |
† After intermediate ImageNet fine-tuning.
These results indicate that centroid abstraction can yield comparable or improved accuracy in summarization, classification, segmentation, and scientific detection tasks, with substantial reductions in computation.
6. Hyperparameter Choices, Ablation Studies, and Adaptivity
The number of centroids $M$ and the initialization strategy significantly influence task performance (Wu et al., 2021). Mean-pooling and learned projections yield higher accuracy than random sampling. Increasing the number of clustering iterations $T$ gives modest accuracy gains but increases computation, so a small $T$ is typically sufficient. Too small an $M$ can bottleneck feature representation, while too large an $M$ diminishes efficiency.
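The three initialization strategies compared in this ablation can be sketched as follows. "Learned projection" is read here as a trainable $(M \times N)$ token-mixing matrix; that reading is an assumption, and Wu et al. may parameterize it differently. Function names are hypothetical.

```python
import numpy as np

def init_random(X, M, rng):
    """Random sampling: pick M distinct input tokens as initial centroids."""
    return X[rng.choice(len(X), size=M, replace=False)]

def init_mean_pool(X, M):
    """Mean-pooling: average non-overlapping windows of N // M consecutive tokens (M must divide N)."""
    N, d = X.shape
    return X.reshape(M, N // M, d).mean(axis=1)

def init_learned_projection(X, P):
    """Learned projection (assumed form): U = P X with a trainable (M, N) token-mixing matrix P."""
    return P @ X

X = np.arange(12.0).reshape(6, 2)   # N = 6 tokens, d = 2
M = 3
print(init_mean_pool(X, M))         # rows: [1, 2], [5, 6], [9, 10]

# Mean-pooling is the special case of a projection with fixed block-averaging weights.
P_avg = np.kron(np.eye(M), np.full((1, 2), 0.5))   # shape (3, 6)
```

Because mean-pooling is a fixed instance of the projection form, a learned projection can start from block averages and adapt them. This is one plausible reason the two outperform random sampling.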
Hyperparameter ablations also demonstrate the robustness of centroid transformers to attention heads and initialization schemes. Task-specific tuning is required for optimal abstraction, and adaptive and hierarchical stacking of centroid layers are promising directions. A plausible implication is that centroid transformers may benefit from dynamic, data-dependent abstraction levels in future adaptive models.
7. Applications and Extensions
Centroid transformer architectures generalize beyond standard sequence modeling. Applications include:
- Abstractive summarization: text sequences are compressed via centroid attention, producing more informative summaries (Wu et al., 2021).
- Point cloud classification and reconstruction: centroid attention provides efficient object-level representations and latent capsule abstraction.
- Visual object detection: centroid regression enables compact, robust cell and object representations (Wagner et al., 2022).
- Masked image modeling: CCViT leverages centroid tokenization for efficient ViT pre-training and downstream transfer (Yan et al., 2023).
- Scientific imaging: centroid-extent representation applies to counting pollen, tracking people, and localizing astronomical objects (Wagner et al., 2022).
Potential extensions include adaptive centroid counts, hierarchical abstraction, clustering-augmented losses, and dynamic routing mechanisms bridging capsule networks (Wu et al., 2021).
Centroid Transformers constitute a unifying clustering- and attention-inspired compression paradigm in deep learning, offering computational scalability, robust abstraction, and versatile applicability in computer vision, language modeling, and scientific detection. Empirical studies confirm their effectiveness across diverse benchmarks and tasks.