
Deep Structural Causal Models for Tractable Counterfactual Inference

Published 11 Jun 2020 in stat.ML and cs.LG | (2006.06485v2)

Abstract: We formulate a general framework for building structural causal models (SCMs) with deep learning components. The proposed approach employs normalising flows and variational inference to enable tractable inference of exogenous noise variables - a crucial step for counterfactual inference that is missing from existing deep causal learning methods. Our framework is validated on a synthetic dataset built on MNIST as well as on a real-world medical dataset of brain MRI scans. Our experimental results indicate that we can successfully train deep SCMs that are capable of all three levels of Pearl's ladder of causation: association, intervention, and counterfactuals, giving rise to a powerful new approach for answering causal questions in imaging applications and beyond. The code for all our experiments is available at https://github.com/biomedia-mira/deepscm.

Citations (208)

Summary

  • The paper introduces a deep learning-based framework for Structural Causal Models that enables tractable counterfactual inference.
  • It leverages normalizing flows for invertible likelihood estimation and variational inference to manage high-dimensional exogenous noise.
  • Experimental validation on synthetic and real-world datasets demonstrates potential applications in explainability, data augmentation, and personalized medicine.

Summary of "Deep Structural Causal Models for Tractable Counterfactual Inference" (2006.06485)

The paper "Deep Structural Causal Models for Tractable Counterfactual Inference" introduces a framework for building Structural Causal Models (SCMs) using deep learning components, enabling efficient counterfactual inference. It leverages normalizing flows and variational inference to perform tractable inference of exogenous noise variables, overcoming limitations of existing deep causal learning methods.

Framework Overview

The framework integrates SCMs with deep learning to model complex, high-dimensional data while supporting all three levels of Pearl's ladder of causation: association, intervention, and counterfactuals. Traditional causal inference methods in fields such as econometrics and epidemiology often rely on simple linear models; this framework instead uses deep learning components as flexible causal mechanisms.

Deep Structural Causal Models (DSCMs) use deep learning components to represent the causal mechanisms within an SCM. Three types of mechanism are supported:

  • Invertible, explicit likelihood: Using normalizing flows for tractable maximum likelihood estimation.
  • Amortized, explicit likelihood: Employing variational inference when exact inversion isn't feasible.
  • Amortized, implicit likelihood: Training non-invertible mechanisms with adversarial objectives.

These mechanisms enable inference of the posterior distribution over latent variables necessary for counterfactual reasoning.
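For the invertible case, the idea can be illustrated with a toy one-dimensional mechanism. This is a hypothetical sketch standing in for a conditional normalising flow, not the paper's implementation (the released code builds on deep flow architectures); all names and parameters here are illustrative:

```python
import math

class AffineMechanism:
    """Toy invertible mechanism x = exp(w_s * pa) * u + w_t * pa.

    Stands in for a conditional normalising flow: the map from
    exogenous noise u to observation x is invertible given the
    parent value pa, so the noise can be recovered exactly.
    """

    def __init__(self, w_s, w_t):
        self.w_s = w_s  # conditioner weight for the log-scale
        self.w_t = w_t  # conditioner weight for the shift

    def forward(self, u, pa):
        # Generate an observation from noise and parent value.
        return math.exp(self.w_s * pa) * u + self.w_t * pa

    def inverse(self, x, pa):
        # Recover the exogenous noise that produced x.
        return (x - self.w_t * pa) / math.exp(self.w_s * pa)

mech = AffineMechanism(w_s=0.5, w_t=1.0)
u, pa = 0.3, 2.0
x = mech.forward(u, pa)
assert abs(mech.inverse(x, pa) - u) < 1e-12  # exact noise recovery
```

Because the forward map is invertible with a tractable Jacobian, maximum likelihood training and exact abduction are both straightforward; the amortized variants relax invertibility at the cost of only approximate noise posteriors.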

Counterfactual Inference

Counterfactual inference in a DSCM follows Pearl's three-step procedure of abduction, action, and prediction:

  • Abduction: Estimating the exogenous noise from observed data.
  • Action: Modifying the causal graph according to interventions.
  • Prediction: Sampling from the modified SCM to assess counterfactual outcomes.

The novelty lies in the framework's ability to efficiently infer the exogenous noise through deep learning techniques, a crucial component for generating realistic counterfactuals.
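Under the same invertible-mechanism assumption, the three steps can be sketched in a hypothetical one-variable example (illustrative only, not the paper's code; the structural equation and weights are made up):

```python
import math

def forward(u, pa, w=0.5, b=1.0):
    """Invertible structural assignment x = exp(w * pa) * u + b * pa."""
    return math.exp(w * pa) * u + b * pa

def inverse(x, pa, w=0.5, b=1.0):
    """Invert the assignment to recover the exogenous noise u."""
    return (x - b * pa) / math.exp(w * pa)

def counterfactual(x_obs, pa_obs, pa_cf):
    # 1. Abduction: infer the noise consistent with the observation.
    u = inverse(x_obs, pa_obs)
    # 2. Action: replace the observed parent value with the
    #    intervened value pa_cf (do-operation on the graph).
    # 3. Prediction: push the SAME noise through the modified model.
    return forward(u, pa_cf)

x_cf = counterfactual(x_obs=3.0, pa_obs=2.0, pa_cf=0.0)
```

Reusing the abducted noise is what makes the result a counterfactual rather than a fresh interventional sample: it answers "what would this particular observation have been" under the alternative parent value.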

Experimental Validation

The paper validates the framework with experiments on synthetic and real-world datasets:

  • Morpho-MNIST Dataset: A synthetic dataset where stroke thickness and brightness influence digit images. The model captures the true causal relationships and generates plausible counterfactuals that can be compared against the known ground truth.

    Figure 1: Distributions of thickness and intensity in the true data (left), and learned by the full (centre) and conditional (right) models.

  • Brain MRI Scans: Applied to UK Biobank data, the model explores causal relationships among demographic covariates and anatomical features, demonstrating its utility in medical imaging contexts.

    Figure 2: Original samples and counterfactuals from the full model. Counterfactuals preserve identity and style in low-density regions.

Implications and Future Directions

The research addresses core challenges in causal deep learning, promising improvements in model transparency, fairness, and robustness. Practical implications include:

  • Explainability: Counterfactual explanations can provide insightful causal analyses of ML predictions.
  • Data Augmentation: Counterfactuals extrapolate to novel data configurations, enhancing training datasets.
  • Personalized Medicine: Instance-level causal inference could lead to tailored medical interventions.

However, limitations include the assumption of fully observed data and potential issues with unobserved confounders. Future work could explore enhancing the framework to handle partial observability and further investigate causal discovery capabilities.

Conclusion

This work marks a significant advancement in deep causal inference by enabling tractable counterfactual reasoning within SCMs using modern deep learning techniques. Its successful application to both synthetic and real-world datasets highlights the broad potential for its use in diverse scientific domains.
