Stable Gradients for Stable Learning at Scale in Deep Reinforcement Learning (2506.15544v1)
Abstract: Scaling deep reinforcement learning networks is challenging and often results in degraded performance, yet the root causes of this failure mode remain poorly understood. Several recent works have proposed mechanisms to address this, but they are often complex and fail to highlight the causes underlying this difficulty. In this work, we conduct a series of empirical analyses which suggest that the combination of non-stationarity with gradient pathologies, due to suboptimal architectural choices, underlie the challenges of scale. We propose a series of direct interventions that stabilize gradient flow, enabling robust performance across a range of network depths and widths. Our interventions are simple to implement and compatible with well-established algorithms, and result in an effective mechanism that enables strong performance even at large scales. We validate our findings on a variety of agents and suites of environments.
Summary
- The paper identifies that non-stationarity in deep RL exacerbates gradient pathologies, causing performance collapse in larger, deeper networks.
- It introduces multi-skip residual connections that inject encoder features at every MLP layer, together with a Kronecker-factored second-order optimizer, to maintain stable gradients.
- Combined interventions significantly boost scalability, achieving median improvements of 83% on Atari and enhancing performance across diverse algorithms and environments.
Scaling deep reinforcement learning (deep RL) models to larger network sizes, analogous to the successful scaling trends in supervised and generative learning, often leads to performance degradation rather than improvement. This paper investigates the underlying causes of this phenomenon, arguing that it stems from the interplay between the inherent non-stationarity of RL training and architectural gradient pathologies that are exacerbated at scale. The authors propose simple, practical interventions focused on stabilizing gradient flow and demonstrate their effectiveness across various agents and environments.
The core problem in scaling deep RL, as identified in the paper, is that the dynamic nature of RL training—where the policy, data distribution, and training targets constantly change—interacts negatively with the known issues of gradient propagation in deep networks, such as vanishing or exploding gradients and ill-conditioned Hessians. Unlike supervised learning with static datasets, RL's non-stationarity means the optimization landscape is constantly shifting, making stable gradient updates crucial but harder to achieve. As networks get deeper and wider, these gradient pathologies become more pronounced, leading to issues like inactive neurons, reduced representation capacity, and flat loss landscapes, ultimately hindering effective learning and causing performance collapse.
The authors diagnose these issues through empirical analysis in both stationary and non-stationary supervised learning (using CIFAR-10 with shuffled labels) and deep RL (using PQN on Atari). They show that in non-stationary settings, deeper and wider networks exhibit a marked degradation in gradient magnitudes, a higher fraction of dormant neurons, reduced effective rank of representations, and near-zero Hessian trace (indicating flat loss curvature), all correlating with performance failure. This suggests that scaling exacerbates the inability of standard network architectures and optimizers to maintain useful gradient signals under non-stationary conditions.
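As a concrete reference point, the sketch below computes two of these diagnostics for a batch of activations, the dormant-neuron fraction and the effective rank of the representations, using common formulations of these metrics; the exact definitions and thresholds used in the paper may differ, and the threshold `tau` is an illustrative choice.

```python
import torch

def dormant_fraction(activations, tau=0.025):
    """Fraction of dormant neurons in a layer: a neuron counts as dormant when its
    mean absolute activation, normalized by the layer-wide average, falls below tau.

    activations: (batch, num_neurons) post-activation outputs of one hidden layer.
    """
    score = activations.abs().mean(dim=0)          # per-neuron activity over the batch
    normalized = score / (score.mean() + 1e-8)     # normalize by the layer average
    return (normalized < tau).float().mean().item()

def effective_rank(features):
    """Effective rank of a feature matrix, taken here as the exponential of the
    entropy of the normalized singular value distribution.

    features: (batch, feature_dim) penultimate-layer representations.
    """
    s = torch.linalg.svdvals(features)
    p = s / s.sum()
    return torch.exp(-(p * torch.log(p + 1e-12)).sum()).item()
```

Tracking these quantities over training, alongside per-layer gradient norms, is how the paper correlates scaling failures with collapsing gradient signal.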
Based on this diagnosis, the paper proposes two main types of interventions aimed at stabilizing gradient flow:
- Multi-Skip Residuals: Standard residual connections (1512.03385) help gradients bypass layers, mitigating vanishing gradients. However, in deep RL's highly non-stationary setting, single or dual-layer skips might be insufficient. The authors propose a "multi-skip" architecture for the MLP component, where features from the convolutional encoder are broadcast directly to all subsequent MLP layers. This creates direct shortcut paths for gradients from the final layers back to the shared encoder and intermediate MLP layers, ensuring that gradient information can propagate effectively regardless of MLP depth.
- Implementation: This involves adding the flattened output of the convolutional encoder to the output of each hidden layer in the MLP before applying the activation function. The MLP layers still process the output of the previous MLP layer, but the encoder features are added as a residual at every step (see the sketch after this list).
- Second-Order Optimizers: Standard first-order optimizers like Adam (1412.6980) or RAdam (1908.03265) rely only on first-order gradient information and adaptive learning rates based on historical gradient statistics. They are curvature-agnostic and can struggle in complex, non-stationary landscapes. Second-order methods, which incorporate information about the loss function's curvature (like the Hessian or Fisher Information Matrix), can provide more stable and directionally aware updates. The authors propose using the Kronecker-factored optimizer (Kron), an approximation of the K-FAC optimizer (1503.05671), which uses Kronecker-factored approximations of the Fisher Information Matrix to precondition gradients.
- Implementation: Using Kron involves replacing the standard optimizer (e.g., Adam, RAdam, AdamW) with a Kron implementation. This typically requires computing or approximating the Fisher Information Matrix or Hessian and its inverse (or pseudo-inverse) and using it to precondition the gradient updates. Kron specifically relies on Kronecker-factored approximations to reduce computational cost relative to full second-order methods, though it still introduces additional overhead compared to first-order methods.
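For concreteness, here is a minimal PyTorch sketch of the multi-skip idea: the flattened encoder output is added to every hidden layer's output before normalization and the nonlinearity. The module name, the single projection of the encoder features to the hidden width (so the residual addition is shape-compatible), and the exact placement of LayerNorm are illustrative assumptions, not the authors' exact implementation.

```python
import torch
import torch.nn as nn

class MultiSkipMLP(nn.Module):
    """MLP head in which every hidden layer receives the encoder features as a residual.

    Hypothetical sketch: the encoder output is projected once to the hidden width,
    then added to each hidden layer's output before LayerNorm and the activation.
    """

    def __init__(self, encoder_dim, hidden_dim, num_layers, num_actions):
        super().__init__()
        self.encoder_proj = nn.Linear(encoder_dim, hidden_dim)
        self.hidden = nn.ModuleList(
            [nn.Linear(encoder_dim if i == 0 else hidden_dim, hidden_dim)
             for i in range(num_layers)]
        )
        self.norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers)])
        self.head = nn.Linear(hidden_dim, num_actions)

    def forward(self, encoder_features):
        # The flattened encoder output is broadcast to every hidden layer.
        skip = self.encoder_proj(encoder_features)
        x = encoder_features
        for layer, norm in zip(self.hidden, self.norms):
            # Multi-skip residual: encoder features are added before norm/activation.
            x = torch.relu(norm(layer(x) + skip))
        return self.head(x)

# Example usage with hypothetical Atari-style sizes:
# q_head = MultiSkipMLP(encoder_dim=3136, hidden_dim=512, num_layers=3, num_actions=18)
# q_values = q_head(torch.randn(32, 3136))
```

Because every hidden layer receives the encoder features directly, gradients from the output head reach the shared encoder through a path of length one regardless of MLP depth, which is the property the paper argues matters under non-stationarity.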
The paper demonstrates the effectiveness of these interventions using PQN on the Atari-10 suite. Individually, both the multi-skip architecture and the Kron optimizer significantly improve the scalability of PQN, preventing the performance collapse observed with standard MLPs and RAdam at increased depths and widths (Figure 4). When combined, these interventions yield even greater benefits. An augmented PQN agent using both multi-skip residuals and the Kron optimizer on the full ALE suite (57 games) outperforms the baseline PQN in 90% of environments, with a median relative improvement of 83.27% (Figure 5). This combined approach also enables high accuracy and rapid adaptation in the non-stationary supervised learning setting, maintaining stable gradient flow across varying scales (Figure 6).
The authors further validate the generality of their findings by applying similar interventions to other algorithms and environments:
- PPO on ALE and Isaac Gym: Augmenting PPO with multi-skip connections and the Kron optimizer improves performance on the full ALE (a 31.40% median improvement, outperforming the baseline in 83.64% of games) and prevents performance collapse in continuous control tasks (Cartpole, Anymal) in Isaac Gym at larger scales (Figure 7).
- Impala CNN Encoder: Replacing the standard Atari CNN encoder with a richer Impala CNN combined with the proposed MLP interventions (multi-skip, LayerNorm, Kron) allows both PQN and PPO to effectively leverage the expressive capacity of the larger encoder and scale the MLP component, unlike baselines that collapse (Figure 8).
- Simba on DMC: Augmenting the Simba architecture (which already includes residual blocks and LayerNorm) with the Kron optimizer instead of AdamW prevents the performance degradation observed with AdamW as model size increases on challenging DeepMind Control Suite tasks (Humanoid, Dog), enabling stable and improved learning with larger networks for both SAC and DDPG (Figure 9, Figures 12-17 in Appendix).
- Other Optimizers: While Kron performed best, the paper also shows ablations comparing RAdam, AdaBelief, Shampoo, and Apollo with PPO and PQN on Atari-10. Only Kron consistently enabled stable training at scale, suggesting that specific properties of second-order or curvature-aware optimization are key in this regime (Figure 11 in Appendix).
Implementation Considerations:
- Architectural Changes: Implementing multi-skip residuals for the MLP requires modifying the forward pass to add the initial flattened feature vector to the output of each subsequent dense layer. This is relatively straightforward in standard deep learning frameworks.
- Optimizer Integration: Using the Kron optimizer is the main source of added complexity. While more efficient than full second-order methods, it still requires computing approximations of the Fisher Information Matrix, which adds computational overhead compared to Adam or RAdam. This overhead grows with network size, as shown in Table 8 (Appendix), with training times potentially several times longer for larger models than with Adam. Efficient implementations of the Kronecker-factored approximations are therefore necessary (a minimal preconditioning sketch follows this list).
- Layer Normalization: The paper emphasizes the role of Layer Normalization (1607.06450), which is used by default in PQN and Simba and often integrated with residual connections. LayerNorm helps stabilize activations and gradient norms layer-wise, complementing the other interventions.
- Computational Resources: Experiments were conducted on single-GPU setups (NVIDIA RTX 8000). While using vectorized environments and efficient algorithms helps, the increased computation of Kron means scaling to very large models might still be resource-intensive without further optimization or distributed training.
- Hyperparameter Tuning: The paper provides hyperparameters (Appendix E), which are mostly based on original papers but might require tuning for specific environments or network scales. Optimizers like Kron may have specific hyperparameters related to the FIM approximation.
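To illustrate what Kronecker-factored preconditioning involves for a single linear layer, the following self-contained sketch follows the K-FAC-style recipe of factorizing curvature into input and output-gradient covariances. It is not the Kron optimizer used in the paper, whose update rules, running statistics, and damping schedule differ; it only shows the core idea, with `damping` as an illustrative stabilizer.

```python
import torch

def kfac_precondition(dW, a, g, damping=1e-3):
    """Precondition a linear layer's weight gradient with Kronecker factors.

    dW: (out, in) raw gradient of the weight matrix W
    a:  (batch, in) inputs to the layer
    g:  (batch, out) gradients of the loss w.r.t. the layer's outputs
    The curvature block is approximated as kron(A, G) with A = E[a a^T] and
    G = E[g g^T], so the preconditioned step factorizes as G^-1 @ dW @ A^-1,
    with damping added to both factors for numerical stability.
    """
    batch = a.shape[0]
    A = a.T @ a / batch + damping * torch.eye(a.shape[1])
    G = g.T @ g / batch + damping * torch.eye(g.shape[1])
    right = torch.linalg.solve(A, dW.T).T   # dW @ A^-1 (A is symmetric)
    return torch.linalg.solve(G, right)     # G^-1 @ dW @ A^-1
```

A practical optimizer would maintain running estimates of the factors, refresh their inverses only periodically, and combine the preconditioned step with momentum; the extra matrix solves are the overhead reported in Table 8.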
Practical Implications:
The findings suggest that developers aiming to scale deep RL agents should prioritize network architectures and optimization strategies that explicitly promote stable gradient propagation. Simply increasing network size with standard architectures and first-order optimizers is likely to fail. Integrating architectural elements like multi-skip connections and employing curvature-aware optimizers like Kron are practical steps towards unlocking the potential of larger models in non-stationary RL settings. The trade-off is the increased computational cost and potentially longer training time of second-order methods, which must be weighed against the performance gains and improved stability.
Limitations:
The authors acknowledge limitations, primarily related to computational resources constraining the exploration of even larger architectures. While the interventions show consistent improvements, further scaling beyond the tested sizes remains an open research question. The increased computational cost of second-order optimizers is also a practical limitation, although vectorized environments help mitigate this.
In summary, this paper provides a practical analysis of deep RL scaling failures, identifies gradient instability under non-stationarity as a key culprit, and offers concrete architectural and optimization-based solutions that significantly improve the ability to train larger, more performant RL agents across diverse tasks and algorithms.