Learning and Leveraging World Models in Visual Representation Learning (2403.00504v1)

Published 1 Mar 2024 in cs.CV, cs.AI, and cs.LG

Abstract: Joint-Embedding Predictive Architecture (JEPA) has emerged as a promising self-supervised approach that learns by leveraging a world model. While previously limited to predicting missing parts of an input, we explore how to generalize the JEPA prediction task to a broader set of corruptions. We introduce Image World Models, an approach that goes beyond masked image modeling and learns to predict the effect of global photometric transformations in latent space. We study the recipe of learning performant IWMs and show that it relies on three key aspects: conditioning, prediction difficulty, and capacity. Additionally, we show that the predictive world model learned by IWM can be adapted through finetuning to solve diverse tasks; a fine-tuned IWM world model matches or surpasses the performance of previous self-supervised methods. Finally, we show that learning with an IWM allows one to control the abstraction level of the learned representations, learning invariant representations such as contrastive methods, or equivariant representations such as masked image modelling.


Summary

  • The paper introduces Image World Models (IWM) as a novel self-supervised approach that trains a predictor to model transformation effects in latent space.
  • It employs a Vision Transformer encoder and a Transformer-based predictor conditioned on transformation parameters to learn robust equivariant representations.
  • The learned predictor can be finetuned for downstream tasks, matching or surpassing standard encoder finetuning in ImageNet classification and semantic segmentation.

This paper introduces Image World Models (IWM), a self-supervised learning approach based on the Joint-Embedding Predictive Architecture (JEPA) framework (2403.00504). IWM aims to learn not only high-quality visual representations but also a reusable "world model" capable of predicting the effects of transformations in latent space. Unlike traditional self-supervised methods where the predictor or decoder is often discarded after pretraining, IWM focuses on leveraging this learned world model for downstream tasks.

Core Idea: Learning and Leveraging Image World Models

The central concept is to train a predictor network (the world model) to predict the latent representation of a target view (y) given the latent representation of a transformed source view (x) and information about the transformation (a) applied to get from x to y.

  1. Input Views:
    • Target (y): Generated by applying standard augmentations (random crop, horizontal flip, moderate color jitter) to an image I. Destructive augmentations (like grayscale) are avoided to maximize information content.
    • Source (x): Derived from the target y by applying further, potentially stronger transformations, including color jitter, destructive augmentations (grayscale, blur, solarization), and patch masking (removing four rectangular blocks).
  2. Action (a): Represents the parameters of the transformation needed to reverse the process from x back to y (e.g., color jitter differences, flags for destructive augmentations).
  3. Architecture:
    • Encoder (f_theta): A Vision Transformer (ViT-B/16 used in experiments) encodes the source view x into latent representation z_x.
    • Target Encoder (f_theta^EMA): An exponential moving average (EMA) of the encoder weights encodes the target view y into z_y. This EMA target encoder is crucial for stability and preventing collapse.
    • Predictor (p_phi): The world model, typically a Transformer architecture with adjustable depth and width. It takes the encoded source patches z_x, the transformation parameters a, and positional embeddings for the target patches (m_a) as input.
  4. Objective: The predictor p_phi aims to predict the target representation z_y for the masked patches. The loss is the squared L2 distance between the prediction $\hat{z_y} = p_\phi(z_x, a, m_a)$ and the actual target z_y, summed over the target patch indices:

    $$L(x,y) = \sum_{i\in M_x^C}\left\| p_\phi\left(f_\theta(x), a_{x\rightarrow y}, m_a \right)_i - f_\theta^\text{EMA}(y)_i \right\|_2^2$$

    where $M_x^C$ denotes the indices of the target patches (the complement of the source mask). A minimal sketch of this training step is given after this list.
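
To make the objective concrete, here is a minimal PyTorch-style sketch of one IWM training step. The `encoder`, `ema_encoder`, and `predictor` interfaces are assumptions for illustration, and the EMA update of the teacher is omitted.

```python
import torch
import torch.nn.functional as F

def iwm_training_step(encoder, ema_encoder, predictor, x, y, action, mask_tokens, target_idx):
    """One IWM step on a (source, target) pair (hypothetical interfaces).

    x, y:        source / target views, shape (B, C, H, W)
    action:      transformation parameters a_{x -> y}, shape (B, A)
    mask_tokens: positional mask tokens m_a for the patches to predict, shape (B, N_tgt, D)
    target_idx:  indices M_x^C of the target patches, shape (N_tgt,)
    """
    z_x = encoder(x)                               # (B, N_src, D) source patch embeddings
    with torch.no_grad():
        z_y = ema_encoder(y)                       # (B, N, D) target embeddings from the EMA teacher

    # The predictor (world model) maps source latents + action + mask tokens
    # to predictions of the target latents at the masked positions.
    z_y_hat = predictor(z_x, action, mask_tokens)  # (B, N_tgt, D)

    # Squared L2 distance on the predicted target patches (mean-reduced here,
    # which is proportional to the per-patch sum in the loss above).
    loss = F.mse_loss(z_y_hat, z_y[:, target_idx])
    return loss
```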

Key Factors for Learning a Strong (Equivariant) World Model

The paper identifies three critical aspects for training a capable world model that can accurately predict transformation effects (termed "equivariant"):

  1. Predictor Conditioning: The predictor must be conditioned on the transformation parameters a. Without conditioning, the model defaults to learning invariant representations (as in BYOL/SimSiam). The paper finds that "feature conditioning" (concatenating the transformation parameters with the mask token embeddings and processing them through MLPs) works well; a minimal sketch of this conditioning follows this list.
  2. Transformation Complexity: The prediction task needs to be sufficiently difficult. Using strong augmentations (significant color jitter, destructive augmentations like grayscale, blur) forces the predictor to learn meaningful modeling capabilities. Easy transformations don't necessitate a powerful world model.
  3. Predictor Capacity: The predictor (p_phi) needs adequate capacity (depth and width) to model complex transformations. A deeper predictor (e.g., 18 layers vs. 12) is shown to learn equivariance more reliably across different augmentation strengths.
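
As a rough illustration of feature conditioning, the sketch below concatenates the transformation parameters with the mask tokens and maps the result back to the embedding dimension with a small MLP. The module name, layer sizes, and activation are assumptions, not the paper's exact design.

```python
import torch
import torch.nn as nn

class FeatureConditioning(nn.Module):
    """Hypothetical feature-conditioning module: each mask token is concatenated
    with the transformation parameters a and projected back to the embedding size."""

    def __init__(self, action_dim: int, embed_dim: int = 384):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim + action_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, mask_tokens: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        # mask_tokens: (B, N_tgt, D) positional mask tokens for the target patches
        # action:      (B, A) transformation parameters a_{x -> y}
        a = action.unsqueeze(1).expand(-1, mask_tokens.size(1), -1)   # (B, N_tgt, A)
        return self.mlp(torch.cat([mask_tokens, a], dim=-1))          # (B, N_tgt, D)
```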

World model quality is evaluated using Mean Reciprocal Rank (MRR), measuring how well the predictor can identify the correct transformed target representation among a bank of distractors.
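
The sketch below shows one way to compute MRR under assumed inputs: predicted representations are scored against a candidate bank (here simply the other in-batch targets acting as distractors) by cosine similarity, and the reciprocal rank of the true target is averaged. The paper's exact bank construction may differ.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def mean_reciprocal_rank(preds: torch.Tensor, bank: torch.Tensor) -> float:
    """MRR sketch: preds[i] should match bank[i]; all other bank rows act as distractors.

    preds: (B, D) predicted target representations (e.g., patch-pooled)
    bank:  (B, D) true target representations
    """
    preds = F.normalize(preds, dim=-1)
    bank = F.normalize(bank, dim=-1)
    sims = preds @ bank.T                              # (B, B) cosine similarities
    order = sims.argsort(dim=-1, descending=True)      # candidates sorted by similarity
    truth = torch.arange(preds.size(0), device=preds.device)
    # 1-indexed position of the true target in each sorted candidate list
    rank = (order == truth[:, None]).float().argmax(dim=-1) + 1
    return (1.0 / rank).mean().item()
```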

Leveraging the World Model: Predictor Finetuning

A key contribution is demonstrating that the pretrained predictor (world model) can be effectively reused for downstream tasks, offering an alternative to standard encoder finetuning.

  1. Protocol: Freeze the pretrained encoder (the EMA teacher f_theta^EMA, which performs slightly better than the student). Attach a task-specific head (e.g., a linear layer for classification, a UperNet head for segmentation) to the output of the pretrained predictor (p_phi). Finetune only the predictor and the task head. The predictor is tasked with predicting a clean, untransformed version of the input (using null transformation parameters a); see the sketch after this list.
  2. Performance:
    • Finetuning the pretrained IWM predictor significantly outperforms finetuning a randomly initialized predictor of the same size, especially for equivariant models (e.g., +1.8% ImageNet Top-1 for $\text{IWM}_{18,384}^{\text{Equi}}$). This confirms that the predictor learned useful, transferable knowledge.
    • Predictor finetuning can match or exceed encoder finetuning (e.g., $\text{IWM}_{18,384}^{\text{Equi}}$ predictor finetuning: 83.3% vs. its encoder finetuning: 82.9%; $\text{IWM}_{12,384}^{\text{Inv}}$ encoder finetuning: 83.3%).
    • It's significantly more parameter-efficient than encoder finetuning, as only the (often smaller) predictor is updated (Figure 2).
    • End-to-end finetuning (both encoder and predictor) yields the highest performance (e.g., 84.4% for $\text{IWM}_{18,384}^{\text{Equi}}$).
    • Similar trends hold for semantic segmentation on ADE20k (Table 4).
  3. Multitask Tuning: Inspired by instruction tuning, the predictor can be finetuned for multiple tasks simultaneously by adding learned task-specific tokens as input. A single predictor achieves comparable average performance to separately finetuned predictors, amortizing the parameter cost across tasks (Table 5).
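
A minimal sketch of the predictor-finetuning setup is given below, assuming the same hypothetical `encoder` and `predictor` interfaces as above; the head, pooling, and optimizer settings are illustrative (the paper attaches an attentive head for classification).

```python
import torch
import torch.nn as nn

def setup_predictor_finetuning(encoder, predictor, num_classes: int, embed_dim: int = 384):
    """Freeze the teacher encoder, keep the pretrained predictor trainable,
    and train a task head on the predictor's output (hypothetical interfaces)."""
    for p in encoder.parameters():
        p.requires_grad_(False)                          # frozen (teacher) encoder

    head = nn.Linear(embed_dim, num_classes)             # simple head; an attentive head also fits here
    optimizer = torch.optim.AdamW(
        list(predictor.parameters()) + list(head.parameters()),
        lr=1e-4, weight_decay=0.1,                       # peak LR of 1e-3 divided by 10, WD 0.1
    )

    def forward(x, null_action, mask_tokens):
        with torch.no_grad():
            z = encoder(x)
        # Ask the world model for a "clean" version of the input: null transformation parameters.
        z_hat = predictor(z, null_action, mask_tokens)   # (B, N, D)
        return head(z_hat.mean(dim=1))                   # pool patches, then classify

    return forward, optimizer
```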

Representation Properties: Abstraction Spectrum

IWM allows controlling the properties of the learned representations by modulating the world model's capability:

  • Invariant World Model: Achieved with weaker conditioning, simpler transformations, or lower predictor capacity. The predictor cannot fully invert transformations, forcing the encoder to learn representations invariant to those transformations (by discarding information). These representations are more abstract/semantic, perform well in linear probing (similar to contrastive methods like MoCo v3), but may have lower peak performance with complex heads; e.g., $\text{IWM}_{12,384}^{\text{Inv}}$.
  • Equivariant World Model: Achieved with strong conditioning, complex transformations, and high predictor capacity. The predictor learns to model transformations accurately, allowing the encoder to retain more detailed information about the input. These representations are richer, perform better with more powerful adaptation methods like predictor finetuning or attentive probing (similar to MIM methods like MAE), and show better OOD generalization (Appendix Table S6); e.g., $\text{IWM}_{18,384}^{\text{Equi}}$.

IWM spans the spectrum between highly abstract (contrastive) and highly detailed (MIM) representations, offering flexibility based on downstream needs.

Implementation Considerations

  • Architecture: ViT-B/16 encoder, Transformer predictor (e.g., 12-18 layers, 384-dim embeddings). The appendix suggests a predictor with roughly 30% of the encoder's parameters as a good starting point for scaling.
  • Pretraining: 300 epochs on ImageNet-1k. AdamW optimizer, cosine LR schedule (1e-3 peak), cosine weight decay schedule (0.04 → 0.4). Batch size 1024. Asymmetric augmentations (stronger on the source x) are generally preferred. These settings are collected in the sketch after this list.
  • Predictor Finetuning: 100 epochs. AdamW, cosine LR schedule (1e-3 peak, divided by 10 for pretrained predictor), WD 0.1. Use teacher encoder, null transformation inputs. Attach attentive head for classification.
  • Encoder Finetuning: 100 epochs. Follows MAE protocol (RandAugment, MixUp, CutMix, AdamW, Layer-wise LR decay).
  • Evaluation: Linear probing (90 epochs, LARS), Attentive probing (90 epochs, AdamW), Finetuning (100 epochs), Segmentation (ADE20k, UperNet head).
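
For convenience, the sketch below collects the pretraining settings listed above into a single Python config; the key names and structure are illustrative only and do not come from the paper's code.

```python
# Illustrative config collecting the pretraining hyperparameters above; key names are assumptions.
iwm_pretrain_config = {
    "encoder": "vit_base_patch16",                  # ViT-B/16
    "predictor": {"depth": 18, "width": 384},       # equivariant variant; 12 layers for the invariant one
    "dataset": "imagenet-1k",
    "epochs": 300,
    "batch_size": 1024,
    "optimizer": "adamw",
    "peak_lr": 1e-3,
    "lr_schedule": "cosine",
    "weight_decay": {"schedule": "cosine", "start": 0.04, "end": 0.4},
    "augmentations": {
        "target": ["random_crop", "horizontal_flip", "moderate_color_jitter"],
        "source": ["stronger_color_jitter", "grayscale", "blur", "solarization", "patch_masking"],
    },
}
```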

In summary, IWM presents a framework for learning visual representations by explicitly training a world model within a JEPA structure. It highlights the importance of conditioning, transformation complexity, and predictor capacity. Crucially, it shows that this learned world model is not just a training artifact but can be effectively finetuned for downstream tasks, offering a competitive and parameter-efficient alternative to encoder finetuning, with the added benefit of multitask capability and control over representation abstraction.