Nnet-survival: Neural Network Survival Model
- Nnet-survival is a neural network-based survival analysis method that partitions time into discrete intervals to model conditional hazards.
- The approach uses stochastic gradient descent to optimize likelihoods for censored and uncensored data, supporting both non-proportional and proportional hazards.
- Its scalable Keras implementation demonstrates competitive calibration and predictive performance on large-scale biomedical datasets.
Nnet-survival refers to a family of neural network-based models that parameterize survival and hazard functions over time partitions, enabling flexible, scalable, and deep learning-driven survival analysis. These models partition follow-up time into intervals, predict conditional hazard probabilities per interval, and optimize survival likelihood via modern deep learning frameworks using stochastic gradient descent. The resulting architecture generalizes classic discrete-time survival models and piecewise exponential models, accommodating complex covariate relationships and large-scale data.
1. Discrete-Time Partitioning and Model Parameterization
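The nnet-survival approach begins by partitioning follow-up time into $n$ (or $K$) left-closed, right-open intervals, typically $[t_{j-1}, t_j)$ for $j = 1, \dots, n$ with $t_0 = 0$ (Gensheimer et al., 2018, Holmer et al., 27 Mar 2024). For each individual $i$ with covariates $x_i$ and event time $T_i$, the conditional hazard probability for interval $j$ is defined:

$$h_j(x_i) = P\big(T_i < t_j \mid T_i \ge t_{j-1},\, x_i\big).$$

The survival function to the end of interval $j$ is:

$$S_j(x_i) = \prod_{k=1}^{j} \big(1 - h_k(x_i)\big).$$

Alternatively, in piecewise constant-hazard models, time is indexed via $k(t)$ such that $t \in [t_{k(t)-1}, t_{k(t)})$. The hazard is parameterized as:

$$\lambda(t \mid x) = g\big(w_{k(t)}^\top \phi(x) + b_{k(t)}\big),$$

where $\phi(x)$ is the neural network representation of covariates, $w_k$ and $b_k$ are interval-specific weight vectors and biases, and $g$ denotes an activation that keeps the hazard nonnegative (Holmer et al., 27 Mar 2024).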
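The model also admits a piecewise linear hazard parameterization, in which the hazard is interpolated between values at the interval boundaries:

$$\lambda(t \mid x) = \lambda_{k-1}(x) + \frac{t - t_{k-1}}{t_k - t_{k-1}}\,\big(\lambda_k(x) - \lambda_{k-1}(x)\big), \qquad t \in [t_{k-1}, t_k),$$

with $\lambda_k(x) \ge 0$ produced via neural network outputs passed through activations for nonnegativity.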
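The two parameterizations above can be turned into survival probabilities with a few lines of NumPy. The following is a minimal sketch, not the authors' implementation; the function names `discrete_survival` and `piecewise_constant_survival` are hypothetical and chosen for illustration only.

```python
import numpy as np

def discrete_survival(hazards):
    """Survival probability at the end of each interval from conditional hazards h_j."""
    hazards = np.asarray(hazards, dtype=float)
    # S_j is the product of (1 - h_k) over intervals k = 1..j.
    return np.cumprod(1.0 - hazards, axis=-1)

def piecewise_constant_survival(t, breaks, rates):
    """S(t) = exp(-Lambda(t)) for a hazard that is constant on each [breaks[k], breaks[k+1])."""
    breaks = np.asarray(breaks, dtype=float)
    rates = np.asarray(rates, dtype=float)
    # Time spent in each interval up to t (zero for intervals that start after t).
    exposure = np.clip(t - breaks[:-1], 0.0, np.diff(breaks))
    # Cumulative hazard is the rate-weighted exposure summed over intervals.
    return float(np.exp(-np.sum(rates * exposure)))

# Illustrative values only.
print(discrete_survival([0.1, 0.2, 0.5]))
print(piecewise_constant_survival(2.0, breaks=[0.0, 1.0, 3.0], rates=[0.5, 0.25]))
```

The discrete version needs only a cumulative product, which is why predicted survival curves are cheap to compute once the per-interval outputs are available.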
2. Likelihood Construction and Loss Functions
Survival likelihood is formulated using both censored and uncensored observations. The negative log-likelihood loss is constructed as follows (Gensheimer et al., 2018):
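- For uncensored individual $i$ with event in interval $j$: $\ell_i = -\log h_j(x_i) - \sum_{k=1}^{j-1} \log\big(1 - h_k(x_i)\big)$
- For right-censored individual $i$ at $t_c$ (interval $j$): $\ell_i = -\sum_{k=1}^{j-1} \log\big(1 - h_k(x_i)\big)$

Minibatch SGD optimizes total loss:

$$\mathcal{L} = \sum_{i \in \text{batch}} \ell_i.$$

Alternatively, for the continuous-time (piecewise constant hazard) formulation (Holmer et al., 27 Mar 2024):

$$\ell_i = -\delta_i \log \lambda(t_i \mid x_i) + \Lambda(t_i \mid x_i),$$

where $\delta_i$ is the event indicator and $\Lambda(t \mid x) = \int_0^t \lambda(s \mid x)\, ds$ is the cumulative hazard up to the observed time.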
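Label encoding uses two binary vectors per sample: $s_j$ indicates survival past interval $j$, and $f_j$ denotes event in interval $j$. With this encoding, the per-sample loss can be written as $-\sum_j \big[s_j \log(1 - h_j) + f_j \log h_j\big]$, which covers censored and uncensored individuals with a single expression. The total likelihood is efficiently computed over minibatches and supports large-scale training.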
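The label encoding and the resulting loss can be sketched in NumPy as follows. This is an illustration under stated assumptions: the helper names `make_surv_labels` and `nll` are hypothetical, a censored individual is credited only with fully completed intervals (implementations may use a different convention, such as crediting an interval when censoring occurs after its midpoint), and the model output is taken to be the hazard $h_j$ (a model that outputs conditional survival would swap $h_j$ and $1 - h_j$).

```python
import numpy as np

def make_surv_labels(times, events, breaks):
    """Build [surv_s | surv_f] label rows from follow-up times and event flags.

    breaks: interval boundaries t_0 < t_1 < ... < t_n (n intervals).
    Assumption: censored subjects are credited only with completed intervals.
    """
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=bool)
    breaks = np.asarray(breaks, dtype=float)
    n = len(breaks) - 1
    # surv_s[i, j] = 1 if subject i is known to have survived past the end of interval j.
    surv_s = (times[:, None] >= breaks[None, 1:]).astype(float)
    surv_f = np.zeros((len(times), n))
    for i in np.flatnonzero(events & (times < breaks[-1])):
        # Index of the left-closed interval [t_{j-1}, t_j) that contains the event time.
        j = np.searchsorted(breaks, times[i], side="right") - 1
        surv_f[i, j] = 1.0
    return np.concatenate([surv_s, surv_f], axis=1)

def nll(labels, hazards, eps=1e-7):
    """Per-sample negative log-likelihood given per-interval hazards."""
    hazards = np.asarray(hazards, dtype=float)
    n = hazards.shape[1]
    surv_s, surv_f = labels[:, :n], labels[:, n:]
    return -np.sum(surv_s * np.log(1.0 - hazards + eps)
                   + surv_f * np.log(hazards + eps), axis=1)

# One event at t=1.5 and one censoring at t=2.5, with three unit-length intervals.
labels = make_surv_labels([1.5, 2.5], [1, 0], breaks=[0.0, 1.0, 2.0, 3.0])
print(labels)
print(nll(labels, np.array([[0.1, 0.2, 0.5], [0.1, 0.2, 0.5]])))
```

Because each row of the label matrix carries everything needed to evaluate the likelihood for that sample, the loss can be computed independently per minibatch.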
3. Neural Network Architecture and Time-Varying Effects
Two design variants address temporal dependence:
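- Flexible, non-proportional hazards: The final dense layer has $n$ outputs (logits), each with independent bias, followed by a sigmoid to yield the per-interval conditional probabilities $h_j(x)$ (or, equivalently, their complements $1 - h_j(x)$ when the output is defined as conditional survival); baseline hazard and covariate effects are interval-specific.
- Proportional hazards: Covariate effect is constant across intervals; baseline hazard is interval-specific. The final output is: $1 - h_j(x) = \big(1 - h_{0,j}\big)^{\exp(\eta(x))}$, where $h_{0,j}$ is the baseline hazard for interval $j$ and $\eta(x)$ is the scalar risk score produced by the network.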
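The proportional-hazards variant can be illustrated numerically. The sketch below assumes the common grouped-data form in which the baseline conditional survival probability of each interval is raised to the power $\exp(\eta(x))$; the function name `prop_hazards_output` and its arguments are hypothetical, and a trained model would learn `baseline_logits` and the risk score rather than receive them as inputs.

```python
import numpy as np

def prop_hazards_output(baseline_logits, risk_score):
    """Per-interval hazards under a proportional-hazards constraint.

    baseline_logits: shape (n_intervals,), one free parameter per interval.
    risk_score: shape (n_samples,), scalar score eta(x) for each subject.
    Returns an array of shape (n_samples, n_intervals).
    """
    logits = np.asarray(baseline_logits, dtype=float)
    # Sigmoid maps each logit to a baseline conditional survival probability 1 - h_0j.
    baseline_surv = 1.0 / (1.0 + np.exp(-logits))
    risk = np.exp(np.asarray(risk_score, dtype=float))[:, None]
    # Conditional survival: (1 - h_0j) ** exp(eta(x)), broadcast over subjects and intervals.
    cond_surv = baseline_surv[None, :] ** risk
    return 1.0 - cond_surv

# Two intervals, two subjects; the second subject has twice the relative risk.
h = prop_hazards_output([0.0, 2.0], [0.0, np.log(2.0)])
print(h)
```

Under this form the ratio $\log(1 - h_j(x)) / \log(1 - h_{0,j})$ equals $\exp(\eta(x))$ in every interval, which is what makes the covariate effect constant over time.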
A typical network pipeline (Gensheimer et al., 2018, Holmer et al., 27 Mar 2024):
- Input layer for covariates (numerical, categorical, image, or text embeddings).
- Hidden layers (dense or convolutional, with non-linear activations like ReLU or tanh).
- Survival output layer: either independent sigmoids (non-proportional) or procedural generation of interval-wise hazard by combining baseline and covariate terms.
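The number of output-layer parameters grows linearly with the number of intervals, since each interval has its own weight vector and bias (Holmer et al., 27 Mar 2024).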
4. Training Procedures, Scalability, and Implementation
Optimization employs Adam, RMSprop, or SGD with batch sizes $32$–$256$. Training is performed with automatic differentiation over custom loss functions, facilitating rapid convergence and robustness across random restarts. Model hyperparameters are tuned via cross-validation, including interval count, hidden layer dimensionality, and regularization factors (e.g., weight decay).
Keras implementation utilizes standard Dense layers; for proportional hazards, a custom layer raises the baseline conditional survival probabilities to the power of the exponentiated covariate effect. Loss functions operate directly on survival status encodings. Because only the minibatch is required in memory, nnet-survival models scale efficiently: experiments demonstrate learning on large samples on a single CPU without out-of-memory errors (Gensheimer et al., 2018).
Pseudocode illustrates these procedures, detailing data assembly, model construction, fitting, and prediction, including computation of survival curves and evaluation metrics (C-index, Brier scores).
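The two evaluation metrics named above can be sketched as follows. These are simplified illustrations, not the evaluation code used in the cited work: `concordance_index` is a plain Harrell-style pairwise estimate, and `brier_score` is an unweighted version that drops subjects censored before the horizon instead of applying inverse-probability-of-censoring weights. Both function names are hypothetical.

```python
import numpy as np

def concordance_index(times, events, risk):
    """Harrell-style C-index: fraction of comparable pairs ranked correctly by risk."""
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=bool)
    risk = np.asarray(risk, dtype=float)
    concordant, comparable = 0.0, 0
    for i in range(len(times)):
        if not events[i]:
            continue  # a pair is comparable only if the earlier time is an observed event
        for j in range(len(times)):
            if times[i] < times[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5  # ties in predicted risk count as half
    return concordant / comparable if comparable else float("nan")

def brier_score(times, events, surv_prob, horizon):
    """Unweighted Brier score at `horizon`, using only subjects with known status.

    surv_prob: predicted probability of surviving past `horizon` for each subject.
    """
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=bool)
    surv_prob = np.asarray(surv_prob, dtype=float)
    alive = times > horizon
    # Status is unknown for subjects censored at or before the horizon.
    known = alive | events
    return float(np.mean((alive[known].astype(float) - surv_prob[known]) ** 2))

print(concordance_index([1, 2, 3, 4], [1, 1, 0, 1], [4.0, 3.0, 2.0, 1.0]))
print(brier_score([1, 2, 3, 4], [1, 0, 0, 1], [0.1, 0.5, 0.8, 0.9], horizon=2.5))
```

For a discrete-time model, a natural choice of risk score is one minus the predicted survival probability at a fixed follow-up time, and `surv_prob` is the predicted survival curve evaluated at the horizon.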
5. Empirical Performance and Model Comparisons
Empirical validation encompasses simulated, image-based, and real-world datasets:
| Dataset | Nnet-survival (flexible) | Cox-nnet | DeepSurv | Cox PH | Brier score (Nnet-survival) |
|---|---|---|---|---|---|
| Simulated (exp, n=5000) | C-index ~0.66 (perfect calibration) | - | - | - | - |
| MNIST (CNN, n=30,596/5,139) | Test C-index 0.713 | - | Oracle 0.770 | - | - |
| SUPPORT (n=9,105; 19 intervals) | C-index 0.732 | 0.735 | 0.730 | 0.734 | 0.181/0.184/0.177 |
Nnet-survival exhibits robust calibration and competitive C-index, outperforming Cox-nnet on scalability, with best Brier scores across benchmark intervals (Gensheimer et al., 2018).
Piecewise hazard networks (constant and linear) match the accuracy of energy-based models on simulated Weibull data (negative log-likelihood –$0.60$) and require approximately one-third the computation time (Holmer et al., 27 Mar 2024).
6. Connections to Established Survival Models
Nnet-survival generalizes and extends several approaches:
- Discrete-time logistic models (e.g., DeepHit): The probability of event per interval can be recovered by transforming hazard, $P(t_{j-1} \le T < t_j \mid x) = h_j(x) \prod_{k=1}^{j-1} \big(1 - h_k(x)\big)$.
- Piecewise exponential models: Interval-wise hazard is a linear function of covariates; nnet-survival replaces this with an arbitrary neural network representation.
- Cox-nnet: Cox-nnet substitutes the linear Cox PH risk score with a shallow feed-forward neural network and optimizes the partial Cox likelihood (Garmire, 2020).
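The link to models that predict the event-time distribution directly can be checked numerically: multiplying each hazard by the probability of reaching its interval gives a probability mass function over intervals. The sketch below is illustrative, and `hazards_to_pmf` is a hypothetical helper name.

```python
import numpy as np

def hazards_to_pmf(hazards):
    """Unconditional probability of the event falling in each interval."""
    hazards = np.asarray(hazards, dtype=float)
    surv = np.cumprod(1.0 - hazards, axis=-1)
    # Probability of reaching interval j is S_{j-1}, with S_0 = 1.
    surv_prev = np.concatenate([np.ones_like(surv[..., :1]), surv[..., :-1]], axis=-1)
    return hazards * surv_prev

hazards = np.array([0.1, 0.2, 0.5])
pmf = hazards_to_pmf(hazards)
# The event probabilities plus the probability of surviving all intervals sum to one.
print(pmf, pmf.sum() + np.prod(1.0 - hazards))
```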
The flexible hazard parameterization in nnet-survival increases expressivity, while efficient minibatch likelihood calculation supports scalability and practical deployment in high-dimensional biomedical applications.
7. Practical Applications and Implementation Example
Deep neural approaches to survival, including nnet-survival, are adopted in clinical risk modeling, time-to-event prediction in medicine, and integration of multi-modal data (e.g., genomics + imaging). The nnet-survival Keras implementation exemplifies a fully differentiable workflow supporting customization for application-specific covariate structures, interval schemes, and regularization.
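```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model


def nnet_surv_loss(n_intervals):
    """Negative log-likelihood for discrete-time survival labels.

    y_true concatenates surv_s (1 for each interval survived) and surv_f
    (1 for the interval containing the event). y_pred is the conditional
    probability of surviving each interval, i.e. 1 - h_j.
    """
    def loss(y_true, y_pred):
        surv_s = y_true[:, :n_intervals]
        surv_f = y_true[:, n_intervals:]
        eps = tf.keras.backend.epsilon()
        # Equals log(y_pred) for survived intervals and roughly log(1) = 0 elsewhere.
        term1 = tf.math.log(1. + surv_s * (y_pred - 1.) + eps)
        # Equals log(1 - y_pred), the log hazard, for the event interval; roughly 0 elsewhere.
        term2 = tf.math.log(1. - surv_f * y_pred + eps)
        return -tf.reduce_sum(term1 + term2, axis=1)
    return loss


# n_covariates, n_intervals, X_train, surv_s_train and surv_f_train are
# assumed to be defined by the data-preparation step.
inputs = Input(shape=(n_covariates,))
x = Dense(64, activation='relu')(inputs)
x = Dense(32, activation='relu')(x)
# One sigmoid output per interval (flexible, non-proportional version).
surv_out = Dense(n_intervals, activation='sigmoid', name='surv_out')(x)
model = Model(inputs, surv_out)
model.compile(optimizer='rmsprop', loss=nnet_surv_loss(n_intervals))
model.fit(X_train,
          np.concatenate([surv_s_train, surv_f_train], axis=1),
          batch_size=128, epochs=50, validation_split=0.2)
```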
Nnet-survival provides a scalable, expressive, and theoretically principled neural methodology for survival analysis, directly extending discrete-time and piecewise exponential frameworks with neural network parameterizations of hazard. Its empirical efficacy on real and synthetic data, alongside modular implementation in leading deep learning libraries, underpins its role in modern risk modeling (Gensheimer et al., 2018, Holmer et al., 27 Mar 2024).