GraviBERT: Deep Learning for GW Inference
- GraviBERT is a deep learning framework that uses masked pretraining followed by supervised fine-tuning to accurately infer key binary black hole parameters from gravitational-wave signals.
- It integrates an Inception-inspired convolutional feature extractor with a transformer encoder to process noisy time series, achieving faster convergence and significant reductions in MAE.
- GraviBERT demonstrates robust domain adaptation and cross-approximant generalization, establishing a foundation for advanced multi-messenger gravitational-wave analyses.
GraviBERT is a deep-learning framework for direct inference of binary black hole source parameters from gravitational-wave (GW) time series. It leverages a hybrid neural network architecture combining an Inception-inspired multi-scale convolutional frontend, a transformer encoder, and a regression head. The model is trained in two principal stages: an unsupervised BERT-style pretraining phase, which promotes the acquisition of universal physical patterns by reconstructing masked segments in feature space, followed by supervised fine-tuning for precise parameter estimation. Trained directly on noisy GW waveforms, GraviBERT achieves significant improvements in inference accuracy, convergence speed, and domain adaptation relative to models trained from scratch, and demonstrates adaptability to new detector environments and waveform models (Benedikt et al., 24 Dec 2025).
1. Model Architecture and Components
GraviBERT processes the input GW strain time series (zero-padded and amplitude-rescaled) through the following sequence:
- Convolutional Feature Extractor (a PyTorch sketch follows the summary table below):
- Initial 1D convolution (kernel size 7, 64 channels), batch normalization, ReLU.
- Three “main blocks,” each with three InceptionModules connected sequentially via a residual skip, followed by ReLU and max pooling (kernel 2, stride 2), halving the temporal resolution after each block.
- Each InceptionModule consists of a bottleneck convolution, three parallel convolutions (kernel sizes $10, 20, 40$; output channels $32, 48, 64$ in the three main blocks, respectively), a parallel max-pool branch (kernel 3) followed by a convolution, concatenation along the channel dimension, and batch normalization.
- Output: a latent feature sequence passed to the transformer encoder.
- Transformer Encoder (also sketched below):
- Linear projection to 256-dimensional embedding plus bias.
- A [CLS] token is prepended (extending the sequence length by one) and sinusoidal positional encodings are added.
- Four identical transformer layers (4 heads, feed-forward hidden size 512, GELU, dropout 0.1, LayerNorm, residuals).
- Aggregation by taking the [CLS] token output as the sequence-level representation.
- Regression MLP Head:
- Three dense layers with GELU activations and dropout (0.1) in the hidden layers and no activation on the output layer.
- Outputs: point estimates of the source parameters.
| Component | Dimensions/Params | Function |
|---|---|---|
| Input | GW strain time series | Zero-padded, rescaled network input |
| Feature extractor | See above | Latent sequence extraction |
| Transformer encoder | 4 layers, 4 heads, 256-dim | Contextual aggregation/self-attention |
| Regression head | 3 layers, last layer size 4 | Source parameter estimation |
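The following minimal PyTorch sketch illustrates one InceptionModule as described above. The "same"-padding convention, the bottleneck width, and whether the max-pool branch bypasses the bottleneck are assumptions not fixed by the text; `branch_ch` would be set to 32, 48, or 64 depending on the main block.

```python
import torch
import torch.nn as nn

class InceptionModule(nn.Module):
    """Sketch of one Inception-style block: bottleneck -> three parallel
    convolutions (kernel sizes 10/20/40) plus a max-pool branch, concatenated
    along channels and batch-normalized. Channel bookkeeping is assumed."""
    def __init__(self, in_ch: int, branch_ch: int = 32, bottleneck_ch: int = 32):
        super().__init__()
        self.bottleneck = nn.Conv1d(in_ch, bottleneck_ch, kernel_size=1, bias=False)
        # Three parallel convolutions applied to the bottlenecked features.
        self.convs = nn.ModuleList([
            nn.Conv1d(bottleneck_ch, branch_ch, kernel_size=k, padding=k // 2, bias=False)
            for k in (10, 20, 40)
        ])
        # Parallel max-pool branch (kernel 3) followed by a 1x1 convolution.
        self.pool_branch = nn.Sequential(
            nn.MaxPool1d(kernel_size=3, stride=1, padding=1),
            nn.Conv1d(in_ch, branch_ch, kernel_size=1, bias=False),
        )
        self.bn = nn.BatchNorm1d(4 * branch_ch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.bottleneck(x)
        branches = [conv(z) for conv in self.convs] + [self.pool_branch(x)]
        # Even kernel sizes with padding=k//2 make the conv branches one sample
        # longer than the input; trim so all branches share the time dimension.
        t = min(b.shape[-1] for b in branches)
        return self.bn(torch.cat([b[..., :t] for b in branches], dim=1))

# Shape check: (batch, channels, time) -> (batch, 4 * branch_ch, time)
y = InceptionModule(in_ch=64)(torch.randn(2, 64, 1024))
```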
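A companion sketch of the transformer encoder, [CLS] aggregation, and regression head. The head's hidden widths and the maximum sequence length are placeholders (the text does not state them); the encoder hyperparameters follow the description above.

```python
import math
import torch
import torch.nn as nn

class TransformerRegressor(nn.Module):
    """Linear projection to d_model, prepended [CLS] token, sinusoidal
    positional encodings, 4 transformer layers, and a 3-layer MLP head."""
    def __init__(self, feat_ch: int, d_model: int = 256, n_params: int = 4, max_len: int = 4096):
        super().__init__()
        self.proj = nn.Linear(feat_ch, d_model)            # projection + bias
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        # Fixed sinusoidal positional encodings (max_len is an assumption).
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=4, dim_feedforward=512,
            dropout=0.1, activation="gelu", batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=4)
        self.head = nn.Sequential(                          # hidden widths assumed
            nn.Linear(d_model, 128), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(128, 64), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(64, n_params),                        # no output activation
        )

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # feats: (batch, time, channels) latent sequence from the CNN frontend.
        x = self.proj(feats)
        x = torch.cat([self.cls.expand(x.shape[0], -1, -1), x], dim=1)
        x = x + self.pe[: x.shape[1]]
        x = self.encoder(x)
        return self.head(x[:, 0])                           # [CLS] aggregation
```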
2. Training Methodology: Pretraining and Fine-Tuning
Training proceeds in two main stages:
- BERT-style Pretraining:
- Contiguous spans of the time series are masked (zeroed). Masks are generated from a two-state Markov chain with prescribed mean segment length and masking ratio (see the sketch after this list).
- The masked input is encoded by the convolutional feature extractor and transformer encoder; a prediction head then reconstructs the projected latent representations at the masked positions (pretraining loss given in Section 7).
- This objective imposes long-range, multi-scale dependency learning and facilitates the encoding of universal inspiral–merger–ringdown waveform features.
- Supervised Fine-Tuning:
- Regression targets are the source parameters, each normalized via min–max scaling; the mean squared error (MSE) between predictions and normalized targets is minimized (Section 7).
- Prediction uncertainty is estimated via Monte Carlo dropout with repeated stochastic forward passes (see the sketch below).
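A sketch of the two-state Markov-chain span masking described above: the chain leaves the masked state with probability $1/\ell$ (giving geometric spans of mean length $\ell$) and enters it at a rate chosen so the stationary masked fraction equals the target ratio. The numerical values in the example are illustrative, not the paper's.

```python
import numpy as np

def markov_span_mask(n_steps: int, mean_span: float, mask_ratio: float,
                     rng: np.random.Generator) -> np.ndarray:
    """Boolean mask (True = masked) sampled from a two-state Markov chain."""
    p_leave = 1.0 / mean_span                            # masked -> unmasked
    p_enter = mask_ratio * p_leave / (1.0 - mask_ratio)  # unmasked -> masked
    mask = np.zeros(n_steps, dtype=bool)
    state = bool(rng.random() < mask_ratio)              # start from the stationary distribution
    for t in range(n_steps):
        mask[t] = state
        if state:
            state = bool(rng.random() >= p_leave)        # stay masked with prob. 1 - p_leave
        else:
            state = bool(rng.random() < p_enter)         # switch to masked with prob. p_enter
    return mask

# Example: mask ~30% of a 2048-step latent sequence in spans of mean length 16 (illustrative values).
mask = markov_span_mask(2048, mean_span=16.0, mask_ratio=0.3, rng=np.random.default_rng(0))
```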
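A minimal sketch of the Monte Carlo dropout estimate of prediction uncertainty mentioned above; `model` is any trained GraviBERT-style regressor and the number of passes is a placeholder.

```python
import torch

@torch.no_grad()
def mc_dropout_predict(model: torch.nn.Module, x: torch.Tensor, n_passes: int = 50):
    """Keep dropout active at inference and report the mean and standard
    deviation over repeated stochastic forward passes."""
    model.eval()
    for m in model.modules():                       # re-enable dropout layers only
        if isinstance(m, torch.nn.Dropout):
            m.train()
    preds = torch.stack([model(x) for _ in range(n_passes)])  # (n_passes, batch, n_params)
    return preds.mean(dim=0), preds.std(dim=0)
```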
3. Experimental Setup and Hyperparameters
Data generation involves sampling parameters from a structured grid:
- Primary and secondary masses.
- Dimensionless spins.
- Distances scaled to a fixed SNR.
- ET-B power spectral density, colored Gaussian noise, 4096 Hz sample rate.
- Data split into training, validation, and test sets.
Training details:
- Pretraining: AdamW optimizer, no weight decay, cosine-annealing learning-rate schedule with warmup, up to 200 epochs.
- Fine-tuning: AdamW optimizer with weight decay, up to 120 epochs, gradient clipping at 1.0 (a configuration sketch follows below).
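A hedged sketch of the optimizer and scheduler setup named above (AdamW, warmup plus cosine annealing, gradient clipping at 1.0). The learning rate, weight decay, and warmup length are placeholders, since the text does not reproduce the actual values.

```python
import math
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

def make_optimizer_and_scheduler(model, total_steps: int, warmup_steps: int,
                                 lr: float = 1e-4, weight_decay: float = 1e-2):
    """AdamW plus linear warmup followed by cosine annealing to zero.
    lr / weight_decay defaults are placeholders, not the paper's values."""
    opt = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)       # linear warmup
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return opt, LambdaLR(opt, lr_lambda)

# In the training loop, gradients are clipped before each optimizer step:
#   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```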
4. Quantitative Performance and Evaluation
On in-domain ET-B test data, pretraining followed by fine-tuning yields marked improvements:
- Convergence speed: markedly faster convergence than training from scratch, across the reported dataset sizes and SNRs.
- MAE reductions: substantial reductions for all four target parameters, on both the large and small training sets (metric definitions are sketched after this list).
- Absolute precision (large dataset): low mean relative errors, together with a separately reported effective-spin MAE.
- $R^2$ scores: uniformly high for masses and distances.
- Scaling law: the test loss decreases approximately as a power law in training-set size, with one of the two reported exponents being $0.36$.
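For reference, the metrics quoted in this section (MAE, mean relative error, $R^2$) can be computed per parameter as in the short sketch below; this is standard bookkeeping rather than code from the paper.

```python
import numpy as np

def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
    """Per-parameter MAE, mean relative error, and R^2 for arrays of shape
    (n_samples, n_params)."""
    err = y_pred - y_true
    mae = np.abs(err).mean(axis=0)
    rel = np.abs(err / y_true).mean(axis=0)          # assumes nonzero true values
    ss_res = (err ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0)
    return {"MAE": mae, "mean_rel_err": rel, "R2": 1.0 - ss_res / ss_tot}
```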
5. Domain Adaptation and Cross-Approximant Transfer
GraviBERT demonstrates efficient adaptation to new noise and waveform models:
- ET-B → LIGO–Virgo (LV) domain adaptation:
- Fine-tuning on small LV sets ($5$k–$10$k samples) achieves substantial MAE reductions and convergence acceleration (a minimal transfer recipe is sketched at the end of this section).
- Even at $5$k samples, the MAE drops for all four parameters.
- Cross-approximant generalization (SEOBNRv4 → IMRPhenomD):
- Zero-shot inference retains accuracy for three of the four parameters, with noticeable degradation for the remaining one.
- Fine-tuning on $5$k IMRPhenomD events yields large MAE reductions, at least $6$-fold faster convergence, and high $R^2$ for the masses even at the lower SNR (versus $0.74$–$0.87$ when training from scratch).
These results indicate detector- and waveform-agnostic latent representations, enabling rapid transfer across noise environments and waveform models.
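A minimal sketch of the transfer recipe implied above: initialize from pretrained weights and fine-tune on a small target-domain set (LV noise or a different approximant). Whether any layers are frozen is not stated, so this sketch fine-tunes everything; the checkpoint path and hyperparameters are illustrative.

```python
import torch

def adapt_to_new_domain(model: torch.nn.Module, ckpt_path: str, train_loader,
                        n_epochs: int = 20, lr: float = 1e-5):
    """Load pretrained weights and fine-tune all parameters on a small
    target-domain dataset; values and names here are illustrative."""
    model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
    loss_fn = torch.nn.MSELoss()
    model.train()
    for _ in range(n_epochs):
        for x, theta in train_loader:                # theta: normalized target parameters
            opt.zero_grad()
            loss = loss_fn(model(x), theta)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
    return model
```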
6. Foundation-Model Perspective and Downstream Applications
GraviBERT’s pretrained encoder and transformer backbone are adaptable to various downstream tasks by substituting or augmenting the regression head (a minimal head-swap sketch follows the list below). Downstream application domains include:
- Overlapping signal separation
- Anomaly detection
- Multi-messenger joint inference
- Tests of general relativity
- Searches for new physics
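The head substitution mentioned above can be as simple as replacing the final MLP while keeping the pretrained CNN and transformer backbone; a hedged sketch follows, assuming the head is exposed as an attribute named `head` (an implementation detail not stated in the text).

```python
import torch.nn as nn

def swap_head_for_classification(model: nn.Module, d_model: int = 256, n_classes: int = 2) -> nn.Module:
    """Replace the regression MLP with a classification head, e.g. for
    anomaly detection, reusing the pretrained backbone unchanged."""
    model.head = nn.Sequential(
        nn.Linear(d_model, 128), nn.GELU(), nn.Dropout(0.1),
        nn.Linear(128, n_classes),
    )
    return model
```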
Current limitations include:
- Point estimation only (no posterior inference)
- Four intrinsic parameters only
- Single-channel analysis
- No architecture hyperparameter search
Planned future enhancements are full 15-parameter inference (including extrinsic and sky-location parameters), multi-channel/multi-detector extensions, low-frequency ET operation (down to $5$ Hz), and optimized transformer variants.
7. Mathematical Formulations and Metrics
Key formulas underpinning waveform modeling and training include:
- Waveform Model (leading PN order): the standard leading-order (quadrupole) expressions are recalled below for reference.
- Pretraining Loss:
$\mathcal{L}_\text{pretrain} = \frac{1}{|I|}\sum_{(b,i)\in I}\bigl\|\mathbf{z}^{\text{target}}_{b,i,\cdot} - \mathbf{z}^{\text{pred}}_{b,i,\cdot}\bigr\|_2^2$
- Fine-tuning MSE Loss: the mean squared error between predicted and min–max-normalized true parameters, of the form $\mathcal{L}_\text{finetune} = \frac{1}{B}\sum_{b=1}^{B}\bigl\|\hat{\boldsymbol{\theta}}_b - \boldsymbol{\theta}_b\bigr\|_2^2$ (up to the batch-averaging convention).
- Evaluation Metrics: mean absolute error (MAE), mean relative error, and coefficient of determination ($R^2$), as reported in Section 4.
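For orientation, the standard leading-order (quadrupole) expressions for the strain amplitude and frequency evolution of a quasi-circular binary, in terms of the chirp mass $\mathcal{M} = (m_1 m_2)^{3/5}(m_1 + m_2)^{-1/5}$ and luminosity distance $d_L$, take the textbook form below; the exact expression adopted by the paper is not reproduced in the text above.

$$
h(t) \simeq \frac{4}{d_L}\left(\frac{G\mathcal{M}}{c^{2}}\right)^{5/3}\left(\frac{\pi f_{\mathrm{gw}}(t)}{c}\right)^{2/3}\cos\Phi(t),
\qquad
\dot f_{\mathrm{gw}} = \frac{96}{5}\,\pi^{8/3}\left(\frac{G\mathcal{M}}{c^{3}}\right)^{5/3} f_{\mathrm{gw}}^{11/3}
$$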
A plausible implication is that the masked reconstruction pretraining confers the ability to encode universal GW properties in latent representations, supporting generalization across domains and waveforms.
GraviBERT establishes a new paradigm for foundation-style models in GW astrophysics, introducing masked-segment reconstruction for noisy time series and demonstrating substantial gains in parameter inference accuracy and efficiency, domain adaptation, and extensibility to broader multi-messenger inference pipelines (Benedikt et al., 24 Dec 2025).