
State-Conditional VQ-VAE

Updated 26 February 2026
  • The paper introduces state-conditional VQ-VAE, which extends standard VQ-VAE by leveraging discrete latent codes and conditional autoregressive modeling for improved multi-modal trajectory generation.
  • The architecture employs an encoder for quantizing expert trajectories, a decoder with a differentiable quadratic program, and a conditional PixelCNN to model complex state-dependent behaviors.
  • Empirical results demonstrate reduced collision rates and enhanced diversity in autonomous driving scenarios by accurately capturing distinct behavioral modes compared to CVAE baselines.

State-conditional Vector-Quantized Variational Autoencoders (state-conditional VQ-VAE) extend the VQ-VAE framework to enable multi-modal conditional generation, particularly suited for tasks exhibiting complex, multi-modal mapping from system state to actions or trajectories. Unlike models with continuous latent priors such as the Conditional Variational Autoencoder (CVAE), state-conditional VQ-VAE leverages discrete latent spaces learned from expert demonstration data and a conditional autoregressive model over codes, yielding empirically superior coverage of distinct behavioral modes in trajectory generation and lower error rates in safety-critical domains such as autonomous driving (Idoko et al., 2024).

1. Formal Architecture

The state-conditional VQ-VAE comprises three principal components: an encoder mapping demonstration data into discrete code indices via quantization; a decoder reconstructing trajectories from code embeddings and a differentiable quadratic program (QP); and a conditional PixelCNN that models the distribution of code indices as a function of external state.

  • Encoder $f_e$: Inputs a trajectory $\tau_e \in \mathbb{R}^{2\times T}$ (e.g., sequences of $(x, y)$ positions for $T=100$) and outputs $l$ code vectors $Z_e \in \mathbb{R}^{l\times D}$ via an MLP (e.g., $[2T \to 512 \to 256 \to l\cdot D]$). Each code vector is quantized by nearest-neighbor search in a codebook of $K$ embeddings $[e_0, \ldots, e_{K-1}]$.
  • Quantization: For each code slot $i$, $z_{q,i} = e_{r_i}$, where $r_i = \arg\min_{j} \|z_{e,i} - e_j\|^2$ and $z_{e,i} \in \mathbb{R}^D$. This quantization is non-differentiable; gradients use a straight-through estimator.
  • Decoder $f_d$ (plus QP): The quantized codes $Z_q$ are flattened and passed through an MLP (e.g., $[l\cdot D \to 256 \to 512 \to 2N]$) to produce trajectory setpoints $p \in \mathbb{R}^{2\times N}$ (e.g., forward velocity and lateral offset). The final output trajectory $\hat\tau$ is computed by solving a differentiable QP:

$$\xi^\star = \arg\min_{\xi} \; \frac{1}{2} \xi^\top Q\xi + q(p)^\top \xi \quad \text{subject to} \quad A\xi = b$$

The resulting solution $\xi^\star$ is mapped back to physical coordinates.

  • Conditional PixelCNN: At inference, the codes are sampled autoregressively from $p(h \mid s)$, where $h$ is the sequence of code indices and $s \in \mathbb{R}^{55}$ encodes the current system state. The PixelCNN (10 masked convolutional layers, 64 channels) conditions on state embeddings and previously drawn indices to capture complex multi-modal dependencies.
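Because the QP above has only equality constraints, it can be solved in closed form via its KKT linear system. The sketch below illustrates this with NumPy; the dimensions and matrices are toy values chosen for the example, not taken from the paper.

```python
import numpy as np

def solve_eq_qp(Q, q, A, b):
    """Solve min_xi 0.5*xi^T Q xi + q^T xi  s.t.  A xi = b
    via the KKT system (valid when Q is positive definite on the
    nullspace of A)."""
    n, m = Q.shape[0], A.shape[0]
    kkt = np.block([[Q, A.T], [A, np.zeros((m, m))]])
    rhs = np.concatenate([-q, b])
    sol = np.linalg.solve(kkt, rhs)
    return sol[:n]  # xi*; sol[n:] holds the Lagrange multipliers

# Toy instance: minimize ||xi||^2 subject to xi_0 + xi_1 = 1.
Q = 2.0 * np.eye(2)
q = np.zeros(2)
A = np.array([[1.0, 1.0]])
b = np.array([1.0])
xi = solve_eq_qp(Q, q, A, b)  # -> [0.5, 0.5]
```

Because the solution is a differentiable function of $(Q, q, A, b)$, gradients with respect to the setpoint-dependent term $q(p)$ can be obtained by implicit differentiation of this same KKT system, which is what makes the QP usable as a network layer.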

2. Discrete Latent Embedding and Codebook Mechanics

The distinctive feature of VQ-VAE is its use of a learned discrete codebook for latent representation. For codebook size $K$ (e.g., 512) and code dimension $D$ (e.g., 64), the encoder output for each slot is quantized to the nearest codebook vector. This technique enables the model to capture discrete, interpretable latent modes, a property especially advantageous in scenarios with strong multi-modality (e.g., alternative passing maneuvers in vehicle trajectory planning).

The codebook is updated during training to minimize the codebook (“embedding”) and commitment losses, optimizing both encoder outputs’ proximity to selected codes and codebook stability.
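The nearest-neighbor lookup can be vectorized over all code slots at once. A minimal sketch, using the illustrative sizes from the text ($K=512$, $D=64$) and a random codebook in place of a learned one:

```python
import numpy as np

K, D, l = 512, 64, 4
rng = np.random.default_rng(0)
codebook = rng.normal(size=(K, D))   # embeddings [e_0, ..., e_{K-1}]
z_e = rng.normal(size=(l, D))        # encoder outputs, one row per slot

# r_i = argmin_j ||z_{e,i} - e_j||^2, computed for all l slots at once
dists = ((z_e[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)  # (l, K)
r = dists.argmin(axis=1)             # code indices, shape (l,)
z_q = codebook[r]                    # quantized codes Z_q, shape (l, D)
```

The indices `r` are exactly the discrete tokens that the conditional PixelCNN is later trained to predict from the state.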

3. Loss Functions and Optimization

The model is trained to jointly minimize:

  • Reconstruction loss:

$$\ell_{\mathrm{rec}} = \|\hat{\tau} - \tau_e\|_2^2$$

  • Codebook loss:

$$\ell_{\mathrm{codebook}} = \|\mathrm{sg}[Z_e] - E_{Z_q}\|_2^2$$

Encourages codebook vectors to approach the encoder outputs.

  • Commitment loss:

$$\ell_{\mathrm{commit}} = \|Z_e - \mathrm{sg}[E_{Z_q}]\|_2^2$$

Encourages encoder outputs to match chosen codes.

Combined objective:

$$\mathcal{L}_{\mathrm{VQ\text{-}VAE}} = \ell_{\mathrm{rec}} + \alpha\,\ell_{\mathrm{commit}} + \beta\,\ell_{\mathrm{codebook}}$$

Typical weighting: $\alpha=1.0$, $\beta=0.25$ (Idoko et al., 2024). The quantization bottleneck is differentiated via the straight-through estimator, and gradients propagate through the differentiable QP layer.
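Numerically, the codebook and commitment terms coincide in the forward pass, since $\mathrm{sg}[\cdot]$ is the identity there; they differ only in which tensor receives gradients. A forward-pass sketch of the combined objective, with the weights from the text:

```python
import numpy as np

def vq_vae_loss(tau_hat, tau_e, z_e, z_q, alpha=1.0, beta=0.25):
    """Forward-pass values of the three VQ-VAE loss terms.
    sg[.] is omitted because it is the identity in the forward pass;
    in an autodiff framework, the codebook term would detach z_e and
    the commitment term would detach z_q."""
    l_rec = ((tau_hat - tau_e) ** 2).sum()
    l_codebook = ((z_e - z_q) ** 2).sum()  # gradients flow to the codebook
    l_commit = ((z_e - z_q) ** 2).sum()    # gradients flow to the encoder
    return l_rec + alpha * l_commit + beta * l_codebook

# Toy numbers: l_rec = 6, l_codebook = l_commit = 4,
# so total = 6 + 1.0*4 + 0.25*4 = 11.
total = vq_vae_loss(np.ones((2, 3)), np.zeros((2, 3)),
                    np.ones((2, 2)), np.zeros((2, 2)))
```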

4. State-Conditional Sampling and Generation

During inference, $\tau_e$ is not available. Instead, given a current system state $s$, sampling involves:

  1. The conditional PixelCNN generates a code index sequence $h = (r_1, \ldots, r_l)$ by autoregressively sampling each $r_i$ from $p(r_i \mid s, r_{<i})$.
  2. Each code index $r_i$ is mapped to its codebook embedding, forming $Z_q$.
  3. The decoder plus QP reconstructs the output trajectory $\hat{\tau} = f_d(Z_q)$.

This indirect conditioning on $s$, via code indices sampled from $p(h \mid s)$, enables the model to represent multi-modal distributions over trajectories for a fixed state. In contrast, direct continuous latent conditioning (e.g., in CVAEs) often yields unimodal reconstructions and mode averaging due to Gaussian priors.
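The autoregressive loop in step 1 can be sketched as follows. Here `code_logits` is a hypothetical stand-in for the conditional PixelCNN, which is replaced with a toy function of the state and the prefix of already-drawn indices; only the sampling mechanics are meant to be faithful.

```python
import numpy as np

K, l = 512, 4
rng = np.random.default_rng(1)

def code_logits(s, prefix):
    # Hypothetical stand-in for p(r_i | s, r_<i); a real model would be
    # a masked-convolution network conditioned on a state embedding.
    return rng.normal(size=K) + 0.01 * s.sum() + 0.01 * sum(prefix)

def sample_codes(s):
    h = []
    for _ in range(l):
        logits = code_logits(s, h)
        p = np.exp(logits - logits.max())  # numerically stable softmax
        p /= p.sum()
        h.append(rng.choice(K, p=p))
    return h  # code index sequence (r_1, ..., r_l)

s = np.zeros(55)      # current system state
h = sample_codes(s)   # step 1: sample indices
# steps 2-3: map h to codebook embeddings, then decode + QP (omitted)
```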

5. Advantages over Conditional VAEs

Traditional CVAEs posit a Gaussian prior $p(z) = \mathcal{N}(0, I)$ over latents. In domains such as autonomous driving, $p(z \mid s)$ may be highly multi-modal: certain road states admit fundamentally distinct trajectories (e.g., overtaking options on either side). CVAEs tend to suffer mode collapse, as the Gaussian prior encourages unimodality or "mean" behaviors.

By discretizing the latent representation, VQ-VAE facilitates allocation of separate codebook entries to distinct modes, and, with a conditional PixelCNN, $p(h \mid s)$ learns distributions assigning high probability to each relevant mode. Empirically, VQ-VAE with PixelCNN exhibits sharply multi-modal marginals in key trajectory features (velocity, lateral offset), substantially outperforming CVAE baselines in both diversity and safety-critical metrics (up to a 12× reduction in collision rate in dense scenarios) (Idoko et al., 2024).

6. Training Procedure and Dataflow

Training employs demonstration data pairs $(\tau_e, s)$, with the following procedure:

  • Forward pass: encode the demonstration trajectory, quantize, decode, and reconstruct the trajectory.
  • Compute losses: reconstruction, codebook, and commitment losses as defined above.
  • Backpropagate gradients: through the QP layer (analytic gradients) and over the quantization with the straight-through estimator.
  • Update codebook vectors to minimize embedding loss.
  • Separately, fit the conditional PixelCNN to predict codes $h$ given $s$, leveraging the codes inferred from demonstration trajectories.

At inference, the model samples from the learned code distribution conditioned on state, decodes, and applies an optional differentiable, optimization-based safety filter to enforce collision constraints.
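The dataflow of a single forward pass can be checked at the level of tensor shapes. The sketch below uses random matrices in place of the learned MLPs; the layer widths follow the illustrative sizes given in Section 1, and the QP refinement is omitted.

```python
import numpy as np

rng = np.random.default_rng(2)
T, N, l, D, K = 100, 20, 4, 64, 512   # N is a toy setpoint count

tau_e = rng.normal(size=(2, T))                  # demonstration trajectory
W_enc = rng.normal(size=(2 * T, l * D)) * 0.01   # stand-in for the encoder MLP
codebook = rng.normal(size=(K, D))

# Encode and quantize
z_e = (tau_e.reshape(-1) @ W_enc).reshape(l, D)
r = ((z_e[:, None] - codebook[None]) ** 2).sum(-1).argmin(1)
z_q = codebook[r]

# Decode quantized codes to setpoints p (QP refinement omitted)
W_dec = rng.normal(size=(l * D, 2 * N)) * 0.01   # stand-in for the decoder MLP
p = (z_q.reshape(-1) @ W_dec).reshape(2, N)

# The indices r become the PixelCNN's training targets for state s.
```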

7. Empirical Impact and Application Contexts

State-conditional VQ-VAE, as evaluated in (Idoko et al., 2024), demonstrates:

  • More realistic and diverse trajectory generation in autonomous driving scenarios, especially under dense, multi-modal conditions.
  • Superior performance in safety-critical metrics, markedly reducing collision rates compared to Gaussian prior CVAE baselines.
  • Effective modeling of distinct behavioral homotopies without degrading task-relevant performance metrics such as average speed.

A plausible implication is that this architecture is broadly applicable to any domain where the mapping from state to output is multi-modal and safety or diversity are paramount. The separation of representation learning (via VQ-VAE) and conditional multi-modal sampling (via PixelCNN) distinguishes this approach from prior generative models relying on continuous priors.


References:

  • "Learning Sampling Distribution and Safety Filter for Autonomous Driving with VQ-VAE and Differentiable Optimization" (Idoko et al., 2024)
