Latent Plan Transformer
- Latent Plan Transformer is a model that introduces a continuous latent plan variable to enable credit assignment and trajectory stitching without step‐wise rewards.
- It employs a causal Transformer-based trajectory generator conditioned on latent plans, trained via maximum likelihood with MCMC sampling to achieve temporal consistency.
- Empirical evaluations demonstrate that LPT outperforms baseline models on offline RL benchmarks by ensuring cohesive long-horizon planning and effective trajectory composition.
The Latent Plan Transformer (LPT) is a generative model for trajectory abstraction and planning, specifically designed to address offline reinforcement learning (RL) settings where only full-trajectory returns are available and step-wise reward signals are absent. LPT introduces a latent continuous variable, termed the "plan," which enables the enforcement of temporal consistency across entire episodes, credit assignment over long horizons, and compositional planning via trajectory stitching. Its distinguishing algorithmic feature is planning as latent space inference, realized by a Transformer-based trajectory generator conditioned on the plan variable, with learning and inference achieved via maximum likelihood estimation and Markov Chain Monte Carlo (MCMC) sampling in the latent space (Kong et al., 7 Feb 2024).
1. Problem Formulation and Motivation
LPT is motivated by the challenge of long-term planning with offline RL datasets $\mathcal{D} = \{(\tau_i, y_i)\}_{i=1}^{n}$, where each $\tau_i = (s_1, a_1, \ldots, s_T, a_T)$ is a trajectory of state-action pairs and $y_i$ is the total episode return (a minimal data-format sketch follows the list below). In this setting:
- Credit assignment: Associating sparse or delayed rewards with temporally distant actions is difficult in the absence of step-wise reward signals.
- Trajectory stitching: New, high-return trajectories must be composed from observed suboptimal fragments.
- Temporal consistency: Policy drift must be mitigated in autoregressive models that operate on a finite context, conditioned solely on past states and a single summary return.
LPT addresses these issues by introducing a latent "plan" variable that generates trajectories and predicts scalar returns, facilitating episode-level coherence and scalable planning as latent variable inference (Kong et al., 7 Feb 2024).
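To make this data regime concrete, the following is a minimal sketch (with illustrative names, not taken from the paper) of an offline dataset in which each episode stores states, actions, and a single scalar return, with no per-step rewards:

```python
from dataclasses import dataclass
from typing import List

import numpy as np


@dataclass
class OfflineTrajectory:
    """One offline episode: state-action sequence plus a single scalar return."""
    states: np.ndarray      # shape (T, state_dim)
    actions: np.ndarray     # shape (T, action_dim)
    episode_return: float   # total return y; no per-step rewards are available


def make_dummy_dataset(n: int = 8, T: int = 50, state_dim: int = 11,
                       action_dim: int = 3) -> List[OfflineTrajectory]:
    """Builds a toy dataset D = {(tau_i, y_i)} of the kind LPT is trained on."""
    rng = np.random.default_rng(0)
    return [
        OfflineTrajectory(
            states=rng.normal(size=(T, state_dim)),
            actions=rng.normal(size=(T, action_dim)),
            episode_return=float(rng.normal()),
        )
        for _ in range(n)
    ]
```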
2. Probabilistic Model and Inference
LPT defines a joint generative model over the latent plan $z$, the trajectory $\tau$, and the final return $y$:
$$p_\theta(\tau, y, z) = p_\alpha(z)\, p_\beta(\tau \mid z)\, p_\gamma(y \mid z), \qquad \theta = (\alpha, \beta, \gamma),$$
where
- $p_\alpha(z)$: Plan prior. Implicit, with $z_0 \sim \mathcal{N}(0, I_d)$ mapped to $z = U_\alpha(z_0)$ via a neural network $U_\alpha$ (U-Net or MLP).
- $p_\beta(\tau \mid z)$: Trajectory generator. An autoregressive, causal Transformer operating over a finite context of past tokens; each token position predicts the next action, so that $p_\beta(\tau \mid z) = \prod_{t} p_\beta(a_t \mid s_{\le t}, a_{<t}, z)$.
- $p_\gamma(y \mid z)$: Return predictor. Gaussian likelihood $p_\gamma(y \mid z) = \mathcal{N}(y; m_\gamma(z), \sigma^2)$, with $m_\gamma$ an MLP and $\sigma^2$ fixed.
The evidence lower bound (ELBO) for maximum likelihood training is
$$\log p_\theta(\tau, y) \ge \mathbb{E}_{q(z \mid \tau, y)}\big[\log p_\theta(\tau, y, z) - \log q(z \mid \tau, y)\big],$$
with approximate posterior $q(z \mid \tau, y)$. If $q(z \mid \tau, y) = p_\theta(z \mid \tau, y)$, the bound is tight. The marginal likelihood involves integrating out $z$:
$$p_\theta(\tau, y) = \int p_\alpha(z)\, p_\beta(\tau \mid z)\, p_\gamma(y \mid z)\, dz$$
(Kong et al., 7 Feb 2024).
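The link between this marginal likelihood and the posterior-sampling training described in the next section is the standard identity for the learning gradient (a derivation sketch in the notation above):
$$\nabla_\theta \log p_\theta(\tau, y) = \frac{1}{p_\theta(\tau, y)} \int \nabla_\theta\, p_\theta(\tau, y, z)\, dz = \mathbb{E}_{p_\theta(z \mid \tau, y)}\big[\nabla_\theta \log p_\theta(\tau, y, z)\big],$$
so an unbiased estimate of the maximum-likelihood gradient only requires samples from the posterior $p_\theta(z \mid \tau, y)$, which LPT obtains via Langevin dynamics.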
3. Model Architecture and Training
LPT's architecture comprises:
- Plan prior: Samples $z_0 \sim \mathcal{N}(0, I_d)$ and maps it to $z = U_\alpha(z_0)$, where $U_\alpha$ is either a U-Net or an MLP, representing an implicit prior $p_\alpha(z)$.
- Trajectory generator: Stack of Transformer blocks with causal self-attention over the past tokens and cross-attention from $z$ at each token position. At each timestep $t$, it outputs a Gaussian policy $p_\beta(a_t \mid s_{\le t}, a_{<t}, z)$.
- Return head: MLP computing the mean $m_\gamma(z)$ of the Gaussian return predictor $p_\gamma(y \mid z)$ (all three components are sketched schematically below).
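As referenced above, the following is a schematic PyTorch sketch of these three components. It is an illustration under simplifying assumptions (a single plan token, an MLP stand-in for the U-Net, and `nn.TransformerDecoder` supplying causal self-attention plus cross-attention), not the authors' implementation; the class name `LPTSketch` is hypothetical:

```python
import torch
import torch.nn as nn


class LPTSketch(nn.Module):
    """Illustrative sketch: implicit plan prior, plan-conditioned causal Transformer
    policy, and Gaussian return head."""

    def __init__(self, state_dim: int, action_dim: int, z_dim: int = 16,
                 d_model: int = 64, n_layers: int = 2, n_heads: int = 4):
        super().__init__()
        # Plan prior: z_0 ~ N(0, I) pushed through a small MLP (stand-in for U_alpha).
        self.prior_net = nn.Sequential(nn.Linear(z_dim, 64), nn.ReLU(), nn.Linear(64, z_dim))
        # Token embedding for (s_t, a_{t-1}) pairs and projection of the plan z.
        self.embed = nn.Linear(state_dim + action_dim, d_model)
        self.z_proj = nn.Linear(z_dim, d_model)
        # Causal self-attention over past tokens + cross-attention to the plan.
        layer = nn.TransformerDecoderLayer(d_model, n_heads, dim_feedforward=128, batch_first=True)
        self.decoder = nn.TransformerDecoder(layer, n_layers)
        # Gaussian policy head: mean of p_beta(a_t | s_<=t, a_<t, z).
        self.action_mean = nn.Linear(d_model, action_dim)
        # Return head: mean m_gamma(z) of the Gaussian return predictor p_gamma(y | z).
        self.return_head = nn.Sequential(nn.Linear(z_dim, 64), nn.ReLU(), nn.Linear(64, 1))

    def sample_plan(self, batch_size: int) -> torch.Tensor:
        """Draw z = U_alpha(z_0), z_0 ~ N(0, I): a sample from the implicit prior."""
        z0 = torch.randn(batch_size, self.prior_net[0].in_features)
        return self.prior_net(z0)

    def forward(self, states: torch.Tensor, actions: torch.Tensor, z: torch.Tensor):
        """states: (B, T, state_dim); actions: (B, T, action_dim); z: (B, z_dim).
        Token t carries (s_t, a_{t-1}), so the causal mask yields a_t | s_<=t, a_<t, z."""
        prev_actions = torch.cat([torch.zeros_like(actions[:, :1]), actions[:, :-1]], dim=1)
        tokens = self.embed(torch.cat([states, prev_actions], dim=-1))   # (B, T, d_model)
        memory = self.z_proj(z).unsqueeze(1)                             # (B, 1, d_model)
        causal_mask = nn.Transformer.generate_square_subsequent_mask(tokens.shape[1])
        h = self.decoder(tokens, memory, tgt_mask=causal_mask)
        return self.action_mean(h), self.return_head(z)                  # action means, return mean


# Minimal shape check with random inputs.
model = LPTSketch(state_dim=11, action_dim=3)
z = model.sample_plan(batch_size=2)
a_mean, y_mean = model(torch.randn(2, 20, 11), torch.randn(2, 20, 3), z)
```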
Training algorithm:
- LPT is optimized via offline maximum likelihood, leveraging (approximate) posterior sampling in the latent space.
- For each training example $(\tau, y)$, the posterior expectation under $p_\theta(z \mid \tau, y)$ is approximated via Langevin dynamics on $z$, where transitions
$$z_{k+1} = z_k + s\, \nabla_z \log p_\theta(z_k \mid \tau, y) + \sqrt{2s}\, \epsilon_k, \qquad \epsilon_k \sim \mathcal{N}(0, I_d),$$
are performed, with $\nabla_z \log p_\theta(z \mid \tau, y) = \nabla_z \log p_\theta(z, \tau, y)$, i.e. gradients of the joint log-probability of the plan, trajectory, and return.
- Model parameters $\theta = (\alpha, \beta, \gamma)$ are updated by gradient ascent on $\log p_\theta(z, \tau, y)$ using empirical averages over the sampled $z$ (Kong et al., 7 Feb 2024). A minimal sketch of one such training step follows below.
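The sketch below illustrates the training step just described: short-run Langevin sampling of the plan followed by a gradient update of the model parameters. To stay self-contained it uses a toy Gaussian stand-in for $\log p_\theta(z, \tau, y)$, with the trajectory term omitted and a standard-normal prior in place of the implicit one; `toy_log_joint` and `theta` are illustrative assumptions, not the paper's code:

```python
import torch


def langevin_posterior_sample(z_init: torch.Tensor, log_joint, n_steps: int = 20,
                              step_size: float = 0.1) -> torch.Tensor:
    """Short-run Langevin targeting p(z | tau, y), which is proportional to exp(log_joint(z))."""
    z = z_init.clone().detach().requires_grad_(True)
    for _ in range(n_steps):
        grad, = torch.autograd.grad(log_joint(z), z)
        with torch.no_grad():
            z = z + step_size * grad + (2 * step_size) ** 0.5 * torch.randn_like(z)
        z.requires_grad_(True)
    return z.detach()


# Toy stand-in for log p_theta(z, tau, y); in LPT this is the Transformer + return head.
z_dim = 8
theta = torch.randn(z_dim, requires_grad=True)   # toy "return head" parameters
y_observed = torch.tensor(3.0)                   # final return of one training trajectory


def toy_log_joint(z: torch.Tensor) -> torch.Tensor:
    log_prior = -0.5 * (z ** 2).sum()                    # N(0, I) prior stand-in
    log_return = -0.5 * (y_observed - theta @ z) ** 2    # Gaussian return likelihood, sigma = 1
    return log_prior + log_return                        # trajectory term omitted in this toy


# One maximum-likelihood update: sample z from the posterior, then ascend log p_theta(z, tau, y).
z_post = langevin_posterior_sample(torch.randn(z_dim), toy_log_joint)
loss = -toy_log_joint(z_post)     # negative joint log-probability at the sampled plan
loss.backward()                   # gradient w.r.t. theta only (z_post is detached)
with torch.no_grad():
    theta -= 1e-2 * theta.grad
    theta.grad = None
```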
4. Planning as Latent Space Inference
At test time, LPT realizes planning as inference by conditioning on a desired return $y^*$. The plan is inferred from $p_\theta(z \mid y^*) \propto p_\alpha(z)\, p_\gamma(y^* \mid z)$ via MCMC in the latent space:
$$z_{k+1} = z_k + s\, \nabla_z \log\big[p_\alpha(z_k)\, p_\gamma(y^* \mid z_k)\big] + \sqrt{2s}\, \epsilon_k, \qquad \epsilon_k \sim \mathcal{N}(0, I_d).$$
After $K$ steps, set $z = z_K$. An episode is then generated by rolling out the autoregressive policy:
$$a_t \sim p_\beta(a_t \mid s_{\le t}, a_{<t}, z), \qquad t = 1, \ldots, T.$$
This approach enables specification of arbitrary target returns, framing planning as finding a plan most compatible with the desired outcome in the learned latent space (Kong et al., 7 Feb 2024).
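A corresponding test-time sketch: Langevin inference of the plan conditioned only on the desired return (prior and return terms, no trajectory term), followed by a rollout with $z$ held fixed. The return head, policy, and environment here are toy stand-ins; `w_return`, `toy_policy`, and `toy_env_step` are assumed names for illustration only:

```python
import torch

z_dim, action_dim, state_dim = 8, 3, 4
w_return = torch.randn(z_dim)                              # toy return head: m_gamma(z) = w_return @ z
toy_policy = lambda states, actions, z: torch.tanh(z[:action_dim])  # toy stand-in for p_beta
toy_env_step = lambda action: torch.randn(state_dim)       # toy stand-in for the environment


def infer_plan(y_star: torch.Tensor, n_steps: int = 100, step_size: float = 0.05) -> torch.Tensor:
    """Langevin on p(z | y*), proportional to p(z) p_gamma(y* | z); no trajectory term at test time."""
    z = torch.randn(z_dim, requires_grad=True)
    for _ in range(n_steps):
        log_p = -0.5 * (z ** 2).sum() - 0.5 * (y_star - w_return @ z) ** 2
        grad, = torch.autograd.grad(log_p, z)
        with torch.no_grad():
            z = z + step_size * grad + (2 * step_size) ** 0.5 * torch.randn_like(z)
        z.requires_grad_(True)
    return z.detach()


# Infer the plan once for the desired return, then roll out with z held fixed.
z_plan = infer_plan(torch.tensor(5.0))
states, actions = [torch.randn(state_dim)], []
for t in range(10):
    a_t = toy_policy(states, actions, z_plan)   # a_t ~ p_beta(a_t | s_<=t, a_<t, z)
    actions.append(a_t)
    states.append(toy_env_step(a_t))            # the environment supplies the next state
```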
5. Empirical Evaluation and Results
LPT is benchmarked on a range of environments:
| Domain | Tasks/Subsets | Characteristics |
|---|---|---|
| Gym-Mujoco | HalfCheetah, Hopper, Walker2D (medium, replay), AntMaze (umaze, diverse) | Continuous control, dense/sparse reward |
| Maze2D | umaze, medium, large | Navigation, sparse reward |
| Connect Four | vs. stochastic opponent | Board game, adversarial |
Baselines: CQL, Decision Transformer (DT), Q-learning Decision Transformer (QDT), Online DT (ODT), ESPER.
Metrics: Average return ± standard deviation over 5 seeds.
Key statistical findings:
- On Gym-Mujoco with only final return supervision, LPT outperforms DT and QDT, at times matching or exceeding CQL, which has access to step-wise rewards.
- On Maze2D and AntMaze, LPT yields substantial improvements over DT by stitching suboptimal trajectory fragments into near-optimal full trajectories.
- On Connect Four, LPT matches the performance of the state-of-the-art ESPER and substantially outperforms DT (Kong et al., 7 Feb 2024).
Qualitative insights:
- Posterior gradients with respect to $z$ integrate reward information from sub-trajectories, automating credit assignment.
- t-SNE visualizations indicate the latent plan space allows interpolation between trajectories, capturing novel, high-return behaviors via trajectory stitching.
- During execution, the latent plan $z$ is fixed, but the policy adapts to environment stochasticity, limiting overfitting to dataset contingencies.
6. Strengths, Limitations, and Future Prospects
Strengths:
- Enforces temporal consistency over entire episodes without explicit step-wise reward or return-to-go conditioning.
- Posterior sampling in latent space creates abstractions aggregating information across finite-context fragments.
- Planning as inference (MCMC over $z$) allows for return-conditioned generation without relying on reward-to-go as input.
- Demonstrates competitive or superior empirical results across dense, sparse, and adversarial benchmarks, supporting long-horizon credit assignment and trajectory composition (Kong et al., 7 Feb 2024).
Limitations and open questions:
- MCMC-based latent sampling scales poorly for very long horizons. While persistent chains and fewer steps partially mitigate this, further advances such as amortized inference are desirable.
- The implicit latent prior lacks explicit density modeling; replacing or augmenting it with normalizing flows or energy-based models may improve expressiveness.
- Extending LPT to multi-task or hierarchical RL with discrete/discontinuous returns is an open direction.
- Online continual fine-tuning currently yields limited gains; integrating the LPT posterior machinery with value-based pessimistic objectives remains an open research problem (Kong et al., 7 Feb 2024).