Diffusion for World Modeling: Visual Details Matter in Atari (2405.12399v2)

Published 20 May 2024 in cs.LG, cs.AI, and cs.CV

Abstract: World models constitute a promising approach for training reinforcement learning agents in a safe and sample-efficient manner. Recent world models predominantly operate on sequences of discrete latent variables to model environment dynamics. However, this compression into a compact discrete representation may ignore visual details that are important for reinforcement learning. Concurrently, diffusion models have become a dominant approach for image generation, challenging well-established methods modeling discrete latents. Motivated by this paradigm shift, we introduce DIAMOND (DIffusion As a Model Of eNvironment Dreams), a reinforcement learning agent trained in a diffusion world model. We analyze the key design choices that are required to make diffusion suitable for world modeling, and demonstrate how improved visual details can lead to improved agent performance. DIAMOND achieves a mean human normalized score of 1.46 on the competitive Atari 100k benchmark; a new best for agents trained entirely within a world model. We further demonstrate that DIAMOND's diffusion world model can stand alone as an interactive neural game engine by training on static Counter-Strike: Global Offensive gameplay. To foster future research on diffusion for world modeling, we release our code, agents, videos and playable world models at https://diamond-wm.github.io.

This paper introduces DIAMOND (DIffusion As a Model Of eNvironment Dreams), a reinforcement learning (RL) agent trained entirely within a world model that uses a diffusion model to predict future observations. The core motivation is that traditional world models often rely on discrete latent representations, which can discard important visual details necessary for optimal decision-making in visually complex environments. Diffusion models, known for high-fidelity image generation, offer a way to model the environment directly in pixel space, potentially preserving these critical details.

Core Idea and Implementation

  1. Diffusion-based World Model: Instead of encoding observations into discrete tokens, DIAMOND models the transition dynamics $p(x_{t+1} \mid x_{\le t}, a_{\le t})$ directly in image space using a conditional diffusion model. Specifically, it adapts the EDM framework of "Elucidating the Design Space of Diffusion-Based Generative Models" (Karras et al., 2022).
    • Model Input: The diffusion model $\mathbf{D}_\theta$ takes as input a noised version of the next frame $x_{t+1}^\tau$, the diffusion timestep $\tau$, a history of the $L$ previous clean frames $x_{t-L+1}^0, \dots, x_t^0$, and the corresponding actions $a_{t-L+1}, \dots, a_t$.
    • Model Output: It predicts the clean next frame $x_{t+1}^0$.
    • Training Objective: The model is trained with a denoising score matching objective, essentially an L2 reconstruction loss between the model's prediction and the ground-truth clean frame $x_{t+1}^0$ (a minimal code sketch of this objective is given after this list):

      $\mathcal{L}(\theta) = \mathbb{E} \left[ \Vert \mathbf{D}_\theta(x_{t+1}^\tau, \tau, x_{\le t}^0, a_{\le t}) - x_{t+1}^0 \Vert^2 \right]$

      The specific parameterization follows EDM (Eq. 6, 7), which uses preconditioners $(c_\text{in}^\tau, c_\text{out}^\tau, c_\text{skip}^\tau, c_\text{noise}^\tau)$ to stabilize training across different noise levels $\sigma(\tau)$. The noise level $\sigma(\tau)$ is sampled from a log-normal distribution during training.

    • Network Architecture: A standard 2D U-Net is used for the core denoising network $\mathbf{F}_\theta$. Past observations are concatenated channel-wise to the noisy input $x_{t+1}^\tau$. Past actions and the noise-level embedding $c_\text{noise}^\tau$ are incorporated using Adaptive Group Normalization (AdaGN) layers within the U-Net's residual blocks.

  2. Choice of Diffusion Framework (EDM vs. DDPM): The paper argues that EDM is crucial for efficiency. Standard DDPM (Ho et al., 2020) requires many denoising steps (a high number of function evaluations, NFE) for good quality. EDM's improved preconditioning and training objective (Eq. 7) allow stable, high-quality generation with very few steps (e.g., NFE = 1 to 10). This is critical because the world model is sampled repeatedly during agent training, making low NFE essential for computational feasibility. Experiments show EDM avoids the compounding errors seen with DDPM at low NFE (Figure 4).
  3. Sampling (Imagination): To generate a trajectory ("dream"), the model starts from Gaussian noise $x_t^T$ and iteratively applies the learned denoising function $\mathbf{D}_\theta$ using a numerical solver for the reverse SDE (Eq. 3) or the corresponding ODE.
    • Solver: Euler's method is found to be effective.
    • Number of Function Evaluations (NFE): While NFE = 1 works surprisingly well thanks to EDM, it can lead to blurry predictions in stochastic environments (averaging over modes). NFE = 3 is used in the experiments as a balance between capturing sharper details (modes) and computational cost (Figure 5). Each sampling step conditions on the $L$ previously generated frames and the actions selected by the policy (see the sampling sketch after this list).
  4. Reinforcement Learning Agent:
    • Architecture: A separate actor-critic agent $(\pi_\phi, V_\phi)$ is used, based on a shared CNN-LSTM backbone. It takes the current generated observation $x_t^0$ as input.
    • Training: The agent is trained entirely within the imagined environment generated by the diffusion world model. It uses the REINFORCE algorithm with a value baseline $V_\phi(x_t)$ and $\lambda$-returns as value targets to improve sample efficiency within imagination (a hedged sketch of these losses appears after this list).
      • Value Loss: $\mathcal{L}_V(\phi) = \mathbb{E}_{\pi_\phi} \left[ \sum_{t=0}^{H-1} \big( V_\phi(x_t) - \mathrm{sg}(\Lambda_t) \big)^2 \right]$
      • Policy Loss: $\mathcal{L}_\pi(\phi) = - \mathbb{E}_{\pi_\phi} \left[ \sum_{t=0}^{H-1} \log\left(\pi_\phi\left(a_t \mid x_{\le t}\right)\right) \operatorname{sg}\left(\Lambda_t - V_\phi\left(x_t\right)\right) + \eta \, \mathcal{H}\left(\pi_\phi\left(a_t \mid x_{\le t}\right)\right) \right]$
    • Reward/Termination: Separate predictors $R_\psi$, also using a CNN-LSTM architecture, are trained on the real data to predict rewards and termination signals within the imagination.
  5. Overall Training Loop (Algorithm 1):
    • Collect experience from the real environment using the current policy $\pi_\phi$.
    • Update the diffusion world model $\mathbf{D}_\theta$ using data from the replay buffer $\mathcal{D}$.
    • Update the reward/termination predictor $R_\psi$ using data from $\mathcal{D}$.
    • Train the actor-critic agent $(\pi, V)_\phi$ by generating imagined trajectories with $\mathbf{D}_\theta$ and $R_\psi$ and applying the RL losses.
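
To make the diffusion world model's training objective concrete, below is a minimal PyTorch-style sketch (not the authors' released code) of an EDM-preconditioned denoiser and its loss. The `UNet` call signature and the noise-distribution parameters `P_mean`/`P_std` are illustrative assumptions; the preconditioner formulas follow standard EDM conventions.

```python
# Hedged sketch of an EDM-preconditioned denoiser D_theta for next-frame prediction,
# conditioned on past frames (channel-concatenated) and past actions (injected via
# AdaGN inside the hypothetical `unet`). Not the official DIAMOND implementation.
import torch
import torch.nn.functional as F

def edm_preconditioners(sigma, sigma_data=0.5):
    # EDM-style scalings that keep network inputs/outputs near unit variance.
    c_in = 1.0 / torch.sqrt(sigma**2 + sigma_data**2)
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / torch.sqrt(sigma**2 + sigma_data**2)
    c_noise = 0.25 * torch.log(sigma)
    return c_in, c_out, c_skip, c_noise

def denoise(unet, x_noisy, sigma, past_frames, past_actions):
    # D_theta(x_{t+1}^tau, tau, x_{<=t}^0, a_{<=t}): wrap the raw network F_theta
    # with the preconditioners and predict the clean next frame.
    c_in, c_out, c_skip, c_noise = edm_preconditioners(sigma)
    net_in = torch.cat([c_in.view(-1, 1, 1, 1) * x_noisy, past_frames], dim=1)
    f_out = unet(net_in, c_noise.view(-1), past_actions)  # hypothetical interface
    return c_skip.view(-1, 1, 1, 1) * x_noisy + c_out.view(-1, 1, 1, 1) * f_out

def world_model_loss(unet, next_frame, past_frames, past_actions,
                     P_mean=-0.4, P_std=1.2):
    # Sample a noise level from a log-normal distribution, noise the target frame,
    # and regress the clean frame: an L2 denoising score-matching loss.
    b = next_frame.shape[0]
    sigma = torch.exp(P_mean + P_std * torch.randn(b, device=next_frame.device))
    noised = next_frame + sigma.view(-1, 1, 1, 1) * torch.randn_like(next_frame)
    pred = denoise(unet, noised, sigma, past_frames, past_actions)
    return F.mse_loss(pred, next_frame)
```

These input/output scalings are what allow the same network to be queried stably at the very low NFE used during imagination (see item 2 above).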
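
The imagination step itself can be sketched as a few Euler steps of the EDM probability-flow ODE per generated frame, feeding each new frame back into the channel-stacked history. Again a hedged sketch rather than the released implementation: it reuses the hypothetical `denoise` wrapper from the EDM sketch, the noise schedule `sigmas` is illustrative rather than the paper's, and the `policy` interface is assumed.

```python
# Hedged sketch of "dreaming": generate the next frame with a few Euler steps
# (NFE = len(sigmas)), then roll the history forward with the policy's action.
import torch

@torch.no_grad()
def sample_next_frame(unet, past_frames, past_actions,
                      sigmas=(10.0, 1.0, 0.1), frame_channels=3):
    b, _, h, w = past_frames.shape
    device = past_frames.device
    # Start from pure Gaussian noise at the largest noise level.
    x = sigmas[0] * torch.randn(b, frame_channels, h, w, device=device)
    schedule = list(sigmas) + [0.0]
    for i in range(len(sigmas)):
        sigma = torch.full((b,), schedule[i], device=device)
        denoised = denoise(unet, x, sigma, past_frames, past_actions)
        # Euler step on the probability-flow ODE: dx/dsigma = (x - D(x)) / sigma.
        d = (x - denoised) / schedule[i]
        x = x + d * (schedule[i + 1] - schedule[i])
    return x  # approximately clean next frame

@torch.no_grad()
def imagine(unet, policy, frames, actions, horizon, frame_channels=3):
    # Roll out an imagined trajectory; each generated frame replaces the oldest one
    # in the channel-stacked history of L past frames.
    trajectory = []
    for _ in range(horizon):
        next_frame = sample_next_frame(unet, frames, actions,
                                       frame_channels=frame_channels)
        action = policy(next_frame)  # hypothetical policy interface (action ids)
        trajectory.append((next_frame, action))
        frames = torch.cat([frames[:, frame_channels:], next_frame], dim=1)
        actions = torch.cat([actions[:, 1:], action.unsqueeze(1)], dim=1)
    return trajectory
```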
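
Finally, a minimal sketch of the actor-critic update on an imagined trajectory: $\lambda$-returns computed backwards over the imagination horizon, a REINFORCE policy loss with the value baseline inside a stop-gradient, and an entropy bonus. The discount, $\lambda$, and entropy weight shown here are illustrative, not the paper's reported settings.

```python
# Hedged sketch of the RL losses used within imagination (per trajectory, shapes (H,)).
import torch

def lambda_returns(rewards, values, ends, gamma=0.99, lam=0.95):
    # `values` has length H+1: it includes the bootstrap value of the final state.
    # Lambda_t = r_t + gamma*(1-d_t)*[(1-lam)*V(x_{t+1}) + lam*Lambda_{t+1}].
    H = rewards.shape[0]
    returns = torch.zeros(H, device=rewards.device)
    next_return = values[H]
    for t in reversed(range(H)):
        cont = gamma * (1.0 - ends[t])
        next_return = rewards[t] + cont * ((1 - lam) * values[t + 1] + lam * next_return)
        returns[t] = next_return
    return returns

def actor_critic_losses(logits, actions, values, returns, eta=1e-3):
    # logits: (H, num_actions) from the policy head; values: (H,) from the critic head.
    dist = torch.distributions.Categorical(logits=logits)
    log_prob = dist.log_prob(actions)                       # log pi(a_t | x_<=t)
    advantage = (returns - values).detach()                 # sg(Lambda_t - V(x_t))
    policy_loss = -(log_prob * advantage + eta * dist.entropy()).sum()
    value_loss = ((values - returns.detach()) ** 2).sum()   # V regressed to sg(Lambda_t)
    return policy_loss, value_loss
```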

Practical Implications and Key Results

  • State-of-the-Art Performance: DIAMOND achieves a mean Human Normalized Score (HNS) of 1.46 on the Atari 100k benchmark, setting a new record for agents trained purely within a world model.
  • Visual Fidelity Matters: Qualitative comparisons (Figure 6) show DIAMOND generates more visually consistent and detailed trajectories than IRIS (Micheli et al., 2022), a strong transformer-based discrete world model. For example, DIAMOND more reliably renders details such as score updates and more reliably distinguishes enemies from rewards. This improved fidelity likely contributes to better agent performance, especially in games like Asterix, Breakout, and Road Runner where small objects are critical.
  • Computational Efficiency: By using EDM with only NFE = 3, the diffusion world model's sampling cost is kept manageable, comparable to or even faster than some discrete models like IRIS (which used NFE = 16). DIAMOND uses fewer parameters (13M) than IRIS (30M) or DreamerV3 (18M) (Hafner et al., 2023).
  • Drop-in Substitute: Since the world model operates in pixel space, it can serve as a drop-in substitute for the real environment for analysis or interaction (e.g., the playable models provided with the code).
  • Code Availability: The implementation is open-sourced at https://github.com/eloialonso/diamond, facilitating replication and extension.

Implementation Considerations and Limitations

  • Conditioning: The current model uses simple frame stacking ($L=4$) for history. More sophisticated temporal modeling (e.g., using transformers across time, as in DiT (Peebles et al., 2022)) might improve long-range dependencies but was found less effective in initial experiments (Appendix G).
  • Reward/Termination Model: These are predicted by a separate model. Integrating them into the diffusion process is non-trivial and left for future work.
  • Environment Scope: Evaluation is primarily on discrete-action Atari environments. Performance in continuous control or more complex 3D environments needs further investigation (though Appendix G provides promising initial results on CS:GO and Driving datasets).
  • Computational Cost: While NFE is low, training still requires significant GPU resources (approx. 2.9 days per game on an RTX 4090).

In summary, DIAMOND demonstrates that using diffusion models as world models is a viable and effective approach for RL, particularly when preserving visual detail is important. The key implementation choices involve using the EDM framework for efficient sampling (low NFE) and conditioning the U-Net appropriately on past states and actions. Its strong performance and improved visual fidelity over discrete models highlight the potential of diffusion for building more capable world models.

Authors (7)
  1. Eloi Alonso (8 papers)
  2. Adam Jelley (7 papers)
  3. Vincent Micheli (8 papers)
  4. Anssi Kanervisto (32 papers)
  5. Amos Storkey (75 papers)
  6. Tim Pearce (24 papers)
  7. François Fleuret (78 papers)