The "World Models" paper (Ha et al., 2018 ) presents an approach for training reinforcement learning agents by decoupling representation learning from policy optimization. The core idea is to build a generative model of the environment (the "world model") which the agent can then use to learn a policy, potentially entirely within a simulated version of the environment generated by this model. This methodology aims to address sample efficiency issues often encountered in model-free RL by learning a compressed and predictive representation of the environment's dynamics in an unsupervised manner.
Architecture
The World Model architecture consists of three distinct components:
- Vision Model (V): A Variational Autoencoder (VAE) responsible for compressing high-dimensional observations $o_t$ (e.g., raw pixels) into a low-dimensional latent vector $z_t$ at each timestep $t$. This component learns a spatial representation of the environment state.
- Memory RNN Model (M): A Mixture Density Network Recurrent Neural Network (MDN-RNN) that models the temporal dynamics of the environment in the latent space. It predicts the probability distribution of the next latent state $z_{t+1}$ given the current latent state $z_t$, the action $a_t$ taken at time $t$, and its own hidden state $h_t$. It essentially acts as a predictive model of future compressed states. The hidden state $h_t$ serves as a memory of the past sequence.
- Controller (C): A compact policy network that determines the action to take at each timestep. It receives the current latent state $z_t$ from the VAE and the hidden state $h_t$ from the MDN-RNN as input. The paper utilizes a remarkably simple linear model for this component: $a_t = W_c \, [z_t; h_t] + b_c$.
These components interact sequentially at each timestep: V processes the observation $o_t$ into $z_t$; C uses $z_t$ and the current hidden state $h_t$ to compute the action $a_t$; M then uses $z_t$ and $a_t$ to update its hidden state to $h_{t+1}$ and predict the distribution of $z_{t+1}$.
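To make this data flow concrete, the sketch below shows one rollout in the real environment as illustrative pseudocode; `vae.encode`, `controller.act`, and `mdn_rnn.step` are assumed interfaces in the spirit of the training snippets later in this section, not the paper's actual code.

```python
# Illustrative rollout loop in the real environment (method names are assumptions)
def rollout(environment, vae, mdn_rnn, controller, max_steps=1000):
    obs = environment.reset()
    hidden_state = mdn_rnn.initial_state()
    total_reward = 0.0
    for _ in range(max_steps):
        z = vae.encode(obs)                        # V: compress observation o_t to latent z_t
        action = controller.act(z, hidden_state)   # C: a_t = W_c [z_t; h_t] + b_c
        obs, reward, done = environment.step(action)
        # M: update memory with (z_t, a_t); the predicted z_{t+1} is not needed when acting in the real env
        hidden_state = mdn_rnn.step(z, action, hidden_state)
        total_reward += reward
        if done:
            break
    return total_reward
```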
Vision Model (V): Variational Autoencoder
The V component is implemented as a VAE. Its primary function is to learn an encoding $z_t$ that captures the salient spatial features of the observation $o_t$ in a low-dimensional latent space. The VAE consists of an encoder network that maps $o_t$ to the parameters (mean $\mu$ and log-variance $\log \sigma^2$) of a Gaussian distribution, from which $z_t$ is sampled, and a decoder network that attempts to reconstruct the original observation from the latent vector.
The VAE is trained independently and in an unsupervised manner on a dataset of observations collected from the environment (e.g., by a random policy or during initial exploration). The objective function is the standard Evidence Lower Bound (ELBO), which balances a reconstruction loss (making the reconstruction $\hat{o}_t$ similar to $o_t$) and a Kullback-Leibler (KL) divergence term that regularizes the latent distribution to be close to a standard Gaussian prior $\mathcal{N}(0, I)$.
```python
# Pseudocode: unsupervised VAE training on observations collected from the environment
dataset = collect_observations(environment, num_episodes)
vae = VAE(input_dim, latent_dim)
optimizer = Adam(vae.parameters())

for epoch in range(num_epochs):
    for obs_batch in dataset:
        optimizer.zero_grad()
        z_mean, z_log_var, reconstructed_obs = vae(obs_batch)
        # Reconstruction term of the ELBO
        reconstruction_loss = MSELoss(reconstructed_obs, obs_batch)
        # KL divergence between q(z|o) = N(z_mean, exp(z_log_var)) and the prior N(0, I)
        kl_divergence = -0.5 * (1 + z_log_var - z_mean.pow(2) - z_log_var.exp()).sum()
        loss = reconstruction_loss + beta * kl_divergence  # beta weights the KL term
        loss.backward()
        optimizer.step()
```
The dimensionality of $z_t$ is significantly smaller than that of $o_t$ (e.g., $N_z = 32$ or $64$ versus $64 \times 64 \times 3$ pixel observations in the paper's experiments). This compression is crucial for the efficiency of the subsequent components.
Memory RNN Model (M): MDN-RNN
The M component serves as the predictive core of the world model. It learns a probabilistic model of the environment's dynamics within the latent space defined by V. An LSTM network is used to maintain a hidden state $h_t$ that summarizes the history relevant for prediction.
At each step, the RNN receives the concatenated vector $[z_t; a_t]$ as input and updates its hidden state $h_t$. The output layer of the RNN then predicts the parameters of a Mixture Density Network (MDN) for the next latent state $z_{t+1}$. An MDN represents the conditional probability distribution as a mixture of Gaussian distributions:

$$P(z_{t+1} \mid z_t, a_t, h_t) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}\!\left(z_{t+1} \mid \mu_k, \sigma_k^2\right),$$

where $\sum_{k=1}^{K} \pi_k = 1$, and $\pi_k$, $\mu_k$, $\sigma_k^2$ are the mixture weights, means, and variances output by the RNN for each of the $K$ mixture components. This allows the model to capture multi-modal uncertainty inherent in complex dynamics.
The MDN-RNN is also trained to predict the expected reward $r_{t+1}$ and the probability that the episode terminates (the done flag $d_{t+1}$) at the next step, usually via separate linear output layers from $h_t$.
Training M involves minimizing the negative log-likelihood of the observed $z_{t+1}$ under the predicted mixture distribution, along with standard losses (e.g., MSE for the reward, binary cross-entropy for the done flag) for the auxiliary predictions. The training data consists of latent sequences $(z_t, a_t)$ generated by passing rollouts from the actual environment through V.
```python
import math
import torch

# Pseudocode: training the MDN-RNN on latent-space rollouts encoded by the trained VAE
def mdn_nll_loss(pi, mu, sigma_sq, target):
    # pi:       (batch, num_mixtures)             mixture weights
    # mu:       (batch, num_mixtures, latent_dim) component means
    # sigma_sq: (batch, num_mixtures)             isotropic variance per component
    # target:   (batch, latent_dim)               true next latent state z_{t+1}
    target_expanded = target.unsqueeze(1).expand_as(mu)  # (batch, num_mixtures, latent_dim)
    exponent = -0.5 * torch.sum((target_expanded - mu) ** 2 / sigma_sq.unsqueeze(-1), dim=2)
    log_coeff = -0.5 * latent_dim * torch.log(2 * math.pi * sigma_sq)
    log_gauss_prob = log_coeff + exponent                 # (batch, num_mixtures)
    # Log-sum-exp over mixture components for numerical stability
    max_log_prob = torch.max(log_gauss_prob, dim=1, keepdim=True)[0]
    log_prob = max_log_prob.squeeze(1) + torch.log(
        torch.sum(pi * torch.exp(log_gauss_prob - max_log_prob), dim=1))
    return -torch.mean(log_prob)

rnn_dataset = collect_rollouts_and_encode(environment, vae, num_episodes)
mdn_rnn = MDN_RNN(latent_dim + action_dim, hidden_dim, latent_dim, num_gaussian_mixtures)
optimizer = Adam(mdn_rnn.parameters())

for epoch in range(num_epochs):
    for sequence in rnn_dataset:  # sequence = [(z_t, a_t, r_{t+1}, d_{t+1}), ...]
        hidden_state = mdn_rnn.initial_state()
        total_loss = 0
        for t in range(len(sequence)):
            z_t, a_t, r_next_true, d_next_true = sequence[t]
            # Target is the next latent state (no target for the final step of the sequence)
            z_next_true = sequence[t + 1].z_t if t + 1 < len(sequence) else None
            input_vec = concatenate(z_t, a_t)
            (pi, mu, sigma_sq), r_pred, d_pred_logit, hidden_state = mdn_rnn(input_vec, hidden_state)
            if z_next_true is not None:
                # Negative log-likelihood of z_{t+1} under the predicted mixture
                mdn_loss = mdn_nll_loss(pi, mu, sigma_sq, z_next_true)
                reward_loss = MSELoss(r_pred, r_next_true)
                done_loss = BCEWithLogitsLoss(d_pred_logit, d_next_true)
                total_loss += mdn_loss + reward_loss + done_loss
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
```
A temperature parameter $\tau$ can be introduced during sampling from the MDN. Sampling involves first choosing a mixture component $k$ according to the weights $\pi_k$, then sampling from the corresponding Gaussian $\mathcal{N}(\mu_k, \sigma_k^2)$. Higher $\tau$ increases randomness ("dream temperature").
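A minimal sketch of such temperature-controlled sampling is shown below. It assumes unbatched, per-timestep `pi`, `mu`, and `sigma_sq` outputs named as in the snippet above, and applies one common convention (the paper follows the SketchRNN approach): flattening or sharpening the mixture weights with $\tau$ and scaling the variances by $\tau$.

```python
import torch

def sample_mdn(pi, mu, sigma_sq, temperature=1.0):
    # pi:       (num_mixtures,)             mixture weights for one timestep
    # mu:       (num_mixtures, latent_dim)  component means
    # sigma_sq: (num_mixtures,)             isotropic variance per component
    # Higher temperature flattens the mixture weights and widens each Gaussian.
    pi = torch.softmax(torch.log(pi) / temperature, dim=0)
    sigma_sq = sigma_sq * temperature
    # 1) pick a mixture component k ~ Categorical(pi)
    k = torch.multinomial(pi, num_samples=1).item()
    # 2) sample z_{t+1} ~ N(mu_k, sigma_sq_k * I)
    return mu[k] + torch.sqrt(sigma_sq[k]) * torch.randn_like(mu[k])
```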
Controller (C): Linear Policy
The controller C is responsible for selecting actions based on the information provided by V and M. A striking aspect of the World Models paper is the simplicity of C. It uses a linear mapping from the concatenated latent state $z_t$ and RNN hidden state $h_t$ to the action $a_t$:

$$a_t = W_c \, [z_t; h_t] + b_c$$

Here, $[z_t; h_t]$ represents the concatenation of the current compressed observation and the memory state, and $W_c$ and $b_c$ are the weight matrix and bias vector of the controller, respectively. For an environment like CarRacing-v0 with $N_z = 32$, $N_h = 256$, and a 3-dimensional continuous action space, the controller has only $(32 + 256) \times 3 + 3 = 867$ parameters. This is orders of magnitude smaller than typical deep RL policies operating directly on pixels.
The parameters $W_c$ and $b_c$ are optimized using an evolution strategy, specifically CMA-ES (Covariance Matrix Adaptation Evolution Strategy). CMA-ES is a gradient-free optimization algorithm well suited to low-dimensional parameter spaces and to problems where gradients are noisy or difficult to compute. It works by iteratively sampling sets of candidate parameters (here, the flattened controller parameters) from a multivariate Gaussian distribution, evaluating the fitness (total reward) of each candidate policy by running it in the environment, and updating the mean and covariance matrix of the sampling distribution to favor regions yielding higher fitness.
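As a rough sketch, the search loop might look like the following, here using the `cma` Python package (pycma) as one readily available CMA-ES implementation; `unpack_params` and `evaluate_policy` are hypothetical helpers that reshape a flat parameter vector into $W_c$, $b_c$ and return the total reward of a rollout (in the real environment or in the dream).

```python
import cma  # pycma; one possible CMA-ES implementation

num_params = (latent_dim + hidden_dim) * action_dim + action_dim  # e.g. (32 + 256) * 3 + 3 = 867
es = cma.CMAEvolutionStrategy(num_params * [0.0], 0.1)  # initial mean and step size

while not es.stop():
    candidates = es.ask()                     # sample a population of flat parameter vectors
    fitnesses = []
    for params in candidates:
        W, b = unpack_params(params)          # hypothetical helper: reshape into W_c and b_c
        total_reward = evaluate_policy(W, b)  # average reward over one or more rollouts
        fitnesses.append(-total_reward)       # CMA-ES minimizes, so negate the reward
    es.tell(candidates, fitnesses)
```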
Training Methodology
The training process proceeds in distinct phases (a compact sketch of the full pipeline follows the list):
- Data Collection: Collect rollouts from the real environment, potentially using a random policy or preliminary agent.
- VAE Training (V): Train the VAE on the collected observations to learn the encoding $o_t \mapsto z_t$ and the reconstruction $z_t \mapsto \hat{o}_t$.
- MDN-RNN Training (M): Encode the collected rollouts into latent-space sequences using the trained VAE. Train the MDN-RNN on these sequences to predict $P(z_{t+1} \mid z_t, a_t, h_t)$.
- Controller Training (C): Optimize the parameters of the controller C using CMA-ES. The fitness evaluation for each candidate policy involves rolling out the policy and accumulating the total reward. This evaluation can be done either in the real environment or entirely within the learned world model (M).
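The sketch below strings these phases together; every function name is an illustrative assumption standing in for the corresponding snippet above (data collection, VAE training, MDN-RNN training, and the CMA-ES search), not the paper's actual code.

```python
# Illustrative end-to-end pipeline (function names are assumptions)
rollouts = collect_rollouts(environment, policy=random_policy, num_episodes=num_episodes)  # 1. data collection
vae = train_vae(observations_from(rollouts))                                               # 2. train V
latent_rollouts = encode_rollouts(vae, rollouts)                                           # 3. encode, then train M
mdn_rnn = train_mdn_rnn(latent_rollouts)
controller = train_controller_cmaes(vae, mdn_rnn, fitness="dream")                         # 4. train C (dream or real rollouts)
```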
Training Inside the Dream
A key result of the paper is demonstrating the feasibility of training the controller C entirely within the "dream" generated by the MDN-RNN (M). This process works as follows:
- Initialize the controller C with some parameters $\theta = \{W_c, b_c\}$.
- Initialize the hidden state $h_0$ of the MDN-RNN. Sample an initial latent state $z_0$ (e.g., from the VAE encoding of an initial observation, or randomly).
- For $t = 0, 1, \dots, T$: a. Compute the action $a_t = W_c [z_t; h_t] + b_c$. b. Use the MDN-RNN (M) to predict the distribution of the next latent state $z_{t+1}$ and update its hidden state to $h_{t+1}$. c. Sample $z_{t+1}$ from the predicted distribution, potentially using the temperature parameter $\tau$. d. Optionally, M can also predict the reward $r_{t+1}$ and termination flag $d_{t+1}$ for this simulated step.
- The total accumulated (predicted) reward over the simulated episode serves as the fitness score for the parameter set $\theta$ (see the sketch after this list).
- Use CMA-ES to update the distribution from which controller parameters are sampled based on these fitness scores.
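The fitness evaluation for one candidate controller inside the dream can be sketched as follows; `sample_mdn` is the temperature-sampling helper sketched earlier, and the other names are assumptions consistent with the previous snippets rather than the paper's code.

```python
# Illustrative fitness evaluation of candidate controller parameters (W, b) entirely inside M's dream
def dream_fitness(W, b, vae, mdn_rnn, initial_obs, max_steps=1000, temperature=1.15):
    z = vae.encode(initial_obs)            # seed the dream with a real (or stored) starting observation
    hidden_state = mdn_rnn.initial_state()
    total_reward = 0.0
    for _ in range(max_steps):
        a = W @ concatenate(z, hidden_state) + b              # C: linear controller
        (pi, mu, sigma_sq), r_pred, d_pred_logit, hidden_state = mdn_rnn(
            concatenate(z, a), hidden_state)                  # M: one step of the dream
        z = sample_mdn(pi, mu, sigma_sq, temperature)         # hallucinated next latent state
        total_reward += r_pred                                # predicted reward (or +1 per survived step)
        if torch.sigmoid(d_pred_logit) > 0.5:                 # predicted episode termination
            break
    return total_reward
```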
Crucially, this training loop does not interact with the real environment. The agent learns by interacting with its internal simulation provided by M. Once trained, the same controller parameters are transferred back to the agent interacting with the real environment, using the real-time VAE encoding $z_t$ and the MDN-RNN hidden state $h_t$ (updated using the real $z_t$ and chosen $a_t$, but no longer used to predict $z_{t+1}$) as input to C.
This approach offers significant potential for improving sample efficiency, as the generation of simulated experience within M is computationally much cheaper than interaction with many real-world or complex simulated environments. It also suggests a mechanism for agents to "imagine" and plan.
Experimental Results
The World Models approach was evaluated primarily on two environments:
- CarRacing-v0: A top-down car racing task with continuous control from pixel observations (resized to $64 \times 64$ RGB frames in the paper).
- The VAE used a latent dimension of $N_z = 32$.
- The MDN-RNN used an LSTM with 256 hidden units and 5 Gaussian mixture components.
- The controller C was the linear model with 867 parameters.
- Training C with CMA-ES on rollouts in the actual environment, using both $z_t$ and $h_t$ as inputs, achieved an average score of 906 ± 21 over 100 random trials, effectively solving the task (the threshold is an average of 900 over 100 trials). This exceeded contemporary model-free methods such as A3C, and was achieved with a vastly smaller controller (though initial data collection is still required to train V and M).
- An ablation that fed only the VAE latent $z_t$ to C (omitting the MDN-RNN hidden state) scored substantially lower, underscoring the value of the memory component. The paper also showed qualitatively that the agent can drive inside dream rollouts generated by M; training entirely inside the dream was demonstrated on the VizDoom task below.
- VizDoom Take Cover: A first-person 3D environment where the agent must move left or right to avoid projectiles.
- The VAE used a latent dimension of $N_z = 64$.
- The MDN-RNN used an LSTM with 512 hidden units.
- Controller C was again linear.
- Training C entirely inside the dream (with temperature $\tau = 1.15$) and transferring the policy to the actual environment achieved an average survival time of roughly 1100 timesteps over 100 trials, well above the threshold of 750 timesteps required to solve the task and far better than a random policy (around 210 timesteps). The paper notes that with a low dream temperature the controller can learn to exploit imperfections of the world model, discovering "adversarial" policies that score well in the dream but transfer poorly; raising $\tau$ makes the dream environment more difficult and improves transfer.
The results highlight the ability of the world model to capture sufficient dynamics for control in some tasks, enabling successful policy learning within the model itself. The small size of the controller C (867 parameters for CarRacing vs. millions for typical deep ConvNet policies) is a notable outcome, suggesting that the representation learned by V and M handles much of the complexity.
Implementation Considerations
- VAE Choice: Using a VAE rather than a standard AE encourages a more structured latent space due to the KL regularization, which might be beneficial for the predictive model M. The parameter $\beta$ weighting the KL term can influence the quality of reconstructions versus the smoothness of the latent space.
- MDN Complexity: The number of hidden units in the LSTM and the number of Gaussian mixtures in the MDN are hyperparameters affecting M's predictive capacity. Too few may underfit the dynamics, while too many increase computational cost and risk overfitting.
- CMA-ES vs. Gradient-Based Policy Optimization: CMA-ES was chosen for its effectiveness with small parameter spaces and potentially noisy fitness evaluations (especially in the dream). For larger controllers, gradient-based methods (e.g., policy gradients evaluated on dream rollouts) might become necessary, although potentially more complex to stabilize.
- Computational Cost: Training V and M requires significant computation, especially V on large image datasets. However, once trained, generating dream rollouts is relatively fast compared to running complex physics simulators or real-world experiments. The controller training via CMA-ES involves multiple rollouts per generation, which can be parallelized.
- Transferability: The success of transferring the policy trained in the dream back to the real environment hinges on the fidelity of the world model M. If M fails to capture crucial aspects of the dynamics or reward structure, the transferred policy may perform poorly. The temperature parameter $\tau$ acts as a regularizer during dream training, potentially mitigating overfitting to the learned model.
Limitations and Future Directions
The original World Models paper demonstrated a compelling proof-of-concept but also had limitations. The fidelity of the VAE reconstruction and, more critically, the MDN-RNN's predictive accuracy, directly impact the effectiveness of training in the dream. Environments with more complex, stochastic, or long-horizon dynamics might pose significant challenges for the MDN-RNN. The reliance on CMA-ES limits the complexity of the controller C that can be effectively trained. Subsequent research has explored replacing components, such as using different generative models (e.g., Transformers, discrete autoencoders) for V and M, employing model-based RL algorithms (like planning or policy gradients) within the learned model instead of CMA-ES, and iterating between model learning and policy learning.
Conclusion
The World Models paper introduced a methodology for reinforcement learning centered around building explicit, generative world models (V and M components) that learn compressed spatial and temporal representations of an environment. This learned model can then be used to efficiently train a compact controller (C), potentially entirely within simulations generated by the world model itself ("dreaming"). The key results demonstrated the feasibility of solving tasks like CarRacing-v0 with a very simple controller trained in this manner, highlighting potential benefits for sample efficiency and providing a framework for separating representation learning from policy optimization in RL.