- The paper introduces a method that leverages weak distributional invariances to disentangle stable causal factors across multiple domains.
- It relaxes strong assumptions by allowing for imperfect interventions and flexible causal graphs, achieving block-affine identification.
- Empirical results on linear, polynomial, ball-image, and colored-MNIST datasets show that invariance-constrained autoencoders recover the stable latent factors under diverse domain shifts.
This paper, "Multi-Domain Causal Representation Learning via Weak Distributional Invariances" (arXiv:2310.02854), introduces a method for learning causal representations from unlabeled multi-domain data by leveraging weak distributional invariances. The core idea is that in many real-world scenarios, some aspects of the data change across domains (e.g., background, style) while others remain stable (e.g., object identity, certain physical properties). The proposed approach aims to identify and disentangle these stable latent factors from the unstable ones.
The authors relax common strong assumptions in causal representation learning, such as requiring perfect single-node interventions or a fixed causal graph (Directed Acyclic Graph - DAG) for all data points. Instead, they focus on the observation that a subset of latent variables might exhibit stable distributional properties (like support or variance) across different domains, even under multi-node imperfect interventions.
Problem Statement and Approach
The data generation process (DGP) is defined as follows: for each domain $j$ out of $k$ domains, latent variables $z \in \mathbb{R}^d$ are sampled from a domain-specific distribution $p_Z^{(j)}$. These latents are then transformed by an injective mixing function $g: \mathbb{R}^d \to \mathbb{R}^n$ to produce observations $x \in \mathbb{R}^n$.
$$z \sim p_Z^{(j)}, \qquad x \leftarrow g(z)$$
The goal is to learn an encoder $f: \mathbb{R}^n \to \mathbb{R}^d$ such that its output $\hat{z} = f(x)$ is a good estimate of the true latent $z$. This is typically done by training an autoencoder $(f, h)$ (where $h$ is the decoder) to satisfy the reconstruction identity $h \circ f(x) = x$.
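As a concrete illustration, here is a minimal sketch of such a multi-domain DGP, assuming a toy injective mixing function and Gaussian latents whose stable block keeps a fixed distribution while the unstable block shifts with the domain (all names and choices are illustrative, not the paper's exact setup):

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, k = 4, 10, 8  # latent dim, observation dim, number of domains

# Toy injective mixing g: a full-column-rank linear map followed by an
# elementwise cubic (strictly monotone), so the composition stays injective.
A = rng.normal(size=(n, d))
def g(z):
    h = z @ A.T
    return h + 0.1 * h ** 3

def sample_domain(j, m=1000):
    """Domain j: stable latents z_S keep the same distribution in every
    domain, unstable latents z_U get a domain-dependent scale."""
    z_S = rng.normal(0.0, 1.0, size=(m, d // 2))                 # invariant
    z_U = rng.normal(0.0, 1.0 + 0.5 * j, size=(m, d - d // 2))   # varies with j
    z = np.concatenate([z_S, z_U], axis=1)
    return g(z), z

X_by_domain = [sample_domain(j)[0] for j in range(k)]
```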
The key innovation is to divide the latent components $z$ into a stable set $S$ and an unstable set $U$, so $z = [z_S, z_U]$. The principle is that some functional $F$ of the marginal distribution of $z_S$, i.e., $F\big[p_{z_S}^{(j)}\big]$, remains invariant across domains $j$. The proposed autoencoders incorporate this by enforcing a similar invariance on a subset $\hat{S}$ of the estimated latents $\hat{z}$:
$$h \circ f(x) = x, \quad \forall x \in \mathcal{X}$$
$$F\big[p_{\hat{z}_{\hat{S}}}^{(p)}\big] = F\big[p_{\hat{z}_{\hat{S}}}^{(q)}\big], \quad \forall\, p \neq q,\; p, q \in [k]$$
The learner can find a suitable $\hat{S}$ by starting with the largest possible set and reducing its size until a solution satisfying both reconstruction and invariance is found.
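This search can be written schematically as follows; `train_constrained_autoencoder` is a hypothetical helper that trains the invariance-constrained autoencoder for a candidate block size and reports how well both criteria are met:

```python
def find_stable_block(data_by_domain, d, recon_tol=1e-3, inv_tol=1e-3):
    """Shrink the candidate invariant block until both the reconstruction
    identity and the distributional invariance can be satisfied."""
    for size in range(d, 0, -1):            # start from the largest possible S_hat
        # `train_constrained_autoencoder` is a hypothetical helper (not defined here).
        recon_err, inv_gap = train_constrained_autoencoder(
            data_by_domain, latent_dim=d, stable_size=size
        )
        if recon_err <= recon_tol and inv_gap <= inv_tol:
            return size                      # largest block meeting both criteria
    return 0                                 # no invariant block found
```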
Theoretical Identification Guarantees
The paper provides theoretical guarantees for identifying the stable latents $z_S$ under different assumptions on the latent distribution $p_Z$ and the mixing function $g$.
1. Acyclic Structural Causal Models for $p_Z$
The latents are first assumed to come from an acyclic structural causal model (SCM). The approach builds on a prior result (Theorem 1, from (Ciliberto, 2020)) showing that autoencoders with polynomial decoders, trained on data generated by a polynomial mixing function, achieve affine identification: $\hat{z} = Az + c$.
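For concreteness, a polynomial decoder of this kind can be sketched as below (a degree-2 example with a learnable coefficient matrix; the exact parameterization in the paper may differ):

```python
import torch
import torch.nn as nn

class PolynomialDecoder(nn.Module):
    """Decoder restricted to degree-2 polynomials of the latent: the output is a
    learnable linear map of the features [1, z, vec(z z^T)]."""
    def __init__(self, d, n):
        super().__init__()
        self.coeff = nn.Linear(1 + d + d * d, n, bias=False)   # coefficient matrix

    def forward(self, z):
        ones = torch.ones(z.shape[0], 1, device=z.device)
        quad = (z.unsqueeze(2) * z.unsqueeze(1)).flatten(1)    # all pairwise products
        return self.coeff(torch.cat([ones, z, quad], dim=1))
```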
- Single-Node Imperfect Interventions (Theorem 2):
- DGP: Latents are generated by $z_i^{(j)} \leftarrow q_i\big(z_{\mathrm{Pa}(i)}^{(j)}\big) + \varrho_i^{(j)}$, where the noise $\varrho_i^{(j)}$ can change across domains for $i \in U$.
- Intervention-structure assumption: only one node in $U$ is imperfectly intervened on in each interventional domain; nodes in $S$ are never intervened on; children of any node in $U$ must also be in $U$.
- Marginal invariance constraint: the marginal distribution $p_{\hat{z}_i}^{(p)}$ is enforced to be the same across domains for each $i \in \hat{S}$.
- Result: Achieves block-affine identification, $\hat{z}_{\hat{S}} = D z_S + e$, meaning the learned stable latents are an affine transformation of the true stable latents, disentangled from $z_U$.
- Multi-Node Imperfect Interventions (Theorem 3):
- Assumption: allows imperfect interventions on multiple nodes in $U$ simultaneously; assumes Gaussian noise $\varrho_i$ with variances sampled i.i.d. from a non-atomic density; requires sufficiently many random multi-node interventions (a toy illustration of this interventional regime is sketched after this list).
- Result: With high probability (if the number of interventions $t$ is large enough, scaling as $d \log(d/\delta)$), block-affine identification $\hat{z}_{\hat{S}} = D z_S + e$ is achieved.
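The sketch below illustrates, under assumed toy choices (a linear lower-triangular SCM, coin-flip intervention targets, uniform variance draws), what a multi-node imperfect interventional domain could look like; it is not the paper's exact construction:

```python
import numpy as np

rng = np.random.default_rng(1)
d = 6
S, U = list(range(3)), list(range(3, 6))      # stable vs. unstable indices

# Toy lower-triangular linear SCM: z_i = W[i] . z + noise_i.
W = 0.5 * np.tril(rng.normal(size=(d, d)), k=-1)

def sample_interventional_domain(m=1000):
    """Multi-node imperfect intervention: each targeted unstable node gets a
    fresh noise variance drawn from a continuous (non-atomic) density, while
    stable nodes keep unit noise variance in every domain."""
    sigma = np.ones(d)
    targets = [i for i in U if rng.random() < 0.5]   # random multi-node targets
    for i in targets:
        sigma[i] = rng.uniform(0.5, 3.0)             # non-atomic variance draw
    z = np.zeros((m, d))
    for i in range(d):                               # ancestral sampling
        z[:, i] = z @ W[i] + sigma[i] * rng.normal(size=m)
    return z
```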
2. General Distributions $p_Z$ (Relaxing the Fixed-DAG Assumption)
This section studies scenarios where a single fixed DAG might not describe the relationships between latents across all data. A weaker form of invariance, marginal support invariance, is considered.
- Polynomial Mixing (Theorem 4):
- Support invariance assumption: the minimum and maximum of each true latent $z_i$ for $i \in S$ are invariant across domains.
- Support variability assumption: there exist two domains $p, q$ such that for each $z \in \mathcal{Z}^{(p)}$ there is a $z' \in \mathcal{Z}^{(q)}$ with $z'_i \geq z_i$ for all $i$ and $z'_j > z_j$ for all unstable components $j \in U$.
- Marginal support invariance constraint: the min/max of each learned latent $\hat{z}_i$ for $i \in \hat{S}$ are enforced to be invariant across domains.
- Result: If the row $A_i$ of the affine map (from $\hat{z}_i = A_i^\top z + c_i$) lies in the positive orthant ($A_i \succeq 0$), then $A_{ir} = 0$ for all $r \in U$; that is, $\hat{z}_i$ depends only on $z_S$.
- Implementation Consideration: Extending this beyond the positive orthant requires checking all $2^d$ orthants, potentially needing $2^d + 1$ domains. This can be reduced to $d$ domains if the support is a polytope satisfying certain diversity conditions (see the paper's appendix).
- General Diffeomorphisms (Theorem 5, illustrated with two variables):
- Considers $z = [z_1, z_2]$, where the support of $z_1$ is invariant (equal to $[0, 1]$) and the support of $z_2$ varies across domains.
- Definition of the function class $\Gamma$ (parameterized by $\theta$): functions whose global minimum over $[0, 1] \times [0, 1]$ differs by at least $\eta$ from the minimum when $z_2$ is constrained to certain sub-intervals. Functions that depend only on $z_1$ are not in $\Gamma$.
- Support variability assumption for $z_2$: the support of $z_2$ in randomly drawn domains has a certain probability of being contained in small intervals or of covering large portions such as $[\kappa, 1 - \kappa]$.
- Result ($\Gamma^c$ identification): if enough diverse domains are sampled ($k \geq N(\delta, \varepsilon, \eta, \iota)$), the learned map $a_1(\cdot)$ (where $\hat{z}_1 = a_1(z_1, z_2)$) does not belong to $\Gamma$, pushing $a_1(\cdot)$ towards being a function of $z_1$ only.
- Implementation Consideration: The number of required domains $N$ depends on properties of the function class $\Gamma$ and on the diversity of the supports.
Learning Invariance-Constrained Representations in Practice
A two-stage learning procedure is proposed:
- Stage 1: Train an initial autoencoder $(\tilde{f}, \tilde{h})$ to minimize the reconstruction error $\mathbb{E}\big[\|\tilde{h} \circ \tilde{f}(x) - x\|^2\big]$. Let $\tilde{x} = \tilde{f}(x)$ denote the output of this stage's encoder.
- Stage 2: Train a second autoencoder $(f^\star, h^\star)$ using $\tilde{x}$ as input. The objective combines the reconstruction error with a penalty term for violating the invariance:
$$\mathbb{E}\big[\|h^\star \circ f^\star(\tilde{x}) - \tilde{x}\|^2\big] + \lambda \cdot \text{penalty}$$
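A condensed PyTorch-style sketch of Stage 2 follows, assuming `stage1_encoder` is the frozen Stage-1 encoder, `loaders_by_domain` yields one batch per domain, and `invariance_penalty` is one of the penalty sketches further below; these names and the small MLP sizes are illustrative, not the paper's exact implementation:

```python
import torch
import torch.nn as nn

def train_stage2(stage1_encoder, loaders_by_domain, invariance_penalty,
                 stable_idx, d=128, lam=1.0, steps=2000):
    # Stage-2 MLP autoencoder trained on the frozen Stage-1 representation x_tilde.
    f_star = nn.Sequential(nn.Linear(d, 200), nn.LeakyReLU(), nn.Linear(200, d))
    h_star = nn.Sequential(nn.Linear(d, 200), nn.LeakyReLU(), nn.Linear(200, d))
    params = list(f_star.parameters()) + list(h_star.parameters())
    opt = torch.optim.Adam(params, lr=1e-3)

    for step, batches in zip(range(steps), zip(*loaders_by_domain)):
        with torch.no_grad():
            x_tilde = [stage1_encoder(x) for x in batches]   # one batch per domain
        z_hat = [f_star(xt) for xt in x_tilde]
        recon = sum(((h_star(z) - xt) ** 2).mean() for z, xt in zip(z_hat, x_tilde))
        loss = recon + lam * invariance_penalty(z_hat, stable_idx)
        opt.zero_grad(); loss.backward(); opt.step()
    return f_star, h_star
```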
Two types of penalties are explored:
Min-Max Support Invariance Penalty:
$$\sum_{p \neq q} \sum_{i \in \hat{S}} \left( \Big( \min_{z \in \tilde{\mathcal{Z}}_i^{(p)}} z - \min_{z \in \tilde{\mathcal{Z}}_i^{(q)}} z \Big)^2 + \Big( \max_{z \in \tilde{\mathcal{Z}}_i^{(p)}} z - \max_{z \in \tilde{\mathcal{Z}}_i^{(q)}} z \Big)^2 \right)$$
where $\tilde{\mathcal{Z}}_i^{(p)}$ is the support of the $i$-th component of $f^\star(\tilde{x})$ in domain $p$.
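A batch-based PyTorch sketch of this penalty follows; the support min/max are estimated from the batch, and averaging the `top_k` most extreme values is an assumption borrowed from the robustness trick mentioned in the training details below:

```python
import torch

def minmax_penalty(z_hat_by_domain, stable_idx, top_k=10):
    """Penalize cross-domain differences in the per-coordinate min/max of the
    selected latents, with batch extremes smoothed over the top_k values."""
    mins, maxs = [], []
    for z in z_hat_by_domain:                 # z: (batch_size, latent_dim)
        zs = z[:, stable_idx]
        mins.append(torch.topk(zs, top_k, dim=0, largest=False).values.mean(0))
        maxs.append(torch.topk(zs, top_k, dim=0, largest=True).values.mean(0))
    penalty = 0.0
    for p in range(len(mins)):
        for q in range(p + 1, len(mins)):     # each unordered domain pair once
            penalty = penalty + ((mins[p] - mins[q]) ** 2).sum() \
                              + ((maxs[p] - maxs[q]) ** 2).sum()
    return penalty
```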
MMD-based Distribution Invariance Penalty:
$$\sum_{p \neq q} \mathrm{MMD}\Big(p_{\hat{z}_{\hat{S}}}^{(p)},\, p_{\hat{z}_{\hat{S}}}^{(q)}\Big)$$
This measures the Maximum Mean Discrepancy between the joint distributions of the selected latent subset $\hat{z}_{\hat{S}}$ across domain pairs.
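A hedged sketch of this penalty with an RBF kernel (using the biased V-statistic estimator of squared MMD; the paper's exact estimator and bandwidth handling may differ) could look like this:

```python
import torch

def rbf_mmd2(x, y, bandwidth=1.0):
    """Biased estimate of squared MMD between two samples under an RBF kernel."""
    def k(a, b):
        return torch.exp(-torch.cdist(a, b) ** 2 / (2 * bandwidth ** 2))
    return k(x, x).mean() + k(y, y).mean() - 2 * k(x, y).mean()

def mmd_penalty(z_hat_by_domain, stable_idx, bandwidth=1.0):
    """Sum of pairwise MMDs between the selected latents across domains."""
    zs = [z[:, stable_idx] for z in z_hat_by_domain]
    penalty = 0.0
    for p in range(len(zs)):
        for q in range(p + 1, len(zs)):
            penalty = penalty + rbf_mmd2(zs[p], zs[q], bandwidth)
    return penalty
```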
Empirical Findings
Experiments were conducted on four types of datasets with varying mixing functions $g$ and latent distributions $p_Z$:
- Linear mixing: $x = Az$.
- Polynomial mixing: $g(z)$ is a polynomial function.
- Image rendering of balls: latent variables are ball coordinates, and $g$ is an image renderer.
- Unlabeled colored MNIST: the digit identity (implicitly) plays the role of $z_S$, and color plays the role of $z_U$.
For each, two types of $p_Z$ were studied:
- Independent latents: $z_S$ and $z_U$ are independent.
- Dependent latents (Dynamic SCM, D-SCM): the SCM governing the latents varies across data points, inducing dependencies.
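As a rough illustration of the D-SCM idea, the following sketch draws a fresh random linear SCM for every data point, so no single fixed DAG describes the whole dataset; this is a toy construction for intuition, not the paper's exact generator:

```python
import numpy as np

rng = np.random.default_rng(2)

def sample_dscm_latents(m, d=4):
    """Each data point gets its own randomly drawn lower-triangular linear SCM
    (random edge mask and weights), inducing dependencies among the latents."""
    z = np.zeros((m, d))
    for r in range(m):
        mask = rng.random((d, d)) < 0.5                     # random edge pattern
        W = np.tril(rng.normal(size=(d, d)), k=-1) * mask
        noise = rng.normal(size=d)
        for i in range(d):                                  # ancestral sampling
            z[r, i] = W[i] @ z[r] + noise[i]
    return z
```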
Implementation of Experiments:
- The two-stage procedure was used.
- For linear data, Stage 2 was applied directly.
- For polynomial data (Stage 1: MLP encoder, polynomial decoder) and image data (Stage 1: ResNet encoder, ConvNet decoder), MLP autoencoders were used in Stage 2.
- Three penalty variations were tested: Min-Max, MMD, and MMD + Min-Max.
Evaluation Metrics:
- For synthetic/balls data: $R^2_S$ (R-squared for predicting $z_S$ from $\hat{z}_{\hat{S}}$) and $R^2_U$ (R-squared for predicting $z_U$ from $\hat{z}_{\hat{S}}$). The ideal outcome is high $R^2_S$ and low $R^2_U$.
- For unlabeled colored MNIST: $\mathrm{Acc}_{\mathrm{digits}}$ (accuracy of predicting the digit from $\hat{z}_{\hat{S}}$) and $R^2_{\mathrm{color}}$ (R-squared for predicting color from $\hat{z}_{\hat{S}}$).
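These block-wise $R^2$ scores can be computed with an off-the-shelf regressor; the sketch below uses plain linear regression from scikit-learn as an assumed choice (the paper's evaluation may use a different predictor):

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

def block_r2(z_hat_S, z_true):
    """R^2 of predicting true latents from the learned stable block.
    High values for z_S and low values for z_U indicate good disentanglement."""
    reg = LinearRegression().fit(z_hat_S, z_true)
    return r2_score(z_true, reg.predict(z_hat_S), multioutput="uniform_average")

# Hypothetical usage:
# r2_S = block_r2(z_hat_S, z_S_true)   # should be close to 1
# r2_U = block_r2(z_hat_S, z_U_true)   # should be close to 0
```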
Key Results:
- For linear and polynomial mixing, all three penalty types performed well in achieving block-affine disentanglement.
- For the more complex ball-images and unlabeled colored MNIST datasets, the combination "MMD + Min-Max" penalty worked best.
- The approach achieved notable disentanglement on the challenging unlabeled colored MNIST without using any labels during training.
- Increasing the number of domains ($k$) generally improved identification, and the number of domains required for useful identification was often smaller than the worst-case theoretical bounds. For instance, going from $k = 2$ to $k = 16$ showed significant improvements (see the corresponding results tables in the paper).
Architectural Details for Experiments:
- Polynomial Mixing (Stage 1):
- Encoder: MLP (Input: n→n/2→n/2→d) with LeakyReLU.
- Decoder: Polynomial decoder with learnable coefficient matrix.
- Balls Dataset (Stage 1):
- Encoder: ResNet18.
- Decoder: Standard deconvolutional layers.
- Encoder output: 128-dimensional; the invariance constraint is applied to the first 64 dimensions.
- Unlabeled Colored MNIST (Stage 1):
- Encoder: Linear (784 → 256 → 256 → 128) with ReLU & BatchNorm.
- Decoder: Symmetric.
- Unlabeled Colored MNIST (Stage 2):
- Encoder: Linear (128 → 200 → 200 → 200 → 128) with LeakyReLU & BatchNorm.
- Decoder: Symmetric to the encoder (a code sketch of this Stage-2 autoencoder follows this list).
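For concreteness, the Stage-2 MLP autoencoder for unlabeled colored MNIST described above can be sketched as follows; details such as BatchNorm placement and the exact form of the symmetric decoder are assumptions:

```python
import torch.nn as nn

def mlp(dims):
    """MLP with LeakyReLU + BatchNorm on hidden layers and a linear final layer."""
    layers = []
    for d_in, d_out in zip(dims[:-2], dims[1:-1]):
        layers += [nn.Linear(d_in, d_out), nn.BatchNorm1d(d_out), nn.LeakyReLU()]
    layers.append(nn.Linear(dims[-2], dims[-1]))   # final layer stays linear
    return nn.Sequential(*layers)

# Stage-2 encoder/decoder: 128 -> 200 -> 200 -> 200 -> 128 (decoder is symmetric).
encoder = mlp([128, 200, 200, 200, 128])
decoder = mlp([128, 200, 200, 200, 128])
```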
Training Details:
- Optimizer: Adam (lr $= 10^{-3}$, $\beta_1 = 0.9$, $\beta_2 = 0.999$).
- LR scheduler: reduce on plateau (factor 0.5, patience 10 epochs, min_lr $= 10^{-4}$); a configuration sketch is given after this list.
- Batch size: 1024. Early stopping at 2000 steps.
- Invariance penalty weight (λ): 1.0.
- MMD kernel: RBF (bandwidth 1.0, adaptive for linear mixing).
- Min-Max penalty: batch values are sorted and the 10 most extreme values are used to form robust min/max estimates.
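Assuming the training loop is implemented in PyTorch, the optimizer and scheduler settings above correspond to roughly the following configuration (`params` stands for the Stage-2 autoencoder parameters; feeding the scheduler a validation loss is an assumption):

```python
import torch

# Adam + reduce-on-plateau schedule matching the listed hyperparameters.
# `params`: parameters of the Stage-2 autoencoder (assumed defined elsewhere).
opt = torch.optim.Adam(params, lr=1e-3, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, factor=0.5, patience=10, min_lr=1e-4
)

# Inside the training loop (once per evaluation step):
# scheduler.step(validation_loss)
```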
Conclusions
The paper significantly advances multi-domain causal representation learning by relaxing strong assumptions and introducing a framework based on weak distributional invariances. It demonstrates theoretically and empirically that autoencoders constrained by these invariances can identify stable latent factors from unstable ones under complex domain shifts, including multi-node imperfect interventions and settings where a fixed DAG does not govern the entire dataset. The proposed methods show promise for real-world applications where data comes from diverse sources with varying underlying conditions.