The paper "Transformers are Sample-Efficient World Models" (Micheli et al., 2022 ) introduces IRIS (Imagination with auto-Regression over an Inner Speech), a model-based reinforcement learning agent designed to address the challenge of sample inefficiency in deep RL. The core idea is to leverage the sequence modeling capabilities of Transformers to build accurate world models that enable effective learning within imagined trajectories, thereby significantly reducing the need for interaction with the real environment.
Methodology: The IRIS Agent
IRIS operates on the principle of learning a world model from limited real-world experience and then training a policy entirely within the simulated environment generated by this model. The process involves iteratively collecting experience, updating the world model, and updating the agent's behavior (policy and value function).
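At a high level, this alternation can be summarized as a simple loop. The sketch below is schematic only: every helper and attribute name in it (collect_experience, update_world_model, update_behavior, agent.imagine, and so on) is a hypothetical placeholder rather than the paper's code.

```python
# Schematic outline of the collect / world-model / behavior alternation.
# All helper callables and agent attributes are hypothetical placeholders.
def train(env, agent, buffer, *, epochs, wm_updates, behavior_updates, horizon,
          collect_experience, update_world_model, update_behavior):
    for _ in range(epochs):
        # 1) Gather a small amount of real experience with the current policy.
        collect_experience(env, agent.policy, buffer)
        # 2) Fit the world model (discrete autoencoder + Transformer) on real data.
        for _ in range(wm_updates):
            update_world_model(agent.world_model, buffer.sample())
        # 3) Train the actor and critic purely on rollouts imagined by the model.
        for _ in range(behavior_updates):
            trajectory = agent.imagine(buffer.sample_observations(), horizon)
            update_behavior(agent.policy, agent.value, trajectory)
```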
World Model Architecture
The world model comprises two main components: a discrete autoencoder and an autoregressive Transformer.
- Discrete Autoencoder: This component is responsible for translating high-dimensional image observations $x_t$ into a short sequence of discrete latent variables (tokens) $z_t = (z_t^1, \dots, z_t^K)$, derived from a learned codebook $\mathcal{E} = \{e_i\}_{i=1}^{N} \subset \mathbb{R}^d$, where $N$ is the codebook size and $d$ is the embedding dimension.
- Encoder (E): A convolutional neural network (CNN) maps an input image $x_t$ to a sequence of $K$ feature vectors.
- Vector Quantization: Each feature vector is quantized by finding its nearest neighbor in the codebook $\mathcal{E}$, yielding the discrete token indices $z_t^1, \dots, z_t^K$.
- Decoder (D): A deconvolutional neural network (or transposed CNN) reconstructs the image $\hat{x}_t$ from the sequence of quantized embedding vectors corresponding to $z_t$.
Training: The autoencoder is trained on real environment observations using a composite loss function:

$$\mathcal{L}(E, D, \mathcal{E}) = \| x_t - D(z_q(x_t)) \|_1 + \| \operatorname{sg}(E(x_t)) - e \|_2^2 + \beta \| E(x_t) - \operatorname{sg}(e) \|_2^2 + \mathcal{L}_{\text{perceptual}}\big(x_t, D(z_q(x_t))\big),$$

where $\| x_t - D(z_q(x_t)) \|_1$ is the L1 reconstruction loss, $e$ represents the selected codebook embeddings, $z_q(x_t)$ is the quantized output of the encoder, $\operatorname{sg}(\cdot)$ denotes the stop-gradient operator, $\beta$ is a weighting factor for the commitment loss (aligning encoder outputs and embeddings), and $\mathcal{L}_{\text{perceptual}}$ is a perceptual loss (e.g., LPIPS) to improve reconstruction quality.
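A minimal PyTorch-style sketch of the quantization step and this composite loss is given below; the tensor shapes, the `lpips_fn` callable, and the default $\beta$ are illustrative assumptions, not the paper's exact implementation.

```python
# Sketch of VQ-style nearest-neighbour quantization and the composite
# autoencoder loss. Shapes and helper names are illustrative assumptions.
import torch
import torch.nn.functional as F

def quantize(z_e, codebook):
    """z_e: (B, K, d) encoder outputs; codebook: (N, d) embeddings."""
    # Nearest-neighbour lookup in the codebook for every feature vector.
    dists = torch.cdist(z_e, codebook.unsqueeze(0).expand(z_e.size(0), -1, -1))
    tokens = dists.argmin(dim=-1)          # (B, K) discrete indices z_t
    e = codebook[tokens]                   # (B, K, d) selected embeddings
    z_q = z_e + (e - z_e).detach()         # straight-through estimator
    return tokens, e, z_q

def autoencoder_loss(x, x_rec, z_e, e, lpips_fn, beta=1.0):
    recon = (x - x_rec).abs().mean()             # L1 reconstruction term
    codebook_loss = F.mse_loss(e, z_e.detach())  # pull embeddings toward encoder outputs
    commitment = F.mse_loss(z_e, e.detach())     # commit encoder to its chosen embeddings
    perceptual = lpips_fn(x_rec, x).mean()       # e.g. LPIPS
    return recon + codebook_loss + beta * commitment + perceptual
```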
- Autoregressive Transformer (G): This component models the dynamics of the environment in the discrete latent space. It predicts the next state tokens, reward, and termination signal based on the history of state tokens and actions.
- Input: Sequences of interleaved state tokens $z_t = (z_t^1, \dots, z_t^K)$ and action tokens $a_t$, i.e., $(z_1^1, \dots, z_1^K, a_1, z_2^1, \dots, z_2^K, a_2, \dots)$. Actions are also tokenized.
- Prediction: The Transformer autoregressively predicts the probability distribution over the next state tokens $\hat{z}_{t+1}$, the reward $\hat{r}_t$, and the termination probability $\hat{d}_t$. Specifically, state tokens are predicted sequentially: $p_G(\hat{z}_{t+1} \mid z_{\le t}, a_{\le t}) = \prod_{k=1}^{K} p_G\big(\hat{z}_{t+1}^k \mid z_{\le t}, a_{\le t}, \hat{z}_{t+1}^{<k}\big)$.
- Architecture: A standard decoder-only Transformer architecture (GPT-like) is employed.
Training: The Transformer is trained using maximum likelihood estimation on sequences sampled from the real experience replay buffer. The loss function combines cross-entropy losses for predicting the discrete state tokens $z_{t+1}$ and the termination signal $d_t$, and either a mean squared error (MSE) or a cross-entropy loss for the reward $r_t$; since Atari rewards are clipped to a small discrete set, cross-entropy is a natural choice there as well.
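The corresponding training objective can be sketched as follows, assuming the Transformer has already produced logits for next-frame tokens, rewards, and terminations; the shapes, the three-class reward target, and all names are illustrative assumptions.

```python
# Hedged sketch of the Transformer world-model loss. Shapes, names, and the
# discretized reward target are assumptions for illustration.
import torch
import torch.nn.functional as F

def world_model_loss(token_logits, reward_logits, done_logits,
                     target_tokens, target_rewards, target_dones):
    """token_logits: (B, T, K, N), target_tokens: (B, T, K) long in [0, N)
       reward_logits: (B, T, 3),   target_rewards: (B, T) long reward classes
       done_logits:   (B, T),      target_dones:   (B, T) float in {0., 1.}"""
    # Cross-entropy over the N codebook entries for every predicted token.
    token_loss = F.cross_entropy(token_logits.flatten(0, 2), target_tokens.flatten())
    # Cross-entropy over discretized reward classes (e.g. sign of the clipped reward).
    reward_loss = F.cross_entropy(reward_logits.flatten(0, 1), target_rewards.flatten())
    # Binary cross-entropy for the episode-termination signal.
    done_loss = F.binary_cross_entropy_with_logits(done_logits, target_dones)
    return token_loss + reward_loss + done_loss
```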
Behavior Learning
The policy $\pi$ and value function $V$ are learned entirely within the imagined environment generated by the world model, adapting techniques from DreamerV2.
- Imagination Rollouts: Starting from an initial state $z_0$ (encoded from a real observation $x_0$), the agent generates trajectories of length $H$ purely in imagination; a code sketch of this loop and the updates below follows the actor-critic description. At each step $t$:
- The policy selects an action $a_t \sim \pi(a_t \mid \hat{x}_{\le t})$ based on the history of reconstructed observations $\hat{x}_{\le t}$.
- The Transformer $G$ predicts the next latent state tokens $\hat{z}_{t+1}$, the reward $\hat{r}_t$, and the termination $\hat{d}_t$.
- The decoder $D$ reconstructs the next observation $\hat{x}_{t+1}$ from $\hat{z}_{t+1}$.
- Actor-Critic Updates:
Value Function (Critic): $V$ estimates the expected sum of future discounted rewards within the imagined trajectory. It is trained to predict bootstrapped $\lambda$-returns, calculated as:

$$\Lambda_t = \begin{cases} \hat{r}_t + \gamma (1 - \hat{d}_t)\big[(1 - \lambda)\, V(\hat{x}_{t+1}) + \lambda \Lambda_{t+1}\big] & \text{if } t < H, \\ V(\hat{x}_H) & \text{if } t = H, \end{cases}$$

where $\gamma$ is the discount factor and $\lambda$ controls the weighting of bootstrapped versus Monte Carlo targets. The loss is typically MSE against these targets: $\mathcal{L}_V = \mathbb{E}_\pi \big[ \sum_{t=0}^{H-1} \big( V(\hat{x}_t) - \operatorname{sg}(\Lambda_t) \big)^2 \big]$.
Policy (Actor): $\pi$ is trained using a REINFORCE-style objective to maximize expected returns, using the value function as a baseline for variance reduction (advantage estimation). The objective includes an entropy bonus to encourage exploration:

$$\mathcal{L}_\pi = -\mathbb{E}_\pi \left[ \sum_{t=0}^{H-1} \log \pi(a_t \mid \hat{x}_{\le t}) \operatorname{sg}\big(\Lambda_t - V(\hat{x}_t)\big) + \eta\, \mathcal{H}\big(\pi(\cdot \mid \hat{x}_{\le t})\big) \right],$$

where $\eta$ weights the entropy term $\mathcal{H}$.
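The behavior-learning loop above can be sketched end to end as follows, assuming PyTorch and treating the policy, Transformer step, and decoder as opaque callables. Every name, signature, and hyperparameter default here is an illustrative assumption rather than the paper's implementation; in particular, the real Transformer conditions on the full history of tokens and actions, which the sketch abbreviates.

```python
# Hedged sketch of imagination rollouts plus actor-critic targets and losses.
# `policy`, `world_model_step`, and `decode` are hypothetical callables
# standing in for pi, G, and D; their signatures are assumptions.
import torch

def imagine(z0, policy, world_model_step, decode, horizon):
    """Roll out `horizon` steps purely inside the learned world model."""
    obs = [decode(z0)]                        # x_hat_0 reconstructed from real z_0
    log_probs, entropies, rewards, dones = [], [], [], []
    z = z0
    for _ in range(horizon):
        dist = policy(obs)                    # pi(a_t | x_hat_<=t)
        a = dist.sample()
        # For brevity only the latest tokens are passed; the actual model
        # conditions on the whole history of tokens and actions.
        z, r, d = world_model_step(z, a)      # G: z_hat_{t+1}, r_hat_t, d_hat_t
        obs.append(decode(z))                 # D: x_hat_{t+1}
        log_probs.append(dist.log_prob(a))
        entropies.append(dist.entropy())
        rewards.append(r)
        dones.append(d)
    return (obs, torch.stack(log_probs), torch.stack(entropies),
            torch.stack(rewards), torch.stack(dones))

def lambda_returns(rewards, dones, values, gamma=0.99, lam=0.95):
    """rewards, dones: (H,); values: (H + 1,) with values[H] = V(x_hat_H)."""
    H = rewards.shape[0]
    out, next_return = [], values[H]          # Lambda_H = V(x_hat_H)
    for t in reversed(range(H)):
        blended = (1 - lam) * values[t + 1] + lam * next_return
        next_return = rewards[t] + gamma * (1 - dones[t]) * blended
        out.append(next_return)
    return torch.stack(out[::-1])

def critic_loss(values, returns):
    # MSE against stop-gradient lambda-return targets.
    return ((values[:-1] - returns.detach()) ** 2).mean()

def actor_loss(log_probs, entropies, values, returns, entropy_weight=0.01):
    advantages = (returns - values[:-1]).detach()   # sg(Lambda_t - V(x_hat_t))
    return -(log_probs * advantages + entropy_weight * entropies).mean()
```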
Experimental Results and Sample Efficiency
IRIS was evaluated on the Atari 100k benchmark, which restricts agents to learning from only 100,000 steps of interaction with each environment (equivalent to approximately 2 hours of gameplay).
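For reference, the human normalized score of a game compares the agent's raw score against random and human baselines:

$$\text{HNS} = \frac{\text{score}_{\text{agent}} - \text{score}_{\text{random}}}{\text{score}_{\text{human}} - \text{score}_{\text{random}}},$$

so an HNS of 1.0 corresponds to human-level performance and 0.0 to a random policy.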
- Performance: IRIS achieved a mean human normalized score (HNS) of 1.046 across 26 Atari games. This signifies performance slightly exceeding the average human level on this benchmark.
- Superhuman Performance: The agent achieved scores greater than 1.0 HNS on 10 out of the 26 games evaluated.
- State-of-the-Art (No Lookahead): At the time of publication, IRIS set a new state of the art on Atari 100k among methods that do not use lookahead search (such as the Monte Carlo Tree Search employed by MuZero and EfficientZero). It outperformed prior search-free methods on aggregate metrics including mean HNS, Interquartile Mean (IQM) HNS (0.501), and Optimality Gap (0.512, lower is better).
- For comparison, methods like SimPLe, CURL, DrQ, and SPR achieved considerably lower mean HNS scores on Atari 100k.
- Comparison with Search Methods: While outperforming MuZero (mean HNS 0.562) under the 100k step constraint, IRIS was surpassed by EfficientZero (mean HNS 1.943), highlighting the performance gains achievable by combining learned models with explicit planning/search.
- Qualitative Analysis: Visualizations provided evidence for the world model's capabilities, including generating coherent and diverse future imagined sequences, achieving pixel-perfect prediction in deterministic environments like Pong, and correctly associating rewards and terminations with relevant in-game events.
Limitations
The authors acknowledged certain limitations:
- Rare Events: The model may struggle in environments where progress depends on discovering rare events during real-world interaction, as these events might not be sufficiently represented in the initial 100k steps to be learned accurately by the world model.
- Fine Visual Details: Performance might be limited in games where distinguishing subtle visual details is critical, although this could potentially be addressed by increasing the number of discrete tokens $K$ used to represent each frame (as explored in the paper's appendix).
In conclusion, the IRIS agent demonstrates that employing a Transformer-based architecture within a world model, combined with discrete latent representations via vector quantization, yields a highly sample-efficient reinforcement learning system. By learning accurate dynamics models from limited data, the agent can effectively train its policy primarily through imagination, achieving strong performance on challenging benchmarks like Atari 100k without explicit search mechanisms.