BertsWin: Efficient 3D Pre-training Architecture

Updated 1 January 2026
  • BertsWin is a hybrid self-supervised pre-training architecture that mitigates topological sparsity by preserving a complete 3D token grid using BERT-style masking.
  • It employs Swin Transformer windowed attention to limit computational complexity from O(N²) to O(N) while maintaining local spatial context.
  • The design integrates a 3D CNN stem, a multi-component structural loss, and a GradientConductor optimizer to achieve rapid semantic convergence and resource efficiency.

BertsWin is a hybrid self-supervised pre-training architecture for three-dimensional volumetric data that resolves the topological sparsity encountered by conventional Masked Autoencoders (MAEs) when applied to 3D structures. It integrates a full BERT-style token masking strategy with Swin Transformer windowed attention to simultaneously maintain spatial topology and computational efficiency. BertsWin was introduced to address the structural discontinuities and slow convergence inherent to 3D MAE pre-training, and achieves substantial improvements in both semantic convergence speed and resource utilization, as demonstrated on cone beam computed tomography (CBCT) of the temporomandibular joint (TMJ) (Limarenko et al., 25 Dec 2025).

1. Motivation for BertsWin Pre-training

While state-of-the-art 2D masked autoencoders (MAEs) exploit the redundancy of planar images, masking 75% of 3D patches in volumetric inputs produces severe topological sparsity. This disrupts anatomical continuity, fragmenting spatially connected regions and destroying geometric priors. The resulting fragmentation leads to blocking artifacts and loss of context, with networks forced to reconstruct entire structures from sparse, disconnected tokens, which significantly slows convergence. Computationally, expanding MAE to full 3D attention incurs a prohibitive $O(N^2)$ complexity, where $N$ is the number of volumetric patches.

BertsWin provides two key innovations: (a) it maintains a complete 3D token grid (mask tokens plus visible embeddings), preserving spatial topology throughout the encoder; (b) it replaces global attention with Swin Transformer windows, limiting computation to $O(N)$ complexity while retaining effective local spatial context. This approach targets rapid, resource-efficient learning of structural priors in 3D self-supervised pre-training (Limarenko et al., 25 Dec 2025).
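The severity of this fragmentation can be illustrated with a small simulation (an illustration under assumed settings, not an experiment from the paper): keeping a random 25% of the patches of a 14×14×14 token grid, the grid produced by a 224³ volume with 16³ patches, scatters the visible set into many disjoint 6-connected fragments.

```python
# Illustrative simulation of topological sparsity under 75% random masking.
# Assumptions (not from the paper): a 14x14x14 patch grid (224^3 volume,
# 16^3 patches) and 6-connectivity between visible patches.
import numpy as np
from scipy import ndimage

rng = np.random.default_rng(0)
grid = 14                                           # patches per axis: 224 / 16
visible = rng.random((grid, grid, grid)) > 0.75     # keep roughly 25% of patches
labels, n_fragments = ndimage.label(visible)        # default structure = 6-connectivity
sizes = np.bincount(labels.ravel())[1:]             # size of each visible fragment
print(f"visible patches: {int(visible.sum())} / {grid ** 3}")
print(f"6-connected fragments: {n_fragments}, median fragment size: {int(np.median(sizes))}")
```

Running this prints the number of visible patches and how many disjoint 6-connected fragments they form, quantifying the fragmentation that full-grid BERT-style masking is designed to avoid.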

2. High-Level Architecture

BertsWin comprises four distinct modules in a sequential arrangement:

A. 3D CNN Stem (Hybrid Patch Embedding):

  • Input volume $V\in\mathbb{R}^{B\times D\times H\times W}$ is split into non-overlapping $P^3$-sized patches, yielding $N=(D/P)\cdot(H/P)\cdot(W/P)$ total tokens.
  • 25% of patches are selected at random and passed through four 3D convolutional blocks to produce visible token embeddings $\mathbf{e}_\text{vis}\in\mathbb{R}^{B\times n_\text{vis}\times C}$.

B. Full 3D Token Grid with Positional Embedding:

  • A binary mask $M\in\{0,1\}^{B\times N}$ records the visibility of each patch.
  • Visible embeddings are scattered to their locations in the grid; masked locations are filled with a learnable token $\mathbf{e}_\text{mask}\in\mathbb{R}^C$.
  • Fixed or learnable positional embeddings of shape $(N, C)$ are added for spatial context.

C. Single-Scale Swin Transformer Encoder:

  • Consists of 12 Swin blocks, each performing window-based local attention with window size $w\times w\times w$ (typically $w=7$), shift patterns for inter-window information transfer, 12 attention heads, and no downsampling; a minimal windowing sketch follows these bullets.
  • The output is a feature grid $\mathbf{Z}\in\mathbb{R}^{B\times N\times C}$.
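Windowed attention is what yields the linear scaling. The sketch below (batch size 2 and channel width $C=768$ are assumptions; the text does not fix $C$ here) partitions a 14×14×14 token grid, as produced by a 224³ volume with 16³ patches, into non-overlapping 7×7×7 windows and runs multi-head attention inside each window. The shifted-window step and relative position biases are omitted.

```python
# Minimal 3D window-attention sketch (not the paper's implementation):
# tokens are split into non-overlapping w^3 windows and attention is
# restricted to each window, so cost grows as O(N * w^3) rather than O(N^2).
import torch
import torch.nn as nn

B, C, w, g = 2, 768, 7, 14                       # g = tokens per axis (224 / 16)

def window_partition_3d(x, w):
    """(B, g, g, g, C) -> (B * num_windows, w**3, C)."""
    B, D, H, W, C = x.shape
    x = x.view(B, D // w, w, H // w, w, W // w, w, C)
    return x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape(-1, w ** 3, C)

def window_reverse_3d(win, w, B, g):
    """Inverse of window_partition_3d."""
    x = win.view(B, g // w, g // w, g // w, w, w, w, -1)
    return x.permute(0, 1, 4, 2, 5, 3, 6, 7).reshape(B, g, g, g, -1)

attn = nn.MultiheadAttention(embed_dim=C, num_heads=12, batch_first=True)
tokens = torch.randn(B, g, g, g, C)              # full 3D token grid
windows = window_partition_3d(tokens, w)         # (B * 8, 343, C): eight 7^3 windows
out, _ = attn(windows, windows, windows)         # attention stays inside each window
tokens = window_reverse_3d(out, w, B, g)
print(tokens.shape)                              # torch.Size([2, 14, 14, 14, 768])
```

In the full model, alternating blocks shift the window grid so information can cross window boundaries, as noted in the bullet above.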

D. 3D CNN Decoder:

  • Three transposed convolution layers (strides 4, 2, 2) upsample the features to reconstruct the full volumetric input $\hat{V}\in\mathbb{R}^{B\times D\times H\times W}$.
| Module | Input Shape | Key Parameters |
| --- | --- | --- |
| CNN Stem | $B\times D\times H\times W$ | 4 blocks, patch size $P^3$, stride 2 |
| Token Grid + Positional Embed. | $B\times N$ | Binary mask $M$, $\mathbf{e}_\text{mask}$ |
| Swin Transformer Encoder | $B\times N\times C$ | 12 blocks, window size $w=7$, 12 heads |
| CNN Decoder | $\mathbf{Z}$ | 3 transposed-conv layers (strides 4, 2, 2) |
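A shape-level PyTorch sketch of the four modules is given below. It rests on assumptions the table does not fix: channel width $C=768$, a stem of four stride-2 convolutions (cumulative stride $16=P$), and a plain TransformerEncoder standing in for the 12 shifted-window Swin blocks. For simplicity it also embeds every patch and overwrites masked positions afterwards, whereas BertsWin embeds only the visible 25%.

```python
# Shape-level sketch of the BertsWin pipeline (stem -> token grid -> encoder
# -> decoder). Channel width, stem layout, and the encoder stand-in are
# assumptions; this is not the authors' implementation.
import torch
import torch.nn as nn

B, D, H, W, P, C = 1, 224, 224, 224, 16, 768
N = (D // P) * (H // P) * (W // P)               # 14**3 = 2744 tokens

class BertsWinSketch(nn.Module):
    def __init__(self):
        super().__init__()
        # A. 3D CNN stem: four conv blocks, each stride 2, cumulative stride = P.
        chans = [1, 96, 192, 384, C]
        self.stem = nn.Sequential(*[
            nn.Sequential(nn.Conv3d(chans[i], chans[i + 1], 3, stride=2, padding=1), nn.GELU())
            for i in range(4)
        ])
        # B. Full token grid: learnable mask token and positional embeddings.
        self.mask_token = nn.Parameter(torch.zeros(C))
        self.pos_embed = nn.Parameter(torch.zeros(N, C))
        # C. Encoder stand-in (BertsWin uses 12 Swin blocks, w = 7, 12 heads).
        layer = nn.TransformerEncoderLayer(d_model=C, nhead=12, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=12)
        # D. 3D CNN decoder: transposed convs with strides 4, 2, 2 (4 * 2 * 2 = P).
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(C, 192, 4, stride=4), nn.GELU(),
            nn.ConvTranspose3d(192, 96, 2, stride=2), nn.GELU(),
            nn.ConvTranspose3d(96, 1, 2, stride=2),
        )

    def forward(self, vol, visible):                 # vol: (B,1,D,H,W); visible: (B,N) bool
        feat = self.stem(vol)                        # (B, C, D/P, H/P, W/P)
        tokens = feat.flatten(2).transpose(1, 2)     # (B, N, C)
        grid = torch.where(visible[..., None], tokens, self.mask_token)  # BERT-style grid
        z = self.encoder(grid + self.pos_embed)      # (B, N, C)
        z = z.transpose(1, 2).reshape(-1, C, D // P, H // P, W // P)
        return self.decoder(z)                       # (B, 1, D, H, W)

model = BertsWinSketch()
visible = torch.rand(B, N) > 0.75                    # roughly 25% visible patches
with torch.no_grad():
    out = model(torch.randn(B, 1, D, H, W), visible)
print(out.shape)                                     # torch.Size([1, 1, 224, 224, 224])
```

The decoder strides multiply to 4·2·2 = 16 = P, so the upsampling exactly undoes the patch-level downsampling and the reconstruction matches the input resolution.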

3. 3D BERT-style Masking Mechanism

BertsWin employs 3D patch-level masking, generalizing BERT-style masking to volumetric data. For $N$ patches and a masking ratio $r=0.75$, the binary mask $M_{b,n}$ for sample $b$ and patch $n$ is defined as:

$$M_{b,n} = \begin{cases} 1 & \text{if patch } n \text{ in sample } b \text{ is visible} \\ 0 & \text{if patch } n \text{ is masked} \end{cases}$$

Token construction proceeds via:

$$T_{b,n} = M_{b,n}\,E(x_{b,n}) + (1-M_{b,n})\,E_\text{mask}$$

where $x_{b,n}$ denotes the voxel values in patch $n$, $E(\cdot)$ denotes the CNN-stem embedding, and $E_\text{mask}$ is the learnable embedding for masked positions. Positional embeddings $P_n$ are added:

$$\tilde{T}_{b,n} = T_{b,n} + P_n$$

This complete grid, containing all visible and masked patches, is propagated through the encoder. By preserving the 3D topology during masking and processing, the architecture maintains anatomical coherence and improves convergence dynamics (Limarenko et al., 25 Dec 2025).
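The two equations can be transcribed directly with placeholder tensors (the shapes $B=2$, $N=2744$, $C=768$ and the random stand-ins for $E(x_{b,n})$ and $P_n$ are assumptions for illustration):

```python
# Direct transcription of T = M * E(x) + (1 - M) * E_mask and T~ = T + P.
# All tensors are placeholders; shapes are assumed, not taken from the paper.
import torch

B, N, C = 2, 2744, 768
M = (torch.rand(B, N) > 0.75).float()        # 1 = visible, 0 = masked (r = 0.75)
E_x = torch.randn(B, N, C)                   # stand-in for CNN-stem embeddings E(x_{b,n})
E_mask = torch.zeros(C, requires_grad=True)  # learnable embedding for masked positions
P = torch.randn(N, C)                        # positional embeddings P_n

T = M[..., None] * E_x + (1 - M[..., None]) * E_mask  # full token grid, shape (B, N, C)
T_tilde = T + P                                       # add positional context
print(T_tilde.shape)                                  # torch.Size([2, 2744, 768])
```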

4. Structural Priority Loss: Multi-Component Variance and PhysLoss

BertsWin introduces a structural loss that decomposes the per-patch mean squared error (MSE) into three components, brightness ($L_\mathrm{Br}$), contrast ($L_\mathrm{Cntr}$), and structure ($L_\mathrm{Str}$), defined per patch pair $(X, Y)$:

  • Brightness: $L_\mathrm{Br} = (\mu_X - \mu_Y)^2$
  • Contrast: $L_\mathrm{Cntr} = (\sigma_X - \sigma_Y)^2$
  • Structure: $L_\mathrm{Str} = 2\,\sigma_X\,\sigma_Y\,(1 - \rho(X,Y))$, where $\mu_X$ is the patch mean, $\sigma_X$ is the patch standard deviation, and $\rho(X,Y)$ is the patchwise correlation.

The Multi-Component Variance loss is:

$$L_\mathrm{MVC} = w_\mathrm{Br} L_\mathrm{Br} + w_\mathrm{Cntr} L_\mathrm{Cntr} + w_\mathrm{Str} L_\mathrm{Str}$$

with weights $w_\mathrm{Br}=0.3$, $w_\mathrm{Cntr}=0.2$, $w_\mathrm{Str}=0.5$.

PhysLoss further prioritizes structurally critical regions by computing $L_\mathrm{MVC}$ over three domains:

  • Global patch domain $\Omega$,
  • Soft-tissue mask $M_\text{soft}$,
  • Bone-surface shell mask $M_\text{surf}$.

The complete loss function is:

$$L_\mathrm{PhysLoss} = \lambda_\text{global}\,L_\mathrm{MVC}(\Omega) + \lambda_\text{soft}\,L_\mathrm{MVC}(M_\text{soft}) + \lambda_\text{surf}\,L_\mathrm{MVC}(M_\text{surf})$$

with coefficients $\lambda_\text{global}=0.3$, $\lambda_\text{soft}=0.5$, $\lambda_\text{surf}=0.2$. This multi-component loss supports anatomically aware learning and accelerates semantic convergence (Limarenko et al., 25 Dec 2025).
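A compact sketch of both losses is shown below. It follows the formulas above but makes assumptions about details the text leaves open: reconstructed and target volumes are taken to be pre-split into flattened patches of shape (B, N, P³), the soft-tissue and bone-surface masks are boolean patch selectors supplied by the caller, and any restriction to masked patches is not modeled.

```python
# Sketch of the Multi-Component Variance (MVC) loss and PhysLoss described
# above. Patch extraction and domain masks are placeholders.
import torch

W_BR, W_CNTR, W_STR = 0.3, 0.2, 0.5
LAM_GLOBAL, LAM_SOFT, LAM_SURF = 0.3, 0.5, 0.2

def mvc_loss(x, y, eps=1e-6):
    """MVC loss for flattened patch batches x, y of shape (..., P**3)."""
    mu_x, mu_y = x.mean(-1), y.mean(-1)
    sd_x, sd_y = x.std(-1, unbiased=False), y.std(-1, unbiased=False)
    xc, yc = x - mu_x[..., None], y - mu_y[..., None]
    rho = (xc * yc).mean(-1) / (sd_x * sd_y + eps)   # patchwise correlation rho(X, Y)
    l_br = (mu_x - mu_y) ** 2                        # brightness term
    l_cntr = (sd_x - sd_y) ** 2                      # contrast term
    l_str = 2 * sd_x * sd_y * (1 - rho)              # structure term
    return (W_BR * l_br + W_CNTR * l_cntr + W_STR * l_str).mean()

def phys_loss(pred, target, soft_mask, surf_mask):
    """PhysLoss: MVC re-weighted over global, soft-tissue, and bone-surface domains.
    pred/target: (B, N, P**3) patches; masks: boolean (B, N) patch selectors."""
    return (LAM_GLOBAL * mvc_loss(pred, target)
            + LAM_SOFT * mvc_loss(pred[soft_mask], target[soft_mask])
            + LAM_SURF * mvc_loss(pred[surf_mask], target[surf_mask]))

# Tiny usage example with random patches and random domain masks.
pred, target = torch.randn(2, 64, 16 ** 3), torch.randn(2, 64, 16 ** 3)
soft, surf = torch.rand(2, 64) > 0.5, torch.rand(2, 64) > 0.8
print(phys_loss(pred, target, soft, surf))
```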

5. GradientConductor (GCond) Optimizer

GradientConductor (GCond) is a custom optimizer combining features of LION (sign updates), LARS (trust-ratio scaling), and Adam (bias correction). Parameter $p_t$ is updated as follows:

  1. First-moment estimate (momentum) with bias correction:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)g_t,\quad \hat{m}_t = \frac{m_t}{1-\beta_1^t}$$

  2. Trust ratio scaling (LARS):

$$\lambda_t = \min\left(\frac{\|p_{t-1}\|}{\|\hat m_t\| + \varepsilon},\,\lambda_\text{clip}\right)$$

  3. Parameter update using the sign of the first moment (LION-style):

$$p_t = p_{t-1} - \eta\gamma\,\lambda_t\,\mathrm{sign}(\hat m_t)$$

Here $\beta_1=0.9$, $\varepsilon=10^{-6}$, $\lambda_\text{clip}=10$, and the effective learning rate is $\eta\gamma=1.5\times 10^{-5}$. Only the first moment $m_t$ is stored, reducing optimizer memory by ~50% relative to AdamW. The optimizer confers stable warm-up (bias correction), memory efficiency, and cross-layer gradient scaling (Limarenko et al., 25 Dec 2025).
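The three steps translate directly into a torch.optim.Optimizer. The sketch below follows the update rule as stated; behavior the text does not specify, such as weight decay or per-group exclusions, is omitted.

```python
# Sketch of the GradientConductor (GCond) update rule: bias-corrected first
# moment (Adam-style), clipped trust ratio (LARS-style), sign update (LION-style).
# Not the authors' implementation; unspecified details (e.g. weight decay) are omitted.
import torch

class GCondSketch(torch.optim.Optimizer):
    def __init__(self, params, lr=1.5e-5, beta1=0.9, eps=1e-6, lam_clip=10.0):
        super().__init__(params, dict(lr=lr, beta1=beta1, eps=eps, lam_clip=lam_clip))

    @torch.no_grad()
    def step(self, closure=None):
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            lr, b1, eps, clip = group["lr"], group["beta1"], group["eps"], group["lam_clip"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if not state:                        # lazily create the single moment buffer
                    state["m"] = torch.zeros_like(p)
                    state["t"] = 0
                state["t"] += 1
                m, t = state["m"], state["t"]
                m.mul_(b1).add_(p.grad, alpha=1 - b1)                      # 1. first moment
                m_hat = m / (1 - b1 ** t)                                  #    bias correction
                lam = min((p.norm() / (m_hat.norm() + eps)).item(), clip)  # 2. trust ratio
                p.add_(m_hat.sign(), alpha=-lr * lam)                      # 3. sign update
        return loss

# e.g. opt = GCondSketch(model.parameters())  # lr already encodes eta * gamma = 1.5e-5
```

Only state["m"], one buffer per parameter, is kept between steps, which is where the roughly 50% memory saving over AdamW's two moment buffers comes from.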

6. Computational Complexity and Empirical Convergence

FLOP Analysis

For an input resolution of $224^3$ and patch size $P=16$ (thus $N=2744$):

  • BertsWin: encoder 125.2 GFLOPs, stem 81.7 GFLOPs, decoder 16.9 GFLOPs, total 223.8 GFLOPs.
  • MONAI ViT-MAE: encoder 134.1 GFLOPs, stem 17.3 GFLOPs, decoder 76.9 GFLOPs, total 228.3 GFLOPs.

At $512^3$ resolution:

  • BertsWin: 2673.1 GFLOPs (linear in $N$).
  • ViT baseline: 11035.5 GFLOPs ($O(N^2)$).

BertsWin therefore achieves theoretical FLOP parity at standard resolution and a ~4.1× reduction at high resolution, attributed to its window-based attention.
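The asymptotic gap can be checked with a back-of-the-envelope comparison of the attention term alone; this is not the paper's FLOP accounting, which also counts the stem, projections, MLPs, and decoder and therefore reports a smaller end-to-end ratio.

```python
# Attention-term scaling only: global attention ~ N^2, windowed attention ~ N * w^3
# (window padding at 512^3 ignored). Illustrates the trend, not the reported totals.
def n_tokens(res, patch=16):
    return (res // patch) ** 3

w = 7
for res in (224, 512):
    n = n_tokens(res)
    ratio = n ** 2 / (n * w ** 3)            # O(N^2) versus O(N * w^3)
    print(f"{res}^3: N = {n}, global / windowed attention cost ~ {ratio:.1f}x")
```

At 224³ the attention-only ratio is about 8×, yet the reported totals are near parity, consistent with projection, MLP, stem, and decoder FLOPs dominating at that resolution; at 512³ the attention term grows fast enough to drive the overall ~4.1× end-to-end gap.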

Convergence Speed

  • MONAI ViT-MAE (L2+AdamW): 660 epochs to best validation MSE.
  • BertsWin (L2+AdamW): 114 epochs, yielding a semantic speedup of 5.8×.
  • BertsWin + PhysLoss + GCond: 44 epochs, providing a total speedup of 15×.
  • Due to GFLOPs parity per epoch, GPU-hour requirements drop by the same factors, offering substantial practical acceleration (Limarenko et al., 25 Dec 2025).

7. Summary and Context

BertsWin resolves topological sparsity in masked pre-training for 3D volumes by restoring a complete token grid, leveraging Swin-style local attention for scalable computation, and implementing a structural loss with anatomical prioritization. Combined with a memory-efficient custom optimizer, the architecture provides efficient training and rapid convergence, achieving up to a 15× reduction in epochs with maintained or reduced per-iteration FLOPs. These innovations are empirically validated on 3D TMJ CT segmentation, addressing both computational and topological bottlenecks in 3D self-supervised learning (Limarenko et al., 25 Dec 2025).
