Wasserstein Policy Optimization

Published 1 May 2025 in cs.LG and cs.AI (arXiv:2505.00663v1)

Abstract: We introduce Wasserstein Policy Optimization (WPO), an actor-critic algorithm for reinforcement learning in continuous action spaces. WPO can be derived as an approximation to Wasserstein gradient flow over the space of all policies projected into a finite-dimensional parameter space (e.g., the weights of a neural network), leading to a simple and completely general closed-form update. The resulting algorithm combines many properties of deterministic and classic policy gradient methods. Like deterministic policy gradients, it exploits knowledge of the gradient of the action-value function with respect to the action. Like classic policy gradients, it can be applied to stochastic policies with arbitrary distributions over actions -- without using the reparameterization trick. We show results on the DeepMind Control Suite and a magnetic confinement fusion task which compare favorably with state-of-the-art continuous control methods.

Summary

  • The paper introduces Wasserstein Policy Optimization (WPO), a novel actor-critic reinforcement learning algorithm for continuous action spaces based on Wasserstein gradient flows.
  • The core WPO update uses the action-value gradient ($\nabla_{\mathbf{a}} Q$) to guide stochastic policy updates, applying approximations like variance rescaling for parametric policies and incorporating KL regularization.
  • Experimental evaluation shows WPO achieves robust performance, scales effectively to high-dimensional action spaces, and performs well on challenging continuous control tasks like magnetic fusion.

Wasserstein Policy Optimization (WPO) is a novel actor-critic reinforcement learning algorithm designed for environments with continuous action spaces. It combines advantages from both classic stochastic policy gradient methods and deterministic policy gradient (DPG) methods by leveraging the theory of Wasserstein gradient flows. The core idea is to derive a policy update that minimizes a functional (the negative expected return) in the space of probability distributions using the Wasserstein metric, and then project this nonparametric flow onto a finite-dimensional parametric space (e.g., neural network weights).

Traditional policy gradient methods for stochastic policies estimate the gradient of the expected return by weighting the log-likelihood gradient of sampled actions by a scalar value (like the Q-value or advantage). DPG methods, on the other hand, operate on deterministic policies and use the gradient of the Q-function with respect to the action. WPO bridges these by deriving an update for stochastic policies that also depends on the gradient of the Q-function with respect to the action, specifically $\nabla_{\mathbf{a}} Q^\pi(s, \mathbf{a})$.

The theoretical foundation of WPO lies in Wasserstein gradient flows (Sec 2.2). A gradient flow in the space of probability distributions that minimizes a functional $\mathcal{J}[\pi]$ can be described by a partial differential equation: $\frac{\partial \pi}{\partial t} = -\nabla_{\mathbf{a}} \cdot \left(\pi \left(-\nabla_{\mathbf{a}} \frac{\delta \mathcal{J}}{\delta \pi}\right)\right)$. For reinforcement learning, the functional derivative of the expected return $\mathcal{J}[\pi]$ with respect to the policy $\pi(\mathbf{a}|s)$ is proportional to the action-value function $Q^\pi(s, \mathbf{a})$ times the discounted state occupancy measure $d^\pi(s)$ (Sec A.1), i.e., $\frac{\delta \mathcal{J}}{\delta \pi}(s, \mathbf{a}) \propto Q^\pi(s, \mathbf{a})\, d^\pi(s)$. For per-state updates, the $d^\pi(s)$ term arises naturally from sampling states, leaving the update proportional to $Q^\pi(s, \mathbf{a})$. The optimal transport literature shows that for the 2-Wasserstein metric, the steepest-descent flow corresponds to a velocity field proportional to $-\nabla_{\mathbf{a}} \frac{\delta \mathcal{J}}{\delta \pi}$. Thus, the policy should ideally evolve by transporting probability mass along the gradient of the Q-function with respect to the action.
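To make the flow concrete, here is a small numerical illustration (a toy sketch, not from the paper): a population of action samples for a single fixed state is transported along the velocity field $\nabla_{\mathbf{a}} Q$, which ascends the Q-function and concentrates probability mass near its maximizer. The quadratic `q_grad` is an assumed toy critic.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy action-value function for a fixed state: Q(a) = -(a - 2)^2,
# so the velocity field grad_a Q = -2 (a - 2) points toward a* = 2.
def q_grad(a):
    return -2.0 * (a - 2.0)

# Particle approximation of the policy density over actions.
particles = rng.normal(loc=-1.0, scale=1.0, size=2000)

# Euler-discretized Wasserstein gradient flow: each particle moves
# with the velocity field, here grad_a Q, for many small steps.
step = 0.05
for _ in range(200):
    particles = particles + step * q_grad(particles)

print(particles.mean())  # close to the maximizer a* = 2
```

After enough steps the particle cloud collapses onto the argmax of Q, which is the nonparametric limit the parametric WPO update is designed to approximate.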

To translate this nonparametric flow into an update for a parametric policy $\pi_\theta(\mathbf{a}|s)$, the paper derives the parameter update $\Delta\theta$ that best approximates the desired distribution shift in a small time step $dt$ by minimizing the KL divergence between the distribution produced by the flow and the parametrically updated distribution. Using a local approximation of the KL divergence by the Fisher information matrix, this minimization leads to the core WPO update rule: $\Delta\theta = \mathcal{F}_{\theta\theta}^{-1}\, \mathbb{E}_{\pi}\left[\nabla_\theta \nabla_{\mathbf{a}} \log\pi(\mathbf{a}|s)\, \nabla_{\mathbf{a}} Q^\pi(s,\mathbf{a})\right]$, where $\mathcal{F}_{\theta\theta}$ is the Fisher information matrix of the policy parameters and the expectation is taken over actions sampled from the current policy $\pi_\theta$. Like DPG, this update uses the action gradient $\nabla_{\mathbf{a}} Q^\pi(s,\mathbf{a})$ to guide the policy; but critically, it applies to stochastic policies, pairing that gradient with the parameter gradient of $\nabla_{\mathbf{a}} \log\pi(\mathbf{a}|s)$, averaged over samples from the policy.
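For a one-dimensional Gaussian policy the update can be written out by hand and checked numerically. In that case $\nabla_a \log\pi(a) = -(a-\mu)/\sigma^2$, so $\nabla_\mu \nabla_a \log\pi = 1/\sigma^2$ and $\nabla_\sigma \nabla_a \log\pi = 2(a-\mu)/\sigma^3$; preconditioning by the diagonal Fisher entries ($1/\sigma^2$ for $\mu$, $2/\sigma^2$ for $\sigma$) gives $\Delta\mu = \mathbb{E}[\nabla_a Q]$ and $\Delta\sigma = \mathbb{E}[\frac{a-\mu}{\sigma}\nabla_a Q]$. The sketch below is an illustration of this closed form under a toy quadratic critic, not the paper's implementation.

```python
import numpy as np

rng = np.random.default_rng(1)

mu, sigma = 0.0, 1.0   # Gaussian policy parameters for one state
a_star = 1.5           # maximizer of the toy critic Q(a) = -(a - a*)^2

def q_grad(a):
    return -2.0 * (a - a_star)

# Actions sampled from the current stochastic policy.
a = rng.normal(mu, sigma, size=200_000)

# Fisher-preconditioned WPO update for a 1-D Gaussian:
#   d_mu    = E[ grad_a Q ]                  (a DPG-like mean update)
#   d_sigma = E[ (a - mu)/sigma * grad_a Q ]
d_mu = q_grad(a).mean()
d_sigma = ((a - mu) / sigma * q_grad(a)).mean()

# Analytic values for the quadratic critic:
#   E[grad_a Q] = -2 (mu - a*) = 3,  E[(a-mu)/sigma * grad_a Q] = -2 sigma = -2
print(d_mu, d_sigma)
```

Note that $\Delta\sigma$ comes out negative here: near a concave maximum the flow shrinks the policy's variance, moving the stochastic policy toward a deterministic optimum.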

For practical implementation with deep neural networks, the exact Fisher information matrix $\mathcal{F}_{\theta\theta}$ is computationally expensive. The paper makes a key approximation by focusing on Gaussian policies with diagonal covariance, $\pi_\theta(\mathbf{a}|s) = \mathcal{N}(\mathbf{a} \mid \boldsymbol{\mu}_\theta(s), \boldsymbol{\Sigma}_\theta(s))$, where $\boldsymbol{\Sigma}$ is diagonal. The Fisher information matrix for such a policy is diagonal, with entries $1/\sigma_i^2$ for the means and $2/\sigma_i^2$ for the standard deviations. Rather than forming and inverting the full FIM, WPO uses a variance-rescaling heuristic: it scales the gradients of the log-likelihood with respect to the mean and standard-deviation outputs by the corresponding inverse diagonal FIM entries ($\sigma_i^2$ for $\mu_i$ and $\frac{1}{2}\sigma_i^2$ for $\sigma_i$) before backpropagating through the network (Alg 1, line 18). This rescaling helps stabilize training, especially as the variance decreases, preventing the likelihood gradient $\nabla_{\mathbf{a}}\log\pi$ from blowing up.
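The effect of the rescaling can be seen directly: as $\sigma$ shrinks, the raw gradient with respect to the mean grows like $1/\sigma^2$, while the rescaled update stays bounded. The snippet below is a minimal sketch of that behavior for a single mean parameter, not the paper's code.

```python
def raw_mean_grad(sigma, dq):
    # d/d_mu of grad_a log pi, times grad_a Q: (1 / sigma^2) * dq,
    # which blows up as sigma -> 0.
    return dq / sigma**2

def rescaled_mean_grad(sigma, dq):
    # WPO-style rescaling: multiply by the inverse diagonal FIM
    # entry sigma^2, leaving an O(1) update regardless of sigma.
    return sigma**2 * raw_mean_grad(sigma, dq)

dq = -1.0  # some fixed action-value gradient
for sigma in [1.0, 0.1, 0.01]:
    print(sigma, raw_mean_grad(sigma, dq), rescaled_mean_grad(sigma, dq))
```

The same cancellation is what makes the Gaussian mean update reduce to the plain expected action gradient, as in the closed-form example above.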

Furthermore, like many successful deep RL algorithms, WPO incorporates regularization to prevent overly large policy updates. This is done by adding a penalty proportional to the KL divergence between the current policy and a target policy (e.g., a lagged copy of the current policy) to the optimization objective (Eq. 7, Eq. 8). This can be implemented as a soft constraint (adding a KL term to the loss) or a hard constraint (using a dual optimization approach, similar to MPO). The gradient of this KL penalty is computed conventionally. The overall actor update then combines the approximate Wasserstein gradient from the Q-function and the KL regularization term.
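A soft KL penalty of this kind has a closed form for diagonal Gaussian policies. The helper below is a generic sketch of that penalty term (the function name and example values are illustrative, not from the paper):

```python
import numpy as np

def kl_diag_gauss(mu_p, sigma_p, mu_q, sigma_q):
    """KL( N(mu_p, diag(sigma_p^2)) || N(mu_q, diag(sigma_q^2)) )."""
    var_p, var_q = sigma_p**2, sigma_q**2
    return 0.5 * np.sum(
        np.log(var_q / var_p) + (var_p + (mu_p - mu_q) ** 2) / var_q - 1.0
    )

# Soft-constraint variant: this KL term, weighted by a coefficient,
# is added to the actor loss to keep the updated policy close to a
# lagged target policy.
mu_new, sig_new = np.array([0.1, -0.2]), np.array([1.0, 0.9])
mu_old, sig_old = np.array([0.0, 0.0]), np.array([1.0, 1.0])
penalty = kl_diag_gauss(mu_new, sig_new, mu_old, sig_old)
print(penalty)  # small positive number: the policies barely differ
```

In the hard-constraint variant the weight on this term would instead be adjusted by a dual optimizer to keep the divergence below a fixed bound, as in MPO.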

WPO is implemented within an actor-critic framework. The critic $Q_w(s, \mathbf{a})$ is trained using a standard $n$-step TD update (Eq. 9) based on sampled transitions from a replay buffer. The target value for the TD update uses a target critic network $\bar{Q}_{\bar{w}}$ and samples actions from a target policy network $\bar{\pi}_{\bar{\theta}}$. The actor $\pi_\theta$ is updated using the WPO rule derived from the current critic $Q_w$. The algorithm uses an off-policy replay buffer for states but samples actions from the current policy when computing the policy gradient (i.e., actions are on-policy with respect to the current policy parameters, given the sampled state).
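As a concrete illustration of the critic's learning signal, here is a generic $n$-step TD target computed from a sampled trajectory segment; the bootstrap value stands in for the target critic evaluated at an action drawn from the target policy. This is a schematic sketch of the standard construction, not the paper's code.

```python
def n_step_td_target(rewards, bootstrap_value, gamma):
    """Discounted n-step return: sum_k gamma^k r_k + gamma^n * bootstrap."""
    target = bootstrap_value
    for r in reversed(rewards):  # fold from the far end of the segment
        target = r + gamma * target
    return target

# 3-step target with gamma = 0.9 and a bootstrap value of 10.0:
# 1 + 0.9*2 + 0.81*3 + 0.729*10 = 12.52
target = n_step_td_target([1.0, 2.0, 3.0], 10.0, 0.9)
print(target)
```

The critic loss is then the squared error between $Q_w(s_t, \mathbf{a}_t)$ and this target, with the target networks held fixed during the update.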

The paper evaluates WPO on tasks from the DeepMind Control Suite (DMC) and a challenging magnetic confinement fusion control task. On DMC, WPO demonstrates robust performance across a wide range of tasks, often matching or exceeding the performance of baseline algorithms like DDPG and SAC, which can struggle on certain tasks (Fig. 5, App C.1). WPO shows greater out-of-the-box generalization across environments compared to DDPG and SAC, which are more sensitive to hyperparameter tuning. To test scalability to very high dimensions, WPO is applied to "combined" DMC tasks, controlling multiple replicas of an environment simultaneously (e.g., 5 Humanoid Stand tasks, leading to 105 action dimensions). Results show that WPO learns faster and takes off earlier in training than MPO and DDPG as the number of action dimensions increases (Fig. 6), suggesting an advantage in high-dimensional control. On the magnetic fusion control task (19 action dimensions, 93 state dimensions), WPO achieves performance comparable to MPO, a state-of-the-art method for this specific domain, demonstrating its applicability to complex real-world problems (Fig. 7, App C.3).

Potential practical considerations include numerical stability, especially of the action-value gradient $\nabla_{\mathbf{a}} Q$. The paper explores applying a squashing function to this gradient (e.g., a cube root), derived from the theory of c-Wasserstein distances (App C.2). While not consistently improving performance across all tasks in their ablations, it offers a principled approach to handling potential gradient blow-ups. The choice of neural network activation function (ELU vs. SiLU) also shows minor sensitivity on some tasks. WPO's performance can likewise be sensitive to the KL regularization weights and the specific method used for critic bootstrapping (e.g., mean or max over target action samples).
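The cube-root squashing idea can be sketched in one line: a sign-preserving transform applied elementwise to $\nabla_{\mathbf{a}} Q$ before it enters the update, compressing large magnitudes while leaving each component's direction intact. This is a toy illustration of the idea, not the paper's exact implementation.

```python
import numpy as np

def cube_root_squash(g):
    # Sign-preserving cube root: large gradient components are
    # compressed, small ones are amplified, directions are unchanged.
    return np.sign(g) * np.abs(g) ** (1.0 / 3.0)

g = np.array([-8.0, -0.001, 0.0, 0.001, 1000.0])
print(cube_root_squash(g))
```

Because the transform is monotone and odd, the squashed gradient still points uphill on Q; only the step's sensitivity to outlier gradient magnitudes changes.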

In summary, WPO provides a theoretically grounded, practically competitive approach for continuous action reinforcement learning. Its derivation from Wasserstein gradient flows leads to a novel update rule that effectively utilizes the gradient of the Q-function in action space for stochastic policies. Practical implementation requires approximations to the Fisher information and standard regularization techniques. Experimental results highlight WPO's robustness, scalability to higher dimensions, and applicability to real-world control problems.
