- The paper introduces a causal framework for deep generative models that decouples latent factors for enhanced interpretability.
- It employs techniques from linear factor analysis and nonlinear mappings to address identifiability challenges in latent representations.
- The approach facilitates targeted interventions and scalable estimation for high-dimensional data, with potential integration into advanced foundation models.
Deep generative models (DGMs) excel at learning complex data representations and achieve state-of-the-art performance in domains such as image and text generation. However, their "black box" nature makes these representations difficult to interpret. Causal Representation Learning (CRL) has emerged as a field that aims to close this gap by building interpretable DGMs from the ground up. CRL combines principles from deep learning, statistics, and causality to develop flexible, interpretable, and transferable generative models. The core idea is to learn latent representations whose coordinates correspond to the causal factors of variation in the data, together with the causal relationships among them. This gives a deeper understanding of the data-generating process and enables targeted interventions in the latent space.
CRL draws heavily on concepts from classical latent variable models such as factor analysis. In linear factor analysis, observed high-dimensional data $\mathbf{x}_i \in \mathbb{R}^D$ is modeled as a linear combination of lower-dimensional latent factors $\mathbf{z}_i \in \mathbb{R}^K$ plus noise: $\mathbf{x}_i = B\mathbf{z}_i + \varepsilon_i$. This model provides a latent representation $\mathbf{z}_i$ and interpretable loadings $B$, but it is not identifiable. In particular, it suffers from rotational invariance when $\mathbf{z}_i \sim \mathcal{N}(0, I_K)$. For any orthogonal $K \times K$ matrix $R$, the pair $(BR, R^\top \mathbf{z}_i)$ generates exactly the same distribution of $\mathbf{x}_i$ as $(B, \mathbf{z}_i)$. Classical methods resolve this by imposing restrictions on the model. Examples include constraining the loadings matrix to be triangular, assuming sparsity or "anchor features" (observed variables that depend on only one latent factor), or assuming non-Gaussian or heavy-tailed distributions for the latent factors. These restrictions can be interpreted through the lens of graphical models, which relate assumptions on $B$ or on the distribution of $\mathbf{z}_i$ to structural properties of an underlying graph.
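The rotational invariance can be checked numerically. The following Python sketch makes two assumptions for illustration: $\mathbf{z}_i \sim \mathcal{N}(0, I_K)$, and Gaussian noise with a diagonal covariance $\Psi$. The loadings are drawn at random. Under these assumptions $\mathbf{x}_i \sim \mathcal{N}(0, BB^\top + \Psi)$, and replacing $B$ by $BR$ for an orthogonal $R$ leaves this covariance unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)
D, K = 6, 2  # observed and latent dimensions (illustrative)

B = rng.normal(size=(D, K))                    # random loadings (illustrative)
Psi = np.diag(rng.uniform(0.1, 0.5, size=D))   # assumed diagonal noise covariance

# With z ~ N(0, I_K) and eps ~ N(0, Psi), x = B z + eps ~ N(0, B B^T + Psi).
cov_x = B @ B.T + Psi

# x = (B R)(R^T z) + eps, and R^T z ~ N(0, I_K) for any orthogonal R,
# so the rotated loadings B R induce the same distribution of x.
theta = 0.7
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
B_rot = B @ R
cov_x_rot = B_rot @ B_rot.T + Psi

print(np.allclose(cov_x, cov_x_rot))  # True: same distribution of x
print(np.allclose(B, B_rot))          # False: different loadings
```

Because a Gaussian is fully determined by its mean and covariance, no amount of observational data can distinguish $B$ from $BR$ here.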
In CRL, the focus shifts to a more general nonlinear factor model, $\mathbf{x}_i = f(\mathbf{z}_i) + \varepsilon_i$. Here $f$ is a potentially complex nonlinear function, often parameterized by a deep neural network. The latent vector $\mathbf{z}_i$ follows a latent causal model, frequently assumed to be a linear structural equation model (SEM): $\mathbf{z}_i = A^\top \mathbf{z}_i + D^{1/2}\nu_i$. In this equation, $D$ is a diagonal matrix of noise variances and $\nu_i$ is a vector of exogenous noise terms. The matrix $A$ encodes the directed acyclic graph (DAG) over the latent variables, representing their causal relationships. The goal of CRL is to recover $f$ and the latent causal model (including $A$) from the observations $\mathbf{x}_i$.
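To make the generative process concrete, here is a minimal NumPy simulation sketch. Several choices are hypothetical: the three-node DAG, the noise variances `d` (the diagonal of $D$, assuming $D$ is diagonal), the observation noise level, and the random two-layer network standing in for $f$. Because $\mathbf{z}_i = A^\top \mathbf{z}_i + D^{1/2}\nu_i$, an entry $A_{jk} \neq 0$ denotes an edge $z_j \to z_k$. For an acyclic $A$, the latents can then be sampled in closed form as $\mathbf{z}_i = (I - A^\top)^{-1} D^{1/2}\nu_i$.

```python
import numpy as np

rng = np.random.default_rng(1)
K, D, n = 3, 5, 1000  # latent dim, observed dim, sample size (illustrative)

# Hypothetical latent DAG z1 -> z2 -> z3 and z1 -> z3.
# A[j, k] != 0 means z_j -> z_k; A is strictly upper triangular, hence acyclic.
A = np.array([[0.0, 0.8, -0.5],
              [0.0, 0.0,  1.2],
              [0.0, 0.0,  0.0]])
d = np.array([1.0, 0.5, 0.3])  # hypothetical noise variances (diagonal of the SEM matrix D)

# Row-wise form of z = A^T z + D^{1/2} nu:  z^T = nu^T D^{1/2} (I - A)^{-1}.
nu = rng.normal(size=(n, K))
Z = (nu * np.sqrt(d)) @ np.linalg.inv(np.eye(K) - A)

# Hypothetical nonlinear mixing f: R^K -> R^D (a random two-layer tanh network).
W1 = rng.normal(size=(K, 16))
W2 = rng.normal(size=(16, D))

def f(z):
    return np.tanh(z @ W1) @ W2

sigma = 0.1  # assumed observation noise scale
X = f(Z) + sigma * rng.normal(size=(n, D))
```

A CRL method would observe only `X` and try to recover `f`, `A`, and `d` up to the ambiguities discussed next.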
A key challenge in this nonlinear setting, as in linear factor analysis, is identifiability. The model is generally non-identifiable: multiple pairs $(f, p(\mathbf{z}))$ can generate the same observed data distribution $p(\mathbf{x}_i)$. Identifiability is crucial for reliably interpreting the learned factors and their relationships. CRL research therefore explores conditions under which the model parameters $(f, p(\mathbf{z}), A)$ are identifiable, at least up to meaningful transformations such as permutations and element-wise scaling (often referred to as "disentanglement").
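In simulation studies, identifiability up to permutation and element-wise scaling is often assessed with the mean correlation coefficient (MCC). The procedure has three steps: correlate every true latent coordinate with every estimated one, match coordinates one-to-one so that the total absolute correlation is maximal, and average the matched values. The sketch below is a minimal NumPy version using Pearson correlation, which detects element-wise linear rescalings. Its brute-force search over permutations is only practical for small $K$; the Hungarian algorithm (for example `scipy.optimize.linear_sum_assignment`) is the usual choice otherwise. The test data are synthetic.

```python
import itertools
import numpy as np

def mean_corr_coef(z_true, z_hat):
    """MCC: mean |Pearson correlation| after the best one-to-one matching of coordinates."""
    K = z_true.shape[1]
    # Cross-correlation block: rows = true coordinates, columns = estimated ones.
    corr = np.abs(np.corrcoef(z_true.T, z_hat.T)[:K, K:])
    # Brute-force search over all matchings (fine for small K).
    best = max(sum(corr[k, perm[k]] for k in range(K))
               for perm in itertools.permutations(range(K)))
    return best / K

rng = np.random.default_rng(2)
z = rng.normal(size=(2000, 3))

# Estimate equal to the truth up to permutation, rescaling, and sign flips.
z_perm_scaled = z[:, [2, 0, 1]] * np.array([3.0, -0.5, 2.0])

# Estimate whose first two coordinates are mixed by a 45-degree rotation.
theta = np.pi / 4
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta),  np.cos(theta), 0.0],
              [0.0,            0.0,           1.0]])
z_rotated = z @ R.T

print(round(mean_corr_coef(z, z_perm_scaled), 6))  # 1.0
print(mean_corr_coef(z, z_rotated) < 0.95)         # True
```

An MCC near 1 is consistent with recovery up to the allowed transformations. The rotated estimate scores noticeably lower because each of its first two coordinates mixes two true factors.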
Identifiability in CRL is typically achieved by imposing constraints on the function class $\mathcal{F}$ for $f$ or on the distribution family $\mathcal{P}$ for $p(\mathbf{z})$.
Assumptions on the mixing function $f$:
- Sparsity and Anchor Features: As in linear factor analysis, assuming that $f$ is sparse in a structured way can help identify the latent factors up to element-wise transforms and permutations [moran2022identifiable]. For example, each observed variable may depend on only a subset of the latent factors, and in particular some "anchor features" may depend on only one factor.
- Subset Condition: For discrete latent variables, a related condition requires that the set of observed variables influenced by one latent is not a subset of the set influenced by any other latent [kivva2021learning].
- Independent Mechanism Analysis (IMA): Assuming that the Jacobian of the mixing function $f$ has orthogonal columns at every point constrains the set of ambiguous transformations of $f$ [gresele2021independent]. This assumption is inspired by the principle of independent mechanisms.
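[gresele2021independent] quantify deviations from this orthogonality assumption with a contrast of the following form. For a square Jacobian $J_f(\mathbf{z})$ with columns $\partial f / \partial z_k$, define $c(f, \mathbf{z}) = \sum_k \log \lVert \partial f / \partial z_k \rVert - \log \lvert \det J_f(\mathbf{z}) \rvert$. By Hadamard's inequality, $c(f, \mathbf{z}) \ge 0$, with equality exactly when the columns are orthogonal. Below is a minimal NumPy sketch of this local contrast, assuming $f$ maps $\mathbb{R}^K$ to $\mathbb{R}^K$ so the determinant is defined. It is evaluated on two illustrative Jacobians: the polar-to-Cartesian map, whose columns are orthogonal, and a shear, whose columns are not.

```python
import numpy as np

def local_ima_contrast(J):
    """sum_k log ||J[:, k]|| - log |det J|; zero iff the columns of J are orthogonal."""
    col_norms = np.linalg.norm(J, axis=0)
    _, logabsdet = np.linalg.slogdet(J)
    return np.sum(np.log(col_norms)) - logabsdet

def polar_jacobian(r, theta):
    """Jacobian of the illustrative map (r, theta) -> (r cos theta, r sin theta)."""
    return np.array([[np.cos(theta), -r * np.sin(theta)],
                     [np.sin(theta),  r * np.cos(theta)]])

shear = np.array([[1.0, 1.0],
                  [0.0, 1.0]])

print(np.isclose(local_ima_contrast(polar_jacobian(2.0, 0.3)), 0.0))  # True
print(local_ima_contrast(shear))  # approximately 0.3466, i.e. log(sqrt(2))
```

In a learning algorithm, this quantity would be averaged over latent samples and used as a regularizer or diagnostic. That usage is a design choice beyond what this sketch shows.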
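Assumptions on the latent factor distribution $p(\mathbf{z})$:
- Auxiliary Information: Observing auxiliary variables $\mathbf{u}_i$ that modulate the distribution of $\mathbf{z}_i$ can help pin down $p(\mathbf{z})$ and $f$, often up to linear or element-wise transforms [khemakhem2020variational]. In that work, the components of $\mathbf{z}_i$ are assumed conditionally independent given $\mathbf{u}_i$.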
- Mixture Priors: Assuming $\zpr$ is a mixture distribution (e.g., Gaussian mixture) introduces structure that can break rotational invariance and aid identifiability up to certain transformations [kivva2022identifiability].
- Temporal Dependence: Exploiting temporal structure in time-series data can restrict the set of possible latent transformations, enabling identifiability [halva2021disentangling, lachapelle2022disentanglement].
- Invariance/Multiple Environments: Leveraging data from multiple environments or utilizing data augmentations can provide constraints based on invariant properties of the latent factors, leading to identification [von2021self, eastwood2023self, yao2024multiview].
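The intuition behind auxiliary-variable identifiability can be seen in a simple Gaussian case. Suppose, as an assumption for illustration, that $\mathbf{z}_i \mid \mathbf{u}_i \sim \mathcal{N}(0, \operatorname{diag}(\lambda(\mathbf{u}_i)))$, so the latent coordinates are conditionally independent given $\mathbf{u}_i$. A rotated representation $R\mathbf{z}_i$ is then generally correlated given $\mathbf{u}_i$ and violates this assumption. A permutation with rescaling, by contrast, keeps every conditional covariance diagonal. The NumPy sketch below checks this for hypothetical per-environment variances. It is an illustration only, not the identifiability proof of [khemakhem2020variational]. It also relies on the variances changing across values of $\mathbf{u}_i$: in an environment with equal variances alone, the rotation would go undetected.

```python
import numpy as np

# Hypothetical variances of K = 2 latents in three environments (values of u):
# z | u ~ N(0, diag(lam[u])).
lam = np.array([[1.0, 4.0],
                [3.0, 0.5],
                [2.0, 2.0]])

def max_offdiag_cov(M):
    """Largest |off-diagonal| entry of Cov(M z | u) = M diag(lam_u) M^T over all u."""
    worst = 0.0
    for lam_u in lam:
        C = M @ np.diag(lam_u) @ M.T
        off = C - np.diag(np.diag(C))
        worst = max(worst, np.abs(off).max())
    return worst

theta = np.pi / 6
rotation = np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])
perm_scale = np.array([[0.0,  2.0],
                       [-1.5, 0.0]])  # swap coordinates, rescale, flip a sign

print(max_offdiag_cov(rotation) > 0.1)               # True: rotated latents correlate given u
print(np.isclose(max_offdiag_cov(perm_scale), 0.0))  # True: still independent given u
```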
Crucially, identifying the causal graph $\mathcal{G}$ among the latent variables requires richer data, typically interventions, because observational data alone is generally insufficient. Recent theoretical work shows that identifying $\mathcal{G}$ often requires interventions on each latent variable.
- In the linear mixing case with linear SEMs on $\mathbf{z}_i$, single-node perfect interventions on all latent variables can identify the latent causal graph [squires2023linear].
- For nonlinear mixing $f$ and linear SEMs on Gaussian latents, single-node perfect interventions on all latents can identify both $f$ (up to element-wise transforms) and the latent causal graph [buchholz2023learning]. Relaxing to soft interventions requires additional sparsity assumptions on the graph [zhang2023interventions].
- Moving beyond linear SEMs or Gaussianity, nonparametric identifiability of the latent causal graph from unknown interventions is possible under certain conditions, often requiring multiple interventions per latent node and utilizing properties of score functions [varici2024general, von2023nonparametric]. Another approach leverages an "independent support" condition on the latent variables [wang2024desiderata, ahuja2023interventional].
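One way to see why single-node interventions are informative in the linear Gaussian setting is to track the latent precision matrix. For $\mathbf{z} = A^\top \mathbf{z} + D^{1/2}\nu$ with standard Gaussian $\nu$, the precision matrix is $(I - A) D^{-1} (I - A)^\top$. A perfect intervention on node $i$ removes its incoming edges and replaces its noise variance, so the precision matrix changes only in the rows and columns of $i$ and its parents. The NumPy sketch below checks this on a hypothetical three-node DAG. It illustrates the kind of structural signal that intervention-based identifiability arguments exploit, not the algorithm of any cited paper.

```python
import numpy as np

# Hypothetical latent DAG z0 -> z1 -> z2 and z0 -> z2 (A[j, k] != 0 means z_j -> z_k).
A = np.array([[0.0, 0.8, -0.5],
              [0.0, 0.0,  1.2],
              [0.0, 0.0,  0.0]])
d = np.array([1.0, 0.5, 0.3])  # hypothetical noise variances (diagonal of D)

def precision(A, d):
    """Precision of z = A^T z + D^{1/2} nu, nu ~ N(0, I): (I - A) D^{-1} (I - A)^T."""
    I = np.eye(len(d))
    return (I - A) @ np.diag(1.0 / d) @ (I - A).T

def perfect_intervention(A, d, i, new_var):
    """Cut all edges into node i and replace its noise variance."""
    A_int, d_int = A.copy(), d.copy()
    A_int[:, i] = 0.0
    d_int[i] = new_var
    return A_int, d_int

i = 1  # intervene on z1, whose only parent is z0
diff = precision(A, d) - precision(*perfect_intervention(A, d, i, new_var=2.0))

# Indices of rows/columns where the precision matrix changed.
support = np.flatnonzero(np.abs(diff).sum(axis=0) > 1e-12)
print(support)  # [0 1]: the intervened node and its parent; z2 is untouched
```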
Practical implementations of CRL methods often build on deep generative models such as Variational Autoencoders (VAEs) [kingma2013auto, rezende2014vae]. A standard VAE learns a latent representation, but that representation is typically not identifiable, so CRL methods modify the VAE objective or architecture to enforce identifiability. For instance, the sparse VAE [moran2022identifiable] introduces a binary masking variable $W$ with a sparsity-inducing prior. This makes the mapping from latent factors to observed features sparse, analogous to sparse factor loadings. The CausalDiscrepancy VAE [zhang2023interventions] uses interventional data to guide learning. It combines the VAE evidence lower bound (ELBO) with two additional terms: a maximum mean discrepancy (MMD) term that matches the distributions of generated and real interventional data, and a penalty on the latent causal graph $A$. Score-based methods [varici2024general] learn the inverse mapping $h = f^{-1}$ from differences between the score functions (gradients of log-densities) of observational and interventional data. By the change-of-variables formula, these score differences in the observed space are tied to the corresponding latent score differences through the Jacobian of $h$. The latent score differences are sparse: an intervention on one latent variable changes only the factor of the latent density that involves that variable and its parents, so they are nonzero only in those coordinates.
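The MMD term compares two sample sets through kernel mean embeddings. Below is a minimal NumPy sketch of the standard unbiased squared-MMD estimator with a Gaussian kernel. The bandwidth, the sample sizes, and the toy Gaussian data standing in for real and generated interventional samples are all assumptions. This is not the implementation from [zhang2023interventions]; inside a VAE, the same formula would be written in an autodiff framework so that it can be minimized by gradient descent.

```python
import numpy as np

def rbf_kernel(X, Y, bandwidth):
    """Gaussian kernel matrix k(x, y) = exp(-||x - y||^2 / (2 * bandwidth^2))."""
    sq = np.sum(X**2, axis=1)[:, None] + np.sum(Y**2, axis=1)[None, :] - 2 * X @ Y.T
    return np.exp(-np.maximum(sq, 0.0) / (2 * bandwidth**2))

def mmd2_unbiased(X, Y, bandwidth=1.0):
    """Unbiased estimate of the squared MMD between samples X (n x d) and Y (m x d)."""
    n, m = len(X), len(Y)
    Kxx = rbf_kernel(X, X, bandwidth)
    Kyy = rbf_kernel(Y, Y, bandwidth)
    Kxy = rbf_kernel(X, Y, bandwidth)
    # Exclude the diagonal k(x_i, x_i) terms from the within-sample averages.
    term_xx = (Kxx.sum() - np.trace(Kxx)) / (n * (n - 1))
    term_yy = (Kyy.sum() - np.trace(Kyy)) / (m * (m - 1))
    return term_xx + term_yy - 2 * Kxy.mean()

rng = np.random.default_rng(3)
real_int = rng.normal(loc=1.0, size=(500, 2))  # toy stand-in for real interventional data
gen_good = rng.normal(loc=1.0, size=(500, 2))  # generator that reproduces the intervention
gen_bad = rng.normal(loc=0.0, size=(500, 2))   # generator that ignores it

print(mmd2_unbiased(real_int, gen_good) < mmd2_unbiased(real_int, gen_bad))  # True
```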
Despite recent progress, CRL is an active and developing field with significant open questions. Key challenges include extending methods beyond specific types of interventions to more general environments, developing robust and scalable estimation algorithms for high-dimensional nonlinear models, establishing rigorous finite-sample guarantees and methods for uncertainty quantification, and integrating CRL concepts with powerful foundation models like LLMs to enhance their interpretability and enable causal reasoning capabilities. Research is ongoing to address these issues and fully realize the potential of interpretable deep generative models grounded in causality.