BertsWin: Efficient 3D Pre-training Architecture
- BertsWin is a hybrid self-supervised pre-training architecture that mitigates topological sparsity by preserving a complete 3D token grid using BERT-style masking.
- It employs Swin Transformer windowed attention to limit computational complexity from O(N²) to O(N) while maintaining local spatial context.
- The design integrates a 3D CNN stem, a multi-component structural loss, and a GradientConductor optimizer to achieve rapid semantic convergence and resource efficiency.
BertsWin is a hybrid self-supervised pre-training architecture for three-dimensional volumetric data that resolves the topological sparsity encountered by conventional Masked Autoencoders (MAEs) when applied to 3D structures. It integrates a full BERT-style token masking strategy with Swin Transformer windowed attention to simultaneously maintain spatial topology and computational efficiency. BertsWin was introduced to address the structural discontinuities and slow convergence inherent to 3D MAE pre-training, and achieves substantial improvements in both semantic convergence speed and resource utilization, as demonstrated on cone beam computed tomography (CBCT) of the temporomandibular joint (TMJ) (Limarenko et al., 25 Dec 2025).
1. Motivation for BertsWin Pre-training
While state-of-the-art 2D masked autoencoders (MAEs) exploit the redundancy of planar images, masking 75% of 3D patches in volumetric inputs produces severe topological sparsity. This disrupts anatomical continuity, fragmenting spatially connected regions and destroying geometric priors. The resulting fragmentation leads to blocking artifacts and loss of context, forcing networks to reconstruct entire structures from sparse, disconnected tokens, which significantly slows convergence. Computationally, expanding MAE to full 3D attention incurs a prohibitive $O(N^2)$ complexity, where $N$ is the number of volumetric patches.
BertsWin provides two key innovations: (a) it maintains a complete 3D token grid (mask tokens plus visible embeddings), preserving spatial topology throughout the encoder; (b) it replaces global attention with Swin Transformer windows, limiting computation to $O(N)$ complexity while retaining effective local spatial context. This approach targets rapid, resource-efficient learning of structural priors in 3D self-supervised pre-training (Limarenko et al., 25 Dec 2025).
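A short derivation clarifies the scaling claim; here $w$ denotes the window edge length in patches and $d$ the embedding dimension, symbols introduced only for this illustration. Global attention over all $N$ tokens costs $O(N^2 d)$, whereas windowed attention partitions the grid into $N / w^3$ windows, each attending over $w^3$ tokens:

$$\frac{N}{w^3} \cdot O\big((w^3)^2 d\big) = O\big(N\, w^3 d\big),$$

which is linear in $N$ for a fixed window size, matching the $O(N)$ behavior cited above.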
2. High-Level Architecture
BertsWin comprises four distinct modules in a sequential arrangement:
A. 3D CNN Stem (Hybrid Patch Embedding):
- The input volume is split into non-overlapping patches of size $p \times p \times p$ voxels, yielding $N$ total tokens.
- 25% of patches are selected at random and passed through four 3D convolutional blocks to produce the visible token embeddings.
B. Full 3D Token Grid with Positional Embedding:
- A binary mask $M$ records the visibility of each patch.
- Visible embeddings are scattered to their locations in the grid; masked locations are filled with a learnable mask token $e_{[\mathrm{MASK}]}$.
- Fixed or learnable positional embeddings matching the token-grid shape are added for spatial context.
C. Single-Scale Swin Transformer Encoder:
- Consists of 12 Swin blocks, each performing window-based local attention with a fixed 3D window size, shift patterns for inter-window information transfer, 12 attention heads, and no downsampling.
- The output is a feature grid with the same spatial layout as the input token grid.
D. 3D CNN Decoder:
- Three transposed convolution layers (strides 4, 2, 2) upsample the features to reconstruct the full volumetric input (a minimal structural sketch follows the summary table below).
| Module | Key Parameters |
|---|---|
| CNN Stem | 4 convolutional blocks, stride 2 |
| Token Grid + Positional Embedding | Binary mask, learnable mask token, positional embeddings |
| Swin Transformer Encoder | 12 blocks, windowed attention, 12 heads |
| CNN Decoder | 3 transposed-conv layers (strides 4, 2, 2) |
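The following PyTorch-style sketch illustrates this module arrangement. It is a structural illustration only: the class and module names, the embedding dimension, the patch/grid sizes, and the use of a plain transformer layer as a stand-in for the shifted-window Swin blocks are assumptions, not the reference implementation.

```python
# Illustrative sketch of the BertsWin pipeline (stem -> token grid -> encoder -> decoder).
import torch
import torch.nn as nn

class BertsWinSketch(nn.Module):
    def __init__(self, grid=(8, 8, 8), embed_dim=768, depth=12, heads=12):
        super().__init__()
        self.grid = grid
        n_tokens = grid[0] * grid[1] * grid[2]
        # Stand-in for the 4-block 3D CNN stem (here a single strided conv patchifier).
        self.stem = nn.Conv3d(1, embed_dim, kernel_size=16, stride=16)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))      # learnable [MASK] embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, n_tokens, embed_dim))
        # Stand-in for 12 shifted-window Swin blocks with no downsampling.
        self.encoder = nn.ModuleList(
            [nn.TransformerEncoderLayer(embed_dim, heads, batch_first=True) for _ in range(depth)]
        )
        # 3D CNN decoder: three transposed convolutions with strides 4, 2, 2.
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(embed_dim, 128, 4, stride=4), nn.GELU(),
            nn.ConvTranspose3d(128, 64, 2, stride=2), nn.GELU(),
            nn.ConvTranspose3d(64, 1, 2, stride=2),
        )

    def forward(self, volume, visible_mask):
        # volume: (B, 1, D, H, W); visible_mask: (B, N) with 1 = visible, 0 = masked.
        tokens = self.stem(volume).flatten(2).transpose(1, 2)             # (B, N, C) patch embeddings
        m = visible_mask.unsqueeze(-1).float()
        tokens = m * tokens + (1.0 - m) * self.mask_token                 # full token grid, nothing dropped
        tokens = tokens + self.pos_embed
        for blk in self.encoder:
            tokens = blk(tokens)
        feat = tokens.transpose(1, 2).reshape(tokens.size(0), -1, *self.grid)
        return self.decoder(feat)                                         # reconstructed volume
```

For a $128^3$ input with $16^3$ patches this yields an $8^3$ token grid ($N = 512$) and reconstructs a $128^3$ volume; unlike the paper's stem, the sketch embeds all patches before substituting the mask token, which matches the token formula in Section 3 but is not compute-optimal.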
3. 3D BERT-style Masking Mechanism
BertsWin employs 3D patch-level masking, generalizing BERT-style masking to volumetric data. For $N$ patches and a masking ratio of 0.75, the binary mask $M_{b,n}$ for sample $b$ and patch $n$ is defined as:
$$M_{b,n} = \begin{cases} 1 & \text{if patch } n \text{ is visible} \\ 0 & \text{if patch } n \text{ is masked} \end{cases}$$
Token construction proceeds via:

$$T_n = M_{b,n}\, f_{\mathrm{stem}}(x_n) + (1 - M_{b,n})\, e_{[\mathrm{MASK}]},$$

where $x_n$ denotes the voxel values in patch $n$, $f_{\mathrm{stem}}(x_n)$ denotes the CNN stem embedding, and $e_{[\mathrm{MASK}]}$ is the learnable embedding for masked positions. Positional embeddings are then added:

$$T_n \leftarrow T_n + \mathrm{PE}_n.$$
This complete grid, containing all visible and masked patches, is propagated through the encoder. By preserving the 3D topology during masking and processing, the architecture maintains anatomical coherence and improves convergence dynamics (Limarenko et al., 25 Dec 2025).
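A minimal sketch of the mask sampling and token assembly described above, assuming a 0.75 masking ratio and illustrative tensor shapes (the helper names are hypothetical):

```python
import torch

def sample_patch_mask(batch: int, n_patches: int, mask_ratio: float = 0.75) -> torch.Tensor:
    """Return M of shape (batch, n_patches): 1 = visible, 0 = masked (patch-level, BERT-style)."""
    n_visible = int(n_patches * (1.0 - mask_ratio))
    scores = torch.rand(batch, n_patches)                 # random score per patch
    visible_idx = scores.topk(n_visible, dim=1).indices   # keep a random 25% as visible
    mask = torch.zeros(batch, n_patches)
    mask.scatter_(1, visible_idx, 1.0)
    return mask

def assemble_tokens(stem_embed, mask, mask_token, pos_embed):
    """T_n = M * f_stem(x_n) + (1 - M) * e_[MASK], followed by positional embeddings."""
    m = mask.unsqueeze(-1)                                # (B, N, 1)
    tokens = m * stem_embed + (1.0 - m) * mask_token      # the full grid is preserved; nothing is dropped
    return tokens + pos_embed
```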
4. Structural Priority Loss: Multi-Component Variance and PhysLoss
BertsWin introduces a structural loss that decomposes the per-patch mean squared error (MSE) into three components, brightness ($L_{\mathrm{bri}}$), contrast ($L_{\mathrm{con}}$), and structure ($L_{\mathrm{str}}$), defined for each pair of original and reconstructed patches $(x, y)$:
- Brightness: $L_{\mathrm{bri}} = (\mu_x - \mu_y)^2$
- Contrast: $L_{\mathrm{con}} = (\sigma_x - \sigma_y)^2$
- Structure: $L_{\mathrm{str}} = 2\sigma_x \sigma_y (1 - \rho_{xy})$

where $\mu$ is the patch mean, $\sigma$ is the patch standard deviation, and $\rho_{xy}$ is the patchwise correlation.
The Multi-Component Variance loss is:

$$L_{\mathrm{MCV}} = w_{\mathrm{bri}} L_{\mathrm{bri}} + w_{\mathrm{con}} L_{\mathrm{con}} + w_{\mathrm{str}} L_{\mathrm{str}},$$

with weights $w_{\mathrm{bri}}$, $w_{\mathrm{con}}$, and $w_{\mathrm{str}}$ balancing the three components.
PhysLoss further prioritizes structurally critical regions by computing $L_{\mathrm{MCV}}$ over three domains:
- Global patch domain $\Omega_{\mathrm{global}}$,
- Soft-tissue mask $\Omega_{\mathrm{soft}}$,
- Bone-surface shell mask $\Omega_{\mathrm{bone}}$.

The complete loss function is:

$$L_{\mathrm{total}} = \lambda_{\mathrm{global}}\, L_{\mathrm{MCV}}^{\mathrm{global}} + \lambda_{\mathrm{soft}}\, L_{\mathrm{MCV}}^{\mathrm{soft}} + \lambda_{\mathrm{bone}}\, L_{\mathrm{MCV}}^{\mathrm{bone}},$$

with domain coefficients $\lambda_{\mathrm{global}}$, $\lambda_{\mathrm{soft}}$, and $\lambda_{\mathrm{bone}}$. This multi-component loss supports anatomically aware learning and accelerates semantic convergence (Limarenko et al., 25 Dec 2025).
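A minimal sketch of this loss under the decomposition above; the function names, the placeholder weight values, and the voxel-level domain masks are assumptions for illustration, not the reference configuration:

```python
import torch

def mcv_loss(x, y, w_bri=1.0, w_con=1.0, w_str=1.0, eps=1e-6):
    """Multi-Component Variance loss for original/reconstructed patches x, y of shape (B, N, P).
    The component weights here are placeholders, not the paper's values."""
    mu_x, mu_y = x.mean(-1), y.mean(-1)
    sd_x, sd_y = x.std(-1), y.std(-1)
    xc = x - mu_x.unsqueeze(-1)
    yc = y - mu_y.unsqueeze(-1)
    rho = (xc * yc).mean(-1) / (sd_x * sd_y + eps)     # patchwise correlation
    l_bri = (mu_x - mu_y) ** 2                         # brightness: squared mean difference
    l_con = (sd_x - sd_y) ** 2                         # contrast: squared std difference
    l_str = 2 * sd_x * sd_y * (1 - rho)                # structure: decorrelation term
    return (w_bri * l_bri + w_con * l_con + w_str * l_str).mean()

def phys_loss(x, y, soft_mask, bone_mask, lam_g=1.0, lam_s=1.0, lam_b=1.0):
    """PhysLoss: MCV over the global domain plus soft-tissue and bone-shell domains.
    soft_mask and bone_mask are binary voxel masks shaped like x; lambdas are placeholders."""
    return (lam_g * mcv_loss(x, y)
            + lam_s * mcv_loss(x * soft_mask, y * soft_mask)
            + lam_b * mcv_loss(x * bone_mask, y * bone_mask))
```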
5. GradientConductor (GCond) Optimizer
GradientConductor (GCond) is a custom optimizer combining features of LION (sign updates), LARS (trust-ratio scaling), and Adam (bias correction). A parameter $\theta$ is updated as follows:
- First-moment estimate (momentum) with bias correction:
$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad \hat{m}_t = \frac{m_t}{1 - \beta_1^{\,t}}$$
- Trust-ratio scaling (LARS):
$$\tau_t = \frac{\lVert \theta_{t-1} \rVert}{\lVert \hat{m}_t \rVert + \epsilon}$$
- Parameter update using the sign of the first moment (LION-style):
$$\theta_t = \theta_{t-1} - \eta\, \tau_t\, \operatorname{sign}(\hat{m}_t)$$

where $g_t$ is the gradient, $\beta_1$ the momentum coefficient, $\epsilon$ a small stabilizing constant, $\eta$ the base learning rate, and $\eta\, \tau_t$ the effective learning rate. Only the first moment is stored, reducing optimizer memory by ∼50% relative to AdamW. The optimizer confers stable warm-up (bias correction), memory efficiency, and cross-layer gradient scaling (Limarenko et al., 25 Dec 2025).
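A compact sketch of one GCond-style step assembled from the stated components (single bias-corrected moment, trust ratio, sign update); this is an illustrative re-implementation with assumed hyperparameter defaults, not the authors' code:

```python
import torch
from torch.optim import Optimizer

class GCondSketch(Optimizer):
    """Illustrative GradientConductor-style optimizer: one stored moment per parameter."""

    def __init__(self, params, lr=1e-4, beta1=0.9, eps=1e-8):
        super().__init__(params, dict(lr=lr, beta1=beta1, eps=eps))

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            lr, beta1, eps = group["lr"], group["beta1"], group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    state["m"] = torch.zeros_like(p)          # only the first moment is stored
                    state["t"] = 0
                state["t"] += 1
                m, t = state["m"], state["t"]
                m.mul_(beta1).add_(p.grad, alpha=1 - beta1)   # momentum accumulation
                m_hat = m / (1 - beta1 ** t)                  # Adam-style bias correction
                trust = p.norm() / (m_hat.norm() + eps)       # LARS-style layer-wise trust ratio
                p.sub_(lr * trust * m_hat.sign())             # LION-style sign update
```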
6. Computational Complexity and Empirical Convergence
FLOP Analysis
For the standard input resolution and patch size (with $N$ the resulting number of patch tokens):
- BertsWin: encoder $125.2$ GFLOPs, stem $81.7$ GFLOPs, decoder $16.9$ GFLOPs, total $223.8$ GFLOPs.
- MONAI ViT-MAE: encoder $134.1$, stem $17.3$, decoder $76.9$, total $228.3$ GFLOPs.
At the higher test resolution:
- BertsWin: $2673.1$ GFLOPs (scaling linearly in $N$).
- ViT baseline: $11035.5$ GFLOPs (scaling quadratically in $N$).
BertsWin therefore achieves theoretical FLOP parity at standard resolution and a ~4.1× reduction at high resolution, attributed to its window-based attention.
Convergence Speed
- MONAI MAE ViT (L2 + AdamW): 660 epochs to best validation MSE.
- BertsWin (L2 + AdamW): 114 epochs, a semantic-convergence speedup of roughly 5.8×.
- BertsWin + PhysLoss + GCond: 44 epochs, for a total speedup of 15× (the arithmetic is shown after this list).
- Due to GFLOPs parity per epoch, GPU-hour requirements drop by the same factors, offering substantial practical acceleration (Limarenko et al., 25 Dec 2025).
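The speedup factors follow directly from the reported epoch counts:

$$\frac{660}{114} \approx 5.8\times \quad \text{(BertsWin, L2 + AdamW)}, \qquad \frac{660}{44} = 15\times \quad \text{(BertsWin + PhysLoss + GCond)}.$$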
7. Summary and Context
BertsWin resolves topological sparsity in masked pre-training for 3D volumes by restoring a complete token grid, leveraging Swin-style local attention for scalable computation, and implementing a structural loss with anatomical prioritization. Combined with a memory-efficient custom optimizer, the architecture provides efficient training and rapid convergence, achieving up to a 15× reduction in pre-training epochs with maintained or reduced per-iteration FLOPs. These innovations are empirically validated on 3D TMJ CT segmentation, addressing both computational and topological bottlenecks in 3D self-supervised learning (Limarenko et al., 25 Dec 2025).