
World Models (1803.10122v4)

Published 27 Mar 2018 in cs.LG and stat.ML

Abstract: We explore building generative neural network models of popular reinforcement learning environments. Our world model can be trained quickly in an unsupervised manner to learn a compressed spatial and temporal representation of the environment. By using features extracted from the world model as inputs to an agent, we can train a very compact and simple policy that can solve the required task. We can even train our agent entirely inside of its own hallucinated dream generated by its world model, and transfer this policy back into the actual environment. An interactive version of this paper is available at https://worldmodels.github.io/

The "World Models" paper (Ha et al., 2018 ) presents an approach for training reinforcement learning agents by decoupling representation learning from policy optimization. The core idea is to build a generative model of the environment (the "world model") which the agent can then use to learn a policy, potentially entirely within a simulated version of the environment generated by this model. This methodology aims to address sample efficiency issues often encountered in model-free RL by learning a compressed and predictive representation of the environment's dynamics in an unsupervised manner.

Architecture

The World Model architecture consists of three distinct components:

  1. Vision Model (V): A Variational Autoencoder (VAE) responsible for compressing high-dimensional observations (e.g., pixels) into a low-dimensional latent vector $z_t$ at each timestep $t$. This component learns a spatial representation of the environment state.
  2. Memory RNN Model (M): A Mixture Density Network Recurrent Neural Network (MDN-RNN) that models the temporal dynamics of the environment in the latent space. It predicts the probability distribution of the next latent state $z_{t+1}$ given the current latent state $z_t$, the action $a_t$ taken at time $t$, and its own hidden state $h_t$. It essentially acts as a predictive model of the future compressed states. The hidden state $h_t$ serves as a memory of the past sequence.
  3. Controller (C): A compact policy network that determines the action $a_t$ to take at each timestep. It receives the current latent state $z_t$ from the VAE and the hidden state $h_t$ from the MDN-RNN as input. The paper utilizes a remarkably simple linear model for this component: $a_t = W_c\,[z_t \; h_t] + b_c$.

These components interact sequentially: V processes the observation $o_t$ into $z_t$. M uses $z_t$, $a_t$, and its previous hidden state $h_{t-1}$ to produce $h_t$ and predict the distribution of $z_{t+1}$. C uses $z_t$ and $h_t$ to compute the action $a_t$.
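
As a concrete illustration, the following pseudocode sketches one rollout in the real environment. The vae, mdn_rnn, and controller objects and their methods (encode, act, initial_state, step) are assumed interfaces for this sketch, not the authors' actual API.

obs = env.reset()
h = mdn_rnn.initial_state()
done = False
while not done:
    z = vae.encode(obs)                    # V: compress observation o_t into latent z_t
    a = controller.act(z, h)               # C: linear policy acting on [z_t, h_t]
    obs, reward, done, info = env.step(a)  # act in the real environment
    h = mdn_rnn.step(z, a, h)              # M: update memory h_t -> h_{t+1} from (z_t, a_t)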

Vision Model (V): Variational Autoencoder

The V component is implemented as a VAE. Its primary function is to learn an encoding $z_t = E(o_t)$ that captures the salient spatial features of the observation $o_t$ in a low-dimensional latent space. The VAE consists of an encoder network that maps $o_t$ to the parameters (mean $\mu$ and log-variance $\log(\sigma^2)$) of a Gaussian distribution, from which $z_t$ is sampled, and a decoder network that attempts to reconstruct the original observation $\hat{o}_t = D(z_t)$ from the latent vector.
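
A minimal sketch of such a convolutional VAE is shown below. The layer sizes and class interface are illustrative assumptions (chosen to match 64x64x3 inputs and the training loop later in this section), not necessarily the paper's exact implementation.

import torch
import torch.nn as nn

class VAE(nn.Module):
    # Illustrative convolutional VAE for 64x64x3 observations.
    def __init__(self, latent_dim=32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2), nn.ReLU(),     # 64x64 -> 31x31
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),    # 31x31 -> 14x14
            nn.Conv2d(64, 128, 4, stride=2), nn.ReLU(),   # 14x14 -> 6x6
            nn.Conv2d(128, 256, 4, stride=2), nn.ReLU(),  # 6x6 -> 2x2
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(2 * 2 * 256, latent_dim)
        self.fc_log_var = nn.Linear(2 * 2 * 256, latent_dim)
        self.fc_dec = nn.Linear(latent_dim, 1024)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1024, 128, 5, stride=2), nn.ReLU(),  # 1x1 -> 5x5
            nn.ConvTranspose2d(128, 64, 5, stride=2), nn.ReLU(),    # 5x5 -> 13x13
            nn.ConvTranspose2d(64, 32, 6, stride=2), nn.ReLU(),     # 13x13 -> 30x30
            nn.ConvTranspose2d(32, 3, 6, stride=2), nn.Sigmoid(),   # 30x30 -> 64x64
        )

    def forward(self, obs):
        hidden = self.encoder(obs)
        z_mean = self.fc_mu(hidden)
        z_log_var = self.fc_log_var(hidden)
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
        z = z_mean + torch.exp(0.5 * z_log_var) * torch.randn_like(z_mean)
        reconstructed = self.decoder(self.fc_dec(z).view(-1, 1024, 1, 1))
        return z_mean, z_log_var, reconstructed

    def encode(self, obs):
        # Deterministic encoding for inference: return the mean latent vector.
        return self.fc_mu(self.encoder(obs))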

The VAE is trained independently and in an unsupervised manner on a dataset of observations collected from the environment (e.g., by a random policy or during initial exploration). The objective function is the standard Evidence Lower Bound (ELBO), which balances a reconstruction loss (making $\hat{o}_t$ similar to $o_t$) against a Kullback-Leibler (KL) divergence term that regularizes the latent distribution to be close to a standard Gaussian prior $\mathcal{N}(0, I)$.

import torch
import torch.nn.functional as F

# collect_observations is an assumed helper returning batches of (B, 3, 64, 64) frames.
dataset = collect_observations(environment, num_episodes)
vae = VAE(latent_dim=32)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-4)
beta = 1.0  # weight on the KL term; higher values favor a smoother latent space

for epoch in range(num_epochs):
    for obs_batch in dataset:
        optimizer.zero_grad()
        z_mean, z_log_var, reconstructed_obs = vae(obs_batch)
        # Reconstruction term of the ELBO (summed squared error over pixels).
        reconstruction_loss = F.mse_loss(reconstructed_obs, obs_batch, reduction='sum')
        # KL divergence between q(z|o) = N(mu, sigma^2) and the prior N(0, I).
        kl_divergence = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())
        loss = reconstruction_loss + beta * kl_divergence
        loss.backward()
        optimizer.step()

The dimensionality of $z_t$ is significantly smaller than that of $o_t$ (e.g., $D=32$ or $D=64$ for $64 \times 64 \times 3$ pixel observations). This compression is crucial for the efficiency of the subsequent components.

Memory RNN Model (M): MDN-RNN

The M component serves as the predictive core of the world model. It learns a probabilistic model $P(z_{t+1}, r_{t+1}, d_{t+1} \mid z_t, a_t, h_t)$ of the environment's dynamics within the latent space defined by V. An LSTM network is used to maintain a hidden state $h_t$ that summarizes the history relevant for prediction.

At each step, the RNN receives the concatenated vector $[z_t, a_t]$ as input and updates its hidden state $h_t = \text{LSTM}([z_t, a_t], h_{t-1})$. The output layer of the RNN then predicts the parameters of a Mixture Density Network (MDN) for the next latent state $z_{t+1}$. An MDN represents the conditional probability distribution $P(z_{t+1} \mid z_t, a_t, h_t)$ as a mixture of Gaussian distributions:

$$P(z_{t+1} \mid z_t, a_t, h_t) = \sum_{i=1}^{N} \pi_i(x_t)\, \mathcal{N}\!\left(z_{t+1} \mid \mu_i(x_t), \sigma_i^2(x_t) I\right)$$

where $x_t = (z_t, a_t, h_t)$, and $\pi_i$, $\mu_i$, $\sigma_i^2$ are the mixture weights, means, and variances output by the RNN for each of the $N$ mixture components. This allows the model to capture multi-modal uncertainties inherent in complex dynamics.

The MDN-RNN is also trained to predict the expected reward $r_{t+1}$ and the probability that the episode terminates $d_{t+1}$ (done flag) at the next step, usually via separate linear output layers from $h_t$.

Training M involves minimizing the negative log-likelihood of the observed $z_{t+1}$ under the predicted mixture distribution, along with standard losses (e.g., MSE for the reward, BCE for the done flag) for the auxiliary predictions. The training data consists of sequences $(z_t, a_t, r_{t+1}, d_{t+1})$ generated by V processing rollouts from the actual environment.
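
Before the training loop itself, here is a minimal sketch of an MDN-RNN module with the interface assumed by that pseudocode. The head layout (mixture weights, isotropic per-component variances, means, plus reward and done heads) follows the description above, while the names and shapes are illustrative assumptions.

import torch
import torch.nn as nn

class MDN_RNN(nn.Module):
    # Illustrative MDN-RNN: an LSTM over [z_t, a_t] whose hidden state feeds a
    # mixture-density head over z_{t+1} plus reward and done heads. Isotropic
    # per-component variances are assumed, matching the loss sketch below.
    def __init__(self, input_dim, hidden_dim, latent_dim, num_mixtures):
        super().__init__()
        self.latent_dim, self.num_mixtures = latent_dim, num_mixtures
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTMCell(input_dim, hidden_dim)
        self.mdn_head = nn.Linear(hidden_dim, num_mixtures * (latent_dim + 2))
        self.reward_head = nn.Linear(hidden_dim, 1)
        self.done_head = nn.Linear(hidden_dim, 1)

    def initial_state(self, batch_size=1):
        h = torch.zeros(batch_size, self.hidden_dim)
        c = torch.zeros(batch_size, self.hidden_dim)
        return (h, c)

    def forward(self, input_vec, hidden_state):
        h, c = self.lstm(input_vec, hidden_state)
        params = self.mdn_head(h)
        K, D = self.num_mixtures, self.latent_dim
        logit_pi, log_sigma_sq, mu = torch.split(params, [K, K, K * D], dim=-1)
        pi = torch.softmax(logit_pi, dim=-1)        # (batch, K) mixture weights
        sigma_sq = torch.exp(log_sigma_sq)          # (batch, K) isotropic variances
        mu = mu.view(-1, K, D)                      # (batch, K, latent_dim) component means
        r_pred = self.reward_head(h).squeeze(-1)    # predicted reward r_{t+1}
        d_pred_logit = self.done_head(h).squeeze(-1)  # logit of done flag d_{t+1}
        return (pi, mu, sigma_sq), r_pred, d_pred_logit, (h, c)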

import math
import torch
import torch.nn.functional as F

def mdn_nll_loss(pi, mu, sigma_sq, target):
    # pi: (batch, num_mixtures) mixture weights
    # mu: (batch, num_mixtures, latent_dim) component means
    # sigma_sq: (batch, num_mixtures) isotropic component variances
    # target: (batch, latent_dim) observed next latent state z_{t+1}
    latent_dim = target.shape[-1]

    target_expanded = target.unsqueeze(1).expand_as(mu)  # (batch, num_mixtures, latent_dim)
    exponent = -0.5 * torch.sum((target_expanded - mu) ** 2 / sigma_sq.unsqueeze(-1), dim=2)
    log_coeff = -0.5 * latent_dim * torch.log(2 * math.pi * sigma_sq)
    log_gauss_prob = log_coeff + exponent  # per-component log density, (batch, num_mixtures)

    # Log-sum-exp over mixture components for numerical stability.
    max_log_prob = torch.max(log_gauss_prob, dim=1, keepdim=True)[0]
    log_prob = max_log_prob.squeeze(1) + torch.log(
        torch.sum(pi * torch.exp(log_gauss_prob - max_log_prob), dim=1)
    )
    return -torch.mean(log_prob)

# collect_rollouts_and_encode is an assumed helper that runs episodes in the
# environment and encodes every frame with the trained VAE, yielding sequences
# of (z_t, a_t, r_{t+1}, d_{t+1}) tuples.
rnn_dataset = collect_rollouts_and_encode(environment, vae, num_episodes)
mdn_rnn = MDN_RNN(latent_dim + action_dim, hidden_dim, latent_dim, num_gaussian_mixtures)
optimizer = torch.optim.Adam(mdn_rnn.parameters())

for epoch in range(num_epochs):
    for sequence in rnn_dataset:  # sequence = [(z_t, a_t, r_{t+1}, d_{t+1}), ...]
        hidden_state = mdn_rnn.initial_state()
        total_loss = 0.0
        for t in range(len(sequence) - 1):
            z_t, a_t, r_next_true, d_next_true = sequence[t]
            z_next_true = sequence[t + 1][0]  # target: next latent state z_{t+1}

            input_vec = torch.cat([z_t, a_t], dim=-1)
            (pi, mu, sigma_sq), r_pred, d_pred_logit, hidden_state = mdn_rnn(input_vec, hidden_state)

            # NLL of z_{t+1} under the predicted mixture, plus auxiliary losses
            # for the reward and done-flag predictions.
            mdn_loss = mdn_nll_loss(pi, mu, sigma_sq, z_next_true)
            reward_loss = F.mse_loss(r_pred, r_next_true)
            done_loss = F.binary_cross_entropy_with_logits(d_pred_logit, d_next_true)
            total_loss = total_loss + mdn_loss + reward_loss + done_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

A temperature parameter $\tau$ can be introduced during sampling from the MDN. Sampling $z_{t+1}$ involves first sampling a mixture component $i$ based on the weights $\pi_i$, then sampling from the corresponding Gaussian $\mathcal{N}(\mu_i, (\tau \sigma_i)^2 I)$. Higher $\tau$ increases randomness ("dream temperature").
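
A small sketch of this temperature-adjusted sampling step, following the convention described above (scaling the component standard deviation by $\tau$); the function name and tensor shapes are assumptions for illustration.

import torch

def sample_z_next(pi, mu, sigma_sq, tau=1.0):
    # pi: (num_mixtures,), mu: (num_mixtures, latent_dim), sigma_sq: (num_mixtures,)
    # Returns one sample of z_{t+1} from the temperature-adjusted mixture.
    component = torch.multinomial(pi, num_samples=1).item()  # pick mixture component i ~ pi
    std = tau * torch.sqrt(sigma_sq[component])              # scale the std by the temperature
    return mu[component] + std * torch.randn_like(mu[component])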

Controller (C): Linear Policy

The controller C is responsible for selecting actions based on the information provided by V and M. A striking aspect of the World Models paper is the simplicity of C. It uses a linear mapping from the concatenated latent state $z_t$ and RNN hidden state $h_t$ to the action $a_t$:

$$a_t = W_c\,[z_t \; h_t] + b_c$$

Here, $[z_t \; h_t]$ represents the concatenation of the current compressed observation and the memory state. $W_c$ and $b_c$ are the weight matrix and bias vector of the controller, respectively. For an environment like CarRacing-v0 with $z \in \mathbb{R}^{32}$ and $h \in \mathbb{R}^{256}$, and a 3-dimensional continuous action space, the controller has only $(32 + 256) \times 3 + 3 = 867$ parameters. This is orders of magnitude smaller than typical deep RL policies operating directly on pixels.

The parameters $W_c$ and $b_c$ are optimized using an evolution strategy, specifically CMA-ES (Covariance Matrix Adaptation Evolution Strategy). CMA-ES is a gradient-free optimization algorithm well suited to low-dimensional parameter spaces and to problems where gradients are difficult to compute or noisy. It works by iteratively sampling sets of candidate parameters (policy parameters in this case) from a multivariate Gaussian distribution, evaluating the fitness (total reward) of each candidate policy by running it in the environment, and updating the mean and covariance matrix of the sampling distribution to favor regions yielding higher fitness.
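
A minimal sketch of this controller search using the pycma package. Here evaluate_policy, the rollout_fn helper (assumed to run one episode, in the real environment or the dream, given a policy function and return its total reward), and the hyperparameters are assumptions for illustration; the episode return is negated because CMA-ES minimizes its objective.

import cma
import numpy as np

def evaluate_policy(params, rollout_fn, z_dim=32, h_dim=256, a_dim=3, num_rollouts=4):
    # Unflatten the candidate vector into W_c and b_c and average episode returns.
    W_c = params[: (z_dim + h_dim) * a_dim].reshape(a_dim, z_dim + h_dim)
    b_c = params[(z_dim + h_dim) * a_dim:]
    policy = lambda z, h: np.tanh(W_c @ np.concatenate([z, h]) + b_c)  # squashing is an illustrative choice
    return np.mean([rollout_fn(policy) for _ in range(num_rollouts)])

num_params = (32 + 256) * 3 + 3  # 867 parameters for CarRacing-v0
es = cma.CMAEvolutionStrategy(np.zeros(num_params), 0.1, {'popsize': 64})
while not es.stop():
    candidates = es.ask()                                   # sample candidate parameter vectors
    # CMA-ES minimizes, so use the negative average return as the objective.
    fitnesses = [-evaluate_policy(np.asarray(p), rollout_fn) for p in candidates]
    es.tell(candidates, fitnesses)                          # update the sampling distribution
best_params = es.result.xbest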

Training Methodology

The training process proceeds in distinct phases:

  1. Data Collection: Collect rollouts $(o_t, a_t, r_{t+1}, d_{t+1})$ from the real environment, potentially using a random policy or preliminary agent.
  2. VAE Training (V): Train the VAE on the collected observations $o_t$ to learn the latent space mapping $z_t = E(o_t)$ and the reconstruction $D(z_t)$.
  3. MDN-RNN Training (M): Encode the collected rollouts into latent space sequences $(z_t, a_t, r_{t+1}, d_{t+1})$ using the trained VAE. Train the MDN-RNN on these sequences to predict $(z_{t+1}, r_{t+1}, d_{t+1})$ given $(z_t, a_t, h_t)$.
  4. Controller Training (C): Optimize the parameters of the controller C using CMA-ES. The fitness evaluation for each candidate policy involves rolling out the policy and accumulating the total reward. This evaluation can be done either in the real environment or entirely within the learned world model (M). The overall pipeline is sketched after this list.
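
A condensed sketch of how these phases might be orchestrated; each helper function named here is hypothetical and is assumed to wrap one of the training loops sketched earlier.

# Hypothetical end-to-end pipeline; every helper below is an assumed wrapper.
rollouts = collect_rollouts(environment, policy=random_policy, num_episodes=10000)  # 1. data collection
vae = train_vae(rollouts.observations)                      # 2. train V on raw frames
latent_sequences = encode_rollouts(vae, rollouts)           # compress rollouts to (z_t, a_t, r, d)
mdn_rnn = train_mdn_rnn(latent_sequences)                   # 3. train M on latent sequences
controller_params = train_controller_cma_es(vae, mdn_rnn)   # 4. train C (real env or dream)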

Training Inside the Dream

A key result of the paper is demonstrating the feasibility of training the controller C entirely within the "dream" generated by the MDN-RNN (M). This process works as follows:

  1. Initialize the controller C with some parameters $\theta = (W_c, b_c)$.
  2. Initialize the hidden state $h_0$ of the MDN-RNN. Sample an initial latent state $z_0$ (e.g., from the VAE encoding of an initial observation, or randomly).
  3. For $t = 0, 1, 2, \ldots$ (a code sketch of this loop follows the list):
     a. Compute the action $a_t = C(z_t, h_t; \theta)$.
     b. Use the MDN-RNN (M) to predict the distribution of the next latent state $P(z_{t+1} \mid z_t, a_t, h_t)$ and update its hidden state to $h_{t+1}$.
     c. Sample $z_{t+1}$ from the predicted distribution, potentially using the temperature parameter $\tau$.
     d. Optionally, M can also predict the reward $r_{t+1}$ and termination condition $d_{t+1}$ for this simulated step.
  4. The total accumulated (predicted) reward over the simulated episode serves as the fitness score for the parameter set $\theta$.
  5. Use CMA-ES to update the distribution from which controller parameters $\theta$ are sampled based on these fitness scores.
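
A sketch of one such dream rollout used as a fitness evaluation, reusing the MDN-RNN interface sketched earlier and inlining the temperature-adjusted sampling; the policy argument, the default horizon, and the $\tau$ value are illustrative assumptions, and every tensor is assumed to carry a batch dimension of 1.

import torch

def dream_rollout(policy, mdn_rnn, z0, max_steps=1000, tau=1.15):
    # Fitness evaluation entirely inside the learned model ("the dream");
    # no real-environment interaction occurs. tau = 1.15 is an illustrative value.
    z, hidden = z0, mdn_rnn.initial_state()
    total_reward = 0.0
    for _ in range(max_steps):
        a = policy(z, hidden[0])                               # C acts on (z_t, h_t)
        (pi, mu, sigma_sq), r_pred, d_logit, hidden = mdn_rnn(torch.cat([z, a], dim=-1), hidden)
        i = torch.multinomial(pi[0], num_samples=1).item()     # pick mixture component i ~ pi
        std = tau * torch.sqrt(sigma_sq[0, i])                 # temperature-scaled std
        z = (mu[0, i] + std * torch.randn_like(mu[0, i])).unsqueeze(0)  # hallucinate z_{t+1}
        total_reward += float(r_pred)                          # accumulate the *predicted* reward
        if torch.sigmoid(d_logit).item() > 0.5:                # predicted episode termination
            break
    return total_reward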

Crucially, this training loop does not interact with the real environment. The agent learns by interacting with its internal simulation provided by M. Once trained, the same controller parameters $\theta$ are transferred back to the agent acting in the real environment: the real-time VAE encoding $z_t$ and the MDN-RNN hidden state $h_t$ (updated from the real $z_t$ and the chosen $a_t$, but no longer used to predict the next state) serve as the inputs to C.

This approach offers significant potential for improving sample efficiency, as the generation of simulated experience within M is computationally much cheaper than interaction with many real-world or complex simulated environments. It also suggests a mechanism for agents to "imagine" and plan.

Experimental Results

The World Models approach was evaluated primarily on two environments:

  • CarRacing-v0: A top-down car racing task with continuous control from pixel observations ($64 \times 64 \times 3$).
    • The VAE used $z \in \mathbb{R}^{32}$.
    • The MDN-RNN used $h \in \mathbb{R}^{256}$ and 5 Gaussian mixtures.
    • The controller C was the linear model with 867 parameters.
    • Training C entirely within the world model (dream) achieved an average score of $906 \pm 21$ over 100 trials, effectively solving the task (threshold is 900). This performance was comparable to or better than contemporary model-free methods like A3C, but achieved with a vastly smaller controller and potentially less interaction with the real environment (though initial data collection is still required for V and M).
    • Training C using CMA-ES directly in the real environment yielded similar performance but required significantly more environment interactions for the policy search phase.
  • VizDoom Take Cover: A first-person 3D environment where the agent must move left or right to avoid projectiles.
    • The VAE used $z \in \mathbb{R}^{64}$.
    • The MDN-RNN used $h \in \mathbb{R}^{512}$.
    • Controller C was again linear.
    • Training C in the dream achieved an average survival time of $\approx 750$, significantly better than random ($\approx 210$) but lower than DQN or A3C ($\approx 1200$). The paper suggests the discrepancy might be due to the world model (M) not perfectly capturing the environment dynamics or reward structure. Training C directly in the environment achieved performance closer to the model-free methods ($\approx 820$).

The results highlight the ability of the world model to capture sufficient dynamics for control in some tasks, enabling successful policy learning within the model itself. The small size of the controller C (867 parameters for CarRacing vs. millions for typical deep ConvNet policies) is a notable outcome, suggesting that the representation learned by V and M handles much of the complexity.

Implementation Considerations

  • VAE Choice: Using a VAE rather than a standard AE encourages a more structured latent space due to the KL regularization, which might be beneficial for the predictive model M. The $\beta$ parameter weighting the KL term can influence the quality of reconstructions vs. the smoothness of the latent space.
  • MDN Complexity: The number of hidden units in the LSTM and the number of Gaussian mixtures in the MDN are hyperparameters affecting M's predictive capacity. Too few may underfit the dynamics, while too many increase computational cost and risk overfitting.
  • CMA-ES vs. Gradient-Based Policy Optimization: CMA-ES was chosen for its effectiveness with small parameter spaces and potentially noisy fitness evaluations (especially in the dream). For larger controllers, gradient-based methods (e.g., policy gradients evaluated on dream rollouts) might become necessary, although potentially more complex to stabilize.
  • Computational Cost: Training V and M requires significant computation, especially V on large image datasets. However, once trained, generating dream rollouts is relatively fast compared to running complex physics simulators or real-world experiments. The controller training via CMA-ES involves multiple rollouts per generation, which can be parallelized.
  • Transferability: The success of transferring the policy trained in the dream back to the real environment hinges on the fidelity of the world model M. If M fails to capture crucial aspects of the dynamics or reward structure, the transferred policy may perform poorly (as potentially observed in the VizDoom experiment). The temperature parameter $\tau$ acts as a regularizer during dream training, potentially mitigating overfitting to the learned model.

Limitations and Future Directions

The original World Models paper demonstrated a compelling proof-of-concept but also had limitations. The fidelity of the VAE reconstruction and, more critically, the MDN-RNN's predictive accuracy, directly impact the effectiveness of training in the dream. Environments with more complex, stochastic, or long-horizon dynamics might pose significant challenges for the MDN-RNN. The reliance on CMA-ES limits the complexity of the controller C that can be effectively trained. Subsequent research has explored replacing components, such as using different generative models (e.g., Transformers, discrete autoencoders) for V and M, employing model-based RL algorithms (like planning or policy gradients) within the learned model instead of CMA-ES, and iterating between model learning and policy learning.

Conclusion

The World Models paper introduced a methodology for reinforcement learning centered around building explicit, generative world models (V and M components) that learn compressed spatial and temporal representations of an environment. This learned model can then be used to efficiently train a compact controller (C), potentially entirely within simulations generated by the world model itself ("dreaming"). The key results demonstrated the feasibility of solving tasks like CarRacing-v0 with a very simple controller trained in this manner, highlighting potential benefits for sample efficiency and providing a framework for separating representation learning from policy optimization in RL.

Authors (2)
  1. David Ha
  2. Jürgen Schmidhuber
Citations (918)