TwinSegNet: Federated Brain Tumor Segmentation
- TwinSegNet is a federated framework that enables privacy-preserving brain tumor segmentation across multiple clinical sites using client-personalized digital twins.
- It employs a hybrid ViT-UNet architecture combining convolutional encoders for local spatial detail with a transformer-based bottleneck for global context modeling.
- Experimental evaluations demonstrate superior Dice scores and robustness on heterogeneous, non-IID 3D MRI datasets compared to centralized segmentation models.
TwinSegNet is a federated learning framework designed for privacy-preserving, real-time brain tumor segmentation across multiple clinical institutions. It integrates a hybrid architecture—combining convolutional neural networks (CNNs) with 3D Vision Transformer (ViT) components—against the backdrop of highly heterogeneous, non-IID 3D MRI data. Central to TwinSegNet's methodology is the concept of a client-personalized “digital twin”, a model instance locally fine-tuned at each site. This supports both collective knowledge sharing and institution-specific adaptation, thereby enabling scalable segmentation workflows that adhere to strict data confidentiality constraints. Experimental results on major public and custom MRI cohorts demonstrate that TwinSegNet maintains SOTA segmentation fidelity while ensuring robust client privacy (Wakili et al., 19 Dec 2025).
1. Hybrid ViT-UNet Model Architecture
TwinSegNet adopts a modified UNet structure featuring an encoder–bottleneck–decoder topology, where the central bottleneck leverages transformer-based global context modeling. The input, comprising four resampled MRI modalities (T1, T1ce, T2, FLAIR) at 128×128×128 resolution, is processed as follows:
- Convolutional Encoder: Four downsampling stages, each employing two Conv3D layers (with BatchNorm3D and ReLU activations), followed by max-pooling. Channel widths double at each stage (e.g., 32→64→128→256).
- ViT Bottleneck: The final encoder output (256 channels at 8×8×8 resolution) is reshaped into 2×2×2 3D patches, yielding 64 tokens of dimension 2,048 each, linearly projected to 512. Positional embeddings are added, followed by four transformer encoder layers utilizing 8-head self-attention.
- Convolutional Decoder: Symmetric to the encoder, each upsampling stage uses a transposed convolution (stride=2), skip-connects features from the encoding path, and applies Conv3D + BatchNorm3D + ReLU blocks (channels halve each stage).
- Final Output: A Conv3D projects to four voxel classes (background, edema, tumor core, enhancing tumor), followed by a softmax layer.
The core forward pass is summarized in the following pseudocode:
def TwinSegNetForward(x):                    # x: [B, 4, 128, 128, 128]
    skips = []
    for i in range(4):                       # convolutional encoder
        x = Conv3D_BN_ReLU(x)
        skips.append(x)
        x = MaxPool3D(x)
    tokens = PatchEmbed3D(patch_size=2)(x)   # [B, 64, 512]
    tokens += PositionalEmbedding
    for _ in range(4):                       # ViT bottleneck
        tokens = tokens + MSA(LayerNorm(tokens))
        tokens = tokens + MLP(LayerNorm(tokens))
    x = Unpatch(tokens)
    for j in range(4):                       # convolutional decoder
        x = ConvTranspose3D(x)
        x = Concat(x, skips[3 - j])
        x = Conv3D_BN_ReLU(x)
    out = Conv3D_1x1x1(x)                    # project to 4 voxel classes
    return Softmax(out)
This architecture combines efficient local spatial detail extraction with global context modeling crucial for pathological pattern recognition in 3D MRI.
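The stage-by-stage tensor shapes implied by the architecture above can be traced with simple arithmetic. The sketch below is illustrative only: the 128³ input, four pooling stages, base width 32, patch size 2, and the resulting 64 tokens follow the pseudocode; the function name and helper structure are assumptions.

```python
def encoder_shapes(spatial=128, base_ch=32, stages=4):
    """Return (channels, spatial) per encoder stage plus the bottleneck input."""
    shapes = []
    ch, sp = base_ch, spatial
    for _ in range(stages):
        shapes.append((ch, sp))  # two Conv3D blocks keep the spatial size
        sp //= 2                 # max-pooling halves each spatial dimension
        ch *= 2                  # channel width doubles for the next stage
    return shapes, (ch // 2, sp)  # bottleneck sees the last stage, pooled

shapes, (bott_ch, bott_sp) = encoder_shapes()
n_tokens = (bott_sp // 2) ** 3        # 2x2x2 patches over the 8^3 volume
token_dim_raw = bott_ch * 2 ** 3      # flattened patch size before projection

print(shapes)                   # [(32, 128), (64, 64), (128, 32), (256, 16)]
print(bott_ch, bott_sp)         # 256 8
print(n_tokens, token_dim_raw)  # 64 2048
```

This confirms the bottleneck figures quoted above: 64 tokens of raw dimension 2,048, which the linear patch embedding projects down to 512.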
2. Federated Learning Workflow
The federated setup encompasses nine clients (hospitals), each holding an institution-specific MRI dataset. Synchronous federated rounds operate as follows:
- Initialization: The server distributes global weights w^(0) to all clients.
- Local Training: Each client k trains for E local epochs on its dataset D_k, obtaining updated weights w_k^(t+1).
- Aggregation: The server computes the size-weighted average
  w^(t+1) = Σ_{k=1}^{K} (n_k / n) · w_k^(t+1),
  where n_k = |D_k| and n = Σ_k n_k.
- Communication: Only model weights are exchanged; raw images remain strictly local.
After aggregation, each client locally fine-tunes the global model to create its digital twin for improved site-specific performance. This process is decentralized and does not trigger further inter-client communication.
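The size-weighted aggregation step above is standard FedAvg; a minimal sketch follows, using plain Python dicts as stand-ins for model state dicts. The client values and dataset sizes are illustrative, not taken from the paper.

```python
def fedavg(client_weights, client_sizes):
    """Size-weighted average: w = sum_k (n_k / n) * w_k."""
    n = sum(client_sizes)
    keys = client_weights[0].keys()
    return {
        key: sum((nk / n) * w[key] for w, nk in zip(client_weights, client_sizes))
        for key in keys
    }

# Three hypothetical clients, each with a single scalar "layer":
clients = [{"conv.w": 1.0}, {"conv.w": 2.0}, {"conv.w": 4.0}]
sizes = [100, 100, 200]  # n_k = |D_k|; larger sites pull the average harder
global_w = fedavg(clients, sizes)
print(global_w)  # {'conv.w': 2.75}
```

Note that only the weight dictionaries cross the network boundary here, mirroring the communication constraint: raw images never leave the client.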
3. Digital Twin Personalization
“Digital twins” in this context are site-specific model copies, initialized with aggregated parameters and further fine-tuned on each client's validation split for up to 2 epochs. Personalization employs a composite loss:
- Cross-Entropy: L_CE = −(1/N) Σ_i Σ_c y_{i,c} log ŷ_{i,c}
- Dice Loss: L_Dice = 1 − (2 Σ_i y_i ŷ_i + ε) / (Σ_i y_i + Σ_i ŷ_i + ε)
- Total Loss: L = λ · L_CE + (1 − λ) · L_Dice, with λ balancing the two terms
Optimization uses the Adam optimizer, optionally incorporating weight regularization to mitigate overfitting. Digital twins are intended to reflect each institution's case-mix and imaging distribution, thereby improving the generalization and utility of the deployed model within each hospital.
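The composite loss above can be sketched on a toy binary mask. This is a minimal stand-alone implementation, assuming λ = 0.5 and ε = 1e−6 (both illustrative; the paper's values are not specified here), with per-voxel probabilities in place of full softmax volumes.

```python
import math

def composite_loss(y_true, y_prob, lam=0.5, eps=1e-6):
    """lam * cross-entropy + (1 - lam) * soft Dice loss on flat voxel lists."""
    n = len(y_true)
    ce = -sum(t * math.log(p + eps) + (1 - t) * math.log(1 - p + eps)
              for t, p in zip(y_true, y_prob)) / n
    inter = sum(t * p for t, p in zip(y_true, y_prob))
    dice = 1 - (2 * inter + eps) / (sum(y_true) + sum(y_prob) + eps)
    return lam * ce + (1 - lam) * dice

y = [1.0, 1.0, 0.0, 0.0]      # ground-truth voxel labels
p = [0.9, 0.8, 0.2, 0.1]      # predicted foreground probabilities
loss = composite_loss(y, p)   # small but nonzero for a near-correct prediction
```

Pairing cross-entropy with Dice is a common choice for segmentation: cross-entropy gives well-behaved per-voxel gradients, while the Dice term directly targets region overlap and counteracts the heavy background/foreground class imbalance in tumor masks.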
4. Experimental Evaluation
TwinSegNet was evaluated on nine MRI corpora, including BraTS 2019–2021 and five custom collections (Glioma, Meningioma, Metastatic, Pediatric, Secondary), spanning client scales from 60 to 1,251 cases. Each client's data utilized a 70%/15%/15% train/validation/test split. Core results include:
| Model | Dice WT | Dice TC | Dice ET | Sens. | Spec. |
|---|---|---|---|---|---|
| TwinSegNet Global | 0.931 ± 0.03 | 0.851 ± 0.05 | 0.844 ± 0.04 | 0.91 ± 0.02 | 0.93 ± 0.02 |
| TwinSegNet DT | 0.940 ± 0.02 | 0.881 ± 0.03 | 0.864 ± 0.03 | — | — |
Compared to centralized SOTA models—TransUNet (mean Dice: 0.86) and nn-U-Net variants (0.87–0.88)—TwinSegNet achieves superior average Dice when personalized digital twins are employed (mean Dice up to 0.897). Personalized digital twins yield an average gain of +1.8% Dice and +1.6% IoU over the federated global baseline. ROC AUCs per class range from 0.80 (TC) to 0.95 (ET), with consistently improved per-client performance on small and skewed datasets.
5. Implementation and Privacy Considerations
TwinSegNet experiments were conducted in PyTorch 1.13 environments on an NVIDIA RTX 3080 GPU (8 GB), Intel i9 CPUs, and 64 GB RAM. The federated learning simulation used logical isolation for each of the nine clients and one central server. Training used a batch size of 2 and the Adam optimizer with weight decay over synchronous federated rounds. Preprocessing leveraged TorchIO for intensity normalization, resampling, and online data augmentation.
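One representative preprocessing step is per-volume intensity z-scoring of foreground (brain) voxels, in the style of TorchIO's normalization transforms. The sketch below is a plain-Python illustration on a flat voxel list, not the paper's pipeline; the function name and the nonzero-foreground convention are assumptions.

```python
import math

def znormalize(voxels):
    """Z-score the nonzero (brain) voxels; zero background stays zero."""
    fg = [v for v in voxels if v != 0]
    mean = sum(fg) / len(fg)
    var = sum((v - mean) ** 2 for v in fg) / len(fg)
    std = math.sqrt(var) or 1.0          # guard against constant volumes
    return [(v - mean) / std if v != 0 else 0.0 for v in voxels]

vol = [0, 0, 10.0, 20.0, 30.0]           # toy volume: 2 background voxels
out = znormalize(vol)                    # foreground now has mean 0, unit std
```

Normalizing each volume independently is what makes the federated setting tractable: every site's scanner-specific intensity scale is mapped to a comparable range without sharing any statistics between clients.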
Regarding security, the framework ensures that no raw images or clinical features leave client premises. Only weight updates are communicated, with secure channels (e.g., TLS) assumed. The framework is compatible with future privacy enhancements such as secure aggregation and differential privacy, although neither was implemented in the present study.
6. Future Prospects and Applications
TwinSegNet's hybrid ViT-UNet backbone facilitates a balance between local feature extraction and global volumetric context, supporting high-fidelity segmentations in heterogeneous clinical MRI. The federated approach is demonstrably robust to non-IID data distributions and preserves privacy without discernible loss in performance compared to centralized baselines. Digital twins enable targeted improvements with negligible overhead through lightweight fine-tuning.
A plausible implication is that deploying TwinSegNet in real-world clinical networks could yield scalable, privacy-compliant, and personalized segmentation across diverse hospital environments. Recommended directions for further work include adaptive communication protocols (partial participation, FedProx), integration of differential privacy, field deployment studies under practical network constraints, and extending the methodology to longitudinal imaging, multi-modality fusion (e.g., PET/MRI), and closed-loop clinical feedback systems (Wakili et al., 19 Dec 2025).