
GraviBERT: Deep Learning for GW Inference

Updated 29 December 2025
  • GraviBERT is a deep learning framework that uses masked pretraining followed by supervised fine-tuning to accurately infer key binary black hole parameters from gravitational-wave signals.
  • It integrates an Inception-inspired convolutional feature extractor with a transformer encoder to process noisy time series, achieving faster convergence and significant reductions in MAE.
  • GraviBERT demonstrates robust domain adaptation and cross-approximant generalization, establishing a foundation for advanced multi-messenger gravitational-wave analyses.

GraviBERT is a deep-learning framework for direct inference of binary–black-hole gravitational-wave (GW) time series. It leverages a hybrid neural network architecture combining an Inception-inspired multi-scale convolutional frontend, a transformer encoder, and a regression head. The model is trained in two principal stages: an unsupervised BERT-style pretraining phase, which promotes the acquisition of universal physical patterns by reconstructing masked segments in feature space, followed by supervised fine-tuning for precise parameter estimation. Trained directly on noisy GW waveforms, GraviBERT achieves significant improvements in inference accuracy, convergence speed, and domain adaptation relative to models trained from scratch, and demonstrates adaptability to new detector environments and waveform models (Benedikt et al., 24 Dec 2025).

1. Model Architecture and Components

GraviBERT processes input GW strain time series $x \in \mathbb{R}^{4096}$ (zero-padded, rescaled by $10^{21}$) through the following sequence:

  • Convolutional Feature Extractor (F\mathcal{F}):
    • Initial 1D convolution (kernel size 7, 64 channels), batch normalization, ReLU.
    • Three “main blocks,” each with three InceptionModules connected sequentially via a residual skip, followed by ReLU and max pooling (kernel 2, stride 2), resulting in $8\times$ temporal downsampling ($N = 512$).
    • Each InceptionModule consists of a $1 \times 1$ bottleneck convolution, three parallel convolutions (kernel sizes 10, 20, 40; output channels per block 32, 48, 64), a parallel max pool (kernel 3) followed by a $1 \times 1$ convolution, concatenation along channels, and batch normalization.
    • Output: latent feature tensor $L \in \mathbb{R}^{B \times 512 \times 64}$.
  • Transformer Encoder:
    • Linear projection to a 256-dimensional embedding plus bias.
    • Prepending a [CLS] token (sequence length 513), with sinusoidal positional encodings.
    • Four identical transformer layers (4 heads, feed-forward hidden size 512, GELU, dropout 0.1, LayerNorm, residuals).
    • Aggregation by taking the [CLS] token output.
  • Regression MLP Head:
    • Three dense layers (final layer of size 4, no activation in the output), with GELU activations and dropout (0.1) in the hidden layers.
    • Outputs: point estimates of the four source parameters.
| Component | Dimensions/Params | Function |
|---|---|---|
| Input | $x \in \mathbb{R}^{4096}$ | GW strain time series |
| Feature extractor | See above | Latent sequence extraction |
| Transformer encoder | 4 layers, embedding dim 256 | Contextual aggregation/self-attention |
| Regression head | 3 layers, last layer size 4 | Source parameter estimation |
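The pipeline above can be sketched in PyTorch. This is a minimal illustration under stated assumptions: uniform branch widths inside the InceptionModule, no residual skips or positional encodings, and all class names (`InceptionModule`, `GraviBERT`) are placeholders rather than the authors' code.

```python
import torch
import torch.nn as nn

class InceptionModule(nn.Module):
    """1x1 bottleneck, three parallel convs (kernels 10/20/40), plus a
    max-pool branch with a 1x1 conv; channel concat + batch norm.
    Uniform branch width is a simplification of the paper's 32/48/64."""
    def __init__(self, in_ch, branch_ch):
        super().__init__()
        self.bottleneck = nn.Conv1d(in_ch, branch_ch, 1)
        self.convs = nn.ModuleList(
            nn.Conv1d(branch_ch, branch_ch, k, padding=k // 2)
            for k in (10, 20, 40))
        self.pool_branch = nn.Sequential(
            nn.MaxPool1d(3, stride=1, padding=1),
            nn.Conv1d(in_ch, branch_ch, 1))
        self.bn = nn.BatchNorm1d(4 * branch_ch)

    def forward(self, x):
        z = self.bottleneck(x)
        branches = [conv(z) for conv in self.convs] + [self.pool_branch(x)]
        # even kernels overshoot length by 1 sample; crop to common length
        n = min(b.shape[-1] for b in branches)
        return self.bn(torch.cat([b[..., :n] for b in branches], dim=1))

class GraviBERT(nn.Module):
    def __init__(self, n_params=4, d_model=256):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv1d(1, 64, 7, padding=3), nn.BatchNorm1d(64), nn.ReLU())
        # stand-in for the three Inception main blocks, 8x downsampling
        self.blocks = nn.Sequential(
            InceptionModule(64, 16), nn.ReLU(), nn.MaxPool1d(2),
            InceptionModule(64, 16), nn.ReLU(), nn.MaxPool1d(2),
            InceptionModule(64, 16), nn.ReLU(), nn.MaxPool1d(2))
        self.proj = nn.Linear(64, d_model)           # to 256-dim embedding
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        enc_layer = nn.TransformerEncoderLayer(
            d_model, nhead=4, dim_feedforward=512, dropout=0.1,
            activation="gelu", batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=4)
        self.head = nn.Sequential(                   # regression MLP head
            nn.Linear(d_model, 128), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(128, 64), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(64, n_params))

    def forward(self, x):                            # x: (B, 4096)
        h = self.blocks(self.stem(x.unsqueeze(1)))   # (B, 64, 512)
        h = self.proj(h.transpose(1, 2))             # (B, 512, 256)
        h = torch.cat([self.cls.expand(len(x), -1, -1), h], dim=1)
        return self.head(self.encoder(h)[:, 0])      # [CLS] -> 4 params

out = GraviBERT()(torch.randn(2, 4096))
print(out.shape)  # torch.Size([2, 4])
```

Note how the sequence length 513 arises: the convolutional stages reduce 4096 samples to 512 latent positions, and the [CLS] token is prepended before the encoder.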

2. Training Methodology: Pretraining and Fine-Tuning

Training proceeds in two main stages:

  • BERT-style Pretraining:
    • Contiguous spans of the time series are masked (zeroed). Masks are generated from a two-state Markov chain with a fixed mean segment length and overall masking ratio.
    • The masked input $\tilde{x}$ is encoded via the feature extractor $\mathcal{F}$ and the transformer encoder. A prediction head attempts to reconstruct the projected latent representations at the masked positions, minimizing a mean-squared reconstruction error over the masked set. This objective imposes long-range, multi-scale dependency learning and facilitates the encoding of universal inspiral–merger–ringdown waveform features.

  • Supervised Fine-Tuning:

    • Regression targets are the four source parameters, each normalized via min–max scaling. The mean squared error (MSE) loss between predicted and true normalized parameters is minimized.

    • Prediction uncertainty is estimated via Monte Carlo dropout, aggregating multiple stochastic forward passes.
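The two-state Markov-chain span masking can be sketched as follows. The mean span length and masking ratio below are illustrative placeholders, not the paper's values:

```python
import numpy as np

def markov_span_mask(n, mean_span=20, mask_ratio=0.3, rng=None):
    """Sample a boolean mask from a two-state Markov chain.

    State True = masked. Transition probabilities are chosen so that the
    expected masked-span length is `mean_span` and the stationary
    fraction of masked samples is `mask_ratio`.
    """
    rng = np.random.default_rng(rng)
    p_leave_mask = 1.0 / mean_span                  # masked -> unmasked
    # stationary masked fraction r satisfies r = p_enter/(p_enter+p_leave)
    p_enter_mask = p_leave_mask * mask_ratio / (1.0 - mask_ratio)
    mask = np.empty(n, dtype=bool)
    state = rng.random() < mask_ratio               # sample initial state
    for i in range(n):
        mask[i] = state
        p_flip = p_leave_mask if state else p_enter_mask
        if rng.random() < p_flip:
            state = not state
    return mask

mask = markov_span_mask(4096, rng=0)
x_masked = np.where(mask, 0.0, np.ones(4096))  # zero out masked spans
```

The geometric dwell time of each Markov state yields contiguous masked spans of the desired mean length, unlike i.i.d. per-sample masking.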

3. Experimental Setup and Hyperparameters

Data generation involves sampling parameters from a structured grid:

  • Primary/secondary masses $m_1 \ge m_2$.

  • Dimensionless spins.

  • Distances scaled to yield fixed signal-to-noise ratios (SNRs).

  • ET-B power spectral density, colored Gaussian noise, sample rate 4096 Hz.

  • Data split into training, validation, and test sets.

Training details:

  • Pretraining: AdamW, no weight decay, cosine-annealing LR schedule with warmup, up to 200 epochs.

  • Fine-tuning: AdamW with weight decay, up to 120 epochs, gradient clipping at 1.0.
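The Monte Carlo dropout uncertainty estimate described in Section 2 can be sketched as: keep the dropout layers stochastic at inference time and aggregate repeated forward passes. The toy model and pass count below are placeholders:

```python
import torch
import torch.nn as nn

def mc_dropout_predict(model, x, n_passes=50):
    """Mean and std of predictions over stochastic forward passes."""
    model.eval()                       # freeze batch-norm statistics, etc.
    for m in model.modules():          # ...but re-enable dropout layers
        if isinstance(m, nn.Dropout):
            m.train()
    with torch.no_grad():
        preds = torch.stack([model(x) for _ in range(n_passes)])
    return preds.mean(0), preds.std(0)

# toy stand-in for a trained regression network
net = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Dropout(0.5),
                    nn.Linear(16, 4))
mu, sigma = mc_dropout_predict(net, torch.randn(3, 8))
```

The per-parameter standard deviation `sigma` serves as the prediction uncertainty; only the dropout sub-modules are switched back to training mode so the rest of the network stays deterministic.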

4. Quantitative Performance and Evaluation

On in-domain ET-B test data, pretraining followed by fine-tuning yields marked improvements over training from scratch:

  • Convergence speed: substantially faster convergence, with the largest speedups on the small training set.

  • MAE reductions: consistent reductions in mean absolute error for the masses, effective spin, and distance, on both the small and large training sets and across SNR levels.

  • Absolute precision (large dataset): low mean relative errors for the masses and distance, and a small mean absolute error for the effective spin.

  • $R^2$ scores: close to 1 for the masses and distance.

  • Scaling law: the test loss decreases as a power law in the training-set size.
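A power-law scaling of test loss with dataset size, $L(N) \approx a\,N^{-\alpha}$, can be verified with a log–log linear fit. The amplitude and exponent below are synthetic illustrations, not the paper's fitted values:

```python
import numpy as np

# synthetic losses following L(N) = a * N**(-alpha) (illustrative values)
a_true, alpha_true = 2.0, 0.35
N = np.array([1e3, 3e3, 1e4, 3e4, 1e5])
loss = a_true * N ** (-alpha_true)

# log-log linear fit: log L = log a - alpha * log N
slope, intercept = np.polyfit(np.log(N), np.log(loss), 1)
alpha_fit, a_fit = -slope, np.exp(intercept)
```

With real (noisy) losses the fit residuals indicate how well the power law holds over the sampled range of $N$.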

5. Domain Adaptation and Cross-Approximant Transfer

GraviBERT demonstrates efficient adaptation to new noise and waveform models:

  • ET-B → LIGO–Virgo (LV) domain adaptation: fine-tuning on small LV sets achieves large MAE reductions and marked convergence acceleration relative to models trained from scratch.

  • Cross-approximant generalization (SEOBNRv4 → IMRPhenomD): zero-shot inference retains accuracy for three of the four parameters, with degradation in the remaining one; fine-tuning on a modest set of IMRPhenomD events yields large MAE reductions, faster convergence, and accurate mass estimates even at low SNR.

These results indicate detector- and waveform-agnostic latent representations, enabling rapid transfer across noise environments and waveform models.
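The domain-adaptation recipe above can be sketched as a short supervised fine-tuning loop on a small new-domain set, reusing the optimizer and gradient-clipping settings from Section 3. The architecture, checkpoint path, and hyperparameter values here are placeholders:

```python
import torch
import torch.nn as nn

# stand-in for a pretrained GraviBERT-style model (placeholder layers)
model = nn.Sequential(nn.Linear(64, 32), nn.GELU(), nn.Linear(32, 4))
# model.load_state_dict(torch.load("pretrained.pt"))  # hypothetical checkpoint

opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
loss_fn = nn.MSELoss()

x_new = torch.randn(256, 64)      # small new-domain set (e.g. LV noise)
y_new = torch.rand(256, 4)        # min-max normalized targets in [0, 1]

for epoch in range(5):            # brief fine-tuning, all weights trainable
    opt.zero_grad()
    loss = loss_fn(model(x_new), y_new)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip at 1.0
    opt.step()
```

All weights remain trainable here; the reported gains come from the pretrained initialization rather than from freezing any subset of layers.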

6. Foundation-Model Perspective and Downstream Applications

GraviBERT’s pretrained encoder and transformer backbone are adaptable to various downstream tasks by substituting or augmenting the regression head. Downstream application domains include:

  • Overlapping signal separation
  • Anomaly detection
  • Multi-messenger joint inference
  • Tests of general relativity
  • Searches for new physics
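Adapting the backbone to a downstream task amounts to swapping the regression head, as in this sketch for a hypothetical binary anomaly-detection head (the flat backbone below is a stand-in for the pretrained feature extractor plus encoder):

```python
import torch
import torch.nn as nn

d_model = 256  # [CLS] embedding dimension of the encoder

# stand-in for the pretrained backbone (feature extractor + transformer)
backbone = nn.Sequential(nn.Linear(4096, d_model), nn.GELU())
for p in backbone.parameters():
    p.requires_grad = False       # reuse pretrained representations as-is

# new task head, e.g. binary anomaly detection instead of regression
anomaly_head = nn.Sequential(
    nn.Linear(d_model, 64), nn.GELU(), nn.Dropout(0.1), nn.Linear(64, 1))

logits = anomaly_head(backbone(torch.randn(8, 4096)))
```

Whether to freeze the backbone or fine-tune it end-to-end is a per-task choice; the point is that only the head changes across downstream applications.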

Current limitations include:

  • Point estimation only (no posterior inference)
  • Only four parameters (component masses, spin, distance)
  • Single-channel analysis
  • No architecture hyperparameter search

Planned future enhancements include full 15-parameter inference (including extrinsic and sky-location parameters), multi-channel/multi-detector extensions, low-frequency ET operation, and optimized transformer variants.

7. Mathematical Formulations and Metrics

Key formulas underpinning waveform modeling and training include:

  • Waveform Model (leading PN order): the strain is governed by the chirp mass $\mathcal{M} = (m_1 m_2)^{3/5}/(m_1 + m_2)^{1/5}$, with

$$h(t) \propto \frac{\mathcal{M}^{5/3} f(t)^{2/3}}{d_L}\cos\Phi(t), \qquad \dot f \propto \mathcal{M}^{5/3} f^{11/3}.$$

  • Pretraining Loss (masked latent reconstruction):

$$\mathcal{L}_{\mathrm{pre}} = \frac{1}{|M|}\sum_{i \in M} \left\lVert \hat z_i - z_i \right\rVert_2^2,$$

where $M$ is the set of masked positions, $z_i$ the projected latent representation at position $i$, and $\hat z_i$ its reconstruction.

  • Fine-tuning MSE Loss:

$$\mathcal{L}_{\mathrm{FT}} = \frac{1}{B}\sum_{b=1}^{B} \left\lVert \hat\theta_b - \theta_b \right\rVert_2^2,$$

with $\theta_b$ the min–max-normalized target parameters of sample $b$ and $\hat\theta_b$ the network prediction.

  • Evaluation Metrics:

$$\mathrm{MAE} = \frac{1}{N}\sum_{n=1}^{N} \left| \hat\theta_n - \theta_n \right|, \qquad R^2 = 1 - \frac{\sum_n (\hat\theta_n - \theta_n)^2}{\sum_n (\theta_n - \bar\theta)^2}.$$
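The evaluation metrics can be computed directly; a minimal NumPy sketch with illustrative mass values:

```python
import numpy as np

def mae(y_true, y_pred):
    """Mean absolute error."""
    return np.mean(np.abs(y_pred - y_true))

def mean_relative_error(y_true, y_pred):
    """Mean of |prediction error| / |true value|."""
    return np.mean(np.abs(y_pred - y_true) / np.abs(y_true))

def r2_score(y_true, y_pred):
    """Coefficient of determination R^2."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

y_true = np.array([30.0, 25.0, 40.0])   # e.g. true masses (illustrative)
y_pred = np.array([31.0, 24.0, 41.0])   # predictions
print(mae(y_true, y_pred))              # 1.0
```

An $R^2$ close to 1 means the residual variance is negligible relative to the spread of the true parameter values, matching the in-domain results quoted in Section 4.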

A plausible implication is that the masked reconstruction pretraining confers the ability to encode universal GW properties in latent representations, supporting generalization across domains and waveforms.


GraviBERT establishes a new paradigm for foundation-style models in GW astrophysics, introducing masked-segment reconstruction for noisy time series and demonstrating substantial gains in parameter inference accuracy and efficiency, domain adaptation, and extensibility to broader multi-messenger inference pipelines (Benedikt et al., 24 Dec 2025).

References (1)
