3D Causal Variational Autoencoder
- The model introduces a generative framework that leverages temporal structure and interventions to disentangle independent causal factors in high-dimensional 3D visual data.
- It employs an autoencoder combined with a normalizing flow and a dynamic Bayesian network to effectively factorize and infer vector-valued latent variables.
- The approach offers provable identifiability guarantees and demonstrates strong empirical performance on challenging 3D scenes and interventional domains.
A 3D Causal Variational Autoencoder (3D CausalVAE) is a generative framework designed for learning causal representations from sequential high-dimensional visual data, such as rendered image sequences, where the underlying latent causal factors may be both scalar and multidimensional (e.g., 3D positions and 3D rotations). The approach leverages the temporal structure of the data and a record of interventions to identify and disentangle the independent underlying causes—extending previous identifiability results to the setting of vector-valued causal factors. The CITRIS ("Causal Identifiability from Temporal Intervened Sequences") framework exemplifies a 3D CausalVAE by providing theoretical guarantees, a flexible neural implementation, and empirical validation on challenging 3D scenes (Lippe et al., 2022).
1. Generative Model Structure
The model assumes a latent dynamical system comprising $K$ causal factors $C_1, \dots, C_K$, where each $C_i$ can be multidimensional (enabling, for example, vector-valued 3D rotations). Observed data $X^t$ at each timestep are generated deterministically via a bijective observation function $h$ applied to the causal factors and observation noise $E_o^t$: $X^t = h(C_1^t, \dots, C_K^t, E_o^t)$.
Interventions are specified by a binary vector $I^{t+1} \in \{0, 1\}^K$ (with $I_i^{t+1} = 1$ if factor $C_i$ has been intervened upon), and a latent regime variable $R^{t+1}$ may confound $I^{t+1}$. The process is modeled as a Dynamic Bayesian Network, with parentage stipulated so that each $C_i^{t+1}$ depends on a subset of $C^t$ and its own intervention $I_i^{t+1}$; $C^t$ and $E_o^t$ generate $X^t$.
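The generative assumptions above can be made concrete with a toy simulation. The following is a minimal sketch, not the dataset generator used by CITRIS: the function name `simulate`, the linear dynamics, the intervention probability, and the uniform resampling range are all illustrative assumptions, and the observation function is left out entirely.

```python
import random

def simulate(num_steps, dims, p_intervene=0.2, seed=0):
    """Simulate vector-valued causal factors with known intervention targets.

    dims[i] is the dimensionality of factor i (a hypothetical toy setup).
    Returns (factors, targets): factors[t][i] is a list of floats, and
    targets[t][i] is 1 if factor i was intervened upon when producing step t.
    """
    rng = random.Random(seed)
    K = len(dims)
    state = [[rng.gauss(0.0, 1.0) for _ in range(d)] for d in dims]
    factors = [state]
    targets = [[0] * K]  # no intervention produced the initial state
    for _ in range(num_steps - 1):
        prev = factors[-1]
        # Toy parent summary: the mean of all previous-step values.
        flat = [v for f in prev for v in f]
        parent_mean = sum(flat) / len(flat)
        step_targets = [1 if rng.random() < p_intervene else 0 for _ in range(K)]
        nxt = []
        for i, d in enumerate(dims):
            if step_targets[i]:
                # Intervention: resample the whole factor, ignoring its parents.
                nxt.append([rng.uniform(-2.0, 2.0) for _ in range(d)])
            else:
                # Observational regime: depends on the previous step only
                # (first-order Markov, no instantaneous effects).
                nxt.append([0.8 * prev[i][k] + 0.1 * parent_mean + rng.gauss(0.0, 0.1)
                            for k in range(d)])
        factors.append(nxt)
        targets.append(step_targets)
    return factors, targets

factors, targets = simulate(num_steps=100, dims=[3, 2, 1])
```

Each step records which factors were intervened upon but not the values they were set to, which matches the kind of supervision the framework assumes.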
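The joint density over $T$ steps is factorized as:
$$p\left(C^{1:T}, E_o^{1:T} \mid I^{2:T}\right) = p\left(C^1\right) \prod_{t=1}^{T} p\left(E_o^t\right) \prod_{t=1}^{T-1} \prod_{i=1}^{K} p\left(C_i^{t+1} \mid C^t, I_i^{t+1}\right), \qquad X^t = h\left(C^t, E_o^t\right)$$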
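By exploiting the invertibility of $h$ (with inverse $g_\theta$), the model induces a decoder/likelihood and transition prior in the latent space $\mathcal{Z} \subseteq \mathbb{R}^M$. The one-step conditional likelihood is:
$$p_{\phi,\theta}\left(x^{t+1} \mid x^t, I^{t+1}\right) = \left|\det \frac{\partial g_\theta(x^{t+1})}{\partial x^{t+1}}\right| \, p_\phi\left(z^{t+1} \mid z^t, I^{t+1}\right), \qquad z^t = g_\theta(x^t)$$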
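The transition prior further factorizes over "blocks" of latents assigned to each causal factor and a "junk" block:
$$p_\phi\left(z^{t+1} \mid z^t, I^{t+1}\right) = \prod_{i=0}^{K} p_\phi\left(z_{\psi_i}^{t+1} \mid z^t, I_i^{t+1}\right)$$
with block assignments $\psi: \{1, \dots, M\} \to \{0, \dots, K\}$ (where $z_{\psi_i}$ collects the latents assigned to block $i$), and $I_0^{t+1} = 0$ by convention.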
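The block factorization means the log-prior is a sum of per-block terms. The sketch below illustrates this with independent Gaussians per latent, which is a simplifying assumption; the helper names and the idea of passing precomputed `means` and `stds` (standing in for the outputs of a conditional network) are hypothetical.

```python
import math

def gaussian_logpdf(x, mean, std):
    """Log-density of a univariate Gaussian."""
    return -0.5 * math.log(2.0 * math.pi) - math.log(std) - 0.5 * ((x - mean) / std) ** 2

def block_log_prior(z_next, means, stds, psi, num_factors):
    """Return (per_block, total) log-density of z_next under a block-factorized prior.

    psi[j] in {0, ..., num_factors} is the block of latent j; block 0 is the
    "junk" block. means[j] and stds[j] stand in for the outputs of a network
    conditioned on the previous latents and the intervention target of block psi[j].
    """
    per_block = [0.0] * (num_factors + 1)
    for j, zj in enumerate(z_next):
        per_block[psi[j]] += gaussian_logpdf(zj, means[j], stds[j])
    return per_block, sum(per_block)

per_block, total = block_log_prior(
    z_next=[0.0, 1.0, -1.0],
    means=[0.0, 0.0, 0.0],
    stds=[1.0, 1.0, 1.0],
    psi=[0, 1, 1],        # latent 0 -> junk block, latents 1 and 2 -> factor 1
    num_factors=2,
)
```

A block with no assigned latents, such as factor 2 here, contributes zero to the sum.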
2. Inference, Variational Posterior, and Normalizing Flow
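The variational posterior $q_\theta(z^{t+1} \mid x^{t+1})$ is factored as a product of independent per-latent Gaussian distributions:
$$q_\theta\left(z^{t+1} \mid x^{t+1}\right) = \prod_{j=1}^{M} \mathcal{N}\left(z_j^{t+1};\, \mu_j(x^{t+1}),\, \sigma_j^2(x^{t+1})\right)$$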
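For enhanced expressivity and disentanglement, the AE+NF (Autoencoder + Normalizing Flow) extension is used. Here, an autoencoder (encoder $e$, decoder $d$) is first trained without any constraint on its latent space; after training, these components are frozen and an invertible normalizing flow $f$ maps the autoencoder's embeddings $u = e(x)$ to the latent variables $z = f(u)$. The completed approximate posterior is defined as:
$$q\left(z^{t+1} \mid x^{t+1}\right) = q\left(u^{t+1} \mid x^{t+1}\right) \left|\det \frac{\partial f(u^{t+1})}{\partial u^{t+1}}\right|^{-1}, \qquad z^{t+1} = f(u^{t+1})$$
with a change-of-variables correction:
$$\log \left|\det \frac{\partial f(u)}{\partial u}\right| = \sum_{\ell=1}^{L} \log \left|\det \frac{\partial f_\ell(v_{\ell-1})}{\partial v_{\ell-1}}\right|, \qquad v_0 = u, \quad v_\ell = f_\ell(v_{\ell-1})$$
where $f = f_L \circ \dots \circ f_1$ and each $f_\ell$ is a coupling layer, using MAF/affine autoregressive transformations and interleaved normalization and invertible 1x1 convolutions (inspired by Glow).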
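To show where the log-determinant in a change-of-variables correction comes from, here is a minimal sketch of a single affine coupling layer. It is a generic illustration rather than the layer used in CITRIS: the conditioner functions `scale_fn` and `shift_fn` are toy stand-ins for learned networks, and an even input dimension is assumed.

```python
import math

def coupling_forward(u, scale_fn, shift_fn):
    """Affine coupling: the first half passes through, the second half is
    scaled and shifted by functions of the first half.
    Returns (z, log_det), where log_det = log |det dz/du|."""
    assert len(u) % 2 == 0, "this sketch assumes an even number of dimensions"
    half = len(u) // 2
    a, b = u[:half], u[half:]
    log_s, t = scale_fn(a), shift_fn(a)
    z_b = [bi * math.exp(ls) + ti for bi, ls, ti in zip(b, log_s, t)]
    # The Jacobian is triangular, so its log-determinant is the sum of log-scales.
    return a + z_b, sum(log_s)

def coupling_inverse(z, scale_fn, shift_fn):
    """Exact inverse of coupling_forward."""
    half = len(z) // 2
    a, z_b = z[:half], z[half:]
    log_s, t = scale_fn(a), shift_fn(a)
    b = [(zi - ti) * math.exp(-ls) for zi, ls, ti in zip(z_b, log_s, t)]
    return a + b

# Toy conditioners (hypothetical; a real flow would use small neural networks).
scale_fn = lambda a: [math.tanh(x) for x in a]
shift_fn = lambda a: [2.0 * x for x in a]

u = [0.5, -1.0, 2.0, 0.3]
z, log_det = coupling_forward(u, scale_fn, shift_fn)
u_back = coupling_inverse(z, scale_fn, shift_fn)
```

Stacking several such layers, with the roles of the two halves swapped or permuted between layers, adds the individual log-determinants together.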
3. Learning Objective and Block Assignment
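Learning is performed by maximizing a variational lower bound (ELBO) on the conditional log-likelihood for each transition $(x^t, x^{t+1}, I^{t+1})$:
$$\log p\left(x^{t+1} \mid x^t, I^{t+1}\right) \ge \mathbb{E}_{q_\theta(z^{t+1} \mid x^{t+1})}\left[\log p_\theta\left(x^{t+1} \mid z^{t+1}\right)\right] - \mathbb{E}_{q_\theta(z^t \mid x^t)}\left[\sum_{i=0}^{K} D_{\mathrm{KL}}\left(q_\theta\left(z_{\psi_i}^{t+1} \mid x^{t+1}\right) \,\Big\|\, p_\phi\left(z_{\psi_i}^{t+1} \mid z^t, I_i^{t+1}\right)\right)\right]$$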
The first term drives reconstructions, while the KL divergences align each causal-factor block (and the nuisance block) with its respective transition prior; the $i = 0$ term encourages nuisance information to concentrate in block 0.
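The blockwise KL terms have a closed form when both posterior and prior are diagonal Gaussians. The sketch below makes that assumption explicit; the function names and the flat per-latent parameter lists are illustrative, and the prior parameters would in practice come from a conditional network.

```python
import math

def kl_gauss(mu_q, std_q, mu_p, std_p):
    """KL( N(mu_q, std_q^2) || N(mu_p, std_p^2) ) for univariate Gaussians."""
    return (math.log(std_p / std_q)
            + (std_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * std_p ** 2)
            - 0.5)

def blockwise_kl(mu_q, std_q, mu_p, std_p, psi, num_factors):
    """Sum per-latent KL terms into one value per block (block 0 = nuisance)."""
    per_block = [0.0] * (num_factors + 1)
    for j in range(len(psi)):
        per_block[psi[j]] += kl_gauss(mu_q[j], std_q[j], mu_p[j], std_p[j])
    return per_block

kl_blocks = blockwise_kl(
    mu_q=[1.0, 0.0, 0.0], std_q=[1.0, 1.0, 2.0],
    mu_p=[0.0, 0.0, 0.0], std_p=[1.0, 1.0, 1.0],
    psi=[1, 0, 2],       # hypothetical assignment of three latents to blocks
    num_factors=2,
)
```

The regularization part of the objective is then the sum of `kl_blocks`, with each block's prior conditioned on its own intervention target.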
A target-classification (TC) loss further encourages $z_{\psi_i}^{t+1}$ to be selectively informative about its intervention target $I_i^{t+1}$ and invariant to others. This is implemented by a learned classifier that predicts $I^{t+1}$ from $(z^t, z_{\psi_i}^{t+1})$, with gradients back-propagated selectively.
Block assignments $\psi(j)$ for latent dimension $j$ are parameterized by a categorical variable over $\{0, \dots, K\}$, implemented with Gumbel-Softmax for sampling during training, and argmax assignment at test time.
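The two assignment modes can be sketched as follows. This is a plain-Python illustration of the Gumbel-Softmax relaxation and the test-time argmax, not the training code; the logits, the temperature `tau`, and the function names are assumptions made for the example.

```python
import math
import random

def gumbel_softmax(logits, tau, rng):
    """Relaxed one-hot sample over blocks: softmax((logits + Gumbel noise) / tau)."""
    noisy = []
    for logit in logits:
        # Clamp the uniform sample away from 0 and 1 so both logs are defined.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        noisy.append((logit - math.log(-math.log(u))) / tau)
    m = max(noisy)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]

def hard_assignment(logits):
    """Test-time assignment: the block with the largest logit."""
    return max(range(len(logits)), key=lambda k: logits[k])

rng = random.Random(0)
# One row of logits per latent dimension, one column per block (block 0 = nuisance).
logits_per_latent = [
    [2.0, 0.0, -1.0],
    [0.0, 3.0, 0.0],
    [0.1, 0.2, 1.5],
    [0.0, 0.0, 0.0],
]
soft = [gumbel_softmax(row, tau=1.0, rng=rng) for row in logits_per_latent]
psi = [hard_assignment(row) for row in logits_per_latent]
```

Lower temperatures push the relaxed samples toward one-hot vectors, at the cost of higher-variance gradients.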
4. Identifiability Result
Suppose:
- $h$ is invertible,
- the latent process is stationary, first-order Markov, with no instantaneous effects,
- interventions $I^{t+1}$ are known, non-deterministic, and not always joint,
- $g_\theta$, $p_\phi$, and latent dimension $M$ are sufficiently expressive.
Then, maximizing the conditional likelihood $p_{\phi,\theta}(x^{t+1} \mid x^t, I^{t+1})$ subject to maximizing entropy in block 0 provably recovers for each $i \in \{1, \dots, K\}$ the minimal causal variable of $C_i$ (the component of $C_i$ which responds to intervention $I_i$), up to blockwise invertible transformations.
All equivalent maximizers correspond to assignments that only rearrange intervention-dependent components among blocks, but the entropy penalty on block 0 enforces a unique assignment. Identifiability is thus assured for multidimensional, intervention-targeted, temporal causal factors under the specified assumptions. Two factors whose intervention targets always coincide (both intervened upon, or both left untouched, at every step) are not separable within this framework.
5. Neural Architecture and Training Procedure
The 3D CausalVAE instantiation in CITRIS employs:
- Encoder 8: 4 strided conv layers (stride 2) with 64 channels and 9 kernels, BatchNorm+SiLU, and a final 0 conv, then flattened to yield 1, 2 for each latent 3 via parallel linear heads.
- Decoder 4: Linear layer followed by reshaping and 4 upsampling stages (5 each), each succeeded by a residual block (2 conv layers with BatchNorm+SiLU). Output via a 6 conv and Tanh activation.
- Transition Prior $p_\phi$: Autoregressive MADE network over the $M$ latent dimensions, conditioned on $z^t$ and $I^{t+1}$, predicting Gaussian mean and scale per block.
- Normalizing Flow $f$: 4–6 affine/MAF coupling layers, interleaved with ActNorm and invertible $1 \times 1$ convs.
- Assignment $\psi$: Latent-to-block mapping via Gumbel-Softmax over the $M$ latents and $K+1$ blocks.
Training uses Adam (lr 6), batch size 512, 7, 8, 9 for 3D, over 600–1000 epochs.
Pseudocode for one step: 00
6. Empirical Evaluation on 3D Scene Sequences
CITRIS is evaluated on the Temporal-Causal3DIdent dataset:
- Seven causal factors: object 3D position 0, object rotations 1, spotlight rotation 2, object/spotlight/background hues 3, and object shape 4.
- Interventions: Each 5, random resetting.
- Train/test split: 250k training, 10k test frames.
Metrics include blockwise $R^2$ to ground-truth factors (both diagonal, $R^2$ diag, and separated, $R^2$ sep), Spearman correlation, and "triplet evaluation": combining blocks from different sequences in latent space and measuring recovery fidelity via a specialized CNN encoder.
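As a rough illustration of the blockwise correlation metrics, the sketch below computes a coefficient of determination and summarizes a matrix of scores. The summary definitions are an assumption, not taken from this text: "diag" is read as the mean score of each block on its own factor, and "sep" as the largest score of any block on a different factor. The matrix values are made up for the example.

```python
def r_squared(y_true, y_pred):
    """Coefficient of determination; assumes y_true is not constant."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot

def summarize(r2):
    """r2[i][j]: score when predicting ground-truth factor j from latent block i.
    Returns (diag, sep) under the assumed definitions described above."""
    K = len(r2)
    diag = sum(r2[i][i] for i in range(K)) / K
    sep = max(r2[i][j] for i in range(K) for j in range(K) if i != j)
    return diag, sep

# Hypothetical score matrix for two blocks and two factors.
diag, sep = summarize([[0.9, 0.1],
                       [0.2, 0.8]])
```

Under this reading, a well-disentangled representation has a diagonal summary near 1 and a separated summary near 0.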
Key findings:
| Model | 9 | 0 | Triplet Error |
|---|---|---|---|
| CITRIS-VAE | 1 0.9+ | — | — |
| CITRIS-NF | 2 | 3 | 4 .04 |
| SlowVAE (base.) | — | — | Entangled |
| iVAE* (base.) | — | — | — |
- CITRIS-NF achieves 5 and triplet error 6 (Teapot dataset), outperforming prior approaches such as SlowVAE (which entangles factors) and iVAE* (failing on correlated multidimensional factors like hue and rotation).
On the Interventional Pong domain, CITRIS disentangles five intervened factors (ball position/velocity, paddle positions) with 7–8, and effectively attributes 'score' (non-intervened) to the nuisance block.
7. Generalization and Theoretical-Limitation Analysis
The AE+flow variant allows the autoencoder to be pretrained on heterogeneous observational sources (e.g., blending simulated and real data), and the flow adapted with synthetic interventional data only. Empirical results demonstrate zero-shot generalization to unseen object shapes, with 9 and a moderate triplet error. Performance drops slightly for position and rotation when unseen shape categories present different default axes, but minimal variables are still isolated.
Identifiability requires that for every causal factor, there exist both intervened and non-intervened instances. If the intervention targets of two factors always coincide (both intervened upon, or both left untouched, at every step), the factors are not separable (Proposition 3.1). The approach fundamentally relies on observing the intervention targets, though not their realized values. Identifiability is defined up to blockwise invertible transforms, and unrestricted rotation/mixing within blocks is not penalized.
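These conditions on the intervention record can be checked directly on a dataset before training. The following is a minimal sketch with hypothetical helper names; it only inspects the binary targets and says nothing about the remaining assumptions.

```python
def factors_missing_a_regime(targets):
    """Indices of factors that are always, or never, intervened upon.

    targets[t][i] is 1 if factor i was intervened upon at step t."""
    K = len(targets[0])
    return [i for i in range(K) if len({step[i] for step in targets}) < 2]

def inseparable_pairs(targets):
    """Pairs (i, j) whose intervention targets coincide at every step."""
    K = len(targets[0])
    columns = [tuple(step[i] for step in targets) for i in range(K)]
    return [(i, j) for i in range(K) for j in range(i + 1, K)
            if columns[i] == columns[j]]

# Factors 0 and 1 are always intervened upon together; factor 2 is not tied to them.
record = [[1, 1, 0],
          [0, 0, 1],
          [1, 1, 1]]
pairs = inseparable_pairs(record)
missing = factors_missing_a_regime(record)
```

An empty result from both checks is necessary for the separability conditions described above, though not sufficient for identifiability on its own.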
References:
CITRIS: Causal Identifiability from Temporal Intervened Sequences (Lippe et al., 2022)