CoxKAN Survival Analysis Model
- CoxKAN is a survival analysis model that uses Kolmogorov–Arnold Networks to parameterize the log-partial hazard function for rich, nonlinear risk modeling.
- It employs a compact network architecture with learnable univariate functions and B-spline bases, balancing expressivity and clear symbolic interpretability.
- Empirical evaluations show that CoxKAN outperforms classical Cox models, is competitive with deep neural networks, and reveals biologically plausible nonlinear interactions in clinical and genomics data.
CoxKAN is a survival analysis model that leverages Kolmogorov–Arnold Networks (KANs) to provide a high-performance, interpretable alternative to traditional and deep learning-based survival models. CoxKAN directly parameterizes the log-partial hazard function of the Cox proportional hazards model using the compositional Kolmogorov–Arnold representation, allowing for rich, nonlinear modeling while maintaining explicit symbolic interpretability and performing inherent feature selection. Empirical studies show that CoxKAN consistently outperforms classical Cox proportional hazards models and is competitive with, or superior to, state-of-the-art deep neural network methods, especially in discovering complex multivariate dependencies in clinical and high-dimensional genomics data (Knottenbelt et al., 2024).
1. Mathematical Foundations and Model Definition
CoxKAN models the hazard function in the Cox proportional hazards framework as

$$h(t \mid \mathbf{x}) = h_0(t)\,\exp\big(\hat{\theta}(\mathbf{x})\big),$$

where $h_0(t)$ is the baseline hazard and $\hat{\theta}(\mathbf{x})$ is a real-valued function learned by a Kolmogorov–Arnold Network, mapping covariates $\mathbf{x}$ to a log-risk score. The baseline hazard is left unspecified and handled via the partial likelihood framework inherent to the Cox model; learning focuses wholly on the nonparametric risk function $\hat{\theta}$.
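The proportional-hazards structure is what lets the baseline drop out of the learning problem: the ratio of two patients' hazards depends only on their log-risk scores. A two-line illustration with made-up scores:

```python
import math

# Under h(t | x) = h0(t) * exp(theta(x)), the ratio of two patients' hazards
# cancels the unspecified baseline h0 and is constant over time.
theta_a, theta_b = 1.2, 0.4                 # illustrative log-risk scores
hazard_ratio = math.exp(theta_a - theta_b)  # relative hazard of patient a vs b
```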
The Kolmogorov–Arnold representation used in CoxKAN is

$$f(\mathbf{x}) = \sum_{q=1}^{2n+1} \Phi_q\!\left(\sum_{p=1}^{n} \phi_{q,p}(x_p)\right),$$

where the $\phi_{q,p}$ are univariate “inner” functions (one per feature per neuron) and the $\Phi_q$ are univariate “outer” functions. The summation structure ensures universal approximation capability for continuous multivariate functions as per the Kolmogorov–Arnold theorem (Knottenbelt et al., 2024).
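As a concrete illustration of this compositional form, the sketch below evaluates $f(\mathbf{x}) = \sum_q \Phi_q\big(\sum_p \phi_{q,p}(x_p)\big)$ for hand-picked (not learned) inner and outer functions; in CoxKAN these would be learned spline activations, and all names here are illustrative:

```python
import math

# Evaluate f(x) = sum_q Phi_q( sum_p phi_{q,p}(x_p) ) for fixed functions.
def ka_representation(x, inner_fns, outer_fns):
    # inner_fns[q][p] is phi_{q,p}; outer_fns[q] is Phi_q
    return sum(
        Phi(sum(phi(xp) for phi, xp in zip(phis, x)))
        for Phi, phis in zip(outer_fns, inner_fns)
    )

# Toy decomposition: f(x1, x2) = exp(x1 + x2), using a single outer branch
inner = [[lambda v: v, lambda v: v]]
outer = [math.exp]
value = ka_representation([1.0, 2.0], inner, outer)  # exp(1 + 2)
```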
2. Network Architecture and Parameterization
CoxKAN typically employs a compact architecture (often one or two hidden layers) where each connection is a learnable univariate function rather than a scalar weight. The canonical (single hidden layer) architecture is:
- Inputs: the $n$ covariates $x_1, \dots, x_n$
- Hidden: $m$ units, each receiving all $n$ features via parallel univariate inner functions $\phi_{q,p}$
- Output: $\hat{\theta}(\mathbf{x})$ formed as a sum over outer univariate functions $\Phi_q$
Each univariate function (inner or outer) is parameterized by a small B-spline basis plus an optional residual basis term:

$$\phi(x) = w_b\, b(x) + w_s \sum_{i} c_i\, B_i(x),$$

where the $B_i$ are degree-$k$ B-spline basis functions, the $c_i$ are trainable coefficients, and $w_b$, $w_s$ are trainable scalars ($b(x)$ is a fixed residual function such as SiLU). By using low-order splines (typically $k = 3$) on a small grid ($G = 3$–5), the architecture achieves a balance between expressivity and interpretability (Knottenbelt et al., 2024).
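A minimal sketch of one such edge activation, built on `scipy.interpolate.BSpline`; the knot layout, coefficient initialization, and SiLU residual term are assumptions chosen to mirror the description above, not the paper's exact implementation:

```python
import numpy as np
from scipy.interpolate import BSpline

def silu(x):
    return x / (1.0 + np.exp(-x))

def make_edge_activation(grid_size=5, degree=3, lo=-1.0, hi=1.0, rng=None):
    """Build phi(x) = w_b * silu(x) + w_s * spline(x) on [lo, hi]."""
    rng = np.random.default_rng(rng)
    n_coef = grid_size + degree               # number of spline coefficients
    # Open uniform knot vector: repeated end knots plus interior grid points
    knots = np.concatenate([
        np.full(degree, lo),
        np.linspace(lo, hi, grid_size + 1),
        np.full(degree, hi),
    ])
    c = rng.normal(scale=0.1, size=n_coef)    # trainable in a real model
    spline = BSpline(knots, c, degree, extrapolate=True)
    w_b, w_s = 1.0, 1.0                       # trainable scalars in practice
    return lambda x: w_b * silu(x) + w_s * spline(x)

phi = make_edge_activation(rng=0)
y = phi(np.linspace(-1, 1, 7))  # evaluate the activation on a small grid
```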
3. Training Objective, Regularization, and Feature Selection
CoxKAN is trained to minimize the regularized negative partial Cox log-likelihood

$$\mathcal{L} = \mathcal{L}_{\text{Cox}} + \lambda\,\mathcal{R},$$

where

$$\mathcal{L}_{\text{Cox}} = -\frac{1}{N_e} \sum_{i:\,\delta_i = 1} \left[\hat{\theta}(\mathbf{x}_i) - \log \sum_{j \in \mathcal{R}(t_i)} e^{\hat{\theta}(\mathbf{x}_j)}\right]$$

and $\mathcal{R}(t_i)$ is the risk set for event time $t_i$ (the subjects still under observation at $t_i$), $\delta_i$ indicates an observed event, and $N_e$ is the number of observed events.
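The partial likelihood above can be computed in a numerically stable way by sorting subjects by descending follow-up time, so each risk-set log-sum becomes a running `logaddexp`. A minimal sketch, assuming no tied event times:

```python
import numpy as np

def neg_partial_log_likelihood(theta, time, event):
    """theta: (N,) log-risk scores; time: (N,) follow-up times;
    event: (N,) 1 if the event was observed, 0 if censored."""
    order = np.argsort(-time)               # sort by descending time
    theta, event = theta[order], event[order]
    # Running log-sum-exp: subjects with longer follow-up precede each
    # event in this order, so entry i is log sum over the risk set of i.
    log_cumsum = np.logaddexp.accumulate(theta)
    ll = np.sum((theta - log_cumsum)[event == 1])
    return -ll / max(event.sum(), 1)        # average over observed events

theta = np.array([0.5, -0.2, 1.0, 0.0])
time = np.array([4.0, 3.0, 2.0, 1.0])
event = np.array([1, 0, 1, 1])
loss = neg_partial_log_likelihood(theta, time, event)
```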
The regularizer combines:
- $L_1$-norm of activation magnitudes (encourages edge/neuron sparsity)
- Entropy of activation magnitudes (promotes focused sparse connectivity)
- $L_1$-norm on spline coefficients (encourages function simplicity)
Optimization is performed using the Adam algorithm with early stopping based on validation concordance index (C-Index). After training, a threshold parameter prunes low-activation edges and neurons, resulting in automatic feature selection and topology simplification (Knottenbelt et al., 2024).
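The pruning step can be sketched as thresholding each edge's mean absolute activation over the training data; a feature is retained only if at least one of its outgoing edges survives. Array shapes and the threshold value here are illustrative:

```python
import numpy as np

def prune_edges(edge_activations, threshold=0.01):
    """edge_activations: (n_samples, n_inputs, n_hidden) array of per-edge
    activation values; returns a boolean keep-mask and selected features."""
    magnitude = np.mean(np.abs(edge_activations), axis=0)  # (n_inputs, n_hidden)
    keep = magnitude >= threshold
    # A feature is selected iff at least one of its outgoing edges survives
    selected_features = np.where(keep.any(axis=1))[0]
    return keep, selected_features

acts = np.zeros((100, 3, 2))
acts[:, 0, 0] = 1.0          # only feature 0 carries signal in this toy case
keep, feats = prune_edges(acts, threshold=0.01)
```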
4. Symbolic Formula Extraction and Interpretability
After pruning, each univariate function is fitted against a small library of symbolic templates (e.g., linear, polynomial, exponential, and logarithmic forms of the type $y = a\,g(bx + c) + d$). The best-fitting template is chosen by maximizing the coefficient of determination $R^2$ over the empirical activations. If no template matches ($R^2$ below a threshold), symbolic regression tools such as PySR are invoked to recover a closed-form expression.
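This template-fitting step can be sketched with `scipy.optimize.curve_fit`: fit each candidate form $y = a\,g(bx+c)+d$ to an activation's (input, output) samples and keep the one with the highest $R^2$. The template library below is a small illustrative subset, not the paper's full library:

```python
import numpy as np
from scipy.optimize import curve_fit

TEMPLATES = {
    "linear": lambda x, a, b, c, d: a * (b * x + c) + d,
    "square": lambda x, a, b, c, d: a * (b * x + c) ** 2 + d,
    "exp":    lambda x, a, b, c, d: a * np.exp(b * x + c) + d,
}

def fit_best_template(x, y):
    """Return the template name and R^2 of the best fit to samples (x, y)."""
    best_name, best_r2 = None, -np.inf
    ss_tot = np.sum((y - y.mean()) ** 2)
    for name, f in TEMPLATES.items():
        try:
            popt, _ = curve_fit(f, x, y, p0=[1.0, 1.0, 0.0, 0.0], maxfev=5000)
        except (RuntimeError, ValueError):
            continue                       # skip templates that fail to fit
        r2 = 1.0 - np.sum((y - f(x, *popt)) ** 2) / ss_tot
        if np.isfinite(r2) and r2 > best_r2:
            best_name, best_r2 = name, r2
    return best_name, best_r2

x = np.linspace(-2, 2, 50)
y = 3.0 * x ** 2 + 1.0          # a quadratic activation, noise-free
name, r2 = fit_best_template(x, y)
```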
The final risk score is the sum of these explicitly discovered symbolic curves. This approach provides direct insight into both the overall hazard model and the effect of individual covariates or interactions, differentiating CoxKAN from “black-box” neural competitors (Knottenbelt et al., 2024).
5. Empirical Evaluation and Benchmarking
CoxKAN was evaluated on four synthetic datasets (where ground-truth hazard formulas were known) and nine real-world datasets (comprising five standard clinical and four high-dimensional genomics cohorts). Performance was measured by the Harrell C-Index and, when applicable, the Integrated Brier Score.
| Dataset Type | Comparator Models | CoxKAN Performance |
|---|---|---|
| Synthetic | CoxPH, DeepSurv | Matches/exceeds the true-hazard (oracle) model in 3/4 cases |
| Clinical | CoxPH, DeepSurv | Outperforms CoxPH, matches/exceeds DeepSurv on 4/5 |
| Genomics (TCGA) | CoxPH+Lasso, DeepSurv | Competitive with CoxPH+Lasso; beats DeepSurv on 2/4 |
On synthetic benchmarks, CoxKAN exactly recovered the generating hazard function when expressible by the model. On clinical datasets, CoxKAN symbolic models achieved higher or comparable C-Index versus CoxPH and DeepSurv, with non-overlapping confidence intervals in several cases. In high-dimensional genomics, CoxKAN remained robust where unregularized CoxPH failed due to multicollinearity (Knottenbelt et al., 2024).
6. Discovery of Nonlinear Interactions and Biological Plausibility
CoxKAN demonstrated a unique capacity to discover and symbolize previously unrecognized nonlinear and interaction effects among covariates. For instance, in the SUPPORT dataset, the learned interaction subnetworks between age and metastatic cancer status revealed biologically plausible, cohort-specific risk trajectories. In the GBSG breast cancer dataset, CoxKAN rediscovered nonlinear “sweet-spot” biomarker effects, and in high-dimensional glioma genomics data, it uncovered clear genetic prognostic signatures matching known molecular pathology (Knottenbelt et al., 2024).
7. Practical Implementation and Usage Workflow
CoxKAN’s practical usage involves:
- Selecting the KAN architecture and regularization strength.
- Training the network using the regularized Cox partial-likelihood with early stopping.
- Pruning low-activation edges to yield a minimal feature set.
- Running symbolic fitting or symbolic regression on the remaining activations to produce a final, human-readable hazard model.
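The steps above can be sketched end to end in miniature by substituting a linear log-risk $\hat{\theta}(\mathbf{x}) = \mathbf{x}^\top \mathbf{w}$ for the KAN: train on the $L_1$-regularized negative partial likelihood, then prune small weights to obtain the selected features. All constants, the numerical gradient, and the synthetic data are illustrative simplifications:

```python
import numpy as np

def neg_pll(w, X, time, event, lam):
    """Regularized negative Cox partial log-likelihood for theta = X @ w."""
    order = np.argsort(-time)
    theta = (X @ w)[order]
    ev = event[order]
    log_cumsum = np.logaddexp.accumulate(theta)   # risk-set log-sum-exp
    nll = -np.sum((theta - log_cumsum)[ev == 1]) / max(ev.sum(), 1)
    return nll + lam * np.sum(np.abs(w))          # L1 sparsity penalty

def num_grad(f, w, eps=1e-5):
    """Central-difference gradient (a real model would use autodiff)."""
    g = np.zeros_like(w)
    for i in range(w.size):
        e = np.zeros_like(w); e[i] = eps
        g[i] = (f(w + e) - f(w - e)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
N = 200
X = rng.normal(size=(N, 3))
risk = 1.5 * X[:, 0]                      # only feature 0 drives the hazard
time = rng.exponential(scale=np.exp(-risk))
event = np.ones(N, dtype=int)             # no censoring, for simplicity

w = np.zeros(3)
loss_fn = lambda w: neg_pll(w, X, time, event, lam=0.05)
for _ in range(400):                      # plain gradient descent
    w -= 0.2 * num_grad(loss_fn, w)

keep = np.abs(w) > 0.1                    # pruning: drop near-zero weights
```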
This enables practitioners to derive a sparse, accurate, and interpretable survival model that aligns with regulatory and scientific requirements for transparency in biomedical applications (Knottenbelt et al., 2024).
This summary synthesizes results and methodologies as presented in "CoxKAN: Kolmogorov-Arnold Networks for Interpretable, High-Performance Survival Analysis" (Knottenbelt et al., 2024).