
Multi-Domain Causal Representation Learning via Weak Distributional Invariances (2310.02854v3)

Published 4 Oct 2023 in cs.LG and stat.ML

Abstract: Causal representation learning has emerged as the center of action in causal machine learning research. In particular, multi-domain datasets present a natural opportunity for showcasing the advantages of causal representation learning over standard unsupervised representation learning. While recent works have taken crucial steps towards learning causal representations, they often lack applicability to multi-domain datasets due to over-simplifying assumptions about the data; e.g. each domain comes from a different single-node perfect intervention. In this work, we relax these assumptions and capitalize on the following observation: there often exists a subset of latents whose certain distributional properties (e.g., support, variance) remain stable across domains; this property holds when, for example, each domain comes from a multi-node imperfect intervention. Leveraging this observation, we show that autoencoders that incorporate such invariances can provably identify the stable set of latents from the rest across different settings.


Summary

  • The paper introduces a method that leverages weak distributional invariances to disentangle stable causal factors across multiple domains.
  • It relaxes strong assumptions by allowing for imperfect interventions and flexible causal graphs, achieving block-affine identification.
  • Empirical results across various datasets demonstrate significant improvements in representation learning under diverse domain shifts.

This paper, "Multi-Domain Causal Representation Learning via Weak Distributional Invariances" (2310.02854), introduces a method for learning causal representations from unlabelled multi-domain data by leveraging weak distributional invariances. The core idea is that in many real-world scenarios, while some aspects of the data change across domains (e.g., background, style), others remain stable (e.g., object identity, certain physical properties). The proposed approach aims to identify and disentangle these stable latent factors from the unstable ones.

The authors relax common strong assumptions in causal representation learning, such as requiring perfect single-node interventions or a single fixed causal graph (directed acyclic graph, DAG) governing all data points. Instead, they focus on the observation that a subset of latent variables may exhibit stable distributional properties (such as support or variance) across different domains, even under multi-node imperfect interventions.

Problem Statement and Approach

The data generation process (DGP) is defined as follows: for each domain $j$ out of $k$ domains, latent variables $z \in \mathbb{R}^d$ are sampled from a domain-specific distribution $p_Z^{(j)}$. These latents are then transformed by an injective mixing function $g: \mathbb{R}^d \rightarrow \mathbb{R}^n$ to produce observations $x \in \mathbb{R}^n$.

$$z \sim p_Z^{(j)}, \quad x \leftarrow g(z)$$
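To make the setup concrete, here is a minimal simulation sketch of such a multi-domain DGP, assuming independent Gaussian latents whose noise scale shifts only on an unstable subset and a simple quadratic mixing as a stand-in for a generic polynomial $g$. The dimensions, domain count, and variance ranges are illustrative choices, not the paper's.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, k = 6, 20, 8                         # latent dim, observation dim, number of domains
stable, unstable = [0, 1, 2], [3, 4, 5]    # S keeps a fixed noise scale, U does not

# quadratic mixing g(z) = A [z, z**2] as a simple polynomial stand-in
A = rng.normal(size=(n, 2 * d))
def g(z):
    return np.concatenate([z, z ** 2], axis=1) @ A.T

domains = []
for j in range(k):
    scale = np.ones(d)
    scale[unstable] = rng.uniform(0.5, 3.0, size=len(unstable))  # domain-specific shift on U
    z = rng.normal(scale=scale, size=(2000, d))                  # z ~ p_Z^(j)
    domains.append({"z": z, "x": g(z)})                          # x = g(z)
```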

The goal is to learn an encoder $f: \mathbb{R}^n \rightarrow \mathbb{R}^d$ such that its output $\hat{z} = f(x)$ is a good estimate of the true latent $z$. This is typically done by training an autoencoder $(f, h)$ (where $h$ is the decoder) to satisfy the reconstruction identity $h \circ f(x) = x$.

The key innovation is to divide the latent components $z$ into a stable set $\mathcal{S}$ and an unstable set $\mathcal{U}$, so $z = [z_{\mathcal{S}}, z_{\mathcal{U}}]$. The principle is that some functional $F$ of the marginal distribution of $z_{\mathcal{S}}$, i.e., $F[p_{z_{\mathcal{S}}^{(j)}}]$, remains invariant across domains $j$. The proposed autoencoders incorporate this by enforcing a similar invariance on a subset $\hat{\mathcal{S}}$ of the estimated latents $\hat{z}$:

$$h \circ f(x) = x, \quad \forall x \in \mathcal{X}$$

$$F[p_{\hat{z}_{\hat{\mathcal{S}}}^{(p)}}] = F[p_{\hat{z}_{\hat{\mathcal{S}}}^{(q)}}], \quad \forall p \neq q, \; p, q \in [k]$$

The learner can find a suitable $\hat{\mathcal{S}}$ by starting with the largest possible set and reducing its size until a solution satisfying both reconstruction and invariance is found.
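The shrinking idea can be illustrated with a small greedy heuristic: start from all learned coordinates and discard the most variable ones until a chosen distributional functional (variance, used here as an assumed example) agrees across domains. This is an illustration only, not the paper's exact procedure, and the function names and tolerance are hypothetical.

```python
import numpy as np

def invariance_violation(z_hat_by_domain, idx, functional=np.var):
    """Max pairwise gap of a distributional functional (e.g. variance)
    of the candidate-stable coordinates `idx` across domains."""
    stats = [functional(z[:, idx], axis=0) for z in z_hat_by_domain]
    return max(np.max(np.abs(a - b)) for a in stats for b in stats)

def select_stable_subset(z_hat_by_domain, tol=1e-2):
    """Shrink the candidate set from the full dimension downward until the
    invariance check passes (greedy sketch, not the paper's algorithm)."""
    d = z_hat_by_domain[0].shape[1]
    candidate = list(range(d))
    while candidate:
        if invariance_violation(z_hat_by_domain, candidate) < tol:
            return candidate
        # drop the coordinate whose functional varies most across domains (heuristic)
        per_coord = [invariance_violation(z_hat_by_domain, [i]) for i in candidate]
        candidate.pop(int(np.argmax(per_coord)))
    return []
```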

Theoretical Identification Guarantees

The paper provides theoretical guarantees for identifying the stable latents $z_{\mathcal{S}}$ under different assumptions about the latent distribution $p_Z$ and the mixing function $g$.

1. Acyclic Structural Causal Models for $p_Z$

It is initially assumed that the latent distribution $p_Z$ is induced by an acyclic structural causal model (SCM). The approach first leverages prior results (Theorem 1, from (Ciliberto, 2020)) showing that autoencoders with polynomial decoders (Constraint \ref{assm3: h_poly_new}) and polynomial mixing functions (Assumption \ref{assm1: dgp1}) achieve affine identification: $\hat{z} = Az + c$.
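Affine identification can be checked empirically: if $\hat{z} = Az + c$, a linear regression from the true latents $z$ to the learned $\hat{z}$ should fit almost perfectly. A minimal diagnostic of that kind might look as follows; the probe choice and threshold are illustrative assumptions.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

def affine_identification_score(z_true, z_hat):
    """R^2 of the best affine map z -> z_hat; a value near 1 is consistent
    with z_hat being an affine transformation of the true latents."""
    reg = LinearRegression().fit(z_true, z_hat)
    return reg.score(z_true, z_hat)

# e.g. affine_identification_score(z, encoder_output) > 0.99 as a sanity check
```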

  • Single-Node Imperfect Interventions (Theorem 2):
    • DGP: Latents are generated by $z_i^{(j)} \leftarrow q_i(z_{\mathrm{Pa}(i)}^{(j)}) + \varrho_i^{(j)}$, where the noise $\varrho_i^{(j)}$ can change across domains for $i \in \mathcal{U}$.
    • Assumption \ref{assm: imp_int_structure}: Only one node in $\mathcal{U}$ is imperfectly intervened on in each interventional domain. Nodes in $\mathcal{S}$ are never intervened on. Children of any node in $\mathcal{U}$ must also be in $\mathcal{U}$.
    • Constraint \ref{assm: dist_inv} (Marginal Invariance): The marginal distribution $p_{\hat{z}_i^{(p)}}$ is enforced to be the same across domains for each $i \in \hat{\mathcal{S}}$.
    • Result: Achieves block-affine identification, $\hat{z}_{\hat{\mathcal{S}}} = D z_{\mathcal{S}} + e$, meaning the learned stable latents are an affine transformation of the true stable latents, disentangled from $z_{\mathcal{U}}$.
  • Multi-Node Imperfect Interventions (Theorem 3):
    • Assumption \ref{assm: multi_int_str}: Allows imperfect interventions on multiple nodes in $\mathcal{U}$ simultaneously. Assumes Gaussian noise $\varrho_i$ with variances sampled i.i.d. from a non-atomic density, and requires sufficiently many random multi-node interventions (a simulation sketch of this setup follows the list).
    • Result: With high probability (if the number of interventions $t$ is large enough, scaling with $d \log(d/\delta)$), block-affine identification $\hat{z}_{\hat{\mathcal{S}}} = D z_{\mathcal{S}} + e$ is achieved.
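The multi-node imperfect-intervention setting of Theorem 3 can be simulated roughly as follows: a linear acyclic SCM over the latents whose unstable nodes have their Gaussian noise variances redrawn in each domain. The linear SCM, the variance ranges, and the number of intervened nodes are simplifying assumptions of this sketch; the paper's setting is more general.

```python
import numpy as np

rng = np.random.default_rng(1)
d = 6
stable, unstable = [0, 1, 2], [3, 4, 5]

# Lower-triangular weights give an acyclic linear SCM: z_i <- W[i] @ z + eps_i.
# Listing stable nodes first ensures children of unstable nodes are unstable.
W = np.tril(rng.normal(size=(d, d)), k=-1)

def sample_domain(n_samples=2000, n_targets=2):
    noise_var = np.ones(d)
    # multi-node imperfect intervention: redraw noise variances of some U nodes,
    # i.i.d. from a continuous (non-atomic) density
    targets = rng.choice(unstable, size=n_targets, replace=False)
    noise_var[targets] = rng.uniform(0.2, 4.0, size=n_targets)
    eps = rng.normal(scale=np.sqrt(noise_var), size=(n_samples, d))
    z = np.zeros((n_samples, d))
    for i in range(d):                       # ancestral (topological) sampling
        z[:, i] = z @ W[i] + eps[:, i]
    return z
```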

2. General Distributions $p_Z$ (Relaxing the Fixed DAG)

This section studies scenarios where a single fixed DAG might not describe the relationships between latents across all data. A weaker form of invariance, marginal support invariance, is considered.

  • Polynomial Mixing (Theorem 4):
    • Assumption \ref{assm: supp_invar}: The minimum and maximum of each true latent $z_i$ for $i \in \mathcal{S}$ are invariant across domains.
    • Assumption \ref{assm: sup_var} (Support Variability): There exist two domains $p, q$ such that for each $z \in \mathcal{Z}^{(p)}$ there is a $z' \in \mathcal{Z}^{(q)}$ with $z'_i \geq z_i$ for all $i$ and $z'_j > z_j$ for all unstable components $j \in \mathcal{U}$.
    • Constraint \ref{assm: supp_inv} (Marginal Support Invariance): The min/max of each learned latent $\hat{z}_i$ for $i \in \hat{\mathcal{S}}$ is enforced to be invariant across domains (a small sketch of this support check follows the list).
    • Result: If the affine transformation $A_i$ (from $\hat{z}_i = A_i^T z + c_i$) lies in the positive orthant ($A_i \succcurlyeq 0$), then $A_{ir} = 0$ for all $r \in \mathcal{U}$; that is, $\hat{z}_i$ depends only on $z_{\mathcal{S}}$.
    • Implementation Consideration: Extending this beyond the positive orthant requires checking all $2^d$ orthants, potentially needing $2^{d+1}$ domains. This can be reduced to $d$ domains if the support is a polytope and satisfies certain diversity conditions (Appendix \ref{sec:poly}).
  • General Diffeomorphisms (Theorem 5, illustrated with two variables):
    • Considers $z = [z_1, z_2]$, where the support of $z_1$ is invariant ($[0,1]$) and the support of $z_2$ varies.
    • Definition \ref{def: lipschitz_a}: Defines a class of functions $\Gamma$ (parameterized by $\theta$) whose global minimum over $[0,1] \times [0,1]$ differs by at least $\eta$ from the minimum when $z_2$ is constrained to certain sub-intervals. Functions that depend only on $z_1$ are not in $\Gamma$.
    • Assumption \ref{assm: diverse_int} (Support Variability for $z_2$): Requires that the support of $z_2$ in randomly drawn domains has a certain probability of being contained in small intervals or of covering large portions such as $[\kappa, 1-\kappa]$.
    • Result ($\Gamma^c$ identification): If enough diverse domains are sampled (number $k \geq N(\delta, \varepsilon, \eta, \iota)$), the learned map $a_1(\cdot)$ (where $\hat{z}_1 = a_1(z_1, z_2)$) will not belong to $\Gamma$, pushing $a_1(\cdot)$ towards being a function of $z_1$ only.
    • Implementation Consideration: The number of required domains $N$ depends on properties of the function class $\Gamma$ and on the diversity of the supports.
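A minimal version of the marginal support check used in this section: compute per-coordinate empirical min/max in every domain and flag coordinates whose support barely moves. The tolerance and the use of raw batch extremes are simplifying assumptions of this sketch.

```python
import numpy as np

def support_gap(z_by_domain, i):
    """Max cross-domain discrepancy of the empirical [min, max] support of coordinate i."""
    mins = np.array([z[:, i].min() for z in z_by_domain])
    maxs = np.array([z[:, i].max() for z in z_by_domain])
    return max(mins.max() - mins.min(), maxs.max() - maxs.min())

def stable_by_support(z_by_domain, tol=0.05):
    """Coordinates whose empirical support is (approximately) invariant across domains."""
    d = z_by_domain[0].shape[1]
    return [i for i in range(d) if support_gap(z_by_domain, i) < tol]
```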

Learning Invariance-Constrained Representations in Practice

A two-stage learning procedure is proposed:

  1. Stage 1: Train an initial autoencoder $(\tilde{f}, \tilde{h})$ to minimize the reconstruction error $\mathbb{E}[\|\tilde{h} \circ \tilde{f}(x) - x\|^2]$. Let $\tilde{x} = \tilde{f}(x)$ denote the output of this stage's encoder.
  2. Stage 2: Train a second autoencoder $(f^{\star}, h^{\star})$ using $\tilde{x}$ as input. The objective combines the reconstruction error with a penalty term for violating the invariance:

     $$\mathbb{E}[\|h^{\star} \circ f^{\star}(\tilde{x}) - \tilde{x}\|^2] + \lambda \cdot \text{penalty}$$

     Two types of penalties are explored (a code sketch of both follows this list):

     • Min-Max Support Invariance Penalty (Equation \ref{eqn: pen_minmax}):

       $$\sum_{p \neq q} \sum_{i \in \hat{\mathcal{S}}} \left( \Big(\min_{z \in \tilde{\mathcal{Z}}_i^{(p)}} z - \min_{z \in \tilde{\mathcal{Z}}_i^{(q)}} z\Big)^2 + \Big(\max_{z \in \tilde{\mathcal{Z}}_i^{(p)}} z - \max_{z \in \tilde{\mathcal{Z}}_i^{(q)}} z\Big)^2 \right)$$

       where $\tilde{\mathcal{Z}}_i^{(p)}$ is the support of the $i$-th component of $f^{\star}(\tilde{x})$ in domain $p$.

     • MMD-based Distribution Invariance Penalty (Equation \ref{eqn: pen_mmd}):

       $$\sum_{p \neq q} \mathrm{MMD}\big(p_{\hat{z}_{\hat{\mathcal{S}}}^{(p)}}, p_{\hat{z}_{\hat{\mathcal{S}}}^{(q)}}\big)$$

       This measures the maximum mean discrepancy between the joint distributions of the selected latent subset $\hat{z}_{\hat{\mathcal{S}}}$ across domain pairs.
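Below is a hedged PyTorch sketch of both penalties, summed over all domain pairs as in the equations above. The robust top-k min/max proxy (mirroring the training detail noted later) and the fixed RBF bandwidth are assumptions; the function names are illustrative, not from the paper's code.

```python
import torch

def minmax_penalty(z_p, z_q, idx, topk=10):
    """Squared gap between robust per-coordinate min/max supports in two domains.
    Uses the mean of the top-k extreme values (assumes batch size >= topk)."""
    zp, zq = z_p[:, idx], z_q[:, idx]
    top_p, _ = zp.topk(topk, dim=0); top_q, _ = zq.topk(topk, dim=0)
    bot_p, _ = (-zp).topk(topk, dim=0); bot_q, _ = (-zq).topk(topk, dim=0)
    max_gap = (top_p.mean(0) - top_q.mean(0)).pow(2).sum()
    min_gap = (bot_p.mean(0) - bot_q.mean(0)).pow(2).sum()
    return max_gap + min_gap

def rbf_mmd(z_p, z_q, idx, bandwidth=1.0):
    """Biased empirical MMD^2 with an RBF kernel between two domains' latents."""
    a, b = z_p[:, idx], z_q[:, idx]
    def k(x, y):
        return torch.exp(-torch.cdist(x, y).pow(2) / (2 * bandwidth ** 2))
    return k(a, a).mean() + k(b, b).mean() - 2 * k(a, b).mean()

def invariance_penalty(z_by_domain, idx, use_mmd=True, use_minmax=True):
    """Sum of the chosen penalties over all ordered domain pairs p < q."""
    pen = 0.0
    for p in range(len(z_by_domain)):
        for q in range(p + 1, len(z_by_domain)):
            if use_minmax:
                pen = pen + minmax_penalty(z_by_domain[p], z_by_domain[q], idx)
            if use_mmd:
                pen = pen + rbf_mmd(z_by_domain[p], z_by_domain[q], idx)
    return pen
```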

Empirical Findings

Experiments were conducted on four types of datasets with varying mixing functions $g$ and latent distributions $p_Z$:

  1. Linear mixing: $x = Az$.
  2. Polynomial mixing: $g(z)$ is a polynomial function.
  3. Image rendering of balls: Latent variables are ball coordinates, and $g$ is an image renderer.
  4. Unlabeled colored MNIST: Digits are (implicitly) $z_{\mathcal{S}}$, color is $z_{\mathcal{U}}$.

For each, two types of $p_Z$ were studied:

  • Independent latents: $z_{\mathcal{S}}$ and $z_{\mathcal{U}}$ are independent.
  • Dependent latents (Dynamic SCM, D-SCM): The SCM for the latents varies across data points, inducing dependencies.

Implementation of Experiments:

  • The two-stage procedure described above was used (a minimal end-to-end sketch follows this list).
  • For linear data, Stage 2 was applied directly.
  • For polynomial data (Stage 1: MLP encoder, polynomial decoder) and image data (Stage 1: ResNet encoder, ConvNet decoder), MLP autoencoders were used in Stage 2.
  • Three penalty variations were tested: Min-Max, MMD, and MMD + Min-Max.
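For concreteness, a minimal end-to-end wiring of the two-stage recipe might look like the following, assuming encoders/decoders and a penalty function (such as `invariance_penalty` sketched earlier) are supplied. The full-domain batching, epoch counts, and function signature are simplifications, not the paper's implementation.

```python
import torch
import torch.nn as nn

def train_two_stage(x_by_domain, enc1, dec1, enc2, dec2, idx_S_hat,
                    penalty_fn, lam=1.0, epochs1=100, epochs2=100, lr=1e-3):
    """Illustrative two-stage loop: stage 1 does plain reconstruction;
    stage 2 re-encodes the stage-1 representations under an invariance penalty."""
    mse = nn.MSELoss()

    # Stage 1: reconstruction-only autoencoder on raw observations
    opt1 = torch.optim.Adam(list(enc1.parameters()) + list(dec1.parameters()), lr=lr)
    for _ in range(epochs1):
        for x in x_by_domain:                       # one full-domain "batch" for brevity
            opt1.zero_grad()
            mse(dec1(enc1(x)), x).backward()
            opt1.step()

    with torch.no_grad():
        x_tilde = [enc1(x) for x in x_by_domain]    # stage-1 representations

    # Stage 2: reconstruction + lambda * invariance penalty on selected coordinates
    opt2 = torch.optim.Adam(list(enc2.parameters()) + list(dec2.parameters()), lr=lr)
    for _ in range(epochs2):
        z_hat = [enc2(xt) for xt in x_tilde]
        recon = sum(mse(dec2(z), xt) for z, xt in zip(z_hat, x_tilde))
        opt2.zero_grad()
        (recon + lam * penalty_fn(z_hat, idx_S_hat)).backward()
        opt2.step()
    return enc2, dec2
```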

Evaluation Metrics:

  • For synthetic/balls data: $R^2_{\mathcal{S}}$ (R-squared for predicting $z_{\mathcal{S}}$ from $\hat{z}_{\hat{\mathcal{S}}}$) and $R^2_{\mathcal{U}}$ (R-squared for predicting $z_{\mathcal{U}}$ from $\hat{z}_{\hat{\mathcal{S}}}$). The ideal outcome is high $R^2_{\mathcal{S}}$ and low $R^2_{\mathcal{U}}$ (a small sketch of these metrics follows the list).
  • For unlabeled colored MNIST: $Acc_{\text{digits}}$ (accuracy of predicting the digit from $\hat{z}_{\hat{\mathcal{S}}}$) and $R^2_{\text{color}}$ (R-squared for predicting the color from $\hat{z}_{\hat{\mathcal{S}}}$).
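A small sketch of the $R^2$ metrics, assuming a linear probe (scikit-learn's `LinearRegression`) from the learned stable block to the ground-truth latents; the probe choice is an assumption of this sketch.

```python
from sklearn.linear_model import LinearRegression

def r2_block(z_hat_S, z_block):
    """R^2 (averaged over coordinates) of linearly predicting a ground-truth
    latent block from the learned stable latents z_hat_S."""
    reg = LinearRegression().fit(z_hat_S, z_block)
    return reg.score(z_hat_S, z_block)

# Ideal outcome: r2_block(z_hat_S, z_S) is high (z_S is captured) while
# r2_block(z_hat_S, z_U) is low (the unstable latents are filtered out).
```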

Key Results:

  • For linear and polynomial mixing, all three penalty types performed well in achieving block-affine disentanglement.
  • For the more complex ball-images and unlabeled colored MNIST datasets, the combination "MMD + Min-Max" penalty worked best.
  • The approach achieved notable disentanglement on the challenging unlabeled colored MNIST without using any labels during training.
  • Increasing the number of domains ($k$) generally improved identification, and the number of domains required for useful identification was often smaller than the worst-case theoretical bounds. For instance, going from $k=2$ to $k=16$ showed significant improvements (Tables \ref{table5_results}, \ref{table6_results}).

Architectural Details for Experiments:

  • Polynomial Mixing (Stage 1):
    • Encoder: MLP (input: $n \rightarrow n/2 \rightarrow n/2 \rightarrow d$) with LeakyReLU.
    • Decoder: Polynomial decoder with learnable coefficient matrix.
  • Balls Dataset (Stage 1):
    • Encoder: ResNet18.
    • Decoder: Standard deconvolutional layers.
    • Encoder output: 128-dim, invariance on first 64-dim.
  • Unlabeled Colored MNIST (Stage 1):
    • Encoder: Linear ($784 \rightarrow 256 \rightarrow 256 \rightarrow 128$) with ReLU & BatchNorm.
    • Decoder: Symmetric.
  • Unlabeled Colored MNIST (Stage 2):
    • Encoder: Linear ($128 \rightarrow 200 \rightarrow 200 \rightarrow 200 \rightarrow 128$) with LeakyReLU & BatchNorm.
    • Decoder: Symmetric to the encoder (a sketch of both MNIST autoencoders follows this list).
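One plausible PyTorch reading of the colored-MNIST layer widths listed above; the placement of BatchNorm and activations, and the exact symmetric decoders, are assumptions of this sketch.

```python
import torch.nn as nn

def mlp(dims, act):
    """Fully connected stack with BatchNorm + activation between hidden layers."""
    layers = []
    for i in range(len(dims) - 1):
        layers.append(nn.Linear(dims[i], dims[i + 1]))
        if i < len(dims) - 2:
            layers += [nn.BatchNorm1d(dims[i + 1]), act()]
    return nn.Sequential(*layers)

# Stage 1: 784 -> 256 -> 256 -> 128 encoder (ReLU + BatchNorm), symmetric decoder
stage1_enc = mlp([784, 256, 256, 128], nn.ReLU)
stage1_dec = mlp([128, 256, 256, 784], nn.ReLU)

# Stage 2: 128 -> 200 -> 200 -> 200 -> 128 encoder (LeakyReLU + BatchNorm), symmetric decoder
stage2_enc = mlp([128, 200, 200, 200, 128], nn.LeakyReLU)
stage2_dec = mlp([128, 200, 200, 200, 128], nn.LeakyReLU)
```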

Training Details:

  • Optimizer: Adam (lr $= 10^{-3}$, $\beta_1 = 0.9$, $\beta_2 = 0.999$); see the configuration sketch after this list.
  • LR scheduler: Reduce on plateau (factor 0.5, patience 10 epochs, min_lr $10^{-4}$).
  • Batch size: 1024. Early stopping at 2000 steps.
  • Invariance penalty weight ($\lambda$): 1.0.
  • MMD kernel: RBF (bandwidth 1.0, adaptive for linear mixing).
  • Min-Max penalty: Batch values are sorted and the top 10 values are used as a robust min/max estimate.
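The optimizer and scheduler settings above translate into roughly the following configuration; the placeholder model and the epoch-level `scheduler.step` call are assumptions of this sketch.

```python
import torch

model = torch.nn.Linear(128, 128)   # placeholder for the stage-2 autoencoder
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.5, patience=10, min_lr=1e-4)

lam = 1.0  # weight on the invariance penalty
# Per step: loss = reconstruction_mse + lam * invariance_penalty, batch size 1024.
# Call scheduler.step(validation_loss) each epoch; stop training after 2000 steps.
```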

Conclusions

The paper significantly advances multi-domain causal representation learning by relaxing strong assumptions and introducing a framework based on weak distributional invariances. It demonstrates theoretically and empirically that autoencoders constrained by these invariances can identify stable latent factors from unstable ones under complex domain shifts, including multi-node imperfect interventions and settings where a fixed DAG does not govern the entire dataset. The proposed methods show promise for real-world applications where data comes from diverse sources with varying underlying conditions.
