Do Vision Transformers See Like Convolutional Neural Networks?
Abstract:
In "Do Vision Transformers See Like Convolutional Neural Networks?", Raghu et al. ask how Vision Transformers (ViTs), which have recently been shown to match or exceed Convolutional Neural Networks (CNNs) on image classification, differ from CNNs in how they solve these tasks. By analyzing internal representations, attention patterns, skip connections, and the effect of pretraining-data scale, the authors identify clear structural and functional differences between the two architectures.
Introduction:
Convolutional Neural Networks have long been the cornerstone of visual data processing, building in spatial (translation) equivariance as an inductive bias through their convolutional layers. This inductive bias helps CNNs learn robust visual representations that transfer effectively across tasks. ViTs replace convolutions with self-attention and, like Transformer language models, can aggregate global context from their earliest layers. This raises the question of whether ViTs solve vision tasks the way CNNs do or follow a distinct path.
Representation Structure of ViTs and CNNs:
A central part of the paper compares the internal representation structure of CNNs (ResNets) and ViTs. Using Centered Kernel Alignment (CKA), a similarity measure computed from two layers' activations on the same set of inputs, the paper reveals stark differences:
- ViTs have a more uniform representation structure across layers, with high similarity between lower and higher layers.
- ResNets, in contrast, show a stage-wise structure: layers within a stage are similar, while distant layers are much less alike.
The CKA heatmaps thus show a more uniform propagation of information through ViTs than the stage-by-stage processing of ResNets.
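For reference, linear CKA between two layers can be computed directly from their activation matrices on the same inputs. The sketch below is a minimal NumPy version of the standard linear-CKA formula, CKA(X, Y) = ||Y^T X||_F^2 / (||X^T X||_F ||Y^T Y||_F) on column-centered features. The function names `linear_cka` and `cka_heatmap` are hypothetical, and the paper's own estimator (which may, for example, be computed over minibatches) could differ in details.

```python
import numpy as np


def linear_cka(x: np.ndarray, y: np.ndarray) -> float:
    """Linear CKA between two layers.

    x: (n_examples, d1) activations of one layer.
    y: (n_examples, d2) activations of another layer on the same examples, same order.
    """
    # Center every feature so that the implied Gram matrices are centered.
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    # With linear kernels, HSIC reduces to squared Frobenius norms of (cross-)covariances.
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    self_x = np.linalg.norm(x.T @ x, ord="fro")
    self_y = np.linalg.norm(y.T @ y, ord="fro")
    return float(cross / (self_x * self_y))


def cka_heatmap(layers: list) -> np.ndarray:
    """Pairwise CKA matrix for a list of (n_examples, d_i) layer activations."""
    n = len(layers)
    heat = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            heat[i, j] = linear_cka(layers[i], layers[j])
    return heat


# Usage with random stand-in activations (not real model features).
rng = np.random.default_rng(0)
layer_a = rng.normal(size=(500, 64))
rotation, _ = np.linalg.qr(rng.normal(size=(64, 64)))
layer_b = 2.0 * layer_a @ rotation          # rotated and rescaled copy of layer_a
layer_c = rng.normal(size=(500, 32))        # unrelated activations
print(np.round(cka_heatmap([layer_a, layer_b, layer_c]), 2))
```

Because linear CKA is invariant to orthogonal transformations and isotropic scaling, a layer compared with a rotated, rescaled copy of itself still scores 1. Applied to every pair of layers in a real network, the resulting matrix is the kind of heatmap described above.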
Local and Global Information Utilization:
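The paper then examines how ViTs and CNNs incorporate local and global spatial information. Analyzing the attention distance of individual ViT attention heads shows that:
- In ViT's lower layers, some heads attend locally while others attend globally, so these layers mix local and global information; in the upper layers, all heads attend globally.
- CNNs aggregate only local information in their lower layers, and their receptive fields grow gradually with depth because each convolution uses a fixed-size kernel.
Crucially, this early access to global information leads ViTs to learn lower-layer features that are quantitatively different from the local features of ResNets. The scale of the pretraining data modulates this behavior: how strongly the lower ViT layers learn to attend locally depends on how much pretraining data is available.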
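Attention distance is typically computed per head as the average spatial distance between a query patch and the patches it attends to, weighted by the attention weights. The following minimal sketch makes several assumptions: the attention array is for a single image and covers only the spatial tokens on a square grid (CLS token dropped, rows renormalized), and distances are measured in pixels between patch positions. `mean_attention_distance` is a hypothetical helper; in practice one would average its output over many images.

```python
import numpy as np


def mean_attention_distance(attn: np.ndarray, grid_size: int, patch_size: int) -> np.ndarray:
    """Per-head mean attention distance, in pixels, for one image.

    attn: (num_heads, N, N) attention weights over the N = grid_size**2 spatial tokens,
          with the CLS token already dropped.
    """
    num_heads, n, _ = attn.shape
    assert n == grid_size * grid_size, "expects a square grid of spatial tokens"
    # Renormalize rows, since dropping the CLS token removes some attention mass.
    attn = attn / attn.sum(axis=-1, keepdims=True)
    # (row, col) grid position of each token in row-major order, scaled to pixels.
    rows, cols = np.divmod(np.arange(n), grid_size)
    coords = np.stack([rows, cols], axis=1) * patch_size                    # (N, 2)
    # Euclidean pixel distance between every pair of patch positions.
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)  # (N, N)
    # Attention-weighted distance for each query, then the mean over queries.
    per_query = (attn * dist[None, :, :]).sum(axis=-1)                      # (num_heads, N)
    return per_query.mean(axis=-1)                                          # (num_heads,)


# Usage: head 0 attends only to itself (local), head 1 attends uniformly (global).
grid, patch = 2, 16
n_tokens = grid * grid
toy_attn = np.stack([np.eye(n_tokens), np.full((n_tokens, n_tokens), 1.0 / n_tokens)])
print(mean_attention_distance(toy_attn, grid, patch))
```

A head that attends only to its own patch scores 0, while a head with uniform attention scores the average pairwise patch distance. Plotting these per-head values layer by layer is what separates local heads from global ones.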
Role of Skip Connections:
The paper also examines the role of skip connections in ViTs. Comparing, token by token, the norm of the representation carried by a skip connection with the norm of the residual branch's output shows that:
- Skip connections are even more influential in ViTs than in ResNets, strongly propagating features from lower to higher layers.
- A striking phase transition occurs: in the lower layers, the skip connection largely preserves the classification (CLS) token representation, while in the higher layers it predominantly carries the spatial tokens.
These strong skip connections help explain the uniform representation structure of ViTs, whereas in ResNets skip connections play a less dominant role.
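The norm ratio can be computed per token for each residual block as ||z|| / ||f(z)||, where z is the hidden representation carried by the skip connection and f(z) is the output of the residual branch (self-attention or MLP). The sketch below assumes these two arrays have already been extracted from a model (for example, with forward hooks) and that token 0 is the CLS token. All helper names are hypothetical, and averaging over many inputs is left to the caller.

```python
import numpy as np


def skip_norm_ratio(z: np.ndarray, branch_out: np.ndarray) -> np.ndarray:
    """Per-token ratio ||z|| / ||f(z)|| for one residual block.

    z:          (num_tokens, d) hidden representation carried by the skip connection.
    branch_out: (num_tokens, d) output f(z) of the residual branch (attention or MLP).
    A large ratio means the skip connection dominates that token's block output.
    """
    return np.linalg.norm(z, axis=-1) / np.linalg.norm(branch_out, axis=-1)


def cls_vs_spatial(ratios_per_layer: list) -> list:
    """Summarize each layer as (CLS ratio, mean spatial-token ratio); token 0 is assumed to be CLS."""
    return [(float(r[0]), float(r[1:].mean())) for r in ratios_per_layer]


def first_crossover(summary: list) -> int:
    """Index of the first layer whose CLS ratio falls below the spatial mean, or -1 if none."""
    for layer, (cls_ratio, spatial_ratio) in enumerate(summary):
        if cls_ratio < spatial_ratio:
            return layer
    return -1


# Usage with hand-made ratios for a 4-layer stack (illustrative numbers, not measurements).
toy_ratios = [np.array([10.0, 1.0, 1.0]), np.array([8.0, 2.0, 2.0]),
              np.array([1.0, 5.0, 5.0]), np.array([0.5, 6.0, 6.0])]
print(first_crossover(cls_vs_spatial(toy_ratios)))  # -> 2
```

Locating the layer where the CLS and spatial-token ratios swap order is one simple way to pinpoint a phase transition like the one described above.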
Spatial Localization and Classification:
Because spatial information matters for tasks beyond classification, the paper examines how well ViTs preserve spatial localization:
- ViTs retain spatial information more effectively than CNNs, which is important for tasks such as object detection.
- The classification head matters: ViTs that classify from a CLS token preserve spatial information more strongly than ViTs trained with global average pooling (GAP).
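One way to make spatial localization concrete is to compare each final-layer spatial token with the input patch embeddings at every location, then check whether it is most similar to the patch at its own position. The sketch below does this with linear CKA across a batch of images. The array shapes, the argmax-based `localization_accuracy` score, and all function names are illustrative assumptions rather than the paper's exact protocol.

```python
import numpy as np


def linear_cka(x: np.ndarray, y: np.ndarray) -> float:
    """Linear CKA between (n, d1) and (n, d2) activations on the same n images."""
    x = x - x.mean(axis=0)
    y = y - y.mean(axis=0)
    return float(np.linalg.norm(y.T @ x) ** 2
                 / (np.linalg.norm(x.T @ x) * np.linalg.norm(y.T @ y)))


def patch_token_similarity(patches: np.ndarray, tokens: np.ndarray) -> np.ndarray:
    """sim[j, i] = CKA between output token j and input patch i, computed across images.

    patches: (n_images, N, d_in)  input patch embeddings.
    tokens:  (n_images, N, d_out) final-layer spatial tokens (CLS token removed).
    """
    n = patches.shape[1]
    sim = np.empty((n, n))
    for j in range(n):
        for i in range(n):
            sim[j, i] = linear_cka(patches[:, i, :], tokens[:, j, :])
    return sim


def localization_accuracy(sim: np.ndarray) -> float:
    """Fraction of output tokens whose most similar input patch is at their own position."""
    return float(np.mean(sim.argmax(axis=1) == np.arange(sim.shape[0])))


# Usage with synthetic data: tokens are noisy linear maps of the patch at the same position.
rng = np.random.default_rng(0)
patches = rng.normal(size=(200, 9, 16))
mixing, _ = np.linalg.qr(rng.normal(size=(16, 16)))
tokens = patches @ mixing + 0.1 * rng.normal(size=patches.shape)
print(localization_accuracy(patch_token_similarity(patches, tokens)))
```

Under this toy setup, a model that keeps spatial information scores near 1. A model whose output tokens had lost their link to position would instead produce a flatter similarity matrix and a lower score.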
Effect of Scale on Transfer Learning:
Lastly, the paper investigates how the scale of the pretraining dataset affects the learned representations:
- Large pretraining datasets (e.g., JFT-300M) are instrumental for ViTs, particularly for larger models, which rely on them to develop the strong intermediate representations essential for transfer learning.
- Models pretrained on large datasets perform significantly better on downstream tasks than those trained on smaller datasets such as ImageNet.
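The strength of intermediate representations is commonly measured with linear probes, in which a linear classifier is fit on the frozen features of one layer and evaluated on held-out data. The sketch below uses a closed-form ridge-regression probe on one-hot labels as a simple stand-in. The regularization strength, the choice of features (for example, the CLS token or the average of spatial tokens), and the function names are assumptions, not the paper's exact setup.

```python
import numpy as np


def ridge_probe_accuracy(train_x, train_y, test_x, test_y, num_classes, l2=1e-3):
    """Fit a closed-form ridge-regression probe on one-hot labels; return test accuracy.

    train_x: (n_train, d) frozen features from one layer; train_y: (n_train,) integer labels.
    """
    def add_bias(x):
        # Append a constant column so the probe has a bias term (also regularized here).
        return np.hstack([x, np.ones((x.shape[0], 1))])

    xtr, xte = add_bias(train_x), add_bias(test_x)
    targets = np.eye(num_classes)[train_y]                     # (n_train, num_classes)
    # Solve (X^T X + l2 * I) W = X^T T for the probe weights W.
    gram = xtr.T @ xtr + l2 * np.eye(xtr.shape[1])
    weights = np.linalg.solve(gram, xtr.T @ targets)
    preds = (xte @ weights).argmax(axis=1)
    return float(np.mean(preds == test_y))


def probe_all_layers(train_feats, test_feats, train_y, test_y, num_classes):
    """Probe accuracy for each layer, given per-layer lists of train and test features."""
    return [ridge_probe_accuracy(a, train_y, b, test_y, num_classes)
            for a, b in zip(train_feats, test_feats)]


# Usage with synthetic features: an "informative" layer with class-dependent means
# and an "uninformative" layer of pure noise (stand-ins, not real model activations).
rng = np.random.default_rng(0)
num_classes = 3
y_train = rng.integers(0, num_classes, size=300)
y_test = rng.integers(0, num_classes, size=150)
means = 10.0 * np.eye(num_classes)
informative = [means[y] + 0.5 * rng.normal(size=(len(y), num_classes)) for y in (y_train, y_test)]
noise = [rng.normal(size=(len(y), num_classes)) for y in (y_train, y_test)]
print(probe_all_layers([informative[0], noise[0]], [informative[1], noise[1]],
                       y_train, y_test, num_classes))
```

Running the same probe on every layer of a model pretrained on a smaller dataset and of one pretrained on a larger dataset would show how early useful features emerge with depth in each. A closed-form ridge probe is chosen here only because it needs no optimizer.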
Implications and Future Directions:
The insights from this comparative study have broad implications for the design of neural network architectures for vision. The findings point to the value of aggregating global information early and to the critical role of mechanisms that propagate representations through the network, such as skip connections, in achieving robust performance. Future research could therefore explore hybrid models that combine the strengths of ViTs and CNNs, and further investigate how well ViTs apply to domains requiring fine-grained spatial localization.
In conclusion, while ViTs diverge from CNNs in fundamental ways, particularly in their use of self-attention and in how they propagate representations, their ability to learn effective visual representations underscores their potential for advancing the state of the art in computer vision.