The paper "Model-Based Reinforcement Learning for Atari" (Kaiser et al., 2019 ) introduces Simulated Policy Learning (SimPLe), an MBRL algorithm designed to enhance sample efficiency in the Atari Learning Environment (ALE) by leveraging learned video prediction models. The primary objective is to achieve competitive performance within a low-data regime, specifically 100,000 environment interactions (equivalent to approximately two hours of real-time play), a stark contrast to the millions of interactions typically required by contemporary model-free reinforcement learning (MFRL) methods.
The SimPLe Algorithm
SimPLe operates through an iterative procedure involving three main stages:
- Data Collection: The current policy interacts with the real environment (the Atari game) to collect a dataset D of experience tuples (observation, action, reward, next observation). In the initial iteration a random policy is used; subsequent iterations use the policy learned in the previous cycle.
- World Model Training: A world model M is trained on the aggregated dataset D. The model learns to predict the next observation and the reward given the current state (typically a stack of recent frames) and the chosen action. Training is supervised.
- Policy Training: A policy π is trained with an MFRL algorithm, specifically Proximal Policy Optimization (PPO), entirely within the learned world model M. This phase generates a large number of simulated interactions inside M without requiring further interaction with the real environment.
This cycle of data collection, world model training, and policy training is repeated, allowing the agent to iteratively refine both the world model and the policy. The policy guides data collection towards relevant state-action regions, improving the model's accuracy in those regions, which in turn facilitates better policy learning. The reported experiments use 15 such iterations.
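A compact sketch of this loop, assuming hypothetical helper components (RandomPolicy, WorldModel, collect_experience, train_policy_ppo) that stand in for the pieces described above; the per-iteration interaction budget is illustrative, chosen so the total real-environment budget stays near 100k:

```python
def simple_training(env, num_iterations=15, real_steps_per_iter=100_000 // 15):
    dataset = []                              # aggregated real-environment experience D
    policy = RandomPolicy(env.action_space)   # the first iteration collects data with a random policy
    world_model = WorldModel()                # learned next-frame + reward predictor M

    for _ in range(num_iterations):
        # 1. Data collection: run the current policy in the real environment.
        dataset += collect_experience(env, policy, num_steps=real_steps_per_iter)

        # 2. World model training: supervised prediction of the next frame and
        #    reward from a stack of recent frames and the chosen action.
        world_model.fit(dataset)

        # 3. Policy training: PPO entirely inside the learned model, using many
        #    short simulated rollouts started from states sampled from `dataset`.
        policy = train_policy_ppo(world_model, dataset, init_policy=policy)

    return policy
```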
World Model Architectures
A key component of SimPLe is the world model M, responsible for predicting future frames and rewards. The paper evaluates several architectures:
- Deterministic Model: A convolutional neural network architecture, drawing inspiration from prior work on action-conditional video prediction. It uses skip connections and multiplicative conditioning for the action embedding. Outputs can be raw RGB values (using L2 loss) or a categorical distribution over pixel values (using per-pixel softmax loss). Techniques like clipped loss and scheduled sampling are employed during training.
- Stochastic Model (VAE-based): This model incorporates continuous latent variables to capture stochasticity, following approaches like SV2P. An inference network estimates an approximate posterior over a latent variable z, and the prediction model conditions on a z sampled from this posterior during training. At prediction time, z is sampled from a fixed prior, typically N(0, I). However, the authors report difficulties in tuning the KL-divergence weight and observe divergence between the learned posterior and the fixed prior, hindering predictive performance.
- Novel Stochastic Model with Discrete Latents: This architecture, found to be the most effective, uses discrete latent variables instead of continuous ones.
- An inference network predicts parameters for the posterior distribution over latent variables.
- Samples are drawn and then discretized into a sequence of bits.
- A separate auxiliary network (an LSTM) is trained autoregressively to predict the sequence of discrete latent bits.
- During prediction (simulation), this autoregressive network generates the latent bits, replacing sampling from a fixed prior.
- Gradient flow through the non-differentiable discretization step uses a straight-through estimator (see the sketch below).
- Techniques like adding uniform noise before discretization and dropout after discretization are used for regularization.
This discrete latent model proved more adept than the deterministic model at handling the kinds of stochasticity present in Atari (e.g., flickering objects, partially observed state, enemy behavior), and it avoided the posterior/prior mismatch issues of the standard VAE approach within this framework.
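The straight-through discretization step described above can be illustrated with a short sketch. This is generic PyTorch code, not the paper's TensorFlow implementation, and the noise and dropout magnitudes are assumptions:

```python
import torch
import torch.nn.functional as F

def discretize_bits_straight_through(logits, training=True, noise=0.2, drop_p=0.1):
    """Turn real-valued latent predictions into {0, 1} bits while keeping gradients."""
    probs = torch.sigmoid(logits)
    if training:
        # Uniform noise before discretization acts as a regularizer.
        probs = probs + (torch.rand_like(probs) - 0.5) * noise
    hard_bits = (probs > 0.5).float()          # non-differentiable rounding
    # Straight-through estimator: the forward pass uses the hard bits, while the
    # backward pass treats the rounding as the identity and backpropagates
    # through `probs`.
    bits = probs + (hard_bits - probs).detach()
    if training:
        # Dropout after discretization, as mentioned in the text.
        bits = F.dropout(bits, p=drop_p)
    return bits
```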
Implementation Considerations
Several practical considerations are crucial for SimPLe's implementation:
- Short Rollouts: To mitigate the compounding prediction errors inherent in multi-step generation with learned models, policy training inside M uses relatively short simulated rollouts (50 steps in the paper). After each rollout, the simulator is reset to a state sampled from the real experience dataset D.
- Value Function Bootstrapping: The truncation bias introduced by short rollouts is addressed by bootstrapping the value function: the estimated value of the state reached at the end of the short rollout is added to the accumulated discounted reward, providing a more informed target for the policy update (see the sketch after this list).
- Policy Optimization: PPO is used for policy optimization within the simulator. A discount factor of γ = 0.95 is employed, slightly lower than the γ = 0.99 typically used in MFRL for Atari, potentially increasing tolerance to model inaccuracies over longer horizons.
- Data Scaling: While interaction with the real environment is strictly limited (100k steps), the policy training phase leverages the computational efficiency of the learned model to generate a significantly larger volume of simulated experience (reported as 15.2 million steps across all training iterations). This amplification of experience is central to SimPLe's sample efficiency.
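A sketch of how a single short rollout with a bootstrapped return target might look, assuming hypothetical components (sample_state, world_model.step, value_fn) and the rollout length and discount discussed above:

```python
def simulated_rollout_return(world_model, policy, value_fn, dataset,
                             horizon=50, gamma=0.95):
    # Reset the simulator to a state sampled from real experience.
    state = sample_state(dataset)
    rewards = []
    for _ in range(horizon):
        action = policy(state)
        # The learned world model predicts the next observation and the reward.
        state, reward = world_model.step(state, action)
        rewards.append(reward)

    # Discounted return accumulated over the short rollout ...
    ret = sum(gamma ** t * r for t, r in enumerate(rewards))
    # ... plus a bootstrapped value estimate for the final state, compensating
    # for truncating the rollout after `horizon` steps.
    return ret + gamma ** horizon * value_fn(state)
```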
Experimental Results and Sample Efficiency
SimPLe was evaluated on 26 Atari games under the 100k interaction budget constraint and compared against well-tuned implementations of Rainbow (a high-performing MFRL algorithm) and PPO.
- Sample Efficiency: The primary finding is SimPLe's superior sample efficiency. On the majority of the tested games, SimPLe achieved significantly higher scores than Rainbow and PPO at the 100k interaction mark. For instance, on more than half the games, Rainbow required over 200k interactions to match SimPLe's 100k performance. On Freeway, SimPLe demonstrated over 10x greater sample efficiency than Rainbow.
- Performance Comparison: SimPLe at 100k interactions often surpassed the performance of Rainbow and PPO trained for 200k interactions.
- Data Regime Dependence: The advantage of SimPLe is most pronounced in the very low data regime (50k-100k steps). As the amount of real interaction data increases (e.g., 500k, 1M steps), the performance gap narrows, and MFRL methods eventually match or exceed SimPLe's scores, consistent with typical MBRL vs. MFRL trade-offs.
- Model Architecture: Ablation studies confirmed the superiority of the proposed discrete latent stochastic model over the deterministic and VAE-based alternatives within this setup.
- Stochastic Environments: The stochastic world model demonstrated robustness when evaluated under conditions with sticky actions (where the environment randomly repeats the agent's previous action), maintaining performance better than the deterministic counterpart.
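For reference, sticky actions amount to a thin wrapper around the environment. A generic Gym-style sketch follows; the 0.25 repeat probability is the value commonly used in ALE evaluations, not necessarily the paper's exact setting:

```python
import random

class StickyActionEnv:
    """With probability `repeat_prob`, repeat the previously executed action
    instead of the agent's chosen action."""

    def __init__(self, env, repeat_prob=0.25):
        self.env = env
        self.repeat_prob = repeat_prob
        self.prev_action = 0  # NOOP

    def reset(self):
        self.prev_action = 0
        return self.env.reset()

    def step(self, action):
        if random.random() < self.repeat_prob:
            action = self.prev_action
        self.prev_action = action
        return self.env.step(action)
```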
Limitations
Despite its success in the low-data regime, the paper acknowledges several limitations:
- Asymptotic Performance: While highly sample-efficient, the final performance achieved by SimPLe generally remains below the state-of-the-art scores obtained by MFRL algorithms trained with substantially more data (tens or hundreds of millions of interactions). This suggests that inaccuracies in the learned world model eventually limit the policy's peak performance.
- Training Variance: The results exhibit high variance across different training runs, indicating sensitivity to initialization or stochasticity in the training process.
- Computational Cost: Training the complex video prediction models is computationally intensive, although inference (simulation) within the learned model is fast. The overall wall-clock time might still be considerable, even if environment interactions are minimized.
Conclusion
The SimPLe algorithm demonstrates that model-based reinforcement learning, predicated on learned video prediction models, can significantly enhance sample efficiency for complex visual control tasks like Atari games, particularly when interaction budgets are severely constrained. By iteratively refining a stochastic world model with discrete latent variables and leveraging it for extensive policy training via simulation, SimPLe achieves performance levels at 100k interactions that often require substantially more data for state-of-the-art model-free methods. While limitations in asymptotic performance and training stability exist, the work establishes video prediction models as a viable component for building sample-efficient agents in visually complex domains.