
Interventional Variational Graph Auto-Encoder

Updated 10 March 2026
  • iVGAE is a variational graph auto-encoder that answers interventional (do-operator) and counterfactual queries within a causal inference framework that assumes a known DAG.
  • The architecture employs a message-passing GNN to enforce conditional independence, enabling precise modeling of observational, interventional, and counterfactual distributions.
  • The model uses variational inference with a shallow encoder and importance-weighted ELBO, yielding competitive empirical performance in causal tasks and fairness auditing.

The Interventional Variational Graph Auto-Encoder (iVGAE), also referred to as VACA, is a class of variational graph autoencoders designed for causal inference in the absence of hidden confounders, supporting both interventional (do-operator) and counterfactual queries. The model assumes access to a known causal directed acyclic graph (DAG) over the observed variables and constructs a latent-variable graphical model whose encoding and decoding architectures enforce the conditional independence structure of the underlying causal graph. This framework enables principled approximation of both interventional and counterfactual distributions from observational data and the known DAG, without parametric assumptions about the structural equations or latent confounder distributions (Sanchez-Martin et al., 2021).

1. Core Architecture and Graphical Structure

iVGAE operates over a causal DAG $G=(V,E)$, where $V$ indexes the $d$ observed (possibly vector-valued) variables $X=(X_1,\dots,X_d)$. Each node $i$ is associated with an independent latent variable $Z_i$, typically with a standard Normal prior $p(Z_i)=\mathcal N(0,I)$. The overall joint model is

$$p_{\theta}(X, Z \mid G) = p(Z)\, p_{\theta}(X \mid Z, A),$$

where the adjacency matrix $A \in \{0,1\}^{d \times d}$ encodes the known graph structure. The decoder $p_{\theta}(X \mid Z, A)$ is realized as a message-passing GNN with $L$ hidden layers and is designed such that each node's output depends only on the latent variables of its ancestors in $G$. Specifically, for each node $i$, the GNN propagates messages so that after $L+1 \geq \operatorname{diam}(G)$ passes, the output $h_i^{(L+1)}$ encodes all information sent by the ancestral $Z_j$. The likelihood for each node is factorized as

$$p_{\theta_i}(X_i \mid h_{\mathrm{anc}(i)}) = p(X_i; \eta_i),$$

where $\eta_i$ are GNN-readout parameters, supporting both Gaussian (continuous) and categorical (discrete) data types. This strictly enforced conditional-independence structure mirrors Pearl's causal factorization:

$$p(X \mid Z, A) = \prod_i p\bigl(X_i \mid Z_{\mathrm{anc}(i)}\bigr).$$
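
To make the ancestral-dependence constraint concrete, here is a minimal PyTorch sketch of such a decoder. All class and parameter names (`DecoderGNN`, `h_dim`, etc.) are illustrative assumptions, not the authors' released code; the key property is that messages are aggregated only along edges of the known DAG.

```python
import torch
import torch.nn as nn

class DecoderGNN(nn.Module):
    """Illustrative iVGAE-style decoder: messages flow only along DAG edges,
    so with enough layers each node's readout depends only on its
    ancestors' latents."""

    def __init__(self, d, z_dim, h_dim, n_layers):
        super().__init__()
        self.embed = nn.Linear(z_dim, h_dim)   # per-node latent -> hidden state
        self.msg = nn.ModuleList(
            [nn.Linear(2 * h_dim, h_dim) for _ in range(n_layers)]
        )
        self.readout = nn.Linear(h_dim, 2)     # Gaussian (mu, log-variance) per node

    def forward(self, z, A):
        # z: [batch, d, z_dim]; A: [d, d] with A[j, i] = 1 iff edge j -> i.
        h = torch.tanh(self.embed(z))
        for layer in self.msg:
            # Parent-to-child messages: m_i = sum_j A[j, i] * h_j.
            m = torch.einsum("ji,bjh->bih", A, h)
            h = torch.tanh(layer(torch.cat([h, m], dim=-1)))
        eta = self.readout(h)                  # [batch, d, 2]
        return eta[..., 0], eta[..., 1]        # per-node likelihood parameters
```

Each message-passing step extends a node's receptive field by one hop along the DAG, so once the number of passes reaches $\operatorname{diam}(G)$, node $i$'s output depends on exactly its ancestral latents, realizing the factorization above.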

2. Variational Inference and Posterior Structure

The variational posterior is constructed to closely reflect the causal structure of the DAG. It uses a "shallow" message-passing GNN (a single aggregation layer), where each node's encoder aggregates observations only from itself and its parents:

$$q_{\phi}(Z \mid X, A) = \prod_{i=1}^d q_{\phi_i}\bigl(Z_i \mid x_i, x_{\mathrm{pa}(i)}\bigr), \qquad q_{\phi_i}\bigl(Z_i \mid x_i, x_{\mathrm{pa}(i)}\bigr) = \mathcal N\bigl(Z_i;\, \mu_i(h_i^{(1)}),\, \sigma_i^2(h_i^{(1)})\bigr),$$

where $h_i^{(1)}$ is produced by parent-only message passing. Multi-hop propagation is unnecessary: by the abduction step in SCMs, the exogenous noise $Z_i$ depends only on $x_i$ and its immediate parents $x_{\mathrm{pa}(i)}$.
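
A matching sketch of the shallow encoder, under the same illustrative interfaces (a single round of parent-only aggregation):

```python
import torch
import torch.nn as nn

class ShallowEncoder(nn.Module):
    """Illustrative single-layer encoder: q(Z_i | x_i, x_pa(i)) is computed
    from one round of parent-only aggregation, so no multi-hop passing."""

    def __init__(self, d, x_dim, h_dim, z_dim):
        super().__init__()
        self.embed = nn.Linear(x_dim, h_dim)
        self.agg = nn.Linear(2 * h_dim, h_dim)
        self.mu = nn.Linear(h_dim, z_dim)
        self.log_var = nn.Linear(h_dim, z_dim)

    def forward(self, x, A):
        # x: [batch, d, x_dim]; A[j, i] = 1 iff j is a parent of i.
        e = torch.tanh(self.embed(x))
        m = torch.einsum("ji,bjh->bih", A, e)          # messages from parents only
        h1 = torch.tanh(self.agg(torch.cat([e, m], dim=-1)))
        return self.mu(h1), self.log_var(h1)           # parameters of q(Z_i | .)
```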

3. Objective Functions and Causal Queries

Three key variational objectives are implemented:

a. Observational Evidence Lower Bound (ELBO):

$$\mathcal L_{\text{obs}}(\theta, \phi) = \mathbb E_{q_{\phi}(Z \mid X, A)}\left[ \sum_{i=1}^d \log p_{\theta_i}(X_i \mid Z, A) \right] - \sum_{i=1}^d \mathrm{KL}\bigl[q_{\phi_i}(Z_i \mid X, A) \,\Vert\, p(Z_i)\bigr]$$

Practically, optimization is performed using a multi-sample importance-weighted ELBO (IWAE) with $K$ latent samples per data point.
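
A minimal sketch of this $K$-sample bound, assuming the `ShallowEncoder`/`DecoderGNN` interfaces sketched above, scalar node observations (`x_dim = 1`), and a standard Normal prior:

```python
import math
import torch

def iwae_bound(x, encoder, decoder, A, K=5):
    # x: [B, d, 1] observations; returns the K-sample IWAE bound (a scalar).
    mu_q, log_var_q = encoder(x, A)                     # [B, d, z]
    std_q = torch.exp(0.5 * log_var_q)
    z = mu_q + std_q * torch.randn(K, *mu_q.shape)      # K samples: [K, B, d, z]

    # log q(z | x) and log p(z), summed over nodes and latent dimensions.
    log_q = torch.distributions.Normal(mu_q, std_q).log_prob(z).sum(dim=(-2, -1))
    log_p_z = torch.distributions.Normal(0.0, 1.0).log_prob(z).sum(dim=(-2, -1))

    # log p(x | z): decode each latent sample under the known adjacency.
    mu_x, log_var_x = decoder(z.flatten(0, 1), A)       # each [K*B, d]
    mu_x = mu_x.view(K, *x.shape[:2])
    std_x = torch.exp(0.5 * log_var_x).view(K, *x.shape[:2])
    log_p_x = torch.distributions.Normal(mu_x, std_x).log_prob(x.squeeze(-1)).sum(-1)

    # IWAE: log (1/K) sum_k p(x, z_k) / q(z_k | x), averaged over the batch.
    log_w = log_p_x + log_p_z - log_q                   # [K, B]
    return (torch.logsumexp(log_w, dim=0) - math.log(K)).mean()
```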

b. Interventional ELBO:

$$\mathcal L_{\text{int}}(a) = \mathbb E_{q_{\phi}(Z \mid X, A_{\operatorname{do}(a)})}\left[ \sum_{i \notin I} \log p_{\theta_i}\bigl(X_i \mid Z, A_{\operatorname{do}(a)}\bigr) \right] - \sum_{i \notin I} \mathrm{KL}\bigl[q_{\phi_i}(Z_i \mid X, A_{\operatorname{do}(a)}) \,\Vert\, p(Z_i)\bigr]$$

Here, an intervention $\operatorname{do}(X_I = a)$ modifies the adjacency matrix by severing all parental edges into the intervened nodes $I$, and the encoder is rerun using these clamped values.
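
The graph surgery itself is mechanical; a hedged sketch with the same illustrative tensor conventions (the helper name `apply_intervention` is an assumption):

```python
import torch

def apply_intervention(x, A, targets, values):
    """do(X_I = a): sever all parental edges into the intervened nodes I
    and clamp their observed values. A[j, i] = 1 iff edge j -> i."""
    A_do = A.clone()
    A_do[:, targets] = 0.0                  # column i of A holds node i's parents
    x_do = x.clone()
    for i, a in zip(targets, values):
        x_do[:, i, :] = a                   # clamp X_i = a_i
    return x_do, A_do
```

Re-running the encoder and decoder on `(x_do, A_do)` then yields samples from the estimated interventional distribution.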

c. Counterfactual Objective:

For a factual observation $X^F$, abduction samples $Z \sim q_{\phi}(Z \mid X^F, A)$; the action $\operatorname{do}(a)$ is applied; the encoder and decoder are run with the modified inputs and adjacency; and counterfactuals are sampled or predicted via

$$p_{\theta}(X \mid X^F, \operatorname{do}(a)) \approx \int p_{\theta}\bigl(X \mid Z, A_{\operatorname{do}(a)}\bigr)\, q_{\phi}(Z \mid X^F, A)\, dZ.$$

4. Causal Inference Workflow: Abduction–Action–Prediction

Causal queries are addressed in three phases:

  • Abduction: The encoder is run once on the observed $X^F$ and the original adjacency $A$ to obtain the posterior over latents.
  • Action: The modified adjacency $A_{\operatorname{do}(a)}$ is produced by removing all incoming edges to the intervened nodes and clamping those $X_i = a_i$.
  • Prediction: Samples or means of $Z$ from the abducted posterior are decoded under $A_{\operatorname{do}(a)}$ to predict the counterfactual $X$ under the intervention.

This procedure exactly mimics the structural causal steps of abduction, action, and prediction as formalized in the SCM literature.
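
A sketch of the full pipeline, reusing the illustrative `apply_intervention` helper and encoder/decoder interfaces from above (an assumption-laden sketch, not the reference implementation; in particular, splicing in the intervened nodes' latents is one simple way to propagate the clamped values to descendants):

```python
import torch

def counterfactual(x_factual, A, targets, values, encoder, decoder, n_samples=100):
    # 1. Abduction: posterior over latents given the factual data and original A.
    mu_q, log_var_q = encoder(x_factual, A)
    z = mu_q + torch.exp(0.5 * log_var_q) * torch.randn(n_samples, *mu_q.shape)

    # 2. Action: sever incoming edges to the intervened nodes and clamp values.
    x_do, A_do = apply_intervention(x_factual, A, targets, values)

    # Encode the clamped inputs under the mutilated graph and splice in the
    # intervened nodes' latents, so do(a) reaches their descendants.
    mu_do, log_var_do = encoder(x_do, A_do)
    z_do = mu_do + torch.exp(0.5 * log_var_do) * torch.randn(n_samples, *mu_do.shape)
    z[:, :, targets, :] = z_do[:, :, targets, :]

    # 3. Prediction: decode the abducted latents under the mutilated adjacency.
    mu_x, _ = decoder(z.flatten(0, 1), A_do)
    x_cf = mu_x.view(n_samples, *x_factual.shape[:2]).mean(dim=0)   # [B, d]
    for i, a in zip(targets, values):
        x_cf[:, i] = a                      # intervened nodes keep their set values
    return x_cf
```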

5. Training Regime and Computational Properties

Training jointly optimizes a linear blend of the observational, interventional, and (optionally) counterfactual ELBOs:

$$\mathcal{L}(\theta, \phi) = \mathcal{L}_{\text{obs}}(\theta, \phi) + \lambda_{\text{int}} \mathcal{L}_{\text{int}}(\theta, \phi) + \lambda_{\text{cf}} \mathcal{L}_{\text{cf}}(\theta, \phi)$$

Minibatch SGD is used, with each iteration executing the following steps (a minimal sketch of one iteration appears after the list):

  1. Minibatch data sampling.
  2. Encoder forward pass; $K$ samples drawn for the variational latents.
  3. Decoder GNN pass and ELBO computation.
  4. (Optional) Intervention sampling and repeated pass with modified data/graph.
  5. Monte Carlo gradient computation and parameter updates.
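
A compact sketch of this loop, with an interventional term on a randomly chosen node (reusing `iwae_bound` and `apply_intervention` from above; for brevity this sketch scores all nodes in the interventional pass, whereas $\mathcal L_{\text{int}}$ excludes the intervened set $I$):

```python
import torch

def train(loader, encoder, decoder, A, opt, lam_int=1.0, K=5, n_epochs=10):
    d = A.shape[0]
    for _ in range(n_epochs):
        for x in loader:                                   # 1. minibatch sampling
            loss = -iwae_bound(x, encoder, decoder, A, K)  # 2-3. observational bound

            # 4. random single-node intervention; repeat the pass on the
            #    clamped data and mutilated graph (clamp value is illustrative).
            i = torch.randint(d, (1,)).item()
            a = x[:, i, :].mean()
            x_do, A_do = apply_intervention(x, A, [i], [a])
            loss = loss - lam_int * iwae_bound(x_do, encoder, decoder, A_do, K)

            opt.zero_grad()
            loss.backward()                                # 5. Monte Carlo gradients
            opt.step()
```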

Complexity per minibatch is $O(B \cdot d \cdot N_h \cdot \deg_{\max} \cdot K)$, where $B$ is the minibatch size, $N_h$ the hidden width, and $\deg_{\max}$ the maximum node degree in $G$; efficient GPU-based message passing is supported via libraries such as PyTorch Geometric.

6. Empirical Results and Comparison

Experiments span several synthetic SCMs ("collider," "triangle," "chain," "M-graph") and semi-synthetic finance and demographic datasets ("Loan," "Adult"). Comparative baselines include MultiCVAE (independent node-wise CVAEs) and CAREFL (causal autoregressive normalizing flows). Evaluation metrics comprise observational and interventional maximum mean discrepancy (MMD), the error in the estimated interventional mean (MeanE) and standard deviation (StdE), and counterfactual mean squared error (MSE) and standard error (SSE).

| Model | Obs MMD | Int MMD | MeanE | StdE | CF MSE | CF SSE |
|---|---|---|---|---|---|---|
| MultiCVAE | 30.4 ± 8.2 | 44.7 ± 12.3 | 13.3 ± 4.8 | 46.6 ± 2.4 | 87.4 ± 3.6 | 65.2 ± 2.8 |
| CAREFL | 9.3 ± 1.5 | 4.9 ± 0.5 | 0.35 ± 0.08 | 81.9 ± 1.8 | 8.1 ± 0.6 | 7.8 ± 0.6 |
| iVGAE (VACA) | 1.5 ± 0.7 | 1.6 ± 0.4 | 0.75 ± 0.31 | 42.0 ± 0.3 | 9.9 ± 0.7 | 7.1 ± 0.4 |

iVGAE outperforms these baselines in observational and interventional MMD, closely reproduces moments of interventional distributions, and uniquely recovers the full variance structure owing to its architecture-induced factorization. Counterfactual error metrics are competitive, attesting to reliable counterfactual estimation.

A notable practical use case is counterfactual fairness auditing and fair classifier learning, demonstrated on the German-Credit dataset: generating $\operatorname{do}$-samples enables auditing a classifier for counterfactual fairness, and imposing fairness constraints on these counterfactual predictions during training yields approximately counterfactually fair classifiers.
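
As an illustration of the auditing step, a hedged sketch (the classifier `clf` and the `counterfactual` helper are assumptions carried over from the earlier sketches; `sensitive` indexes the protected attribute):

```python
def audit_counterfactual_fairness(x, A, clf, sensitive, a0, a1, encoder, decoder):
    # Generate counterfactuals with the sensitive attribute set to each value.
    x_cf0 = counterfactual(x, A, [sensitive], [a0], encoder, decoder)
    x_cf1 = counterfactual(x, A, [sensitive], [a1], encoder, decoder)
    # A counterfactually fair classifier gives (near-)identical predictions
    # on the two counterfactual versions of each individual.
    return (clf(x_cf0) - clf(x_cf1)).abs().mean()
```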

References

  • Sanchez-Martin, P., Rateike, M., & Valera, I. (2021). VACA: Design of Variational Graph Autoencoders for Interventional and Counterfactual Queries. arXiv:2110.14690.
