Interventional Variational Graph Auto-Encoder
- iVGAE is a variational graph auto-encoder that integrates intervention (do-operator) and counterfactual queries within a causal inference framework using a known DAG.
- The architecture employs a message-passing GNN to enforce conditional independence, enabling precise modeling of observational, interventional, and counterfactual distributions.
- The model uses variational inference with a shallow encoder and importance-weighted ELBO, yielding competitive empirical performance in causal tasks and fairness auditing.
The Interventional Variational Graph Auto-Encoder (iVGAE), also referred to as VACA, is a class of variational graph autoencoders designed for causal inference in the absence of hidden confounders, incorporating both interventions (do-operator) and counterfactual queries. The model assumes access to a known causal directed acyclic graph (DAG) over observed variables, and constructs a latent-variable graphical model whose encoding and decoding architectures precisely enforce the conditional independence structure of the underlying causal graph. This framework enables principled, nonparametric approximation of both interventional and counterfactual distributions from observational data and the known DAG without requiring parametric assumptions about the structural equations or latent confounder distributions (Sanchez-Martin et al., 2021).
1. Core Architecture and Graphical Structure
iVGAE operates over a causal DAG $\mathcal G = (V, E)$, where $V = \{1, \dots, d\}$ indexes the observed (possibly vector-valued) variables $X_1, \dots, X_d$. Each node $i$ is associated with an independent latent variable $Z_i$, typically with a standard Normal prior $p(Z_i) = \mathcal N(0, I)$. The overall joint model is $p_{\theta}(X, Z \mid A) = p_{\theta}(X \mid Z, A)\,\prod_{i=1}^d p(Z_i)$, where the adjacency matrix $A$ encodes the known graph structure. The decoder is realized as a message-passing GNN with multiple hidden layers and is designed such that each node's output depends only on the latent variables of its ancestors in $\mathcal G$. Specifically, for each node $i$, the GNN propagates messages so that after sufficiently many passes, the output for node $i$ encodes all information sent by the ancestral latents $Z_{\anc(i)}$. The likelihood for each node is factorized as $p_{\theta}(X \mid Z, A) = \prod_{i=1}^d p_{\theta_i}(X_i \mid Z_{\anc(i)})$, where $\theta_i$ are GNN-readout parameters, supporting both Gaussian (for continuous) and categorical (for discrete) data types. This strictly enforced conditional-independence structure mirrors Pearl's causal factorization: $p(X \mid Z, A) = \prod_i p(X_i \mid Z_{\anc(i)})$
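The ancestral-dependence constraint above can be made concrete by computing, for each node, the set of ancestors reachable through the DAG. The following is a minimal numpy sketch (not the authors' implementation); `ancestor_mask` is a hypothetical helper that derives the mask $\anc(i)$ used to restrict which latents each decoder readout may see. It assumes the adjacency convention $A[j, i] = 1$ for an edge $j \to i$.

```python
import numpy as np

def ancestor_mask(A: np.ndarray) -> np.ndarray:
    """Boolean mask M where M[i, j] is True iff node j is an ancestor
    of node i in the DAG with adjacency A (A[j, i] = 1 means j -> i)."""
    d = A.shape[0]
    reach = (A > 0).astype(int)          # paths of length 1
    for _ in range(d):                   # extend to paths of length up to d
        reach = ((reach + reach @ A) > 0).astype(int)
    # reach[j, i] == 1 iff j reaches i, so transpose gives ancestors of i.
    return reach.T.astype(bool)

# Chain graph X1 -> X2 -> X3 (nodes 0, 1, 2)
A = np.array([[0, 1, 0],
              [0, 0, 1],
              [0, 0, 0]])
M = ancestor_mask(A)
# Node 2 has ancestors {0, 1}; node 0 has none.
```

In a decoder GNN, such a mask corresponds to what repeated message passing along the DAG achieves implicitly: after enough rounds, node $i$'s representation aggregates exactly its ancestral latents.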
2. Variational Inference and Posterior Structure
The variational posterior is constructed to closely reflect the causal structure of the DAG. It uses a "shallow" message-passing GNN (specifically, a single aggregation layer), where each node's encoder aggregates observations only from its parents: $q_{\phi}(Z \mid X, A) = \prod_{i=1}^d q_{\phi_i}(Z_i \mid x_{\pa(i)})$ with $q_{\phi_i}(Z_i \mid x_{\pa(i)}) = \mathcal N(Z_i; \mu_i(h_i^1), \sigma_i^2(h_i^1))$, where $h_i^1$ is the node-$i$ embedding generated by a single round of parent-only message passing. Multi-hop propagation is unnecessary because, by the abduction step in SCMs, the exogenous noise for node $i$ is conditionally dependent only on $x_i$ and its immediate parents $x_{\pa(i)}$.
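The single-layer, parent-only aggregation can be sketched as follows. This is a simplified numpy illustration, not the authors' encoder; `parent_aggregate`, `W_self`, and `W_parent` are hypothetical names, and the same adjacency convention $A[j, i] = 1$ for an edge $j \to i$ is assumed.

```python
import numpy as np

def parent_aggregate(X, A, W_self, W_parent):
    """One round of parent-only message passing:
    h_i = W_self x_i + sum_{j in pa(i)} W_parent x_j.
    X: (d, f) node features; A: (d, d) adjacency with A[j, i] = 1 for j -> i."""
    msgs = A.T @ X                       # msgs[i] = sum of parent features of i
    return X @ W_self.T + msgs @ W_parent.T

# Chain graph X1 -> X2 -> X3 with 2-dim features and identity weights.
X = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])
A = np.array([[0, 1, 0],
              [0, 0, 1],
              [0, 0, 0]], dtype=float)
I2 = np.eye(2)
H = parent_aggregate(X, A, I2, I2)
# H[0] = x0 (no parents); H[1] = x1 + x0; H[2] = x2 + x1
```

The embedding `H[i]` would then be passed through per-node heads $\mu_i$ and $\sigma_i$ to parameterize the Gaussian posterior over $Z_i$.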
3. Objective Functions and Causal Queries
Three key variational objectives are implemented:
a. Observational Evidence Lower Bound (ELBO):
$\mathcal L_{\text{obs}} = \mathbb E_{q_{\phi}(Z \mid X, A)}\left[ \sum_{i=1}^d \log p_{\theta_i}(X_i \mid Z, A) \right] - \sum_{i=1}^d \mathrm{KL}\left[q_{\phi_i}(Z_i \mid X, A)\Vert p(Z_i)\right]$
Practically, optimization is performed using a multi-sample importance-weighted ELBO (IWAE) with $K$ latent samples per data point.
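The multi-sample importance-weighted bound can be estimated as a log-mean-exp of importance weights. Below is a minimal numpy sketch of the standard IWAE estimator (the function name and interface are illustrative, not from the paper); it assumes the $K$ per-sample log-densities have already been evaluated.

```python
import numpy as np

def iwae_bound(log_p_xz, log_p_z, log_q_z):
    """Importance-weighted ELBO from K latent samples for one data point.
    Inputs are arrays of shape (K,): log p(x|z_k), log p(z_k), log q(z_k|x).
    Returns log[(1/K) sum_k w_k] with w_k = p(x|z_k) p(z_k) / q(z_k|x)."""
    log_w = log_p_xz + log_p_z - log_q_z
    m = log_w.max()
    # Log-sum-exp trick for numerical stability.
    return m + np.log(np.exp(log_w - m).mean())
```

With $K = 1$ this reduces to the ordinary (single-sample) ELBO estimate; increasing $K$ tightens the bound toward the true log-likelihood.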
b. Interventional ELBO:
$\mathcal L_{\text{int}}(a) = \mathbb E_{q_{\phi}(Z \mid X, \do(a))}\left[ \sum_{i \notin I} \log p_{\theta_i}(X_i \mid Z, A_{\do(a)}) \right] - \sum_{i \notin I} \mathrm{KL}\left[q_{\phi_i}(Z_i \mid X, A_{\do(a)})\Vert p(Z_i)\right]$
Here, an intervention $\do(X_I = a)$ modifies the adjacency matrix by severing all parental edges into the intervened nodes $I$, and the encoder is rerun with $X_I$ clamped to the intervened values $a$.
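The graph surgery for $\do(X_I = a)$ amounts to zeroing the incoming-edge columns of the intervened nodes and overwriting their observed values. A minimal numpy sketch (`apply_do` is a hypothetical helper, not the authors' API; scalar node values and the convention $A[j, i] = 1$ for $j \to i$ are assumed):

```python
import numpy as np

def apply_do(A, X, intervened, values):
    """Return (A_do, X_do) for the intervention do(X_I = a):
    sever all incoming edges of each intervened node and clamp its value."""
    A_do = A.copy()
    X_do = X.copy()
    for i, v in zip(intervened, values):
        A_do[:, i] = 0   # remove every edge j -> i
        X_do[i] = v      # clamp X_i = a_i
    return A_do, X_do

# Chain graph X1 -> X2 -> X3; intervene do(X2 = 7).
A = np.array([[0, 1, 0],
              [0, 0, 1],
              [0, 0, 0]], dtype=float)
X = np.array([0.5, 1.0, 2.0])
A_do, X_do = apply_do(A, X, [1], [7.0])
```

Note that only edges *into* the intervened node are removed; its outgoing edge to $X_3$ is preserved, so downstream effects still propagate through the decoder.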
c. Counterfactual Objective:
For a factual observation $X^F$, abduction is performed to sample $Z \sim q_{\phi}(Z \mid X^F, A)$, the action is performed ($\do(a)$), the encoder and decoder are run with the modified inputs and adjacency, and counterfactuals are sampled or predicted via: $p_{\theta}(X \mid X^F, \do(a)) \approx \int p_{\theta}(X \mid Z, A_{\do(a)})\, q_{\phi}(Z \mid X^F, A)\, dZ$
4. Causal Inference Workflow: Abduction–Action–Prediction
Causal queries are addressed in three phases:
- Abduction: The encoder is run once on the observed $X^F$ and the original adjacency $A$ to obtain the posterior over latents.
- Action: The modified adjacency $A_{\do(a)}$ is produced by removing all incoming edges to the intervened nodes $I$ and clamping those nodes to the values $a$.
- Prediction: Samples or means of $X$ under the action posterior are decoded to predict the counterfactual under the intervention.
This procedure exactly mimics the structural causal steps of abduction, action, and prediction as formalized in the SCM literature.
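The three phases can be illustrated end to end on a toy linear SCM where abduction is exact in closed form; this is a sketch of the abduction-action-prediction recipe itself, not of VACA's learned encoder/decoder (the SCM, its coefficients, and the function name are assumptions for illustration):

```python
import numpy as np

# Toy linear SCM: X1 = U1, X2 = 2*X1 + U2.
def counterfactual(x_factual, do_x1):
    """Answer the query X2 | X^F, do(X1 = do_x1) in three phases."""
    x1, x2 = x_factual
    # Abduction: recover the exogenous noise from the factual observation.
    u2 = x2 - 2.0 * x1
    # Action: X1 has no parents here, so the surgery just clamps X1.
    x1_cf = do_x1
    # Prediction: push the recovered noise through the modified model.
    x2_cf = 2.0 * x1_cf + u2
    return np.array([x1_cf, x2_cf])

# Factual (X1, X2) = (1, 3) implies U2 = 1; under do(X1 = 5):
cf = counterfactual((1.0, 3.0), 5.0)  # -> [5.0, 11.0]
```

In iVGAE, the closed-form abduction above is replaced by sampling $Z \sim q_{\phi}(Z \mid X^F, A)$, and the structural equation by the decoder run on $A_{\do(a)}$.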
5. Training Regime and Computational Properties
Training jointly optimizes a linear combination of the observational, interventional, and (optionally) counterfactual ELBOs. Minibatch SGD is used, with each iteration executing:
- Minibatch data sampling.
- Encoder forward pass; Monte Carlo samples drawn for the variational latents.
- Decoder GNN pass and ELBO computation.
- (Optional) Intervention sampling and repeated pass with modified data/graph.
- Monte Carlo gradient computation and parameter updates.
Complexity per minibatch scales with the minibatch size and, per message-passing layer, linearly in the number of graph edges, with efficient GPU-based message passing supported via libraries such as PyTorch Geometric.
6. Empirical Results and Comparison
Experiments span several synthetic SCMs (“collider,” “triangle,” “chain,” “M-graph”) and semi-synthetic finance and demographic datasets (“Loan,” “Adult”). Comparative baselines include MultiCVAE (independent node-wise CVAEs) and CAREFL (causal autoregressive normalizing flows). Evaluation metrics comprise Observational and Interventional Maximum Mean Discrepancy (MMD), mean-error of mean (MeanE), mean-error of standard deviation (StdE), Counterfactual MSE, and standard error (SSE).
| Model | Obs MMD | Int MMD | MeanE | StdE | CF MSE | CF SSE |
|---|---|---|---|---|---|---|
| MultiCVAE | 30.4±8.2 | 44.7±12.3 | 13.3±4.8 | 46.6±2.4 | 87.4±3.6 | 65.2±2.8 |
| CAREFL | 9.3±1.5 | 4.9±0.5 | 0.35±0.08 | 81.9±1.8 | 8.1±0.6 | 7.8±0.6 |
| iVGAE (VACA) | 1.5±0.7 | 1.6±0.4 | 0.75±0.31 | 42.0±0.3 | 9.9±0.7 | 7.1±0.4 |
iVGAE outperforms these baselines in observational and interventional MMD, closely reproduces moments of interventional distributions, and uniquely recovers the full variance structure owing to its architecture-induced factorization. Counterfactual error metrics are competitive, attesting to reliable counterfactual estimation.
A notable practical use-case is counterfactual fairness auditing and fair classifier learning, demonstrated on the German-Credit dataset: generating $\do$ samples enables auditing classifiers for counterfactual fairness, and imposing fairness constraints on these counterfactual samples during training supports learning approximately counterfactually fair classifiers.