TD-JEPA: Latent-Predictive Zero-Shot RL
- TD-JEPA is an unsupervised reinforcement learning architecture that uses temporal-difference latent-predictive representation learning to enable zero-shot policy optimization.
- It unifies successor feature learning and latent predictive modeling by training distinct state and task encoders with a policy-conditioned multi-step predictor using offline data.
- Empirical evaluations on DMC and OGBench demonstrate that TD-JEPA improves performance in zero-shot settings by effectively generalizing to arbitrary linear reward functions.
TD-JEPA (Temporal-Difference Joint-Embedding Predictive Architecture) is an unsupervised reinforcement learning framework that leverages latent-predictive representation learning via a temporal-difference (TD) loss. TD-JEPA addresses key limitations of previous joint-embedding predictive approaches by enabling multi-step, policy-conditioned, and off-policy latent prediction using reward-free, offline data. Through explicit training of state and task encoders, a policy-conditioned multi-step predictor, and a parameterized family of latent policies, TD-JEPA enables direct zero-shot optimization for arbitrary linear reward functions at test time. The architecture unifies principles of successor feature learning and latent-predictive modeling, offering theoretical and empirical advances for scalable zero-shot reinforcement learning from high-dimensional input spaces and large offline datasets (Bagatella et al., 1 Oct 2025).
1. Motivation: Latent-Predictive Representation Learning
Latent-predictive approaches such as Joint-Embedding Predictive Architectures (JEPAs) train an encoder $\varphi$ and a predictor $T$ to satisfy $T(\varphi(s), a) \approx \varphi(s')$ for transitions $(s, a, s')$. These models operate entirely in a learned latent space rather than reconstructing observations, which circumvents the challenges inherent in modeling high-dimensional input (e.g., pixel-level data). The signal for learning is derived exclusively from raw state transitions, allowing for efficient exploitation of vast offline datasets without rewards.
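As a concrete illustration, the single-step latent-prediction objective can be sketched as follows. This is a minimal numpy version with stand-in linear maps for the encoder and predictor (`W_enc`, `W_pred`, and `jepa_loss` are illustrative names, not the paper's code):

```python
import numpy as np

rng = np.random.default_rng(0)

d_obs, d_act, d_lat = 16, 4, 8
W_enc = rng.normal(size=(d_obs, d_lat)) / np.sqrt(d_obs)                   # stand-in encoder phi
W_pred = rng.normal(size=(d_lat + d_act, d_lat)) / np.sqrt(d_lat + d_act)  # stand-in predictor T

def phi(s):
    """Encode an observation into the latent space."""
    return s @ W_enc

def predict(z_s, a):
    """Predict the next latent from the current latent and the action."""
    return np.concatenate([z_s, a], axis=-1) @ W_pred

def jepa_loss(s, a, s_next):
    """One-step latent prediction error ||T(phi(s), a) - phi(s')||^2.

    The target phi(s') is treated as a constant (stop-gradient),
    as is standard in JEPA-style training."""
    target = phi(s_next)          # no gradient would flow through the target
    pred = predict(phi(s), a)
    return float(np.mean(np.sum((pred - target) ** 2, axis=-1)))

# A batch of random transitions (s, a, s').
s = rng.normal(size=(32, d_obs))
a = rng.normal(size=(32, d_act))
s_next = rng.normal(size=(32, d_obs))
print(jepa_loss(s, a, s_next))    # a non-negative scalar
```

Note that this sketch is single-step and unconditioned on any policy or task, which is precisely the limitation TD-JEPA targets.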
The strengths of JEPA-based latent prediction include:
- Utilization of unsupervised, reward-free transitions.
- Avoidance of reconstructive losses over high-dimensional observations.
- Potential to encode transition dynamics relevant to value estimation and planning.
However, existing JEPA methods suffer from three notable limitations:
- Restriction to single-step latent prediction, which fails to capture the multi-step dynamics necessary for value or successor feature approximation.
- Dependence on on-policy rollouts for positive sample construction, hindering application to offline or off-policy datasets.
- Focus on single-task learning, which inhibits generalization to multiple reward functions required for zero-shot RL scenarios.
TD-JEPA addresses these gaps by introducing a latent-predictive loss that is multi-step, policy-conditioned, and applicable in off-policy, multitask settings, with learned representations simultaneously capturing both state and task information (Bagatella et al., 1 Oct 2025).
2. Theoretical Principles: Latent Spaces, Successor Features, and Non-Collapse
TD-JEPA is built on the assumption of a reward-free Markov Decision Process (MDP) $(\mathcal{S}, \mathcal{A}, P, \gamma)$. Two primary encoders are learned:
- $\varphi: \mathcal{S} \to \mathbb{R}^d$ (state encoder)
- $\psi: \mathcal{S} \to \mathbb{R}^d$ (task encoder)
A family of latent-conditioned policies $\{\pi_z\}_{z \in \mathbb{R}^d}$ is parameterized, where each latent $z$ plays the role of a task descriptor.
The TD-JEPA loss is grounded in the approximation of the discounted sum of future $\psi$-features under policy $\pi_z$. Let the $K$-step Monte-Carlo target be:

$$y^{(K)}_t = \sum_{k=0}^{K-1} \gamma^k\, \psi(s_{t+k+1}),$$

with the predictor $T(\varphi(s_t), a_t, z)$ trained to approximate its expectation as $K \to \infty$. The TD-JEPA (temporal-difference) loss, which enables off-policy training with bootstrapping, is defined as:

$$\mathcal{L}_{\mathrm{TD}} = \mathbb{E}\Big[\big\| T(\varphi(s), a, z) - \mathrm{sg}\big[\psi(s') + \gamma\, \bar{T}(\bar{\varphi}(s'), a', z)\big] \big\|^2\Big], \qquad a' \sim \pi_z(\cdot \mid s'),$$

where $\mathrm{sg}[\cdot]$ is the stop-gradient operator and barred quantities denote target networks.
Under standard linearity and covariance-preserving conditions, $T(\varphi(\cdot), \cdot, z)$ recovers the successor features of $\psi$ under $\pi_z$, and the covariance of the learned embeddings is theoretically constant through fast predictor updates, preventing trivial collapse (Bagatella et al., 1 Oct 2025).
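The successor-feature connection can be made explicit by unrolling the TD fixed point. As a sketch (taking expectations over transitions and over $a' \sim \pi_z$), a zero-loss solution satisfies the Bellman-style recursion on the left, which unrolls into the discounted sum on the right:

```latex
T(\varphi(s), a, z) = \mathbb{E}\big[\psi(s') + \gamma\, T(\varphi(s'), a', z)\big]
\;\Longrightarrow\;
T(\varphi(s_t), a_t, z) = \mathbb{E}_{\pi_z}\Big[\sum_{k \ge 0} \gamma^{k}\, \psi(s_{t+k+1})\Big]
```

That is, the predictor's output is exactly the successor features of $\psi$ under $\pi_z$, which is what enables linear-reward policy evaluation downstream.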
3. Algorithmic Structure and Network Architecture
TD-JEPA simultaneously trains four parameterized modules using a reward-free, offline transition dataset $\mathcal{D} = \{(s, a, s')\}$:
- State encoder $\varphi$
- Task encoder $\psi$
- Predictor $T$
- Latent policies $\pi_z$
The combined loss function is:

$$\mathcal{L} = \mathcal{L}^{s}_{\mathrm{TD}} + \mathcal{L}^{t}_{\mathrm{TD}} + \lambda_s \mathcal{L}^{s}_{\mathrm{reg}} + \lambda_t \mathcal{L}^{t}_{\mathrm{reg}} + \mathcal{L}_{\mathrm{actor}},$$

with
- $\mathcal{L}^{s}_{\mathrm{TD}}, \mathcal{L}^{t}_{\mathrm{TD}}$: TD-JEPA losses in both directions (asymmetric),
- $\mathcal{L}^{s}_{\mathrm{reg}}, \mathcal{L}^{t}_{\mathrm{reg}}$: regularizers for near-orthonormal embeddings,
- $\mathcal{L}_{\mathrm{actor}}$: actor loss aligning $\pi_z$ to maximize the successor features in the direction $z$.
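The near-orthonormality regularizer can be sketched as a penalty on the empirical feature covariance. This is a minimal numpy version (the exact form used in the paper may differ):

```python
import numpy as np

def orthonormal_reg(features):
    """Penalize deviation of the empirical covariance from the identity.

    features: (n, d) batch of latent embeddings.
    Driving this penalty to zero makes the embedding dimensions
    unit-variance and mutually decorrelated, ruling out the trivial
    collapsed solution in which all embeddings coincide.
    """
    n, d = features.shape
    cov = features.T @ features / n          # (d, d) empirical covariance
    return float(np.sum((cov - np.eye(d)) ** 2))

rng = np.random.default_rng(0)

collapsed = np.ones((256, 8))                # all embeddings identical -> large penalty
spread = rng.normal(size=(256, 8))           # roughly isotropic -> small penalty
print(orthonormal_reg(collapsed), orthonormal_reg(spread))
```

The collapsed batch incurs a much larger penalty than the isotropic one, which is the behavior the regularizer relies on.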
Each component typically comprises shallow MLPs (2–4 layers for encoders, 2–3 for predictors; CNNs for pixel input), with latent features normalized to unit length. Latent-conditioned policies are Gaussian MLPs conditioned on the state embedding and the task latent $z$ (Bagatella et al., 1 Oct 2025).
Core training pseudocode
```
for each gradient step:
    sample (s, a, s') ~ D and z ~ Uniform(Z)
    a_prime = π(φ_s(s'), z).sample()
    φsp = φ_t_target(s')
    pred_next = T_target(φ_s_target(s'), a_prime, z)
    y = φsp + γ * pred_next
    pred = T(φ_s(s), a, z)
    loss_td_s = ‖pred - y‖²
    # symmetric loss, orthonormal reg, and actor computed analogously
    total_loss = loss_td_s + loss_td_t + λs * loss_reg_s + λt * loss_reg_t + loss_actor
    optimizer.step()
    update_target_networks()
```
4. Zero-Shot Reinforcement Learning Procedures
Following unsupervised pretraining, the task encoder $\psi$ defines a space of linear reward functions:

$$r_z(s) = \langle \psi(s), z \rangle, \qquad z \in \mathbb{R}^d.$$
Given a novel reward function $r$, the process for zero-shot RL is:
- Collect a small batch of reward-labeled states $\{(s_i, r(s_i))\}_{i=1}^{n}$ from the target environment.
- Fit $z_r$ via least squares:

$$z_r = \arg\min_{z} \|\Psi z - r\|^2 = (\Psi^\top \Psi)^{-1} \Psi^\top r,$$

where $\Psi$ is the matrix of $\psi(s_i)$ vectors, and $r$ the vector of corresponding rewards.
- Deploy the policy $\pi_{z_r}$ that was pretrained to maximize $r_{z_r}$.
This approach requires no further environment interaction or fine-tuning, enabling true plug-and-play zero-shot RL for linear reward specifications (Bagatella et al., 1 Oct 2025).
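The reward-fitting step above can be sketched directly. In this minimal numpy version, `psi` is a stand-in for the pretrained task encoder (a fixed random feature map) and the reward data are synthetic, constructed to be linear in the $\psi$-features:

```python
import numpy as np

rng = np.random.default_rng(0)
d_state, d_task = 16, 8
W_psi = rng.normal(size=(d_state, d_task))

def psi(states):
    """Stand-in for the pretrained task encoder (a fixed random feature map)."""
    return np.tanh(states @ W_psi)

# Small batch of reward-labeled states from the target environment.
states = rng.normal(size=(64, d_state))
z_true = rng.normal(size=d_task)
rewards = psi(states) @ z_true      # reward is linear in psi-features by construction

# Least-squares fit of the task latent: z_r = argmin_z ||Psi z - r||^2.
Psi = psi(states)                   # (n, d) matrix with rows psi(s_i)
z_r, *_ = np.linalg.lstsq(Psi, rewards, rcond=None)

# The pretrained policy pi_{z_r} would then be deployed as-is,
# with no further environment interaction or fine-tuning.
print(np.allclose(z_r, z_true, atol=1e-6))
```

Because the synthetic reward lies exactly in the span of the $\psi$-features, least squares recovers the task latent; with real rewards the fit is a projection onto that span.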
5. Empirical Benchmarks and Analysis
TD-JEPA was evaluated on ExoRL/DMC for locomotion and navigation (with proprioceptive and pixel/RGB inputs) and on OGBench for navigation and manipulation (using both high- and low-coverage offline data). Main evaluation metrics include expected return (DMC) and success rate (OGBench), both in the zero-shot regime, averaged over multiple tasks.
Key empirical findings:
- On DMC_RGB, TD-JEPA achieves 628.8 average return versus the next-best 582.4.
- On OGBench_RGB, TD-JEPA attains 41.34% success, outperforming or matching baselines such as Laplace, FB, HILP, BYOL*, BYOL-γ, RLDP, and ICVF.
- TD-JEPA demonstrates strong performance across both proprioceptive-based and pixel-based tasks.
Ablation studies reveal:
- Multi-step, policy-conditioned (TD-JEPA) objectives substantially outperform one-step or non-policy-conditioned variants.
- Asymmetric encoder architectures (separating $\varphi$ and $\psi$) yield consistent but modest improvements over symmetric setups.
- Reusing the frozen $\varphi$ for fast adaptation in downstream tasks significantly accelerates fine-tuning (Bagatella et al., 1 Oct 2025).
6. Discussion, Limitations, and Research Directions
TD-JEPA’s multi-step, policy-conditioned objective is directly related to successor feature learning, naturally supporting zero-shot policy evaluation and control. The architectural separation between state and task encoders allows $\varphi$ to focus on capturing dynamics relevant to control, while $\psi$ forms a latent task basis for reward generalization. Orthonormal regularization is crucial to avoid encoder collapse and to promote distributed latent representations.
Current theoretical results are established under assumptions such as symmetric kernels, uniform state distributions, and identity feature covariances. Generalizing the analysis to more realistic, non-uniform, and off-policy sampling regimes remains an open question.
Zero-shot performance degrades on very low-coverage, expert-only datasets (e.g., OGBench antmaze-ls/me). While improved offline corrections (such as flow-based behavior cloning) offer partial remedies, robust solutions for severe data coverage problems are needed. Scaling to continuous action spaces beyond Gaussian policies and validation with real-robot datasets are cited as promising avenues for further work.
TD-JEPA establishes a unified approach to unsupervised latent-predictive representation learning and successor-feature-based zero-shot RL that is multi-step, policy-conditioned, and fully off-policy, enabling robust state and task encoders that avoid collapse, capture long-horizon dynamics, and support plug-and-play zero-shot deployment for any linear reward in the latent task space (Bagatella et al., 1 Oct 2025).