
Nnet-survival: Neural Network Survival Model

Updated 15 December 2025
  • Nnet-survival is a neural network-based survival analysis method that partitions time into discrete intervals to model conditional hazards.
  • The approach uses stochastic gradient descent to optimize likelihoods for censored and uncensored data, supporting both non-proportional and proportional hazards.
  • Its scalable Keras implementation demonstrates competitive calibration and predictive performance on large-scale biomedical datasets.

Nnet-survival refers to a family of neural network-based models that parameterize survival and hazard functions over time partitions, enabling flexible, scalable, and deep learning-driven survival analysis. These models partition follow-up time into intervals, predict conditional hazard probabilities per interval, and optimize survival likelihood via modern deep learning frameworks using stochastic gradient descent. The resulting architecture generalizes classic discrete-time survival models and piecewise exponential models, accommodating complex covariate relationships and large-scale data.

1. Discrete-Time Partitioning and Model Parameterization

The nnet-survival approach begins by partitioning follow-up time into $n$ (or $K$) left-closed, right-open intervals, typically $[0, t_1), [t_1, t_2), \ldots, [t_{n-1}, t_n)$ (Gensheimer et al., 2018, Holmer et al., 27 Mar 2024). For each individual $p$, the conditional hazard probability for interval $i$ is defined as

$$h_{p,i} \equiv P(\text{event in interval } i \mid \text{survived to } t_{i-1}).$$

The survival function to the end of interval $j$ is

$$S_{p,j} = \prod_{i=1}^{j} (1 - h_{p,i}).$$

Alternatively, in piecewise constant-hazard models, time is indexed via $k(t)$ such that $t \in [t_{k-1}, t_k)$, and the hazard is parameterized as

$$h_k(x) = \exp(w_k^T \phi(x) + b_k),$$

where $\phi(x)$ is the neural network representation of the covariates, and $w_k, b_k$ are interval-specific weight vectors and biases (Holmer et al., 27 Mar 2024).
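
To make the hazard-to-survival mapping concrete, a minimal NumPy sketch evaluating $S_{p,j}$ for a single individual is shown below; the hazard values are illustrative placeholders, not outputs of a fitted model.

```python
import numpy as np

# Illustrative per-interval conditional hazards h_{p,i} for one individual
# across five discrete intervals (placeholder values).
hazards = np.array([0.05, 0.08, 0.10, 0.12, 0.15])

# Survival to the end of interval j: S_{p,j} = prod_{i<=j} (1 - h_{p,i})
survival_curve = np.cumprod(1.0 - hazards)
print(survival_curve)  # monotonically decreasing, starting at 0.95
```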

The model also admits a piecewise linear hazard parameterization

$$h_k(x, t) = a_k(x)\,(t - t_{k-1}) + c_k(x),$$

with $a_k(x), c_k(x)$ produced by neural network outputs passed through $\mathrm{softplus}$ activations to ensure nonnegativity.
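
The following is a minimal Keras sketch of how such a piecewise-linear hazard head could be realized; the layer sizes, names, and dimensions are illustrative assumptions, not the architecture from Holmer et al.

```python
import tensorflow as tf
from tensorflow.keras import layers

n_intervals = 10    # assumed number of intervals K
n_covariates = 20   # assumed input dimension

# Shared covariate representation phi(x)
inputs = layers.Input(shape=(n_covariates,))
phi = layers.Dense(32, activation="relu")(inputs)

# Interval-wise slope a_k(x) and intercept c_k(x); softplus keeps both nonnegative
a = layers.Dense(n_intervals, activation="softplus", name="slope")(phi)
c = layers.Dense(n_intervals, activation="softplus", name="intercept")(phi)

hazard_params = tf.keras.Model(inputs, [a, c])
# h_k(x, t) = a_k(x) * (t - t_{k-1}) + c_k(x) is then assembled, together with
# its cumulative hazard, inside the training loss.
```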

2. Likelihood Construction and Loss Functions

Survival likelihood is formulated using both censored and uncensored observations. The negative log-likelihood loss is constructed as follows (Gensheimer et al., 2018):

  • For an uncensored individual $p$ with an event in interval $j_p$: $\ell_p = \ln(h_{p, j_p}) + \sum_{i=1}^{j_p - 1} \ln(1 - h_{p, i})$
  • For a right-censored individual with censoring time $t_p$ in interval $j_p$: $\ell_p = \sum_{i=1}^{j_p - 1} \ln(1 - h_{p, i})$

Minibatch SGD minimizes the total loss

$$\text{Loss} = -\sum_{p=1}^{P} \ell_p.$$

Alternatively, for the continuous-time (piecewise constant hazard) formulation (Holmer et al., 27 Mar 2024), the per-observation negative log-likelihood is

$$\ell_i = -\delta_i \log h_{k(\tau_i)}(x_i) + H(\tau_i \mid x_i),$$

where $\delta_i$ is the event indicator and $H(\tau_i \mid x_i)$ is the cumulative hazard up to the observed time $\tau_i$.
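
A minimal NumPy sketch of this continuous-time loss for a single observation is given below; the interval grid, hazard values, and function name are illustrative assumptions rather than code from either paper.

```python
import numpy as np

def piecewise_constant_nll(tau, delta, hazards, boundaries):
    """Negative log-likelihood of one observation under piecewise-constant hazards.

    tau        : observed time (event or censoring time)
    delta      : 1 if the event was observed, 0 if right-censored
    hazards    : per-interval hazard rates h_k, length K
    boundaries : interval edges [t_0 = 0, t_1, ..., t_K]
    """
    k = np.searchsorted(boundaries, tau, side="right") - 1  # index k(tau)
    widths = np.diff(boundaries)
    # Cumulative hazard H(tau | x): fully traversed intervals plus the partial one
    H = np.sum(hazards[:k] * widths[:k]) + hazards[k] * (tau - boundaries[k])
    return -delta * np.log(hazards[k]) + H

# Illustrative values: three unit-width intervals, event observed at t = 1.4
print(piecewise_constant_nll(1.4, 1,
                             hazards=np.array([0.1, 0.2, 0.3]),
                             boundaries=np.array([0.0, 1.0, 2.0, 3.0])))
```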

Label encoding uses two binary vectors per sample: $\text{surv}_s[i]$ indicates survival past interval $i$, and $\text{surv}_f[i]$ indicates an event in interval $i$. The total likelihood is efficiently computed over minibatches and supports large-scale training.
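
As a minimal sketch of this encoding (the function name and the uniform interval edges are illustrative assumptions, not the reference implementation), the two vectors can be constructed as follows:

```python
import numpy as np

def make_surv_labels(time, event, breaks):
    """Encode one observation as concatenated (surv_s, surv_f) binary vectors."""
    n_intervals = len(breaks) - 1
    surv_s = np.zeros(n_intervals)  # 1 where the interval was fully survived
    surv_f = np.zeros(n_intervals)  # 1 in the interval containing the event
    for i in range(n_intervals):
        if time >= breaks[i + 1]:
            surv_s[i] = 1.0
        elif event and time >= breaks[i]:
            surv_f[i] = 1.0
    return np.concatenate([surv_s, surv_f])

# Event at t = 2.5 with yearly intervals [0, 1), [1, 2), [2, 3), [3, 4)
print(make_surv_labels(2.5, 1, np.array([0.0, 1.0, 2.0, 3.0, 4.0])))
# surv_s = [1, 1, 0, 0], surv_f = [0, 0, 1, 0]
```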

3. Neural Network Architecture and Time-Varying Effects

Two design variants address temporal dependence:

  • Flexible, non-proportional hazards: the final dense layer has $n$ outputs (logits), each with an independent bias, followed by a sigmoid to yield $1 - h_{p,i}$; the baseline hazard and covariate effects are interval-specific.
  • Proportional hazards: the covariate effect is constant across intervals while the baseline hazard is interval-specific. The final output is $\hat{y}_p[i] = s_{0,i}^{\exp(\eta_p)}$ with $\eta_p = X\beta$ (a minimal custom-layer sketch of this output follows the list).
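
Below is a minimal sketch of the proportional-hazards output as a custom Keras layer; the trainable per-interval baseline, the layer name, and the sigmoid constraint are illustrative choices, and the published implementation may differ in detail.

```python
import tensorflow as tf
from tensorflow.keras import layers

class PropHazardsOutput(layers.Layer):
    """Raises trainable per-interval baseline survival probabilities s_{0,i}
    to the power exp(eta), where eta is a scalar risk score."""

    def __init__(self, n_intervals, **kwargs):
        super().__init__(**kwargs)
        self.n_intervals = n_intervals

    def build(self, input_shape):
        # Unconstrained weights, squashed through a sigmoid so s_{0,i} stays in (0, 1)
        self.baseline_logit = self.add_weight(
            name="baseline_logit", shape=(1, self.n_intervals),
            initializer="zeros", trainable=True)

    def call(self, eta):
        s0 = tf.sigmoid(self.baseline_logit)   # baseline conditional survival per interval
        return tf.pow(s0, tf.exp(eta))         # y_hat_p[i] = s_{0,i} ** exp(eta_p)

# Usage: eta is a scalar risk score X*beta produced by a linear (or deeper) head
inputs = layers.Input(shape=(20,))
eta = layers.Dense(1)(inputs)
surv_out = PropHazardsOutput(n_intervals=10)(eta)
model_ph = tf.keras.Model(inputs, surv_out)
```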

A typical network pipeline (Gensheimer et al., 2018, Holmer et al., 27 Mar 2024):

  • Input layer for covariates (numerical, categorical, image, or text embeddings).
  • Hidden layers (dense or convolutional, with non-linear activations like ReLU or tanh).
  • Survival output layer: either $n$ independent sigmoid units (non-proportional), or a custom layer that combines the interval-specific baseline hazard with the exponentiated covariate effect (proportional).

Parameter count scales as $O(d \cdot H + H^2 + H \cdot K)$, where $d$ is the input dimension, $H$ the hidden-layer width, and $K$ the number of intervals (Holmer et al., 27 Mar 2024).
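
For example, with $d = 100$ covariates, a single hidden layer of $H = 64$ units, and $K = 20$ intervals (an illustrative configuration, not one reported in the papers), this amounts to roughly $100 \cdot 64 + 64^2 + 64 \cdot 20 \approx 11{,}800$ weights.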

4. Training Procedures, Scalability, and Implementation

Optimization employs Adam, RMSprop, or SGD with batch sizes of 32–256. Training is performed with automatic differentiation over custom loss functions, facilitating rapid convergence and robustness across random restarts. Model hyperparameters are tuned via cross-validation, including interval count, hidden layer dimensionality, and regularization factors (e.g., $L_2$ weight decay).

The Keras implementation uses standard Dense layers; for the proportional-hazards variant, a custom layer raises the interval-specific baseline survival probabilities to the power of the exponentiated covariate effect. Loss functions operate directly on the survival-status encodings. Because only the current minibatch must be held in memory, nnet-survival models scale efficiently: experiments demonstrate training on more than $10^6$ samples on a single CPU without out-of-memory errors (Gensheimer et al., 2018).

Pseudocode and the implementation example in Section 7 illustrate these procedures: data assembly, model construction, fitting, and prediction, including computation of survival curves and evaluation metrics (C-index, Brier score).

5. Empirical Performance and Model Comparisons

Empirical validation encompasses simulated, image-based, and real-world datasets:

| Dataset | Nnet-survival (flexible) | Cox-nnet | DeepSurv | Cox PH | Brier score (nnet-survival) |
|---|---|---|---|---|---|
| Simulated (exponential, n = 5,000) | C-index ≈ 0.66 (perfect calibration) | – | – | – | – |
| MNIST survival simulation (CNN, n = 30,596 / 5,139) | Test C-index 0.713 (oracle: 0.770) | – | – | – | – |
| SUPPORT (n = 9,105; 19 intervals) | C-index 0.732 | 0.735 | 0.730 | 0.734 | 0.181 / 0.184 / 0.177 |

Nnet-survival exhibits robust calibration and a competitive C-index, scales better than Cox-nnet, and achieves the best Brier scores across the benchmark evaluation times (Gensheimer et al., 2018).

Piecewise hazard networks (constant and linear) match the accuracy of energy-based models on simulated Weibull data (negative log-likelihood $\approx 0.56$–$0.60$) and require approximately one-third of the computation time (Holmer et al., 27 Mar 2024).

6. Connections to Established Survival Models

Nnet-survival generalizes and extends several approaches:

  • Discrete-time logistic models (e.g., DeepHit): the per-interval event probability can be recovered from the hazard via $p_k = 1 - \exp(-h_k\,\Delta t_k)$ (a small numeric example follows this list).

  • Piecewise exponential models: Interval-wise hazard is a linear function of covariates; nnet-survival replaces this with an arbitrary neural network representation.
  • Cox-nnet: Cox-nnet substitutes the linear Cox PH risk score with a shallow feed-forward neural network and optimizes the partial Cox likelihood (Garmire, 2020).
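
As an illustrative numeric example of this transformation, a hazard rate of $h_k = 0.2$ per year over a half-year interval ($\Delta t_k = 0.5$) implies an event probability of $p_k = 1 - \exp(-0.1) \approx 0.095$.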

The flexible hazard parameterization in nnet-survival increases expressivity, while efficient minibatch likelihood calculation guarantees scalability and practical deployment in high-dimensional biomedical applications.

7. Practical Applications and Implementation Example

Deep neural approaches to survival—including nnet-survival—are adopted in clinical risk modeling, time-to-event prediction in medicine, and integration of multi-modal data (e.g., genomics + imaging). The nnet-survival Keras implementation exemplifies a fully differentiable workflow supporting customization for application-specific covariate structures, interval schemes, and regularization.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K

def nnet_surv_loss(n_intervals):
    # y_true holds [surv_s | surv_f]; y_pred is the predicted conditional
    # probability of surviving each interval (i.e., 1 - hazard).
    def loss(y_true, y_pred):
        surv_s = y_true[:, :n_intervals]      # 1 where the interval was survived
        surv_f = y_true[:, n_intervals:]      # 1 in the interval containing the event
        term1 = 1. + surv_s * (y_pred - 1.)   # y_pred on survived intervals, 1 elsewhere
        term2 = 1. - surv_f * y_pred          # hazard on the event interval, 1 elsewhere
        log_lik = K.log(K.clip(term1, K.epsilon(), None)) \
                  + K.log(K.clip(term2, K.epsilon(), None))
        return -K.sum(log_lik, axis=-1)       # negative log-likelihood per sample
    return loss

n_covariates = 20   # illustrative input dimension
n_intervals = 10    # illustrative number of discrete time intervals

# Random placeholder data so the snippet runs end to end; real use requires
# covariates plus (surv_s, surv_f) encodings of the observed times and events.
rng = np.random.default_rng(0)
X_train = rng.normal(size=(1000, n_covariates)).astype("float32")
surv_s_train = rng.integers(0, 2, size=(1000, n_intervals)).astype("float32")
surv_f_train = rng.integers(0, 2, size=(1000, n_intervals)).astype("float32")

inputs = Input(shape=(n_covariates,))
x = Dense(64, activation='relu')(inputs)
x = Dense(32, activation='relu')(x)
surv_out = Dense(n_intervals, activation='sigmoid', name='surv_out')(x)
model = Model(inputs, surv_out)
model.compile(optimizer='rmsprop', loss=nnet_surv_loss(n_intervals))
model.fit(X_train, np.concatenate([surv_s_train, surv_f_train], axis=1),
          batch_size=128, epochs=50, validation_split=0.2)
```
The network outputs conditional survival probabilities per interval; survival curves (and pointwise metrics such as the C-index or Brier score) are derived via cumulative products for downstream evaluation and analysis (Gensheimer et al., 2018).
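
Continuing from the snippet above, a brief sketch of this post-processing step (the held-out covariates and the chosen evaluation interval are illustrative assumptions):

```python
# Conditional survival probabilities per interval for held-out samples
X_test = rng.normal(size=(200, n_covariates)).astype("float32")
pred_cond_surv = model.predict(X_test)            # shape (n_samples, n_intervals)

# Survival curve: cumulative product of conditional survival across intervals
surv_curves = np.cumprod(pred_cond_surv, axis=1)

# Example risk score for discrimination metrics (C-index): 1 - S at interval 5;
# this can be passed to a concordance routine from a library such as lifelines.
risk_at_interval_5 = 1.0 - surv_curves[:, 4]
```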


Nnet-survival provides a scalable, expressive, and theoretically principled neural methodology for survival analysis, directly extending discrete-time and piecewise exponential frameworks with neural network parameterizations of hazard. Its empirical efficacy on real and synthetic data, alongside modular implementation in leading deep learning libraries, underpins its role in modern risk modeling (Gensheimer et al., 2018, Holmer et al., 27 Mar 2024).
