Toward Temporal Causal Representation Learning with Tensor Decomposition (2507.14126v1)

Published 18 Jul 2025 in cs.LG, cs.AI, and stat.ML

Abstract: Temporal causal representation learning is a powerful tool for uncovering complex patterns in observational studies, which are often represented as low-dimensional time series. However, in many real-world applications, data are high-dimensional with varying input lengths and naturally take the form of irregular tensors. To analyze such data, irregular tensor decomposition is critical for extracting meaningful clusters that capture essential information. In this paper, we focus on modeling causal representation learning based on the transformed information. First, we present a novel causal formulation for a set of latent clusters. We then propose CaRTeD, a joint learning framework that integrates temporal causal representation learning with irregular tensor decomposition. Notably, our framework provides a blueprint for downstream tasks using the learned tensor factors, such as modeling latent structures and extracting causal information, and offers a more flexible regularization design to enhance tensor decomposition. Theoretically, we show that our algorithm converges to a stationary point. More importantly, our results fill the gap in theoretical guarantees for the convergence of state-of-the-art irregular tensor decomposition. Experimental results on synthetic and real-world electronic health record (EHR) datasets (MIMIC-III), with extensive benchmarks from both phenotyping and network recovery perspectives, demonstrate that our proposed method outperforms state-of-the-art techniques and enhances the explainability of causal representations.

Summary

  • The paper introduces CaRTeD, a framework that jointly infers latent phenotypes and their temporal causal relationships from irregular EHR data.
  • It solves the joint problem by block coordinate descent with ADMM, combining PARAFAC2-based tensor decomposition with a dynamic Bayesian network whose structure is regularized for sparsity and acyclicity.
  • A theoretical analysis establishes convergence to a stationary point, and experiments on synthetic and real-world datasets show better tensor recovery and causal graph accuracy than standard methods.

Temporal Causal Representation Learning with Tensor Decomposition: An Expert Overview

This paper introduces CaRTeD, a joint-learning framework for temporal causal representation learning and irregular tensor decomposition, with a primary application to electronic health record (EHR) data. The work addresses the challenge of uncovering temporal causal structures among latent clusters (e.g., phenotypes) in high-dimensional, irregularly sampled time series, a setting common in real-world healthcare and other domains.

Motivation and Problem Setting

Traditional causal representation learning (CRL) methods, including those based on dynamic Bayesian networks (DBNs), are typically designed for flat, regularly sampled data and do not scale to high-dimensional, irregular tensor data. Conversely, tensor decomposition methods such as PARAFAC2 and its constrained variants (e.g., COPA) are effective for extracting latent clusters (phenotypes) from irregular tensors but do not incorporate causal structure learning. This disconnect limits the interpretability and downstream utility of both approaches, especially in domains like computational phenotyping from EHRs, where both latent structure and causal relationships are of interest.

The paper formalizes the problem as follows: Given a collection of high-dimensional, irregular tensors (e.g., patient × diagnosis × time), the goal is to jointly (1) extract meaningful latent clusters (phenotypes) and (2) infer both contemporaneous and temporal causal relationships among these clusters.
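
To make the data layout concrete, the sketch below stores an irregular tensor as a list of per-patient matrices $X_k$ of shape $I_k \times J$ (visits × diagnoses) and evaluates a standard PARAFAC2 reconstruction loss, $X_k \approx Q_k H \, \mathrm{diag}(s_k) \, V^\top$, where $U_k = Q_k H$ plays the role of the patient's temporal trajectory. The notation ($Q_k$, $H$, $s_k$) follows the common PARAFAC2 formulation and is an assumption, not necessarily the paper's; the helper names are hypothetical and the data are random toy values.

```python
import numpy as np


# Hypothetical helpers; notation follows the generic PARAFAC2 model, not the paper.
def parafac2_reconstruct(Q_list, H, s_list, V):
    """Rebuild each patient slice X_k ≈ (Q_k H) diag(s_k) V^T."""
    return [Q @ H @ np.diag(s) @ V.T for Q, s in zip(Q_list, s_list)]


def reconstruction_loss(X_list, Q_list, H, s_list, V):
    """Sum of squared Frobenius errors over all (irregular) patient slices."""
    recon = parafac2_reconstruct(Q_list, H, s_list, V)
    return sum(np.linalg.norm(X - Xh, "fro") ** 2 for X, Xh in zip(X_list, recon))


rng = np.random.default_rng(0)
J, R = 20, 3              # J diagnoses, R phenotypes (toy sizes)
visits = [5, 8, 12]       # I_k: number of visits differs per patient
H = rng.normal(size=(R, R))                    # shared R x R factor
V = np.abs(rng.normal(size=(J, R)))            # feature-to-phenotype mapping
Q_list = [np.linalg.qr(rng.normal(size=(I_k, R)))[0] for I_k in visits]  # orthonormal columns
s_list = [np.abs(rng.normal(size=R)) for _ in visits]                      # phenotype loadings
X_list = parafac2_reconstruct(Q_list, H, s_list, V)  # noise-free toy data

print([X.shape for X in X_list])  # [(5, 20), (8, 20), (12, 20)]
print(reconstruction_loss(X_list, Q_list, H, s_list, V))  # 0.0 for noise-free data
```

Because each $Q_k$ has orthonormal columns, $U_k^\top U_k = H^\top H$ for every patient; this shared cross-product is what lets PARAFAC2 handle differing $I_k$ while keeping the factors comparable across patients.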

Methodological Contributions

The core methodological innovation is the integration of temporal causal structure learning with irregular tensor decomposition in a unified optimization framework. The key elements are:

  • Irregular Tensor Decomposition: The framework builds on PARAFAC2, which accommodates one mode with varying dimension (e.g., different numbers of visits per patient). The decomposition yields patient-specific temporal trajectories, phenotype loadings, and a shared feature-to-phenotype mapping.
  • Causal Structure Learning: The latent phenotype trajectories are modeled as following a DBN, with both intra-slice (contemporaneous) and inter-slice (temporal) edges. The causal structure is parameterized by matrices $W$ (intra-slice) and $A^{(p)}$ (inter-slice, for lag $p$).
  • Joint Optimization: The objective combines tensor reconstruction loss with a causal regularization term that requires the latent trajectories to be well explained by the learned DBN. Sparsity is encouraged via $\ell_1$ penalties on $W$ and $A^{(p)}$, and acyclicity is enforced using a differentiable constraint.
  • Block Coordinate Descent with ADMM: The non-convex, multi-block optimization is solved via block coordinate descent, with each block (tensor factors, causal parameters) updated using ADMM or closed-form solutions where possible. The approach alternates between updating tensor factors (with causal regularization) and updating the causal structure (given current latent trajectories).
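
The summary does not spell out the causal regularizer, so the following is a minimal sketch under two stated assumptions: the DBN is a linear structural VAR in row-vector form, $u_t = u_t W + \sum_{p} u_{t-p} A^{(p)} + \varepsilon_t$ (consistent with the linear-Gaussian model noted under Limitations), and the "differentiable constraint" is the NOTEARS trace-exponential $h(W) = \mathrm{tr}(e^{W \circ W}) - d$. Function names and penalty weights are hypothetical.

```python
import numpy as np
from scipy.linalg import expm


def svar_residual_loss(U, W, A_list):
    """Sum of squared residuals of the assumed linear DBN
    u_t = u_t W + sum_p u_{t-p} A^(p)   (row-vector convention, W[i, j]: i -> j),
    evaluated on every time step t that has all P lags available."""
    P, T = len(A_list), U.shape[0]
    if T <= P:
        return 0.0  # trajectory too short to contribute
    target = U[P:]                         # rows t = P .. T-1
    pred = target @ W                      # intra-slice (contemporaneous) part
    for p, A in enumerate(A_list, start=1):
        pred = pred + U[P - p:T - p] @ A   # inter-slice part, rows t-p
    return float(np.sum((target - pred) ** 2))


def notears_acyclicity(W):
    """Assumed NOTEARS constraint h(W) = tr(exp(W * W)) - d; h(W) = 0 iff W is a DAG."""
    return float(np.trace(expm(W * W)) - W.shape[0])


def causal_penalty(U_list, W, A_list, lam_w, lam_a):
    """Data-fit term summed over patients plus l1 sparsity penalties."""
    fit = sum(svar_residual_loss(U, W, A_list) for U in U_list)
    l1 = lam_w * np.abs(W).sum() + lam_a * sum(np.abs(A).sum() for A in A_list)
    return fit + l1


# Toy check on 3 phenotypes: the chain 0 -> 1 -> 2 is acyclic; adding 2 -> 0 is not.
W_dag = np.array([[0.0, 0.8, 0.0],
                  [0.0, 0.0, 0.5],
                  [0.0, 0.0, 0.0]])
W_cyc = W_dag.copy()
W_cyc[2, 0] = 0.3
print(notears_acyclicity(W_dag))  # numerically close to zero
print(notears_acyclicity(W_cyc))  # strictly positive
```

In one plausible block coordinate descent arrangement, a term like `causal_penalty` is added to the reconstruction loss when the trajectories $U_k$ are updated, and is minimized over $W$ and $A^{(p)}$ subject to $h(W) = 0$ when the causal block is updated. The lagged matrices need no acyclicity constraint because their edges always point forward in time.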

Theoretical Analysis

A significant theoretical contribution is the convergence analysis of the proposed block coordinate descent scheme under non-convex constraints, particularly for the PARAFAC2 block with causal regularization. The analysis establishes that, under mild conditions and sufficiently large penalty parameters, the algorithm converges to a stationary point of the augmented Lagrangian. This fills a gap in the literature regarding convergence guarantees for irregular tensor decomposition with complex constraints.

Empirical Evaluation

Synthetic Data

Experiments on simulated irregular tensors with known ground-truth causal structure demonstrate that CaRTeD outperforms both sequential (two-step) baselines and state-of-the-art tensor decomposition methods (e.g., COPA) on both tensor recovery and causal graph recovery metrics. Notably:

  • Tensor Recovery: CaRTeD achieves higher cross-product invariance (CPI), similarity (SIM), and recovery rate (RR) than COPA, especially as noise increases.
  • Causal Graph Recovery: CaRTeD yields lower structural Hamming distance (SHD), lower false discovery rate (FDR), and higher true positive rate (TPR) for both intra-slice and inter-slice networks, particularly as the number of patients increases.
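
The graph-recovery metrics can be computed from binary adjacency matrices. This is a minimal sketch assuming the convention common in the structure-learning literature (similar to the evaluation utilities released with NOTEARS): a reversed edge costs one SHD operation and counts as a false discovery. The paper's exact convention may differ, and treating inter-slice matrices $A^{(p)}$ with `count_reversals=False` (since time fixes their direction) is our assumption. CPI, SIM, and RR are not sketched because their definitions are not given here.

```python
import numpy as np


def graph_metrics(B_true, B_est, count_reversals=True):
    """SHD, TPR and FDR between binary adjacency matrices (B[i, j] = 1 means i -> j).

    With count_reversals=True, an edge estimated as i -> j whose true direction
    is j -> i costs one SHD operation instead of two (one FP plus one FN)."""
    B_true = (np.asarray(B_true) != 0).astype(int)
    B_est = (np.asarray(B_est) != 0).astype(int)
    tp = int(np.sum((B_est == 1) & (B_true == 1)))
    fp = int(np.sum((B_est == 1) & (B_true == 0)))
    fn = int(np.sum((B_est == 0) & (B_true == 1)))
    rev = 0
    if count_reversals:
        # i -> j estimated, j -> i true, and j -> i not also estimated
        rev = int(np.sum((B_est == 1) & (B_true == 0) & (B_true.T == 1) & (B_est.T == 0)))
    n_true, n_pred = int(B_true.sum()), int(B_est.sum())
    return {
        "SHD": fp + fn - rev,
        "TPR": tp / n_true if n_true else 0.0,
        "FDR": fp / n_pred if n_pred else 0.0,  # reversed edges count as false discoveries
    }


# True chain 0 -> 1 -> 2; the estimate reverses 0 -> 1 and adds a spurious 0 -> 2.
B_true = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])
B_est = np.array([[0, 0, 1], [1, 0, 1], [0, 0, 0]])
print(graph_metrics(B_true, B_est))                         # SHD 2, TPR 0.5, FDR 2/3
print(graph_metrics(B_true, B_est, count_reversals=False))  # SHD 3 (no reversal credit)
```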

Real-World EHR Data (MIMIC-III)

Applied to the MIMIC-III dataset, CaRTeD extracts clinically meaningful phenotypes and infers a causal phenotype network that aligns with established medical knowledge. The inferred network captures known relationships such as hypertension leading to kidney disease and heart failure leading to respiratory failure, with edge directions and presence validated against the literature. In contrast, benchmark methods either miss key edges or infer implausible directions.

Practical Implications

Implementation Considerations

  • Computational Requirements: The block coordinate descent with ADMM is computationally intensive, especially for large tensors and high-rank decompositions. However, the algorithm is amenable to parallelization across patients and can leverage GPU acceleration for tensor operations.
  • Initialization: Warm-starting the shared feature-to-phenotype matrix $V$ (e.g., via unconstrained tensor decomposition) improves both convergence speed and solution quality.
  • Hyperparameter Tuning: Regularization parameters for sparsity and penalty parameters for ADMM require tuning, typically via cross-validation or based on prior knowledge of expected graph sparsity.
  • Scalability: The method scales to thousands of patients and hundreds of features, as demonstrated on MIMIC-III, but may require further engineering for very large-scale deployments.
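
One simple way to warm-start $V$ is a truncated SVD of the patient matrices stacked along the visit mode, which is well defined even though the $I_k$ differ. This swaps in an SVD initialization for the unconstrained tensor decomposition mentioned in the Initialization bullet; the function name and the sign-flip heuristic are our own assumptions. The idea: under an exact PARAFAC2 model every row of each $X_k$ lies in the column space of $V$, so the top-$R$ right singular vectors recover that subspace up to rotation.

```python
import numpy as np


def warm_start_V(X_list, rank):
    """Hypothetical helper: initial J x R feature-to-phenotype matrix from the top
    `rank` right singular vectors of the patient matrices stacked along visits."""
    stacked = np.vstack(X_list)  # shape (sum_k I_k, J); works for any I_k
    _, _, Vt = np.linalg.svd(stacked, full_matrices=False)
    V0 = Vt[:rank].T
    # Singular vectors are sign-ambiguous; flip columns so each sums to >= 0.
    signs = np.where(V0.sum(axis=0) < 0, -1.0, 1.0)
    return V0 * signs


rng = np.random.default_rng(0)
J, R = 30, 4
V_true = np.abs(rng.normal(size=(J, R)))
X_list = [rng.normal(size=(I_k, R)) @ V_true.T for I_k in (6, 11, 9)]
V0 = warm_start_V(X_list, R)
print(V0.shape)  # (30, 4)
```

If a nonnegativity constraint is imposed on $V$, negative entries of the SVD initialization would additionally need to be clipped or projected before the first update.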

Limitations and Extensions

  • Stationarity Assumption: The current framework assumes a single, time-invariant DBN structure across all patients and time points. Extensions to time-varying or subgroup-specific causal structures are possible and would increase applicability to heterogeneous populations.
  • Linearity: The causal model is linear-Gaussian; incorporating nonlinear or non-Gaussian dynamics (e.g., via neural networks or Gaussian processes) is a promising direction.
  • Mixed Data Types: The current approach assumes continuous-valued latent trajectories; extensions to mixed discrete-continuous data would broaden applicability.

Implications for Future Research

This work establishes a blueprint for integrating causal discovery with unsupervised representation learning in high-dimensional, irregular time series. The approach is directly applicable to computational phenotyping, disease progression modeling, and other domains where both latent structure and causal relationships are of interest. Future research directions include:

  • Nonlinear and Nonstationary Causal Models: Incorporating more expressive models for latent dynamics and allowing for evolving causal structures.
  • Federated and Privacy-Preserving Extensions: Adapting the framework for distributed settings, as in federated EHR analysis.
  • Automated Model Selection: Developing principled methods for selecting the number of phenotypes (rank) and regularization parameters.
  • Broader Applications: Applying the framework to other domains with irregular tensor data, such as sensor networks, genomics, and recommender systems.

Conclusion

CaRTeD represents a significant advance in temporal causal representation learning for irregular tensor data. By jointly learning latent clusters and their causal relationships, the framework enhances both the interpretability and utility of unsupervised phenotyping, with strong empirical performance and theoretical guarantees. The methodology and analysis provide a foundation for further developments in causal representation learning for complex, high-dimensional time series.
