
Transformers are Sample-Efficient World Models (2209.00588v2)

Published 1 Sep 2022 in cs.LG, cs.AI, and cs.CV

Abstract: Deep reinforcement learning agents are notoriously sample inefficient, which considerably limits their application to real-world problems. Recently, many model-based methods have been designed to address this issue, with learning in the imagination of a world model being one of the most prominent approaches. However, while virtually unlimited interaction with a simulated environment sounds appealing, the world model has to be accurate over extended periods of time. Motivated by the success of Transformers in sequence modeling tasks, we introduce IRIS, a data-efficient agent that learns in a world model composed of a discrete autoencoder and an autoregressive Transformer. With the equivalent of only two hours of gameplay in the Atari 100k benchmark, IRIS achieves a mean human normalized score of 1.046, and outperforms humans on 10 out of 26 games, setting a new state of the art for methods without lookahead search. To foster future research on Transformers and world models for sample-efficient reinforcement learning, we release our code and models at https://github.com/eloialonso/iris.

The paper "Transformers are Sample-Efficient World Models" (Micheli et al., 2022) introduces IRIS (Imagination with auto-Regression over an Inner Speech), a model-based reinforcement learning agent designed to address the challenge of sample inefficiency in deep RL. The core idea is to leverage the sequence modeling capabilities of Transformers to build accurate world models that enable effective learning within imagined trajectories, thereby significantly reducing the need for interaction with the real environment.

Methodology: The IRIS Agent

IRIS operates on the principle of learning a world model from limited real-world experience and then training a policy entirely within the simulated environment generated by this model. The process involves iteratively collecting experience, updating the world model, and updating the agent's behavior (policy and value function).
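
The alternation between the three phases can be sketched as a simple loop. This is a minimal illustration under stated assumptions: the callback names are hypothetical placeholders, and the real schedule (how many environment steps and gradient updates each phase gets per iteration) is not described here.

```python
from typing import Callable, List


def train_loop(
    num_epochs: int,
    collect_experience: Callable[[], None],
    update_world_model: Callable[[], None],
    update_behavior: Callable[[], None],
) -> None:
    """Alternate the three phases once per epoch (hypothetical callback names)."""
    for _ in range(num_epochs):
        collect_experience()   # only phase that interacts with the real environment
        update_world_model()   # train autoencoder and Transformer on the replay buffer
        update_behavior()      # train actor and critic on imagined rollouts only


# Usage: stub callbacks that just record the order in which the phases run.
call_log: List[str] = []
train_loop(
    num_epochs=2,
    collect_experience=lambda: call_log.append("collect"),
    update_world_model=lambda: call_log.append("world_model"),
    update_behavior=lambda: call_log.append("behavior"),
)
print(call_log)
```

Keeping the phases as separate callables makes explicit that only the first one consumes real interaction budget, which is the quantity the Atari 100k benchmark limits.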

World Model Architecture

The world model comprises two main components: a discrete autoencoder and an autoregressive Transformer.

  1. Discrete Autoencoder: This component translates high-dimensional image observations $x_t$ into a lower-dimensional sequence of discrete latent variables (tokens) $z_t = (z_t^1, \dots, z_t^K)$, each an index into a learned codebook $\mathcal{E} \in \mathbb{R}^{N \times d}$, where $N$ is the codebook size and $d$ is the embedding dimension.
    • Encoder ($E$): A convolutional neural network (CNN) maps an input image $x_t$ to a sequence of $K$ feature vectors.
    • Vector Quantization: Each feature vector is quantized by finding its nearest neighbor in the codebook $\mathcal{E}$, yielding the discrete token indices $z_t^k$.
    • Decoder ($D$): A deconvolutional (transposed-convolution) network reconstructs the image $\hat{x}_t = D(z_t)$ from the sequence of codebook embeddings corresponding to $z_t$.
    • Training: The autoencoder is trained on real environment observations using a composite loss function:

      $$L_{AE} = L_1 + \|sg[E(x)] - e\|_2^2 + \beta \|E(x) - sg[e]\|_2^2 + \gamma L_{perceptual}$$

      where $L_1 = \|x - D(E_q(x))\|_1$ is the L1 reconstruction loss, $E(x)$ is the continuous encoder output, $e = E_q(x)$ is its quantized version (the selected codebook embeddings), and $sg[\cdot]$ denotes the stop-gradient operator. The second term (codebook loss) moves the embeddings toward the encoder outputs, the third term (commitment loss, weighted by $\beta$) keeps the encoder outputs close to their chosen embeddings, and $L_{perceptual}$, weighted by $\gamma$, is a perceptual loss (e.g., LPIPS) that improves reconstruction quality.

  2. Autoregressive Transformer (G): This component models the dynamics of the environment in the discrete latent space. It predicts the next state tokens, reward, and termination signal based on the history of state tokens and actions.
    • Input: Sequences of interleaved state tokens and action tokens $(z_0, a_0, z_1, a_1, \dots, z_t, a_t)$. Actions $a_t$ are also tokenized.
    • Prediction: The Transformer autoregressively predicts the probability distribution over the next state tokens $p_G(\hat{z}_{t+1} \mid z_{\le t}, a_{\le t})$, the reward $\hat{r}_t$, and the termination probability $\hat{d}_t$. Specifically, state tokens are predicted one at a time: $p_G(\hat{z}_{t+1}^k \mid z_{\le t}, a_{\le t}, z_{t+1}^{<k})$.
    • Architecture: A standard decoder-only Transformer architecture (GPT-like) is employed.
    • Training: The Transformer is trained by maximum likelihood on sequences sampled from the replay buffer of real experience. The loss combines cross-entropy for the discrete state tokens $\hat{z}_{t+1}$ and the termination signal $\hat{d}_t$ with either mean squared error (MSE) or cross-entropy for the reward $\hat{r}_t$:

      $$L_{WM} = -\sum_{k=1}^{K} \log p_G(z_{t+1}^k \mid \dots) - \log p_G(d_t \mid \dots) + L_{reward}(\hat{r}_t, r_t)$$
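
The nearest-neighbor lookup in the vector-quantization step can be sketched in plain Python. This is an illustrative sketch with a tiny hypothetical 2-D codebook, not the IRIS implementation: a real tokenizer works on batched tensors, and training it also needs a way to pass gradients through the non-differentiable lookup (standard VQ-VAE training uses a straight-through estimator).

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def squared_distance(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(
    features: Sequence[Vector], codebook: Sequence[Vector]
) -> Tuple[List[int], List[Vector]]:
    """Replace each encoder feature vector by its nearest codebook entry.

    Returns the token indices z^k and the selected embeddings e that the
    decoder would receive.
    """
    tokens = [
        min(range(len(codebook)), key=lambda i: squared_distance(f, codebook[i]))
        for f in features
    ]
    embeddings = [codebook[i] for i in tokens]
    return tokens, embeddings


# Usage: a toy codebook with N = 3 entries of dimension d = 2, and K = 4 feature vectors.
codebook = [(0.0, 0.0), (1.0, 1.0), (-1.0, 2.0)]
features = [(0.1, -0.2), (0.9, 1.2), (-0.8, 1.7), (0.6, 0.6)]
tokens, embeddings = quantize(features, codebook)
print(tokens)  # [0, 1, 2, 1]
```

The returned `tokens` are what the Transformer consumes; the `embeddings` are what the decoder turns back into an image.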

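The input layout and the state-token part of the world-model loss can be sketched as follows. This is a simplified illustration, not the IRIS code: tagging tokens as `"obs"` or `"act"` is an assumption standing in for separate embedding tables, the probability vectors are hypothetical model outputs, and the reward and termination terms are omitted.

```python
import math
from typing import List, Sequence, Tuple

# Tag each token with its type; a real model would more likely use separate
# embedding tables for observation and action tokens (assumption of this sketch).
Token = Tuple[str, int]


def interleave(frames: Sequence[Sequence[int]], actions: Sequence[int]) -> List[Token]:
    """Flatten (z_0, a_0, z_1, a_1, ...): K state tokens, then one action token, per step."""
    if len(frames) != len(actions):
        raise ValueError("expected exactly one action per frame")
    sequence: List[Token] = []
    for z_t, a_t in zip(frames, actions):
        sequence.extend(("obs", index) for index in z_t)
        sequence.append(("act", a_t))
    return sequence


def next_frame_nll(probs: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """-sum_k log p_G(z_{t+1}^k | history, z_{t+1}^{<k}) for one frame.

    probs[k] is the predicted distribution over the N codebook entries for
    the k-th token, already conditioned on everything that precedes it.
    """
    return -sum(math.log(p[target]) for p, target in zip(probs, targets))


# Usage: K = 2 tokens per frame, codebook size N = 3, two time steps.
seq = interleave(frames=[[2, 1], [0, 2]], actions=[1, 0])
print(seq)  # [('obs', 2), ('obs', 1), ('act', 1), ('obs', 0), ('obs', 2), ('act', 0)]

loss = next_frame_nll(probs=[[0.5, 0.25, 0.25], [0.1, 0.1, 0.8]], targets=[0, 2])
print(round(loss, 4))  # 0.9163
```

Because each frame's tokens are predicted one after another, the loss for a frame is a sum of $K$ per-token cross-entropies rather than a single term.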
Behavior Learning

The policy $\pi$ and value function $V$ are learned entirely within the imagined environment generated by the world model, adapting techniques from DreamerV2.

  1. Imagination Rollouts: Starting from an initial state $z_0$ (encoded from a real observation $x_0$), the agent generates trajectories of length $H$ purely in imagination. At each step $t$:
    • The policy $\pi(a_t \mid \hat{x}_{\le t})$ selects an action based on the history of reconstructed observations $\hat{x}_{\le t} = D(z_{\le t})$.
    • The Transformer $G$ predicts the next latent state tokens $\hat{z}_{t+1}$, the reward $\hat{r}_t$, and the termination flag $\hat{d}_t$.
    • The decoder $D$ reconstructs the next observation $\hat{x}_{t+1} = D(\hat{z}_{t+1})$.
  2. Actor-Critic Updates:
    • Value Function (Critic): $V(\hat{x}_t)$ estimates the expected sum of discounted future rewards within the imagined trajectory. It is trained to predict bootstrapped $\lambda$-returns, calculated as:

      $$V^\lambda_t = (1-\lambda) \sum_{n=1}^{H-1} \lambda^{n-1} V_t^{(n)} + \lambda^{H-1} V_t^{(H)}$$

      where $V_t^{(n)} = \sum_{k=t}^{t+n-1} \gamma^{k-t} \hat{r}_k + \gamma^n V(\hat{x}_{t+n})$ is the $n$-step return (for $t > 0$, $H$ is replaced by the number of imagined steps remaining after $t$, i.e. $H - t$). The loss is typically MSE: $L_V = \|V(\hat{x}_t) - sg[V^\lambda_t]\|^2$.

    • Policy (Actor): $\pi(a_t \mid \hat{x}_{\le t})$ is trained with a REINFORCE-style objective to maximize expected returns, using the value function as a baseline for variance reduction (advantage estimation). The advantage is wrapped in a stop-gradient so that the actor loss does not update the critic. The objective includes an entropy bonus $\mathcal{H}[\pi]$, weighted by $\eta$, to encourage exploration:

      $$L_\pi = -\mathbb{E}_{\pi, G}\Big[\sum_{t=0}^{H-1} \Big( \log \pi(a_t \mid \hat{x}_{\le t})\, sg\big[V^\lambda_t - V(\hat{x}_t)\big] + \eta\, \mathcal{H}\big[\pi(\cdot \mid \hat{x}_{\le t})\big] \Big)\Big]$$
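
The imagination rollout described above can be sketched with toy stand-ins for $\pi$, $G$, and $D$. Everything here is hypothetical (a deterministic policy, a toy "world model" that shifts tokens, an identity-like decoder); the sketch only shows the data flow, not how IRIS batches trajectories or samples actions.

```python
from typing import Callable, List, Tuple

Tokens = Tuple[int, ...]
Observation = Tuple[float, ...]


def imagine(
    z0: Tokens,
    horizon: int,
    policy: Callable[[List[Observation]], int],
    world_model: Callable[[List[Tokens], List[int]], Tuple[Tokens, float, bool]],
    decode: Callable[[Tokens], Observation],
):
    """Roll out `horizon` steps without touching the real environment."""
    tokens: List[Tokens] = [z0]
    observations: List[Observation] = [decode(z0)]  # x_0 = D(z_0)
    actions: List[int] = []
    rewards: List[float] = []
    dones: List[bool] = []
    for _ in range(horizon):
        action = policy(observations)  # a_t chosen from the decoded history x_{<=t}
        actions.append(action)
        z_next, reward, done = world_model(tokens, actions)  # G predicts z_{t+1}, r_t, d_t
        tokens.append(z_next)
        observations.append(decode(z_next))  # x_{t+1} = D(z_{t+1})
        rewards.append(reward)
        dones.append(done)  # recorded, not used to stop early in this sketch
    return observations, actions, rewards, dones


N = 4  # toy codebook size


def toy_world_model(tokens: List[Tokens], actions: List[int]) -> Tuple[Tokens, float, bool]:
    """Deterministic stand-in for G: shift every token by the action; reward 1 for action 1."""
    a = actions[-1]
    z_next = tuple((z + a) % N for z in tokens[-1])
    return z_next, float(a == 1), False


obs, acts, rews, dones = imagine(
    z0=(0, 2),
    horizon=3,
    policy=lambda history: 1,                     # always picks action 1
    world_model=toy_world_model,
    decode=lambda z: tuple(float(t) for t in z),  # stand-in for the decoder D
)
print(acts, rews)  # [1, 1, 1] [1.0, 1.0, 1.0]
```

Passing the policy decoded observations rather than tokens mirrors the text: the agent acts on reconstructed frames $\hat{x}_{\le t}$.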

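The $\lambda$-return formula and the actor objective can be checked numerically. The sketch below computes $V^\lambda_t$ from the closed form above and from the equivalent backward recursion $G_t = r_t + \gamma\big((1-\lambda) V_{t+1} + \lambda G_{t+1}\big)$ with $G_H = V_H$, then evaluates the forward value of $L_\pi$ for one trajectory. The numbers are made up, termination masking is omitted (as in the formulas), and plain floats cannot express the stop-gradient, which only matters in an autodiff framework.

```python
from typing import List, Sequence


def n_step_return(rewards: Sequence[float], values: Sequence[float], t: int, n: int, gamma: float) -> float:
    """V_t^(n) = sum_{k=t}^{t+n-1} gamma^(k-t) r_k + gamma^n V(x_{t+n})."""
    discounted = sum(gamma ** (k - t) * rewards[k] for k in range(t, t + n))
    return discounted + gamma ** n * values[t + n]


def lambda_return(rewards: Sequence[float], values: Sequence[float], t: int, gamma: float, lam: float) -> float:
    """Closed form from the text, with H replaced by the steps remaining after t."""
    remaining = len(rewards) - t
    mixed = sum(lam ** (n - 1) * n_step_return(rewards, values, t, n, gamma) for n in range(1, remaining))
    return (1 - lam) * mixed + lam ** (remaining - 1) * n_step_return(rewards, values, t, remaining, gamma)


def lambda_returns_recursive(rewards: Sequence[float], values: Sequence[float], gamma: float, lam: float) -> List[float]:
    """Backward recursion G_t = r_t + gamma * ((1 - lam) * V_{t+1} + lam * G_{t+1}), with G_H = V_H."""
    horizon = len(rewards)
    returns = [0.0] * horizon
    g = values[horizon]
    for t in reversed(range(horizon)):
        g = rewards[t] + gamma * ((1 - lam) * values[t + 1] + lam * g)
        returns[t] = g
    return returns


def actor_loss(log_probs, entropies, returns, values, eta: float) -> float:
    """Forward value of L_pi for one trajectory; the advantage (G_t - V_t) is a constant here."""
    return -sum(lp * (g - v) + eta * h for lp, h, g, v in zip(log_probs, entropies, returns, values))


# Usage: H = 3 imagined steps; values holds V(x_0), ..., V(x_3).
rewards = [1.0, 0.0, 2.0]
values = [0.5, 0.4, 0.3, 0.2]
recursive = lambda_returns_recursive(rewards, values, gamma=0.9, lam=0.8)
closed = [lambda_return(rewards, values, t, gamma=0.9, lam=0.8) for t in range(3)]
print(all(abs(a - b) < 1e-9 for a, b in zip(recursive, closed)))  # True

# Critic regression target: zip stops after V(x_0..x_2), one value per return.
critic_loss = sum((v - g) ** 2 for v, g in zip(values, recursive)) / len(recursive)
```

The recursion is usually preferred in practice because it computes all $H$ targets in one backward pass instead of summing $n$-step returns separately for every $t$.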
Experimental Results and Sample Efficiency

IRIS was evaluated on the Atari 100k benchmark, which restricts agents to learning from only 100,000 steps of interaction with each environment (equivalent to approximately 2 hours of gameplay).
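
The "approximately 2 hours" figure follows from the standard Atari 100k setup, in which each agent step repeats the chosen action for 4 emulator frames and the emulator runs at 60 frames per second. A quick check:

```python
AGENT_STEPS = 100_000
FRAME_SKIP = 4          # each agent action is repeated for 4 emulator frames
FRAMES_PER_SECOND = 60  # Atari emulator frame rate

frames = AGENT_STEPS * FRAME_SKIP
hours = frames / FRAMES_PER_SECOND / 3600
print(f"{frames} frames = {hours:.2f} hours")  # 400000 frames = 1.85 hours
```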

  • Performance: IRIS achieved a mean human normalized score (HNS) of 1.046 across 26 Atari games. This signifies performance slightly exceeding the average human level on this benchmark.
  • Superhuman Performance: The agent achieved scores greater than 1.0 HNS on 10 out of the 26 games evaluated.
  • State-of-the-Art (No Lookahead): At the time of publication, IRIS established a new state-of-the-art performance level for model-based agents without relying on lookahead search algorithms (like Monte Carlo Tree Search used in MuZero or EfficientZero). It significantly outperformed prior methods in mean HNS, median HNS, Interquartile Mean (IQM) HNS (0.501), and Optimality Gap (0.512, lower is better).
    • For comparison, methods like SimPLe, CURL, DrQ, and SPR achieved considerably lower mean HNS scores on Atari 100k.
  • Comparison with Search Methods: While outperforming MuZero (mean HNS 0.562) under the 100k step constraint, IRIS was surpassed by EfficientZero (mean HNS 1.943), highlighting the performance gains achievable by combining learned models with explicit planning/search.
  • Qualitative Analysis: Visualizations provided evidence for the world model's capabilities, including generating coherent and diverse future imagined sequences, achieving pixel-perfect prediction in deterministic environments like Pong, and correctly associating rewards and terminations with relevant in-game events.
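
The aggregate metrics above all start from the per-game human normalized score, $\text{HNS} = (\text{agent} - \text{random}) / (\text{human} - \text{random})$, so 0 corresponds to random play and 1 to the human reference. The sketch below computes mean, median, IQM, optimality gap, and the count of superhuman games on made-up scores for four hypothetical games. It is simplified: published IQM and optimality-gap numbers are normally computed over several runs per game with bootstrap confidence intervals, whereas this version uses one score per game.

```python
import statistics
from typing import Dict, List, Tuple


def human_normalized(agent: float, random_score: float, human: float) -> float:
    """0 = random play, 1 = human reference score."""
    return (agent - random_score) / (human - random_score)


def interquartile_mean(scores: List[float]) -> float:
    """Mean of the middle 50% of scores (simplified: one score per game)."""
    ordered = sorted(scores)
    cut = len(ordered) // 4
    middle = ordered[cut : len(ordered) - cut]
    return sum(middle) / len(middle)


def optimality_gap(scores: List[float]) -> float:
    """Average shortfall below human level; scores above 1 count as 1 (lower is better)."""
    return sum(max(0.0, 1.0 - s) for s in scores) / len(scores)


# Made-up (agent, random, human) raw scores for four hypothetical games.
raw: Dict[str, Tuple[float, float, float]] = {
    "game_a": (55.0, 5.0, 105.0),
    "game_b": (210.0, 10.0, 110.0),
    "game_c": (10.0, 0.0, 100.0),
    "game_d": (80.0, 0.0, 100.0),
}
hns = [human_normalized(*triple) for triple in raw.values()]
print(hns)  # [0.5, 2.0, 0.1, 0.8]
print(
    round(statistics.mean(hns), 3),
    round(statistics.median(hns), 3),
    round(interquartile_mean(hns), 3),
    round(optimality_gap(hns), 3),
)  # 0.85 0.65 0.65 0.4
print(sum(s > 1.0 for s in hns))  # 1
```

The toy numbers illustrate why several metrics are reported: one strong game pulls the mean above the median, while the IQM and optimality gap are less sensitive to such outliers.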

Limitations

The authors acknowledged certain limitations:

  • Rare Events: The model may struggle in environments where progress depends on discovering rare events during real-world interaction, as these events might not be sufficiently represented in the 100k steps of real experience for the world model to learn them accurately.
  • Fine Visual Details: Performance might be limited in games where distinguishing subtle visual details is critical, although this could potentially be addressed by increasing the number of discrete tokens (KK) used to represent each frame (as explored in the paper's appendix).
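
Increasing $K$ has a cost: with interleaved inputs, each time step contributes $K + 1$ tokens, so the Transformer's context grows linearly in $K$, full self-attention over a context of length $L$ computes on the order of $L^2$ scores, and imagining one frame requires $K$ sequential token predictions. The values of $K$ and the 20-step context below are illustrative only, not hyperparameters taken from the paper.

```python
def context_length(tokens_per_frame: int, steps: int) -> int:
    """Each time step contributes K state tokens plus one action token."""
    return steps * (tokens_per_frame + 1)


# Illustrative values only.
for k in (16, 64):
    length = context_length(k, steps=20)
    print(f"K={k}: {length} tokens, ~{length ** 2:,} attention scores per layer and head")
# K=16: 340 tokens, ~115,600 attention scores per layer and head
# K=64: 1300 tokens, ~1,690,000 attention scores per layer and head
```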

In conclusion, the IRIS agent demonstrates that employing a Transformer-based architecture within a world model, combined with discrete latent representations via vector quantization, yields a highly sample-efficient reinforcement learning system. By learning accurate dynamics models from limited data, the agent can effectively train its policy primarily through imagination, achieving strong performance on challenging benchmarks like Atari 100k without explicit search mechanisms.

Authors (3)
  1. Vincent Micheli (8 papers)
  2. Eloi Alonso (8 papers)
  3. François Fleuret (78 papers)
Citations (133)