Residual GRU+MHSA
- The paper reports the highest accuracy (86.1%) among the evaluated models for cardiovascular disease detection on the UCI Heart Disease dataset, with ranking metrics (ROC-AUC 0.908, PR-AUC 0.904) close to the best baselines.
- The model integrates residual bidirectional GRU stacks with channel reweighting and multi-head self-attention pooling to capture both sequential and global contextual relationships in clinical records.
- The architecture is computationally efficient, with fewer than 1.2 million parameters, and supports real-time inference on resource-constrained healthcare devices.
The Residual GRU+MHSA model is a compact deep learning architecture designed for predictive modeling with tabular clinical records. It integrates residual bidirectional Gated Recurrent Units (GRUs), channel reweighting, and Multi-Head Self-Attention (MHSA) pooling with a learnable classification (CLS) token. This hybrid recurrent-attention model treats each clinical record as a "pseudo-sequence" of its individual feature columns, which enables both sequential modeling and global contextual aggregation. Evaluated on the UCI Heart Disease dataset with 5-fold cross-validation, it achieves higher accuracy and Macro-F1 than classical machine learning and modern deep learning baselines for cardiovascular disease detection, with competitive ranking metrics (Dash et al., 16 Dec 2025).
1. Model Architecture and Module Interactions
The Residual GRU+MHSA model comprises four sequential modules:
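- Input Embedding & Feature Dropout: Each scalar feature $x_i$ ($i = 1, \dots, F$, where $F$ is the number of features) is linearly projected to a $d$-dimensional embedding $\mathbf{e}_i \in \mathbb{R}^d$. A record therefore becomes the pseudo-sequence $E = [\mathbf{e}_1, \dots, \mathbf{e}_F] \in \mathbb{R}^{F \times d}$. During training, column-wise dropout zeroes entire embedding vectors with probability $p$ to mitigate overfitting to individual features.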
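- Residual Bidirectional GRU Stack: The embedded pseudo-sequence $E$ is processed by an initial bidirectional GRU (BiGRU) layer. Its forward and backward outputs are concatenated to dimension $2h$, where $h$ is the GRU hidden size. A learnable projection $W_p \in \mathbb{R}^{d \times 2h}$ matches the input dimensionality to the BiGRU output, forming a residual skip pathway:
$$H^{(0)} = \mathrm{LN}\big(\mathrm{BiGRU}_0(E) + E W_p\big)$$
This is followed by $N$ identical residual BiGRU blocks:
$$H^{(n)} = \mathrm{LN}\big(H^{(n-1)} + \mathrm{BiGRU}_n(H^{(n-1)})\big), \quad n = 1, \dots, N$$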
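- Channel Reweighting (SE-Style Gating): The final sequence $H = H^{(N)}$, with rows $\mathbf{h}_t$, undergoes temporal mean pooling to compute a global summary $\bar{\mathbf{h}} = \frac{1}{F}\sum_{t=1}^{F} \mathbf{h}_t \in \mathbb{R}^{2h}$. This summary is passed through two small fully connected layers and a sigmoid gate to obtain per-channel weights $\mathbf{s} \in (0, 1)^{2h}$. Each time step is then reweighted:
$$\tilde{\mathbf{h}}_t = \mathbf{s} \odot \mathbf{h}_t, \quad t = 1, \dots, F$$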
- Multi-Head Self-Attention Pooling (MHSA) with CLS Token: A learnable CLS token $\mathbf{c}^{(0)}$ attends to the reweighted sequence $\tilde{H}$ over $L$ self-attention layers, updating its embedding at each layer through scaled dot-product attention. The layer-normalized final CLS token $\mathbf{c}^{(L)}$ is fed to a prediction MLP head, which yields a scalar logit $z$ and the probability $\hat{y} = \sigma(z)$.
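The input stage can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions. Each feature gets its own weight and bias vector, since the source does not say whether the projection is shared across columns. Dropout uses inverted scaling by $1/(1-p)$. The function names `embed_features` and `feature_dropout` are hypothetical.

```python
import numpy as np

def embed_features(x, W, B):
    """Project each scalar feature x[i] to a d-dim vector: e_i = x_i * W[i] + B[i].

    x: (F,) one clinical record; W, B: (F, d) per-feature weights and biases
    (an assumed parameterization). Returns the (F, d) pseudo-sequence E.
    """
    return x[:, None] * W + B

def feature_dropout(E, p, rng, training=True):
    """Column-wise dropout: zero whole feature embeddings with probability p.

    Inverted-dropout scaling (assumption) keeps expected activations unchanged,
    so nothing needs rescaling at inference time.
    """
    if not training or p == 0.0:
        return E
    keep = rng.random(E.shape[0]) >= p        # one Bernoulli draw per feature column
    return E * keep[:, None] / (1.0 - p)

rng = np.random.default_rng(0)
F, d = 13, 8                                  # illustrative sizes, not the paper's
x = rng.normal(size=F)
W, B = rng.normal(size=(F, d)), np.zeros((F, d))
E = feature_dropout(embed_features(x, W, B), p=0.3, rng=rng)
print(E.shape)  # (13, 8)
```

Dropping whole embedding vectors, rather than individual coordinates, hides an entire clinical variable from the network for that step. This is what "column-wise" refers to.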
2. Mathematical Formulation and Detailed Operations
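The GRU cell is formulated as:
$$\mathbf{z}_t = \sigma\big(W_z \mathbf{x}_t + U_z \mathbf{h}_{t-1} + \mathbf{b}_z\big)$$
$$\mathbf{r}_t = \sigma\big(W_r \mathbf{x}_t + U_r \mathbf{h}_{t-1} + \mathbf{b}_r\big)$$
$$\mathbf{n}_t = \tanh\big(W_n \mathbf{x}_t + U_n (\mathbf{r}_t \odot \mathbf{h}_{t-1}) + \mathbf{b}_n\big)$$
$$\mathbf{h}_t = (1 - \mathbf{z}_t) \odot \mathbf{h}_{t-1} + \mathbf{z}_t \odot \mathbf{n}_t$$
The symbols are as follows:
- $\mathbf{x}_t$ is the input at step $t$: the embedding $\mathbf{e}_t$ for the first layer, or the previous block's output for later layers.
- $\mathbf{z}_t$ and $\mathbf{r}_t$ are the update and reset gates.
- $\mathbf{n}_t$ is the candidate state.
- $\sigma$ is the logistic sigmoid and $\odot$ is element-wise multiplication.

Bidirectionality concatenates the forward and backward hidden states, $\mathbf{h}_t = [\overrightarrow{\mathbf{h}}_t ; \overleftarrow{\mathbf{h}}_t] \in \mathbb{R}^{2h}$. Layer normalization (LN) is applied after each residual summation.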
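Below is a minimal NumPy sketch of the GRU cell above, its bidirectional wrapper, and the residual stack with layer normalization. Parameters are randomly initialized for illustration, and the layer norm omits its learnable gain and bias. All helper names (`gru_cell`, `bigru`, `residual_bigru_stack`) are hypothetical rather than taken from the paper's code.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def init_gru(d_in, h, rng):
    """Random parameters for one GRU direction (illustrative init only)."""
    p = {}
    for g in ("z", "r", "n"):                       # update, reset, candidate
        p["W" + g] = rng.normal(scale=0.1, size=(h, d_in))
        p["U" + g] = rng.normal(scale=0.1, size=(h, h))
        p["b" + g] = np.zeros(h)
    return p

def gru_cell(x, h_prev, p):
    z = sigmoid(p["Wz"] @ x + p["Uz"] @ h_prev + p["bz"])          # update gate z_t
    r = sigmoid(p["Wr"] @ x + p["Ur"] @ h_prev + p["br"])          # reset gate r_t
    n = np.tanh(p["Wn"] @ x + p["Un"] @ (r * h_prev) + p["bn"])    # candidate n_t
    return (1.0 - z) * h_prev + z * n                              # new state h_t

def run_gru(X, p):
    h = np.zeros(p["Uz"].shape[0])
    out = []
    for x in X:                                     # the F features act as time steps
        h = gru_cell(x, h, p)
        out.append(h)
    return np.stack(out)

def bigru(X, p_fwd, p_bwd):
    fwd = run_gru(X, p_fwd)
    bwd = run_gru(X[::-1], p_bwd)[::-1]             # re-align backward states to step t
    return np.concatenate([fwd, bwd], axis=1)       # (F, 2h)

def layer_norm(H, eps=1e-5):
    """Per-step LN over channels (learnable gain/bias omitted)."""
    mu = H.mean(axis=1, keepdims=True)
    var = H.var(axis=1, keepdims=True)
    return (H - mu) / np.sqrt(var + eps)

def residual_bigru_stack(E, h, n_blocks, rng):
    d = E.shape[1]
    W_p = rng.normal(scale=0.1, size=(d, 2 * h))    # projection on the skip path
    H = layer_norm(bigru(E, init_gru(d, h, rng), init_gru(d, h, rng)) + E @ W_p)
    for _ in range(n_blocks):                       # N residual BiGRU blocks
        H = layer_norm(H + bigru(H, init_gru(2 * h, h, rng), init_gru(2 * h, h, rng)))
    return H

rng = np.random.default_rng(0)
E = rng.normal(size=(13, 8))                        # (F, d) embedded pseudo-sequence
H = residual_bigru_stack(E, h=16, n_blocks=2, rng=rng)
print(H.shape)  # (13, 32)
```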
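In the channel reweighting block, the gating weights are computed from the temporally pooled summary and then applied to every step:
$$\mathbf{s} = \sigma\big(W_2\,\mathrm{ReLU}(W_1 \bar{\mathbf{h}} + \mathbf{b}_1) + \mathbf{b}_2\big), \qquad \tilde{\mathbf{h}}_t = \mathbf{s} \odot \mathbf{h}_t$$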
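The SE-style gate can be sketched as below. The sketch assumes a ReLU between the two fully connected layers and a bottleneck reduction ratio `r = 4`. Both are common squeeze-and-excitation choices, not values confirmed by the source.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def channel_reweight(H, W1, b1, W2, b2):
    """SE-style gating over the 2h channels of the BiGRU output H, shape (F, 2h)."""
    h_bar = H.mean(axis=0)                                     # temporal mean pool -> (2h,)
    s = sigmoid(W2 @ np.maximum(0.0, W1 @ h_bar + b1) + b2)    # FC -> ReLU -> FC -> sigmoid
    return H * s[None, :], s                                   # same gate at every step

rng = np.random.default_rng(0)
F, C, r = 13, 32, 4                     # C = 2h channels; reduction ratio r is assumed
H = rng.normal(size=(F, C))
W1, b1 = rng.normal(scale=0.1, size=(C // r, C)), np.zeros(C // r)
W2, b2 = rng.normal(scale=0.1, size=(C, C // r)), np.zeros(C)
H_tilde, s = channel_reweight(H, W1, b1, W2, b2)
```

Because $\mathbf{s}$ is computed once per record and shared across all $F$ steps, the gate rescales channels globally rather than re-weighting individual features.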
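The MHSA pooling applies, for each layer $\ell = 1, \dots, L$ and each head $j = 1, \dots, M$, with $\tilde{H} \in \mathbb{R}^{F \times 2h}$ stacking the rows $\tilde{\mathbf{h}}_t^{\top}$ and layer-specific projections (layer index omitted on the $W$ matrices):
$$\mathrm{head}_j^{(\ell)} = \mathrm{softmax}\!\left(\frac{(\mathbf{c}^{(\ell-1)} W_j^{Q})(\tilde{H} W_j^{K})^{\top}}{\sqrt{d_k}}\right) \tilde{H} W_j^{V}$$
$$\mathbf{c}^{(\ell)} = \mathrm{LN}\big(\mathbf{c}^{(\ell-1)} + [\mathrm{head}_1^{(\ell)}; \dots; \mathrm{head}_M^{(\ell)}]\, W^{O}\big)$$
where $d_k = 2h/M$ is the per-head dimension. After $L$ layers, the CLS vector $\mathbf{c}^{(L)}$ is passed to a two-layer MLP that produces the logit $z$, with $\hat{y} = \sigma(z)$.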
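One way to realize CLS-token attention pooling is sketched below. It rests on three assumptions:
- The CLS vector is the only query, and the reweighted sequence supplies the keys and values.
- Per-head projections are slices of a single $D \times D$ matrix, which is equivalent to separate $W_j^Q, W_j^K, W_j^V$.
- Dimensions, the random initialization, and the ReLU in the MLP head are illustrative.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)              # numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(v, eps=1e-5):
    return (v - v.mean()) / np.sqrt(v.var() + eps)

def cls_attention_layer(c, H, Wq, Wk, Wv, Wo, n_heads):
    """One pooling layer: CLS vector c (D,) attends over H (F, D) with n_heads heads."""
    D = c.shape[0]
    dk = D // n_heads
    q = (c @ Wq).reshape(n_heads, dk)                     # (M, dk), one query per head
    K = (H @ Wk).reshape(H.shape[0], n_heads, dk)         # (F, M, dk)
    V = (H @ Wv).reshape(H.shape[0], n_heads, dk)
    scores = np.einsum("md,fmd->mf", q, K) / np.sqrt(dk)  # scaled dot products (M, F)
    A = softmax(scores, axis=-1)                          # weights over the F features
    heads = np.einsum("mf,fmd->md", A, V).reshape(D)      # concatenate the M heads
    return layer_norm(c + heads @ Wo), A                  # residual + LN

def mlp_head(c, W1, b1, w2, b2):
    z = w2 @ np.maximum(0.0, W1 @ c + b1) + b2            # two-layer MLP -> logit z
    return z, 1.0 / (1.0 + np.exp(-z))                    # (z, sigma(z))

rng = np.random.default_rng(0)
F, D, M, L = 13, 32, 4, 2                                 # D = 2h; all sizes illustrative
H_tilde = rng.normal(size=(F, D))
c = rng.normal(scale=0.1, size=D)                         # stands in for the learned CLS token
for _ in range(L):
    Wq, Wk, Wv, Wo = (rng.normal(scale=0.1, size=(D, D)) for _ in range(4))
    c, A = cls_attention_layer(c, H_tilde, Wq, Wk, Wv, Wo, M)
z, prob = mlp_head(c, rng.normal(scale=0.1, size=(16, D)), np.zeros(16),
                   rng.normal(scale=0.1, size=16), 0.0)
```

With a single query, each layer costs time linear in $F$. The attention matrix `A` also gives a per-head distribution over input features that can be inspected.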
3. Experimental Evaluation and Comparative Results
The model was evaluated with 5-fold stratified cross-validation on the UCI Heart Disease dataset and benchmarked against classical and deep learning baselines. The reported metrics are Accuracy, Macro-F1, ROC-AUC, and PR-AUC, each given as mean ± standard deviation across folds.
| Model | Accuracy ± std | Macro-F1 ± std | ROC-AUC ± std | PR-AUC ± std |
|---|---|---|---|---|
| Logistic Regression | 0.832 ± 0.050 | 0.830 ± 0.050 | 0.912 ± 0.018 | 0.908 ± 0.035 |
| GaussianNB | 0.855 ± 0.040 | 0.853 ± 0.040 | 0.902 ± 0.023 | 0.895 ± 0.032 |
| Stacked GRU | 0.851 ± 0.038 | 0.848 ± 0.039 | 0.897 ± 0.021 | 0.899 ± 0.033 |
| Deep Transformer | 0.848 ± 0.042 | 0.847 ± 0.042 | 0.909 ± 0.027 | 0.912 ± 0.031 |
| LSTM–Transformer hybrid | 0.858 ± 0.043 | 0.856 ± 0.044 | 0.907 ± 0.020 | 0.907 ± 0.020 |
| Residual GRU + MHSA | 0.861 ± 0.032 | 0.860 ± 0.032 | 0.908 ± 0.022 | 0.904 ± 0.027 |
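Residual GRU+MHSA achieves the highest Accuracy (0.861) and Macro-F1 (0.860), with the lowest fold-to-fold variance on both. Its ROC-AUC (0.908) and PR-AUC (0.904) stay close to the best baselines (Logistic Regression, ROC-AUC 0.912; Deep Transformer, PR-AUC 0.912). Overall, it offers the best trade-off between classification and ranking metrics among the compared models (Dash et al., 16 Dec 2025).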
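The evaluation protocol can be sketched with two small helpers. The first is a stratified fold assignment that preserves the class ratio in every fold. The second is a rank-based ROC-AUC (the normalized Mann-Whitney U statistic). This is not the authors' evaluation code; in practice scikit-learn's `StratifiedKFold` and `roc_auc_score` serve the same purpose. The labels below are synthetic.

```python
import numpy as np

def stratified_folds(y, k=5, seed=0):
    """Return a fold id in [0, k) per sample, dealing each class round-robin
    so every fold keeps (almost exactly) the overall class ratio."""
    y = np.asarray(y)
    rng = np.random.default_rng(seed)
    folds = np.empty(len(y), dtype=int)
    for cls in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == cls))   # shuffle within the class
        folds[idx] = np.arange(len(idx)) % k
    return folds

def roc_auc(y_true, score):
    """ROC-AUC as the normalized Mann-Whitney U statistic; tied scores share
    their average rank."""
    y_true = np.asarray(y_true)
    score = np.asarray(score, dtype=float)
    ranks = np.empty(len(score))
    ranks[np.argsort(score)] = np.arange(1, len(score) + 1)
    for v in np.unique(score):
        tie = score == v
        ranks[tie] = ranks[tie].mean()
    n_pos = int(y_true.sum())
    n_neg = len(y_true) - n_pos
    return (ranks[y_true == 1].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

y = np.array([0] * 50 + [1] * 40)          # synthetic labels, not the UCI data
folds = stratified_folds(y, k=5)
for f in range(5):
    in_fold = folds == f                    # held-out fold f; train on the rest
    print(f, int(in_fold.sum()), int(y[in_fold].sum()))   # 18 samples, 8 positives each
print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))      # 0.75
```

Per-fold metrics computed this way are then summarized as mean ± standard deviation, which is the format of the table above.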
4. Ablation Analysis and Component Contributions
Ablation studies systematically removed or altered single components, revealing the following:
- Removing the channel reweighting (CR) module slightly improved every metric (accuracy 0.865 vs. 0.861, ROC-AUC 0.909 vs. 0.904). This indicates limited channel redundancy for the gate to exploit in small tabular data.
- Replacing MHSA pooling with mean+max pooling reduced ROC-AUC by ≈1.3 percentage points, substantiating attention's role in global context aggregation.
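- Removing residual stacking ($N = 0$) or GRU bidirectionality degraded both classification and ranking metrics. Accuracy fell to 0.855 and 0.841, and ROC-AUC to 0.897 and 0.891, respectively, versus 0.861 and 0.904 for the full model. This indicates that both components contribute to generalization.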
- Feature dropout and deeper MLP heads improved generalization and PR-AUC.
Key ablation variants and results are summarized below:
| Variant | Accuracy ± std | Macro-F1 | ROC-AUC | PR-AUC |
|---|---|---|---|---|
| Residual GRU + MHSA (full) | 0.861 ± 0.037 | 0.859 | 0.904 | 0.908 |
| No Channel Reweighting | 0.865 ± 0.032 | 0.862 | 0.909 | 0.911 |
| No MHSA (mean+max pooling) | 0.859 ± 0.031 | 0.851 | 0.891 | 0.899 |
| No Residual Stack (N=0) | 0.855 ± 0.048 | 0.852 | 0.897 | 0.895 |
| Unidirectional GRU | 0.841 ± 0.029 | 0.849 | 0.891 | 0.896 |
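The percentage-point differences quoted in the bullets can be recomputed directly from the table. The sketch below does this and also shows the mean+max pooling used in place of MHSA. It assumes the two pooled vectors are concatenated before the MLP head; the source does not say whether they are concatenated or summed.

```python
import numpy as np

def mean_max_pool(H):
    """Ablation baseline: concatenate temporal mean and max over the F steps of H (F, D)."""
    return np.concatenate([H.mean(axis=0), H.max(axis=0)])

# Point estimates copied from the ablation table: (Accuracy, Macro-F1, ROC-AUC, PR-AUC).
FULL = (0.861, 0.859, 0.904, 0.908)
VARIANTS = {
    "No Channel Reweighting": (0.865, 0.862, 0.909, 0.911),
    "No MHSA (mean+max pooling)": (0.859, 0.851, 0.891, 0.899),
    "No Residual Stack (N=0)": (0.855, 0.852, 0.897, 0.895),
    "Unidirectional GRU": (0.841, 0.849, 0.891, 0.896),
}

def deltas_pp(variant, full=FULL):
    """Difference from the full model in percentage points, rounded to 0.1."""
    return tuple(round(100 * (v - f), 1) for v, f in zip(variant, full))

for name, metrics in VARIANTS.items():
    print(f"{name:28s}", deltas_pp(metrics))
```

These are differences in fold-averaged point estimates only. With per-fold standard deviations of roughly 0.03 to 0.05, they do not by themselves establish statistical significance.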
5. Computational Efficiency and Deployment Properties
The model is highly compact:
- Parameter count: fewer than 1.2 million parameters in total (≈4.8 MB in float32, ≲1.2 MB with 8-bit quantization).
- Module breakdown: BiGRU stack ≈0.8M, MHSA ≈0.3M, CR+MLP ≈0.1M.
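- Complexity: The recurrent stack runs in time linear in the pseudo-sequence length $F$, costing $O(F\,h\,(d_{\text{in}} + h))$ per BiGRU layer for input width $d_{\text{in}}$. CLS attention pooling is likewise linear in $F$ per layer, since a single query attends over the $F$ positions. Both costs stay small for the values of $F$, $d$, $h$, and $L$ used.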
- Inference: Single-sample latency is ≤5 ms on a mid-range CPU and ≤1 ms on a GPU, which makes real-time inference feasible.
- Resource deployment: After quantization the model fits within ≲2 MB, which suits microcontrollers and edge devices. Its compact dimensions also keep memory-bandwidth requirements low.
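The memory figures follow from simple arithmetic: bytes = parameters × bits / 8. The sketch below also gives the standard parameter count of one GRU layer: three gates, each with input weights, recurrent weights, and two bias vectors, as in PyTorch's `nn.GRU`. This is how a budget like the ≈0.8M BiGRU share arises. The hidden sizes in the example are hypothetical, since the source does not list them here.

```python
def gru_direction_params(d_in, h):
    """One GRU direction: 3 gates, each with input weights (h x d_in), recurrent
    weights (h x h) and two bias vectors (PyTorch-style b_ih and b_hh)."""
    return 3 * (h * d_in + h * h + 2 * h)

def bigru_params(d_in, h):
    """A bidirectional layer has two independent directions."""
    return 2 * gru_direction_params(d_in, h)

def size_mb(n_params, bits_per_param):
    """Storage in megabytes (10**6 bytes) at the given precision."""
    return n_params * bits_per_param / 8 / 1e6

n_total = 1_200_000                     # upper bound quoted above
print(size_mb(n_total, 32))             # 4.8  (float32)
print(size_mb(n_total, 8))              # 1.2  (8-bit quantized)
print(bigru_params(256, 128))           # 296448 for hypothetical d_in=256, h=128
```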
A plausible implication is that this architecture enables robust deployment in resource-constrained healthcare settings, including wearable monitors and on-device clinical screening tools.
6. Embedding Structure and Generalization
Empirical t-SNE visualizations of latent representations indicate that the model produces embedding spaces with clearer separability between disease and non-disease classes than the raw features (Dash et al., 16 Dec 2025). This suggests that the hybrid recurrent-attention approach captures non-linear feature relationships in tabular healthcare data that models relying solely on handcrafted features, or on purely sequential or purely transformer-based designs, may miss.
7. Significance and Future Directions
Residual GRU+MHSA demonstrates that integrating residual recurrence with multi-head attention yields compact yet robust architectures for tabular risk prediction. Ablation indicates that bidirectionality, residual stacking, and MHSA pooling each contribute to generalization. The model's low memory footprint, fast inference, and consistent performance across validation folds support its suitability for healthcare deployment, where accuracy and efficiency are both required (Dash et al., 16 Dec 2025). Future work could scale the architecture to higher-dimensional data, integrate longitudinal records, and extend it to multiclass or multilabel disease prediction in diverse clinical contexts.