
Masked World Models for Visual Control (2206.14244v3)

Published 28 Jun 2022 in cs.RO, cs.AI, cs.CV, and cs.LG

Abstract: Visual model-based reinforcement learning (RL) has the potential to enable sample-efficient robot learning from visual observations. Yet the current approaches typically train a single model end-to-end for learning both visual representations and dynamics, making it difficult to accurately model the interaction between robots and small objects. In this work, we introduce a visual model-based RL framework that decouples visual representation learning and dynamics learning. Specifically, we train an autoencoder with convolutional layers and vision transformers (ViT) to reconstruct pixels given masked convolutional features, and learn a latent dynamics model that operates on the representations from the autoencoder. Moreover, to encode task-relevant information, we introduce an auxiliary reward prediction objective for the autoencoder. We continually update both autoencoder and dynamics model using online samples collected from environment interaction. We demonstrate that our decoupling approach achieves state-of-the-art performance on a variety of visual robotic tasks from Meta-world and RLBench, e.g., we achieve 81.7% success rate on 50 visual robotic manipulation tasks from Meta-world, while the baseline achieves 67.9%. Code is available on the project website: https://sites.google.com/view/mwm-rl.

"Masked World Models for Visual Control" (Seo et al., 2022 ) introduces a framework for visual model-based reinforcement learning (MBRL) that decouples the learning of visual representations from the learning of system dynamics. This approach aims to enhance sample efficiency and improve performance on visual control tasks, particularly those involving fine-grained interactions between robots and small objects, which often pose challenges for end-to-end trained models. The core idea is to train a powerful autoencoder for robust visual representation learning using a masking objective inspired by masked autoencoders (MAE) and concurrently train a latent dynamics model operating on these learned representations.

Decoupled Architecture Overview

The proposed framework, Masked World Model (MWM), consists of two primary components trained separately but updated continually using online data:

  1. Visual Autoencoder: This module learns a compressed latent representation z_t from high-dimensional visual observations o_t. It employs a combination of convolutional layers and a Vision Transformer (ViT) and is trained using a pixel reconstruction objective on masked inputs. An auxiliary reward prediction task is added to encourage the encoding of task-relevant information.
  2. Latent Dynamics Model: This module learns the temporal transitions within the latent space learned by the autoencoder. It predicts the next latent state z_{t+1} given the current latent state z_t and the action a_t.

This decoupling allows the visual encoder to focus solely on extracting salient features from images, potentially leading to more robust and generalizable representations, while the dynamics model focuses on modeling the system's evolution in the compressed latent space. Planning or policy learning is then performed using the learned latent dynamics model.

                  +-----------------------+      +---------------------+
Input Observation |                       |      |                     |
      o_t --------> Convolutional Encoder +------> Vision Transformer  +-----> Latent Representation z_t
                  |      (Feature Map f_t)|      |    (Autoencoder)    |
                  +----------+------------+      +----------+----------+
                             |                           ^  |
                             | Masking                   |  | Reconstruction Loss
                             v                           |  v
                  +----------+------------+              |  +-----------------------+
                  | Masked Feature Map m_t|              |  | Convolutional Decoder |
                  +----------+------------+              |  +-----------+-----------+
                             |                           |              |
                             +---------------------------+              | Pixel Reconstruction
                                                                        v
                                                                  Reconstructed o_t

Latent Space z_t ---> +-----------------------+ ---> Predicted Latent z_{t+1} ---> Planning / Policy
Action a_t ---------> | Latent Dynamics Model | ---> Predicted Reward r_{t+1} ---> Optimization
                      +-----------------------+      (Uses z_t for reward prediction)

Figure 1: High-level architecture of the Masked World Model (MWM) framework, illustrating the decoupled visual autoencoder and latent dynamics model.
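
To make the decoupling concrete, the following PyTorch-style sketch shows the two components with separate objectives and optimizers. All module names, layer sizes, and learning rates here are illustrative placeholders standing in for the paper's actual architecture, not its configuration:

import torch
import torch.nn as nn

# Minimal sketch of the decoupled design (all sizes are illustrative).
class VisualAutoencoder(nn.Module):
    def __init__(self, latent_dim=128):
        super().__init__()
        # Stand-ins for the conv encoder and the ViT of the real model.
        self.conv = nn.Conv2d(3, 32, kernel_size=4, stride=2)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.to_latent = nn.Linear(32, latent_dim)

    def encode(self, obs):                  # obs: (B, 3, H, W) -> z: (B, D)
        f = torch.relu(self.conv(obs))
        return self.to_latent(self.pool(f).flatten(1))

class LatentDynamics(nn.Module):
    def __init__(self, latent_dim=128, action_dim=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + action_dim, 256), nn.ELU(),
            nn.Linear(256, 2 * latent_dim),  # mean and log-std of z_{t+1}
        )

    def forward(self, z, a):
        mean, log_std = self.net(torch.cat([z, a], dim=-1)).chunk(2, dim=-1)
        return mean, log_std

autoencoder = VisualAutoencoder()
dynamics = LatentDynamics()
# Decoupling in practice: each component has its own loss and optimizer.
ae_opt = torch.optim.Adam(autoencoder.parameters(), lr=1e-4)
dyn_opt = torch.optim.Adam(dynamics.parameters(), lr=3e-4)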

Visual Representation Learning Details

The visual autoencoder is central to MWM. It processes an input image observation o_t through several stages:

  1. Convolutional Encoder: Initial convolutional layers extract local features, producing a feature map f_t.
  2. Masking: A significant portion (e.g., 75%) of the spatial patches within the feature map f_t is randomly masked, yielding a masked feature map m_t. This masking is applied to the convolutional features rather than the raw pixels, which distinguishes it from standard MAE applied directly to images. The rationale is that convolutional features already capture higher-level spatial information, so the reconstruction task focuses on semantic understanding rather than low-level pixel details.
  3. Vision Transformer (ViT): The unmasked patches from f_t are processed by a ViT encoder. The ViT's ability to model long-range dependencies is leveraged to infer the content of the masked regions.
  4. Convolutional Decoder: The output of the ViT, along with learnable mask tokens representing the masked patches, is fed into a convolutional decoder that reconstructs the original, unmasked feature map f_t. The final layers then upsample this reconstructed feature map to produce the reconstructed pixel observation \hat{o}_t.
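
A minimal sketch of the masking step, assuming the convolutional feature map has already been flattened into N spatial patches. The shuffle-and-keep trick mirrors the standard MAE implementation; the shapes and the function name are illustrative:

import torch

def mask_feature_patches(feats, mask_ratio=0.75):
    # feats: (B, N, D), N flattened spatial patches of the conv feature map.
    B, N, D = feats.shape
    n_keep = int(N * (1.0 - mask_ratio))
    noise = torch.rand(B, N)             # one random score per patch
    ids_shuffle = noise.argsort(dim=1)   # random permutation of patches
    ids_keep = ids_shuffle[:, :n_keep]   # indices of patches the ViT will see
    kept = torch.gather(feats, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
    return kept, ids_shuffle             # ids needed to restore order later

# Example: 64 patches of 256-dim features, 75% masked -> 16 visible patches.
feats = torch.randn(8, 64, 256)
kept, ids = mask_feature_patches(feats)
print(kept.shape)                        # torch.Size([8, 16, 256])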

The primary training objective for the autoencoder is the pixel reconstruction loss, typically mean squared error (MSE), between the original observation o_t and the reconstruction \hat{o}_t:

L_{recon} = ||o_t - \hat{o}_t||^2

Critically, an auxiliary reward prediction objective is added to the autoencoder training. The latent representation z_t (derived from the ViT output) is used to predict the immediate reward r_{t+1} received after taking action a_t. A simple MLP head is typically attached to z_t for this prediction.

L_{reward} = ||r_{t+1} - \hat{r}_{t+1}(z_t, a_t)||^2

The total loss for the autoencoder is a weighted sum of the reconstruction and reward prediction losses:

L_{autoencoder} = L_{recon} + \lambda_{reward} L_{reward}

where \lambda_{reward} is a hyperparameter balancing the two objectives. This auxiliary task explicitly guides the representation z_t to encode information relevant for predicting future rewards, making it more suitable for downstream control tasks.
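
A sketch of the combined objective, assuming a hypothetical reward_head MLP that maps the concatenated latent and action to a scalar reward (the weighting and all shapes are illustrative):

import torch
import torch.nn.functional as F

def autoencoder_loss(obs, action, reward, recon, z, reward_head,
                     lambda_reward=1.0):
    # obs, recon: (B, C, H, W); z: (B, D); action: (B, A); reward: (B,).
    l_recon = F.mse_loss(recon, obs)                     # L_recon
    r_hat = reward_head(torch.cat([z, action], dim=-1)).squeeze(-1)
    l_reward = F.mse_loss(r_hat, reward)                 # L_reward
    return l_recon + lambda_reward * l_reward

# A hypothetical reward head for 128-dim latents and 4-dim actions.
reward_head = torch.nn.Sequential(
    torch.nn.Linear(128 + 4, 256), torch.nn.ELU(), torch.nn.Linear(256, 1))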

Latent Dynamics Learning

The latent dynamics model operates entirely within the low-dimensional latent space defined by the autoencoder's output z_t. Its goal is to model the transition dynamics p(z_{t+1} | z_t, a_t). This is typically implemented as a probabilistic model, often a Gaussian distribution parameterized by an MLP:

\hat{z}_{t+1} \sim \mathcal{N}(\mu(z_t, a_t), \Sigma(z_t, a_t))

where \mu and \Sigma are the mean and covariance matrix predicted by the MLP, taking the current latent state z_t and action a_t as input.

The dynamics model is trained by maximizing the likelihood of observed transitions (z_t, a_t, z_{t+1}) collected from environment interaction and stored in a replay buffer. The objective is typically the negative log-likelihood of the next latent state z_{t+1} (obtained by encoding the next observation o_{t+1} using the current state of the autoencoder):

L_{dynamics} = -\log p(z_{t+1} | z_t, a_t)

This involves minimizing the difference between the predicted distribution and the encoded next state. In practice, using a fixed diagonal covariance or predicting only the mean (deterministic dynamics) can simplify training.
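
A minimal sketch of the Gaussian negative log-likelihood with a diagonal covariance, where dynamics is any module returning the mean and log-standard-deviation of the next latent (as in the earlier sketch). The target z_{t+1} is detached so the dynamics loss does not backpropagate into the encoder, which is an assumption of this sketch rather than a detail stated above:

import torch
from torch.distributions import Normal

def dynamics_loss(dynamics, z_t, a_t, z_next):
    # dynamics(z_t, a_t) -> (mean, log_std), each (B, D); z_next: (B, D).
    mean, log_std = dynamics(z_t, a_t)
    dist = Normal(mean, log_std.exp())
    # Diagonal covariance: sum log-probs over latent dims, average over batch.
    return -dist.log_prob(z_next.detach()).sum(dim=-1).mean()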

Online Training and Control

MWM employs an online training scheme where both the autoencoder and the latent dynamics model are continuously updated using data collected during interaction with the environment.

  1. Data Collection: An agent interacts with the environment using a policy derived from the current learned world model (e.g., via planning with Model Predictive Control (MPC) or a learned policy). Experiences (o_t, a_t, r_{t+1}, o_{t+1}) are stored in a replay buffer.
  2. Model Updates: Batches of transitions are sampled from the replay buffer. For each batch:
    • The autoencoder is updated using the reconstruction loss L_{recon} and the reward prediction loss L_{reward}. The target z_{t+1} for the dynamics model is also computed using the updated encoder.
    • The latent dynamics model is updated using the dynamics prediction loss L_{dynamics}, using the latent states z_t and z_{t+1} produced by the autoencoder.
  3. Policy Optimization/Planning: The updated latent dynamics model (and reward predictor) is used to optimize the agent's policy. If using MPC, trajectory optimization methods like the Cross-Entropy Method (CEM) can be employed to plan sequences of actions in the latent space that maximize predicted cumulative rewards; a minimal CEM sketch follows this list. The first action of the best sequence is executed, and the process repeats.
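
The sketch below is a minimal CEM planner in latent space. For simplicity it rolls out the mean of the dynamics prediction, and it assumes a reward_fn(z, a) returning one predicted reward per sample; the interface and all hyperparameters are illustrative:

import torch

@torch.no_grad()
def cem_plan(dynamics, reward_fn, z0, horizon=10, n_samples=500,
             n_elites=50, n_iters=5, action_dim=4):
    # z0: (D,) current latent; returns the first action of the best plan.
    mean = torch.zeros(horizon, action_dim)
    std = torch.ones(horizon, action_dim)
    for _ in range(n_iters):
        actions = mean + std * torch.randn(n_samples, horizon, action_dim)
        z = z0.expand(n_samples, -1)
        returns = torch.zeros(n_samples)
        for t in range(horizon):
            returns += reward_fn(z, actions[:, t])  # predicted reward per sample
            z, _ = dynamics(z, actions[:, t])       # roll out the mean latent
        elites = actions[returns.topk(n_elites).indices]
        mean, std = elites.mean(dim=0), elites.std(dim=0)  # refit proposal
    return mean[0]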

This continuous online updating allows the models to adapt to the data distribution encountered during learning and improves sample efficiency compared to offline training regimes.
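
Putting the pieces together, the alternating update can be written as a short training step. The buffer.sample interface and the autoencoder forward signature are hypothetical; autoencoder_loss, dynamics_loss, and reward_head are the sketches from the previous sections:

import torch

def train_step(buffer, autoencoder, dynamics, ae_opt, dyn_opt, batch_size=64):
    obs, act, rew, next_obs = buffer.sample(batch_size)  # hypothetical API

    # 1) Update the autoencoder: masked reconstruction + reward prediction.
    recon, z = autoencoder(obs, mask_ratio=0.75)         # hypothetical forward
    ae_loss = autoencoder_loss(obs, act, rew, recon, z, reward_head)
    ae_opt.zero_grad(); ae_loss.backward(); ae_opt.step()

    # 2) Update the dynamics model on latents from the updated encoder.
    with torch.no_grad():                 # the encoder is a fixed target here
        z_t = autoencoder.encode(obs)
        z_next = autoencoder.encode(next_obs)
    dyn_loss = dynamics_loss(dynamics, z_t, act, z_next)
    dyn_opt.zero_grad(); dyn_loss.backward(); dyn_opt.step()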

Implementation Considerations

  • Network Architectures: The original paper utilized a ResNet-based architecture for the convolutional encoder/decoder and a standard ViT architecture (e.g., ViT-Base) for the transformer component. The latent dynamics model was typically an MLP with several hidden layers. Precise layer configurations, embedding dimensions, and attention heads depend on the specific task complexity and available computational resources.
  • Masking Ratio: A high masking ratio (e.g., 75%) is reported to be effective, forcing the ViT to rely heavily on context and learn meaningful representations.
  • Computational Cost: The inclusion of a ViT significantly increases the computational cost compared to purely convolutional world models, both during training and potentially during inference if frequent re-planning is needed. Training requires substantial GPU memory and compute.
  • Latent Space Dimensionality: The dimension of the latent space z_t is a critical hyperparameter. It needs to be large enough to capture relevant state information but small enough for efficient dynamics learning and planning.
  • Stability: Decoupled training can sometimes lead to instabilities if the representation drifts significantly while the dynamics model is trying to adapt. Careful tuning of learning rates and of the update frequency of each component is necessary. Using target networks or Polyak averaging for the encoder when generating targets z_{t+1} for the dynamics loss can improve stability (see the sketch after this list).
  • Planning Horizon: For MPC-based control, the planning horizon needs to be long enough to solve the task, but longer horizons increase computational cost and can suffer from compounding prediction errors in the latent dynamics model.
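
For the stability point, one common remedy is to generate dynamics targets with a slowly-moving copy of the encoder. A minimal sketch of Polyak averaging, assuming the autoencoder object from the earlier sketches (the tau value is illustrative):

import copy
import torch

target_encoder = copy.deepcopy(autoencoder)   # slow copy used only for targets

@torch.no_grad()
def polyak_update(online, target, tau=0.005):
    # target <- (1 - tau) * target + tau * online, parameter by parameter.
    for p, p_t in zip(online.parameters(), target.parameters()):
        p_t.mul_(1.0 - tau).add_(tau * p)

# After each autoencoder update:
#   polyak_update(autoencoder, target_encoder)
#   z_next = target_encoder.encode(next_obs)  # stable target for L_dynamics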

The official codebase, linked from the project website (https://sites.google.com/view/mwm-rl), provides specific implementation details.

Experimental Results and Applications

MWM was evaluated on challenging visual robotic manipulation benchmarks, including 50 tasks from Meta-world and several tasks from RLBench.

  • Meta-world: MWM achieved a reported average success rate of 81.7% across 50 visual tasks, significantly outperforming the strongest contemporary visual MBRL baseline (DreamerV2), which achieved 67.9%. The improvements were particularly notable on tasks requiring precise manipulation of small objects.
  • RLBench: Similar strong performance was reported on RLBench tasks, demonstrating the applicability of the approach across different simulation environments and task types.
  • Sample Efficiency: The method demonstrated competitive or superior sample efficiency compared to model-free methods and other visual MBRL approaches.

These results suggest that the decoupled approach, combining masked feature reconstruction with ViTs and auxiliary reward prediction, leads to more effective visual representations for complex control tasks. The primary application domain is robotic manipulation from pixels, where accurately modeling object interactions and leveraging visual context is crucial.

Strengths and Limitations

Strengths:

  • Improved Representation Quality: Decoupling allows the autoencoder, enhanced by the MAE-style objective and ViT, to learn potentially more robust and informative visual representations compared to end-to-end models.
  • Handling Fine-Grained Interactions: The architecture seems particularly adept at tasks involving precise object manipulation, likely due to the detailed reconstruction objective and the ViT's ability to model spatial relationships.
  • Sample Efficiency: As an MBRL method, it generally offers better sample efficiency than model-free alternatives, which is crucial for real-world robotics.
  • Modularity: The separate components allow for independent improvements or replacements (e.g., using different encoders or dynamics models).

Limitations:

  • Computational Complexity: The use of ViTs makes the autoencoder computationally expensive, potentially limiting deployment on resource-constrained hardware or requiring significant training infrastructure.
  • Training Complexity: Tuning the hyperparameters of the decoupled system (masking ratio, loss weights, learning rates, network sizes) can be complex. Ensuring stability during online updates requires careful engineering.
  • Potential Representation Mismatch: While decoupling is beneficial, there's a risk that the representation learned by the autoencoder (driven by reconstruction and reward prediction) might not be perfectly optimal for the dynamics model's prediction task.
  • Compounding Errors: Like all MBRL methods relying on multi-step prediction, MWM is susceptible to compounding errors in the dynamics model during long-horizon planning.

Conclusion

The Masked World Model framework presents a notable advancement in visual model-based reinforcement learning by effectively decoupling representation learning and dynamics modeling. Leveraging masked autoencoding principles applied to convolutional features processed by a Vision Transformer, combined with an auxiliary reward prediction task, allows for the learning of rich visual representations suitable for complex robotic control tasks. The strong empirical results on challenging benchmarks highlight its potential for improving sample efficiency and performance in visual robotic manipulation, although the computational demands and training complexity remain important considerations for practical deployment.

Authors (7)
  1. Younggyo Seo
  2. Danijar Hafner
  3. Hao Liu
  4. Fangchen Liu
  5. Stephen James
  6. Kimin Lee
  7. Pieter Abbeel