GraN-DAG: Neural DAG Learning
- GraN-DAG is a score-based framework utilizing neural networks to learn directed acyclic graphs that capture complex, nonlinear dependencies.
- It reformulates causal discovery as a constrained optimization problem with a continuous acyclicity constraint, generalizing linear methods like NOTEARS.
- Empirical evaluations on synthetic and real datasets show competitive performance in terms of SHD and SID, highlighting its practical utility.
GraN-DAG (“Gradient-Based Neural DAG Learning”) is a score-based framework for learning directed acyclic graphs (DAGs) from observational data. It handles complex, nonlinear dependencies between variables by parameterizing each conditional distribution with a neural network. By formulating causal discovery as a constrained optimization problem, GraN-DAG introduces a continuous, differentiable acyclicity constraint defined directly in terms of the network weights, so the graph can be learned by gradient-based optimization of those weights. This approach generalizes linear methods such as NOTEARS to the nonlinear case, combining joint optimization over all edges and differentiable structure learning in a single framework (Lachapelle et al., 2019).
1. Problem Formulation and Parameterization
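GraN-DAG addresses the problem of recovering an unknown DAG $\mathcal{G}$ over $d$ real-valued random variables $X_1, \dots, X_d$ from $n$ i.i.d. samples. Each conditional distribution $p_j(x_j \mid x_{\pi_j})$ is modeled with a neural network $\mathrm{NN}_j$ with weights $\phi_{(j)} = \{W^{(1)}_{(j)}, \dots, W^{(L+1)}_{(j)}\}$, whose input is restricted (via masking) to the current set of putative parents $\pi_j$ of node $j$. The network maps the masked input to the parameters $\theta_{(j)}$ of the conditional (e.g., a Gaussian mean) through a sequence of weight matrices and element-wise nonlinearities $g$:
$$\theta_{(j)} = W^{(L+1)}_{(j)}\, g\big(\cdots g\big(W^{(2)}_{(j)}\, g\big(W^{(1)}_{(j)} M_{(j)}\, x\big)\big)\cdots\big),$$
where the binary diagonal mask $M_{(j)}$ always removes $x_j$ itself, so only the predicted parents can influence the conditional.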
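The masked forward pass can be illustrated with a minimal plain-Python sketch. It assumes a nested-list weight layout, a leaky-ReLU slope of 0.01, and a hypothetical function name `conditional_params`; it is not GraN-DAG's reference implementation. The key property is that a masked input cannot affect the output.

```python
def leaky_relu(v, slope=0.01):
    # Element-wise leaky-ReLU (slope 0.01 is an assumed value).
    return [x if x > 0 else slope * x for x in v]


def matvec(W, v):
    # Plain matrix-vector product on nested lists.
    return [sum(w_ik * v_k for w_ik, v_k in zip(row, v)) for row in W]


def conditional_params(weights, mask, x):
    """Output theta_(j) of NN_j for one input vector x (hypothetical helper).

    weights: [W1, ..., W_{L+1}] as nested-list matrices.
    mask:    0/1 flags, the diagonal of M_(j); mask[i] == 0 removes input x_i
             (always the case for i == j, and for pruned inputs).
    """
    h = [m_i * x_i for m_i, x_i in zip(mask, x)]  # apply M_(j) to x
    for W in weights[:-1]:
        h = leaky_relu(matvec(W, h))  # hidden layers
    return matvec(weights[-1], h)  # linear output layer
```

Changing a masked coordinate (here `x[0]`, node $j = 0$ itself) leaves the output unchanged, which is how masking restricts $\mathrm{NN}_j$ to its putative parents.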
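A core innovation is the introduction of a continuous acyclicity constraint. In the linear setting, NOTEARS enforces acyclicity of a weighted adjacency matrix $W \in \mathbb{R}^{d \times d}$ via $h(W) = \operatorname{Tr}\big(e^{W \odot W}\big) - d = 0$. GraN-DAG generalizes this to the nonlinear case by defining for each node $j$ a connectivity matrix
$$C_{(j)} = \big|W^{(L+1)}_{(j)}\big| \cdots \big|W^{(2)}_{(j)}\big|\, \big|W^{(1)}_{(j)}\big|\, M_{(j)} \in \mathbb{R}_{\ge 0}^{m \times d},$$
where $W^{(1)}_{(j)}, \dots, W^{(L+1)}_{(j)}$ are the weight matrices of the network for node $j$, $|\cdot|$ denotes elementwise absolute values, $m$ is the number of network outputs, and the binary mask matrix $M_{(j)}$ encodes pruned inputs. Entry $(C_{(j)})_{ki}$ sums, over all paths from input $i$ to output $k$, the products of absolute weights along each path, so $(C_{(j)})_{ki} = 0$ implies that output $k$ does not depend on $x_i$. The induced weighted adjacency matrix $A_\phi$ is assembled by summing over the outputs:
$$(A_\phi)_{ij} = \sum_{k=1}^{m} \big(C_{(j)}\big)_{ki} \quad (i \neq j), \qquad (A_\phi)_{jj} = 0.$$
The final acyclicity constraint is $h(\phi) = \operatorname{Tr}\big(e^{A_\phi}\big) - d = 0$, which holds if and only if the graph given by the nonzero entries of $A_\phi$ is acyclic.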
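The path from network weights to $h(\phi)$ can be traced with a small plain-Python sketch. It assumes nested-list matrices, computes $\operatorname{Tr}(e^{A})$ with a truncated Taylor series (adequate only for small matrices of moderate norm; a real implementation would use a library matrix exponential), and uses hypothetical helper names throughout.

```python
def matmul(A, B):
    # Dense matrix product on nested lists.
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def connectivity(weights, mask):
    """C_(j) = |W^(L+1)| ... |W^(1)| M_(j), with M_(j) = diag(mask)."""
    d = len(mask)
    C = [[float(mask[i]) if i == k else 0.0 for k in range(d)] for i in range(d)]
    for W in weights:  # left-multiply by |W^(1)|, then |W^(2)|, ...
        C = matmul([[abs(w) for w in row] for row in W], C)
    return C  # shape m x d


def adjacency(all_weights, all_masks):
    """(A_phi)_ij = sum_k (C_(j))_ki for i != j; zero diagonal."""
    d = len(all_masks)
    A = [[0.0] * d for _ in range(d)]
    for j in range(d):
        C = connectivity(all_weights[j], all_masks[j])
        for i in range(d):
            if i != j:
                A[i][j] = sum(C[k][i] for k in range(len(C)))
    return A


def trace_expm(A, terms=30):
    # Tr(e^A) = sum_n Tr(A^n) / n!, truncated after `terms` terms.
    d = len(A)
    term = [[1.0 if i == k else 0.0 for k in range(d)] for i in range(d)]  # A^0 / 0!
    total = float(d)
    for n in range(1, terms):
        term = [[x / n for x in row] for row in matmul(term, A)]  # A^n / n!
        total += sum(term[i][i] for i in range(d))
    return total


def h_acyclic(A):
    # h = Tr(e^A) - d: zero iff the nonnegative A has no cycles.
    return trace_expm(A) - len(A)
```

With two nodes, letting $x_1$ feed $\mathrm{NN}_0$ but masking all inputs of $\mathrm{NN}_1$ gives a nilpotent $A_\phi$ and $h = 0$; unmasking $x_0$ in $\mathrm{NN}_1$ creates the cycle $0 \to 1 \to 0$ and makes $h$ strictly positive.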
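The objective is to maximize the log-likelihood, summed over nodes and averaged over samples, subject to acyclicity:
$$\max_{\phi}\; \mathcal{L}(\phi) \quad \text{s.t.} \quad h(\phi) = 0,$$
where
$$\mathcal{L}(\phi) = \frac{1}{n} \sum_{l=1}^{n} \sum_{j=1}^{d} \log p_j\big(x^{(l)}_j \mid x^{(l)}_{\pi_j};\, \phi_{(j)}\big)$$
and $x^{(l)}$ denotes the $l$-th sample.
For Gaussian additive noise models (ANMs) with a fixed noise variance $\sigma_j^2$, $p_j$ is a Gaussian whose mean is the output of $\mathrm{NN}_j$, and $-\log p_j$ equals the squared residual divided by $2\sigma_j^2$ plus a constant, so maximizing $\mathcal{L}$ reduces to minimizing a squared-error loss.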
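The Gaussian case can be checked numerically with a short sketch. It assumes a fixed, shared noise scale `sigma` and treats `means` as hypothetical precomputed outputs of each $\mathrm{NN}_j$; the point is only that the averaged log-likelihood is an affine, decreasing function of the squared error.

```python
import math


def gaussian_avg_loglik(X, means, sigma=1.0):
    """(1/n) * sum over samples l and nodes j of log N(x_lj; mu_lj, sigma^2).

    X, means: n x d nested lists; means[l][j] stands in for NN_j's output
    at sample l (hypothetical precomputed values).
    """
    const = -0.5 * math.log(2 * math.pi * sigma ** 2)
    total = 0.0
    for row, mu in zip(X, means):
        for x_j, m_j in zip(row, mu):
            total += const - (x_j - m_j) ** 2 / (2 * sigma ** 2)
    return total / len(X)


def avg_squared_error(X, means):
    # Sum of squared residuals over nodes, averaged over samples.
    return sum((x - m) ** 2 for row, mu in zip(X, means) for x, m in zip(row, mu)) / len(X)
```

With `sigma = 1`, the log-likelihood equals $-\tfrac{d}{2}\log(2\pi)$ minus half the averaged squared error, so both objectives rank candidate models identically.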
2. Optimization and Gradients
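The constrained problem is addressed with the augmented Lagrangian (AL) method. At outer iteration $t$, with Lagrange multiplier $\lambda_t$ and penalty coefficient $\mu_t$, GraN-DAG approximately solves the unconstrained subproblem
$$\phi^*_t = \arg\max_{\phi}\; \mathcal{L}(\phi) - \lambda_t\, h(\phi) - \frac{\mu_t}{2}\, h(\phi)^2.$$
The multiplier and penalty are then updated as is standard for AL methods:
$$\lambda_{t+1} = \lambda_t + \mu_t\, h(\phi^*_t), \qquad \mu_{t+1} = \begin{cases} \eta\, \mu_t & \text{if } h(\phi^*_t) > \gamma\, h(\phi^*_{t-1}), \\ \mu_t & \text{otherwise,} \end{cases}$$
with constants $\eta > 1$ and $0 < \gamma < 1$; that is, $\mu$ is increased whenever $h$ fails to shrink sufficiently.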
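The outer AL loop can be sketched on a one-dimensional toy problem, minimizing $(x - 2)^2$ subject to $h(x) = x^2 = 0$, which stands in for the negative log-likelihood and for $h(\phi)$ (nonnegative, zero only when feasible). The constants `eta=10` and `gamma=0.9` are assumed defaults, not values taken from the source, and the exact bisection inner solver replaces the RMSprop training GraN-DAG uses for its subproblems.

```python
def al_update(lam, mu, h_new, h_prev, eta=10.0, gamma=0.9):
    """One multiplier/penalty update (eta, gamma are assumed constants)."""
    lam = lam + mu * h_new
    if h_new > gamma * h_prev:  # violation did not shrink enough
        mu = eta * mu
    return lam, mu


def solve_subproblem(lam, mu):
    """Minimize (x - 2)^2 + lam * h + (mu / 2) * h^2 with h = x^2.

    The objective is convex in x, so bisection on its derivative over
    [0, 2] (where the derivative changes sign) finds the exact minimizer.
    """
    def grad(x):
        return 2 * (x - 2) + 2 * lam * x + 2 * mu * x ** 3

    lo, hi = 0.0, 2.0
    for _ in range(100):
        mid = 0.5 * (lo + hi)
        if grad(mid) < 0:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def toy_augmented_lagrangian(tol=1e-6, max_outer=500):
    lam, mu, h_prev = 0.0, 1.0, float("inf")
    x = 2.0
    for _ in range(max_outer):
        x = solve_subproblem(lam, mu)
        h = x ** 2
        if h < tol:  # constraint met to the requested precision
            break
        lam, mu = al_update(lam, mu, h, h_prev)
        h_prev = h
    return x, lam, mu
```

Because the minimizer shrinks as either $\lambda$ or $\mu$ grows, $h$ decreases monotonically; each outer step either cuts $h$ by at least the factor $\gamma$ or multiplies $\mu$ by $\eta$, which is what drives the iterate to feasibility.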
The gradients required are as follows:
- For the conditional NNs, the likelihood gradient is standard and, in the Gaussian case, reduces to backpropagation of the squared residual.
- The acyclicity penalty gradient exploits the matrix-calculus identity
$$\nabla_{A} \operatorname{Tr}\big(e^{A}\big) = \big(e^{A}\big)^{\top}$$
combined with the chain rule through $A_\phi$. Because $A_\phi$ depends nonlinearly on the NN weights (via products of absolute values along paths), this step is typically handled by automatic differentiation.
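The identity $\nabla_{A} \operatorname{Tr}(e^{A}) = (e^{A})^{\top}$ can be sanity-checked against central finite differences. This plain-Python sketch assumes a truncated-Taylor matrix exponential (suitable only for small, moderate-norm matrices) and hypothetical helper names; the remaining chain-rule factor $\partial (A_\phi)_{ij} / \partial \phi$ is what an autodiff framework would supply.

```python
def matmul(A, B):
    # Dense matrix product on nested lists.
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def expm(A, terms=30):
    """e^A via truncated Taylor series: sum_n A^n / n!."""
    d = len(A)
    term = [[1.0 if i == j else 0.0 for j in range(d)] for i in range(d)]
    result = [row[:] for row in term]
    for n in range(1, terms):
        term = [[x / n for x in row] for row in matmul(term, A)]
        result = [[r + t for r, t in zip(rr, tt)] for rr, tt in zip(result, term)]
    return result


def trace_expm(A):
    E = expm(A)
    return sum(E[i][i] for i in range(len(A)))


def grad_trace_expm(A):
    """Analytic gradient: d Tr(e^A) / d A_ij = (e^A)_ji."""
    E = expm(A)
    return [[E[j][i] for j in range(len(A))] for i in range(len(A))]


def finite_diff_grad(A, eps=1e-6):
    # Central differences, one matrix entry at a time.
    d = len(A)
    G = [[0.0] * d for _ in range(d)]
    for i in range(d):
        for j in range(d):
            plus = [row[:] for row in A]
            minus = [row[:] for row in A]
            plus[i][j] += eps
            minus[i][j] -= eps
            G[i][j] = (trace_expm(plus) - trace_expm(minus)) / (2 * eps)
    return G
```

On a three-node cycle such as `[[0, 0.5, 0], [0, 0, 0.3], [0.8, 0, 0]]`, the analytic and numerical gradients agree to within finite-difference error.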
3. Neural Architectures and Implementation
In typical scenarios, each node’s conditional is modeled by a neural network whose hidden layers have 10 units each, with leaky-ReLU activations; on data sets with a higher risk of overfitting, a smaller network is used. Weights are initialized with the Xavier (Glorot) scheme, RMSprop is used as the optimizer, and the learning rate used for the initial subproblem differs from the one used for subsequent subproblems.
Mask matrices are updated at each stage by hard-thresholding: any edge $i \to j$ for which $(A_\phi)_{ij}$ falls below a small threshold $\epsilon$ is permanently removed by zeroing out input $i$ in the mask $M_{(j)}$. This keeps the graph sparse and leads to more interpretable final graphs.
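The thresholding step reduces to a few lines. In this sketch the name `threshold_masks`, the layout `masks[j][i]` (1 if $x_i$ is still an input of $\mathrm{NN}_j$), and the value of `eps` are illustrative assumptions.

```python
def threshold_masks(A, masks, eps):
    """Zero the mask entry for every edge i -> j with A[i][j] < eps.

    A:     weighted adjacency matrix A_phi as nested lists.
    masks: masks[j][i] == 1 if x_i is still an input of NN_j (hypothetical layout).
    Returns new masks and leaves the input untouched. A removed input stays
    removed: its zero mask entry zeroes column i of C_(j), hence (A_phi)_ij,
    at every later evaluation.
    """
    d = len(A)
    new_masks = [row[:] for row in masks]
    for j in range(d):
        for i in range(d):
            if new_masks[j][i] and A[i][j] < eps:
                new_masks[j][i] = 0
    return new_masks
```

For example, with `eps = 1e-4` an edge of weight `1e-5` is masked out while an edge of weight `0.5` survives.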
4. Algorithmic Workflow
The high-level GraN-DAG procedure iterates through AL subproblems. For each, it trains the network parameters with RMSprop on minibatches, performs early stopping based on a validation set, updates $\lambda$ and $\mu$, and thresholds small edge weights. This process is repeated until $h(\phi)$ falls below a small numerical tolerance. Because the resulting graph may still contain cycles with tiny weights at that precision, a final acyclic edge selection is performed using Jacobian-based scores $\mathcal{J}_{ij}$, which measure (averaged over samples) the magnitude of the partial derivative of node $j$'s model with respect to input $x_i$: edges are sorted by score and removed, weakest first, until the graph is acyclic.
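The final edge selection can be sketched as greedy removal of the weakest edges, with Kahn's algorithm as the acyclicity check. The score matrix is assumed to be precomputed (for instance, from averaged absolute partial derivatives), and `prune_to_dag` is a hypothetical name.

```python
def is_acyclic(edges, d):
    # Kahn's algorithm: the graph is a DAG iff every node can be popped.
    indeg = [0] * d
    children = [[] for _ in range(d)]
    for i, j in edges:
        children[i].append(j)
        indeg[j] += 1
    queue = [v for v in range(d) if indeg[v] == 0]
    seen = 0
    while queue:
        v = queue.pop()
        seen += 1
        for c in children[v]:
            indeg[c] -= 1
            if indeg[c] == 0:
                queue.append(c)
    return seen == d


def prune_to_dag(scores):
    """Remove edges in increasing score order until the graph is acyclic.

    scores[i][j] > 0 marks a candidate edge i -> j with score J_ij.
    """
    d = len(scores)
    edges = {(i, j) for i in range(d) for j in range(d) if i != j and scores[i][j] > 0}
    for edge in sorted(edges, key=lambda e: scores[e[0]][e[1]]):
        if is_acyclic(edges, d):
            break
        edges.remove(edge)
    return edges
```

With candidate edges $0 \to 1$ (0.9), $1 \to 2$ (0.8), $0 \to 2$ (0.5), and $2 \to 0$ (0.1), only the weakest edge $2 \to 0$ is removed, since that alone breaks both cycles.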
Subsequently, a regression-based pruning step analogous to CAM (fitting Generalized Additive Models and applying significance tests) removes edges that are not statistically significant. The output is the estimated DAG formed by the remaining edges.
5. Empirical Evaluation and Results
GraN-DAG was empirically validated on synthetic and real-world datasets using a suite of metrics:
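- SHD (Structural Hamming Distance): the number of edge additions, deletions, and reversals needed to turn the estimated graph into the true graph.
- SID (Structural Intervention Distance): the number of node pairs $(i, j)$ for which the estimated DAG, used as a causal model, would yield an incorrect interventional distribution $p(x_j \mid \mathrm{do}(x_i))$.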
- SHD-C: SHD over the CPDAG, accommodating methods that only return equivalence classes.
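The SHD definition translates directly into code. This sketch assumes 0/1 adjacency matrices and the convention that a reversed edge counts once (conventions differ across libraries); `shd` is a hypothetical helper name.

```python
def shd(true_adj, est_adj):
    """Structural Hamming distance between two directed graphs.

    true_adj, est_adj: 0/1 adjacency matrices (nested lists).
    Each unordered node pair contributes at most 1, so an extra, a missing,
    or a reversed edge each count once.
    """
    d = len(true_adj)
    dist = 0
    for i in range(d):
        for j in range(i + 1, d):
            if (true_adj[i][j], true_adj[j][i]) != (est_adj[i][j], est_adj[j][i]):
                dist += 1
    return dist
```

For a true graph $0 \to 1 \to 2$ and an estimate with $1 \to 0$, $1 \to 2$, and $0 \to 2$, the SHD is 2: one reversal plus one extra edge.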
Experiments were conducted on synthetic Erdős–Rényi and scale-free graphs (ER1, ER4, SF1, SF4, where the suffix gives the expected number of edges per node) with a fixed number of samples $n$ and several graph sizes $d$, using Gaussian ANM, linear, additive-function, and post-nonlinear (PNL) mechanisms. Realistic cases included the Sachs protein signaling network ($d = 11$) and SynTReN gene regulatory networks.
Results demonstrate:
- For 10-node ER1 graphs, GraN-DAG outperforms the continuous methods NOTEARS and DAG-GNN in SHD and is competitive with the greedy-search baseline CAM (Lachapelle et al., 2019).
- On 50-node graphs, GraN-DAG maintains strong SHD and SID, while the linear and nonlinear continuous baselines degrade.
- On real data (Sachs), GraN-DAG and its heteroskedastic extension GraN-DAG++ achieve SHD and SHD-C comparable to CAM and better than NOTEARS or DAG-GNN.
Comparative results also indicate that GSF (kernel-based scores) performs worse than GraN-DAG, but better than simple random baselines.
6. Strengths, Limitations, and Potential Extensions
GraN-DAG introduces several advances for directed graph learning:
- Nonlinear structural equation modeling using universal function approximators (neural networks).
- A fully differentiable acyclicity constraint lets all edges be optimized jointly with gradient-based methods, rather than through discrete or greedy search procedures.
- The method is compatible with GPU acceleration and flexible network architectures.
However, some limitations are noted:
- The optimization landscape is non-convex, so the procedure can at best be expected to reach stationary points; a global optimum is not guaranteed.
- Algorithm performance depends on careful tuning of model and optimization hyperparameters, and overfitting risks are managed via early stopping and GAM-based pruning.
- Evaluating the matrix exponential in the acyclicity constraint costs $O(d^3)$ in the number of nodes $d$, which may be prohibitive for very large graphs, although the graph sizes used in the experiments remain tractable.
Potential extensions highlighted include:
- Adapting the framework to discrete, mixed, or other conditional exponential families.
- Investigating faster or alternative acyclicity constraints, e.g., fast approximations to $\operatorname{Tr}\big(e^{A_\phi}\big)$.
- Leveraging partial ordering constraints or interventional data.
- Enhancing scalability to thousands of variables, for example via approximate computation of matrix exponentials or block-coordinate methods.
GraN-DAG is thus situated as a competitive tool for causality and structure learning, especially for problems involving nonlinear, potentially high-dimensional dependencies, and offers a technically innovative alternative to existing continuous optimization and greedy search methods (Lachapelle et al., 2019).