Contrastive-JEPA: Enhanced Visual SSL
- Contrastive-JEPA is a self-supervised visual representation learning framework that extends JEPA with VICReg regularization to prevent collapse and enforce mean invariance.
- It combines masked predictive modeling with contrastive-style regularization to achieve faster convergence and superior downstream performance on benchmarks like ImageNet-1K.
- Empirical results demonstrate improved linear-probing and fine-tuning accuracy on ImageNet-1K, and stronger transfer to object detection and segmentation, compared to other state-of-the-art self-supervised methods.
Contrastive-JEPA (C-JEPA) is a self-supervised visual representation learning framework that augments the Joint-Embedding Predictive Architecture (JEPA) with Variance-Invariance-Covariance Regularization (VICReg) to enhance learning stability and prevent representational collapse. By combining predictive masked modeling and contrastive-style regularization, C-JEPA achieves superior training dynamics and downstream performance compared to prior JEPA-based approaches and other state-of-the-art self-supervised methods (Mo et al., 2024).
1. Foundations: JEPA and Its Limitations
The JEPA framework aims to learn high-quality image representations by predicting masked regions of an input image from its unmasked regions in a latent embedding space. In the canonical Image-based JEPA (I-JEPA), an image $x$ is transformed into a context view $x_c$, obtained by masking out some patch blocks, and a target view $x_t$ containing only the masked blocks. The context encoder $f_\theta$ operates on $x_c$, providing latent embeddings for visible patches and mask tokens for masked ones. The target encoder $f_{\bar{\theta}}$, an Exponential Moving Average (EMA) copy of $f_\theta$, processes $x_t$, yielding embeddings for the masked patches. A predictor $g_\phi$ attempts to reconstruct the target patch embeddings from the associated mask tokens within the context embedding. The key JEPA loss is a mean squared error between these predicted and actual masked-patch embeddings:

$$\mathcal{L}_{\text{JEPA}} = \frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} \left\| \hat{s}_i - \operatorname{sg}(s_i) \right\|_2^2,$$

where $\hat{s}_i$ is the prediction for masked patch $i$, $s_i$ is the corresponding target embedding, $\mathcal{M}$ is the set of masked patches, and $\operatorname{sg}(\cdot)$ denotes stop-gradient.
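As a concrete illustration, the prediction loss above can be sketched in a few lines. This is a toy numpy sketch, not the authors' code: embeddings are random vectors, the predictor is a plain linear map, and the target branch is treated as a constant (stop-gradient).

```python
# Toy sketch of the I-JEPA prediction loss (shapes and names are assumptions).
import numpy as np

rng = np.random.default_rng(0)
num_masked, dim = 4, 8

context_tokens = rng.normal(size=(num_masked, dim))  # mask-token outputs of the context encoder
targets = rng.normal(size=(num_masked, dim))         # target-encoder embeddings (stop-gradient: constants)

W = rng.normal(size=(dim, dim)) * 0.1                # stand-in linear predictor g_phi

def jepa_loss(W, context_tokens, targets):
    """Mean squared error between predicted and target masked-patch embeddings."""
    preds = context_tokens @ W
    return np.mean((preds - targets) ** 2)

loss = jepa_loss(W, context_tokens, targets)
```

In the real model the gradient of this loss updates the context encoder and predictor only; the target encoder is updated solely through its EMA rule.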
I-JEPA exhibits two critical limitations:
- The EMA target update alone does not prevent "entire collapse", where all learned representations converge to a constant vector, drastically reducing their utility.
- The predictor does not guarantee agreement between the mean vectors of patch embeddings across views, leading to insufficient mean invariance.
These issues also appear in non-contrastive SSL methods such as SimSiam, for which similar failures of the EMA dynamic have been documented.
2. C-JEPA Architecture and Loss Composition
C-JEPA extends I-JEPA by integrating VICReg's regularization strategies into the model's pipeline. The pipeline comprises:
- Input Augmentation and Masking: Two augmentations $x^{(1)}, x^{(2)}$ of the input image $x$ are produced. Each is split into a context view (with some patch blocks masked out) and a target view (containing only the masked blocks).
- Context Encoder: A Vision Transformer (ViT, e.g., ViT-B/16) $f_\theta$ encodes the context view into patch embeddings $s_c$.
- Target Encoder: The EMA copy $f_{\bar{\theta}}$ encodes the target view into target embeddings $s_t$.
- Predictor Head: A Transformer or MLP of moderate depth (6–12 layers) predicts $\hat{s}_t$ at the masked positions.
- VICReg Projector: A 2–3 layer MLP maps mean context embeddings at masked regions to projected vectors used for regularization.
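The wiring of these components can be sketched as below. This is a hedged stand-in, not the paper's implementation: the encoders, predictor, and projector are toy linear maps rather than ViTs/MLPs, and all names are assumptions.

```python
# Hedged sketch of the C-JEPA forward pass (toy linear maps, not ViTs).
import numpy as np

rng = np.random.default_rng(1)
dim = 16

W_ctx = rng.normal(size=(dim, dim)) * 0.1   # context encoder f_theta (stand-in)
W_tgt = W_ctx.copy()                        # target encoder: EMA copy, no gradient
W_pred = rng.normal(size=(dim, dim)) * 0.1  # predictor g_phi (stand-in)
W_proj = rng.normal(size=(dim, dim)) * 0.1  # VICReg projector (stand-in, 1 layer)

def forward(patches_context, patches_target):
    s_c = patches_context @ W_ctx    # context-view patch embeddings
    s_t = patches_target @ W_tgt     # target embeddings (stop-gradient branch)
    preds = s_c @ W_pred             # predicted embeddings for masked positions
    mean_ctx = s_c.mean(axis=0)      # mean context embedding at masked regions
    projected = mean_ctx @ W_proj    # projected vector fed to VICReg regularizers
    return preds, s_t, projected

preds, s_t, projected = forward(rng.normal(size=(4, dim)),
                                rng.normal(size=(4, dim)))
```

In practice the projector output from both augmented views is what the variance, invariance, and covariance terms operate on.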
The C-JEPA total loss is a weighted combination

$$\mathcal{L}_{\text{C-JEPA}} = \mathcal{L}_{\text{JEPA}} + \lambda \, \mathcal{L}_{\text{var}} + \mu \, \mathcal{L}_{\text{inv}} + \nu \, \mathcal{L}_{\text{cov}},$$

with:
- $\mathcal{L}_{\text{JEPA}}$: patch-prediction MSE with stop-gradient on the target,
- $\mathcal{L}_{\text{var}}$: encourages sufficient per-dimension variance,
- $\mathcal{L}_{\text{inv}}$: enforces mean invariance between views,
- $\mathcal{L}_{\text{cov}}$: suppresses inter-feature correlation.

Standard VICReg weights are $\lambda = \mu = 25$, $\nu = 1$, with the regularization block optionally downscaled by a small coefficient relative to $\mathcal{L}_{\text{JEPA}}$ to balance its influence.
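The three regularizers can be written down compactly. The sketch below follows the standard VICReg formulation (variance hinge at 1, off-diagonal covariance penalty); how exactly C-JEPA weights and applies them to the projected mean embeddings is an assumption here.

```python
# Minimal numpy sketch of the three VICReg terms.
import numpy as np

def vicreg_terms(z1, z2, eps=1e-4):
    """z1, z2: (batch, dim) projected embeddings from the two views."""
    # Invariance: mean squared distance between the two views.
    inv = np.mean((z1 - z2) ** 2)
    # Variance: hinge pushing each dimension's std above 1.
    var = 0.0
    for z in (z1, z2):
        std = np.sqrt(z.var(axis=0) + eps)
        var += np.mean(np.maximum(0.0, 1.0 - std))
    # Covariance: squared off-diagonal entries of the covariance matrix.
    cov = 0.0
    for z in (z1, z2):
        zc = z - z.mean(axis=0)
        c = (zc.T @ zc) / (len(z) - 1)
        off_diag = c - np.diag(np.diag(c))
        cov += (off_diag ** 2).sum() / z.shape[1]
    return inv, var, cov

rng = np.random.default_rng(0)
z1 = rng.normal(size=(32, 8))
z2 = z1 + 0.1 * rng.normal(size=(32, 8))
inv, var, cov = vicreg_terms(z1, z2)
total = 25.0 * inv + 25.0 * var + 1.0 * cov  # standard VICReg weighting
```

The variance and covariance terms act on each view independently, while the invariance term couples the two views.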
3. Training Procedures and Implementation Specifics
C-JEPA is pretrained on unlabeled ImageNet-1K, with different regimes depending on model size: 600 epochs for the larger ViTs (B/L) and 100 epochs for the smaller ones (T/S). The optimizer is AdamW with weight decay increased from 0.04 to 0.4 and batch size 2048; the learning rate is warmed up linearly over the first 15 epochs and then cosine-decayed for the remainder of training.
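The warmup-plus-cosine schedule can be sketched as follows. The peak and final learning-rate values below are placeholders, not the paper's numbers.

```python
# Sketch of linear warmup followed by cosine decay (placeholder rate values).
import math

def lr_at(epoch, warmup_epochs=15, total_epochs=600,
          base_lr=1e-3, final_lr=1e-6):
    """Linear warmup to base_lr, then cosine decay to final_lr."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    t = (epoch - warmup_epochs) / max(total_epochs - warmup_epochs, 1)
    return final_lr + 0.5 * (base_lr - final_lr) * (1 + math.cos(math.pi * t))
```

The schedule peaks at the end of warmup (epoch 14 here) and reaches the final rate exactly at the last epoch.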
The EMA momentum parameter is annealed from 0.996 to 1.0 over the course of training. Each view applies four (possibly overlapping) block masks. Predictor depth is 6 for ViT-T/S/B and 12 for ViT-L; the predictor embedding dimension is 384 (192 for ViT-T), and the MLP projector's output dimension matches that of the encoder. Layer normalization and a small $\epsilon$ inside the square root stabilize the variance term.
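The annealed EMA update can be sketched as below; the linear shape of the anneal is an assumption (the paper states only the endpoints 0.996 and 1.0).

```python
# Sketch of the annealed EMA target update (linear anneal assumed).
def ema_momentum(step, total_steps, start=0.996, end=1.0):
    """Interpolate the EMA coefficient from start to end over training."""
    frac = min(step / max(total_steps, 1), 1.0)
    return start + (end - start) * frac

def ema_update(target_params, online_params, m):
    """target <- m * target + (1 - m) * online, element-wise."""
    return [m * t + (1.0 - m) * o for t, o in zip(target_params, online_params)]

m0 = ema_momentum(0, 1000)        # 0.996 at the start of training
m_half = ema_momentum(500, 1000)  # 0.998 midway
m_end = ema_momentum(1000, 1000)  # 1.0 at the end (target frozen)
```

As the momentum approaches 1.0 the target encoder effectively freezes, which slows the drift of the prediction targets late in training.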
4. Theoretical Guarantees: Preventing Collapse and Mean Alignment
By modeling the JEPA predictor as a linear map and decomposing the training dynamics into eigen-modes (invoking Neural Tangent Kernel theory), one can show that the stop-gradient predictor loss admits stable, non-collapsed solutions for each eigen-mode; without the predictor, or for unstable eigen-modes, collapse may occur. Removing the stop-gradient changes the dynamics so that every eigen-mode decays to zero, guaranteeing collapse. The inclusion of VICReg's variance and covariance regularizers enforces per-dimension variance and zero off-diagonal covariance, stabilizing the training dynamics. The invariance regularizer aligns mean embeddings, directly improving the model's ability to capture consistent latent semantics between augmented views.
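The stabilizing effect of the variance term can be seen in a toy experiment (an illustration only, not the paper's NTK analysis): starting from nearly collapsed embeddings, gradient descent on the VICReg variance hinge alone restores per-dimension spread.

```python
# Toy illustration: the variance hinge pushes collapsed embeddings apart.
import numpy as np

rng = np.random.default_rng(0)
z = 0.01 * rng.normal(size=(64, 4))   # nearly collapsed: tiny per-dim variance

def variance_hinge_grad(z, eps=1e-4):
    """Gradient of mean_d max(0, 1 - std_d) with respect to z."""
    mu = z.mean(axis=0)
    std = np.sqrt(z.var(axis=0) + eps)
    active = (std < 1.0).astype(float)             # hinge active where std < 1
    # d std_d / d z_id = (z_id - mu_d) / (n * std_d)
    return -active * (z - mu) / (len(z) * std) / z.shape[1]

std_before = np.sqrt(z.var(axis=0)).mean()
for _ in range(500):
    z -= 0.5 * variance_hinge_grad(z)             # plain gradient descent
std_after = np.sqrt(z.var(axis=0)).mean()
```

Descent on this term alone drives the average per-dimension standard deviation from roughly 0.01 toward the hinge threshold of 1, which is the mechanism the stability argument relies on.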
5. Empirical Performance and Comparative Evaluation
C-JEPA demonstrates improved representation quality relative to I-JEPA and other masked or non-contrastive baselines (MAE, BEiT, iBOT, data2vec, VICReg) across multiple tasks.
Key Results (ImageNet-1K classification, COCO detection, ADE20K segmentation)
| Model | Linear Probing | Fine-tune | COCO Box AP | ADE20K mIoU |
|---|---|---|---|---|
| I-JEPA B/16 | 72.9% | 83.5% | 49.9 | 47.6 |
| C-JEPA B/16 | 73.7% | 84.5% | 50.7 | 48.7 |
| I-JEPA L/16 | 77.5% | 85.3% | — | — |
| C-JEPA L/16 | 78.1% | 86.2% | — | — |
Performance gains are systematic across COCO box/mask, ADE20K segmentation, and low-level video/vision tasks such as DAVIS and CLEVR, indicating broader downstream utility (Mo et al., 2024).
6. Ablation Analyses
Ablations confirm the individual and joint importance of the three VICReg components. On ViT-B/16 (100 epochs), base JEPA achieves 63.7% linear probing accuracy; adding variance and covariance yields 68.3%, invariance alone 67.6%, and all three together delivers 69.5%. With extended pretraining (600 epochs), the corresponding gains persist: base 72.9%, variance+covariance 73.5%, invariance 73.2%, all terms 73.7%. Varying the VICReg regularization strength reveals an optimal intermediate value: too small a weight leads to suboptimal performance, while excessively high weighting reintroduces collapse. Adjusting the invariance weight likewise shows the best results at intermediate values (e.g., 15).
7. Significance, Limitations, and Prospects
C-JEPA advances the field of predictive joint-embedding self-supervision by providing a tractable, empirically robust mechanism for preventing representational collapse and enforcing meaningful invariance in patch-level means. It achieves faster convergence and superior performance on large-scale vision benchmarks compared to both I-JEPA and related baselines. Limitations include the need for careful tuning of VICReg weights—over-regularization can reintroduce collapse—and open questions regarding extension to multimodal data, alternative masking strategies, and further scaling. Directions for improvement include adaptive or learned weighting of the regularization terms and extension to larger or multimodal architectures (Mo et al., 2024).