
STRAFE: Transformer-Based Survival Analysis

Updated 9 February 2026
  • STRAFE is a deep-learning architecture for discrete-time survival analysis that models longitudinal EHR visits using transformer-based sequence modeling and 1D convolution to predict time-to-event outcomes.
  • It leverages pre-trained OMOP embeddings and sinusoidal timestamps to handle irregular event intervals and accommodate right-censored data effectively.
  • Evaluations on CKD patient cohorts demonstrate significant performance gains, with reduced MAE and improved AUC compared to classical and deep-learning baselines.

STRAFE is a deep-learning architecture for time-to-event (survival) prediction from longitudinal electronic health records (EHR), leveraging transformer-based sequence modeling and designed to handle irregular event intervals and censored outcomes. STRAFE was introduced for predicting chronic kidney disease (CKD) progression using large-scale claims data and demonstrated improved performance in both event-time and fixed-time risk prediction tasks compared to classical and deep-learning baselines (Zisser et al., 2023).

1. Architecture Overview

STRAFE models a patient's visit history as an ordered set of timestamped EHR visits, each consisting of a collection of OMOP concepts. Each visit $V_j^i$ for patient $i$ at time $t_j^i$ includes codes $C_j^i$ encoded through a pre-trained skip-gram embedding $\phi: C \rightarrow \mathbb{R}^{d_e}$ with $d_e = 128$. The content embedding per visit is $\psi(V_j^i) = \sum_{c \in C_j^i} \phi(c)$. To capture irregular visit intervals, each timestamp is mapped via a sinusoidal embedding $\tau(V_j^i) = [\sin(\tilde{t}_j^i \cdot \omega) \,\|\, \cos(\tilde{t}_j^i \cdot \omega)]$, with $\omega$ a geometric progression of frequencies.

The per-visit encoding $x_j = \psi(V_j^i) + \tau(V_j^i)$ yields the visit sequence $x = [x_1; \ldots; x_{n_v}] \in \mathbb{R}^{n_v \times d_e}$, where $n_v = 100$ is the maximum sequence length (sequences are truncated or padded). This sequence is processed by $L = 1$ layer of multi-head self-attention ($H = 4$ heads, dropout $p = 0.3$), producing $Z \in \mathbb{R}^{n_v \times d_e}$.
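
The visit encoding described above can be sketched in NumPy. The function names, the timestamp normalization, and the frequency base of 10000 are illustrative assumptions (the base follows the standard transformer positional encoding; the paper does not state its value):

```python
import numpy as np

def sinusoidal_time_embedding(t, d_e=128):
    """Sinusoidal embedding tau of a (normalized) visit time t.

    The frequencies omega form a geometric progression; the base 10000
    mirrors the standard transformer positional encoding (an assumption).
    """
    half = d_e // 2
    omega = 1.0 / (10000.0 ** (np.arange(half) / half))
    return np.concatenate([np.sin(t * omega), np.cos(t * omega)])

def encode_visit(concept_ids, phi, t, d_e=128):
    """x_j = psi(V_j) + tau(V_j): summed concept embeddings plus time embedding."""
    psi = sum(phi[c] for c in concept_ids)  # psi(V_j) = sum_{c in C_j} phi(c)
    return psi + sinusoidal_time_embedding(t, d_e)
```

Padding or truncating a patient's encoded visits to $n_v = 100$ and stacking them then yields the $n_v \times d_e$ input to the first self-attention block.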

Subsequently, $Z$ is projected onto a fixed monthly time grid of $T_{\max} = 48$ months via a 1D convolution, yielding $U \in \mathbb{R}^{T_{\max} \times d_e}$. Temporal embeddings are added at each time point, and a second self-attention block ($L' = 1$, $H' = 4$) models dependencies across months. A two-layer MLP, applied at each time point, outputs $q(t|X) = 1 - \lambda(t|X)$, the complement of the discrete monthly hazard.

The discrete-time survival function is given by:

$$S(t|X) = \prod_{u=0}^{t} q(u|X)$$

and the mean predicted time-to-event is $\hat{\mu} = \sum_{t=0}^{T_{\max}} S(t|X)$.
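
Given a predicted monthly hazard vector, the survival curve and mean time-to-event follow directly from these two formulas; a minimal NumPy sketch (the function name is illustrative):

```python
import numpy as np

def survival_from_hazard(hazard):
    """Discrete-time survival from a per-month hazard lambda(t|X), t = 0..T_max.

    S(t|X) = prod_{u=0}^{t} q(u|X) with q(t|X) = 1 - lambda(t|X), and the
    mean predicted time-to-event is mu_hat = sum_t S(t|X).
    """
    q = 1.0 - hazard
    S = np.cumprod(q)      # running product over months 0..t
    mu_hat = S.sum()       # expected time-to-event on the monthly grid
    return S, mu_hat
```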

2. Survival Modeling and Loss Function

STRAFE is formulated for discrete-time survival analysis and accommodates right-censored data through a custom likelihood. For each patient $i$, outcomes are encoded as $(T_i, \delta_i)$, where $T_i$ is the event or censoring time and $\delta_i$ is the event indicator. The joint loss over a batch is

$$L = \sum_{i:\delta_i=1} L_i^{\mathrm{obs}} + \sum_{i:\delta_i=0} L_i^{\mathrm{cens}}$$

with observed-case loss:

$$L^{\mathrm{obs}} = -\sum_{t=0}^{T-1} \log \hat{S}(t|X) - \sum_{t=T}^{T_{\max}} \log[1 - \hat{S}(t|X)]$$

and censored-case loss:

$$L^{\mathrm{cens}} = -\sum_{t=0}^{T-1} \log \hat{S}(t|X)$$

This framework does not assume proportional hazards and optimizes the discrete event likelihood across both observed and censored outcomes.
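
The observed and censored cases can be folded into one per-patient function; this NumPy sketch adds a small epsilon inside the logarithms for numerical stability (an implementation detail not specified in the source):

```python
import numpy as np

def strafe_loss(S_hat, T, delta, eps=1e-8):
    """Per-patient discrete-time survival loss.

    S_hat: predicted survival curve S(t|X) for t = 0..T_max.
    T:     event or censoring month; delta = 1 if the event was observed.
    Both cases penalize low survival before T; observed cases additionally
    penalize high survival from T onward.
    """
    loss = -np.log(S_hat[:T] + eps).sum()            # L^cens terms
    if delta == 1:
        loss += -np.log(1.0 - S_hat[T:] + eps).sum()  # extra L^obs terms
    return loss
```

Summing this over a batch, with `delta` set per patient, reproduces the joint loss $L$.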

3. Training Protocol and Implementation

STRAFE employs OMOP concept embeddings pre-trained on 35 million claims using skip-gram word2vec (window = 90 days, $d_e = 128$, vocabulary size ≈ 36,000). Patient sequences are standardized to $n_v = 100$ visits. Model selection was guided by grid search. Key hyperparameters include:

  • First self-attention: $H = 4$, $L = 1$, dropout 0.3
  • 1D convolution projecting $n_v$ visits to $T_{\max} = 48$ monthly bins
  • Second self-attention: $H' = 4$, $L' = 1$
  • Optimization: Adam, batch size 256, learning rate $2 \times 10^{-3}$, no explicit weight decay or momentum
  • Implementation framework: PyTorch (Tesla K80 GPU); comparison baselines in scikit-survival and Pycox

These choices were intended to maximize out-of-sample predictive performance without adding explicit architectural complexity or heavy regularization.

4. Performance Evaluation

STRAFE was evaluated on a real-world dataset of over 130,000 CKD stage 3 patients, with the following results on a held-out cohort ($n \approx 27{,}000$):

Time-to-event prediction (48-month horizon):

| Model          | C-index | MAE (mo) |
|----------------|---------|----------|
| RSF (BOW)      | 0.609   | 32.33    |
| RSF (emb)      | 0.719   | 31.85    |
| DeepHit (BOW)  | 0.580   | 28.39    |
| DeepHit (emb)  | 0.714   | 28.59    |
| STRAFE         | 0.710   | 22.16    |
| STRAFE-LSTM    | 0.710   | 21.59    |
| Uncont. STRAFE | 0.690   | 22.14    |
| Uncont. LSTM   | 0.711   | 23.04    |

Fixed-time risk (AUC-ROC at 6, 12, 24 months):

| Model         | 6 mo  | 12 mo | 24 mo |
|---------------|-------|-------|-------|
| LR (BOW)      | 0.622 | 0.598 | 0.603 |
| LR (emb)      | 0.711 | 0.710 | 0.720 |
| SARD          | 0.725 | 0.731 | 0.748 |
| RSF (emb)     | 0.719 | 0.723 | 0.683 |
| DeepHit (emb) | 0.729 | 0.728 | 0.725 |
| STRAFE        | 0.751 | 0.754 | 0.764 |

STRAFE reduced mean absolute error (MAE) by approximately 25% compared to DeepHit and improved AUC by 2–3 points over SARD ($p \approx 2 \times 10^{-8}$ at 24 months). Embedding usage was the principal driver of the C-index gains. Subgroup analysis showed AUCs rising to ~0.80 in patients under 60 years, with male AUC 0.761 vs. female 0.748.

In top-decile risk stratification, STRAFE achieved a positive predictive value (PPV) of 20.9% at 12 months (vs. a base rate of 6.67%, a roughly 3× lift) and 28.4% at 24 months (vs. 14.98%).
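
The top-decile PPV and lift figures correspond to a simple computation over predicted risks and observed outcomes; a NumPy sketch (the function name and decile convention are illustrative assumptions):

```python
import numpy as np

def top_decile_ppv(risk_scores, events):
    """PPV among the 10% of patients with the highest predicted risk,
    plus the lift of that PPV over the cohort base rate."""
    k = max(1, len(risk_scores) // 10)
    top = np.argsort(risk_scores)[-k:]   # indices of the top-decile risk
    ppv = events[top].mean()             # fraction of flagged patients with events
    base_rate = events.mean()
    return ppv, ppv / base_rate
```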

5. Model Interpretability and Visualization

STRAFE's first self-attention matrix $A \in \mathbb{R}^{n_v \times n_v}$, defined as $A = \operatorname{softmax}(QK^\top / \sqrt{d_k})$, quantifies visit-to-visit relatedness. High-attention visit pairs can be visualized as graph nodes (visits colored by dominant ICD chapter) with edge weights reflecting attention scores. In documented ablation studies, removing highly weighted visits caused substantial shifts in predicted survival, confirming their importance for outcome risk.
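
The attention matrix used for these visualizations is the standard scaled dot-product form; a minimal single-head NumPy sketch:

```python
import numpy as np

def attention_matrix(Q, K):
    """A = softmax(Q K^T / sqrt(d_k)); row j gives visit j's attention
    distribution over all visits, so each row sums to 1."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(scores)
    return w / w.sum(axis=-1, keepdims=True)
```

Thresholding the entries of `A` then yields the edge weights of the visit graph described above.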

This mechanism enables per-patient explanation by highlighting the visits most influential for time-to-deterioration predictions. This suggests utility for targeted clinical interventions, e.g. attribution of risk to hypertension versus respiratory events in CKD management.

6. Clinical Impact and Use Cases

STRAFE's ability to model right-censored survival, to exploit sequence structure in visit-level health data, and to provide highly granular per-patient predictions supports its application in intervention targeting for chronic disease management. The threefold lift in PPV among top-decile risk patients, and improved MAE in time-to-event prediction, demonstrate advantages over both classical survival forests and prior deep-learning methods. Its explainability facilitates clinical deployment by providing actionable patient-specific risk drivers and supporting transparent decision support in high-stakes medical settings (Zisser et al., 2023).
