
GraviBERT: Deep Learning for GW Inference

Updated 29 December 2025
  • GraviBERT is a deep learning framework that uses masked pretraining followed by supervised fine-tuning to accurately infer key binary black hole parameters from gravitational-wave signals.
  • It integrates an Inception-inspired convolutional feature extractor with a transformer encoder to process noisy time series, achieving faster convergence and significant reductions in MAE.
  • GraviBERT demonstrates robust domain adaptation and cross-approximant generalization, establishing a foundation for advanced multi-messenger gravitational-wave analyses.

GraviBERT is a deep-learning framework for direct inference of binary black hole parameters from gravitational-wave (GW) time series. It leverages a hybrid neural network architecture combining an Inception-inspired multi-scale convolutional frontend, a transformer encoder, and a regression head. The model is trained in two principal stages: an unsupervised BERT-style pretraining phase, which promotes the acquisition of universal physical patterns by reconstructing masked segments in feature space, followed by supervised fine-tuning for precise parameter estimation. Trained directly on noisy GW waveforms, GraviBERT achieves significant improvements in inference accuracy, convergence speed, and domain adaptation relative to models trained from scratch, and demonstrates adaptability to new detector environments and waveform models (Benedikt et al., 24 Dec 2025).

1. Model Architecture and Components

GraviBERT processes an input GW strain time series $x \in \mathbb{R}^{4096}$ (zero-padded and rescaled) through the following sequence:

  • Convolutional Feature Extractor ($\mathcal{F}$):
    • Initial 1D convolution (kernel size 7, 64 channels), batch normalization, ReLU.
    • Three “main blocks,” each with three InceptionModules connected sequentially via a residual skip, followed by ReLU and max pooling (kernel 2, stride 2), resulting in $8\times$ temporal downsampling ($N = 512$).
    • Each InceptionModule consists of a $1 \times 1$ bottleneck convolution, three parallel convolutions (kernel sizes 10, 20, 40; output channels per block 32, 48, 64), a parallel max pool (kernel 3) followed by a $1 \times 1$ convolution, concatenation along channels, and batch normalization.
    • Output: latent feature tensor $L \in \mathbb{R}^{B \times 512 \times 64}$.
  • Transformer Encoder ($\mathcal{T}$):
    • Linear projection to a 256-dimensional embedding plus bias.
    • Prepending of a [CLS] token (sequence length $N' = 513$) and sinusoidal positional encodings.
    • Four identical transformer layers (4 heads, feed-forward hidden size 512, GELU, dropout 0.1, LayerNorm, residual connections).
    • Aggregation by taking the [CLS] token output $E_\text{agg} \in \mathbb{R}^{256}$.
  • Regression MLP Head ($\mathcal{H}$):
    • Three dense layers, $256 \rightarrow 256 \rightarrow 64 \rightarrow 4$ (no activation on the output), with GELU activations and dropout (0.1) in the hidden layers.
    • Outputs: point estimates of the source parameters $[\hat m_1, \hat m_2, \hat \chi_\text{eff}, \hat d]$.
| Component | Dimensions/Params | Function |
| --- | --- | --- |
| Input | $x \in \mathbb{R}^{4096}$ | GW strain time series |
| Feature extractor | See above | Latent sequence extraction |
| Transformer encoder | 4 layers, $d_\text{model} = 256$ | Contextual aggregation/self-attention |
| Regression head | 3 layers, last layer size 4 | Source parameter estimation |
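
The layout above can be condensed into a short PyTorch sketch. This is an illustrative re-implementation under stated simplifications, not the authors' code: the three Inception modules per main block are collapsed to one, each with equal-width branches and odd kernel sizes (9, 19, 39) so that "same" padding is exact, and per-branch channel counts are approximated.

```python
# Illustrative PyTorch sketch of the GraviBERT layout (simplified, not the authors' code).
import math
import torch
import torch.nn as nn


def sinusoidal_encoding(length, dim):
    """Standard fixed sinusoidal positional encodings of shape (length, dim)."""
    pos = torch.arange(length, dtype=torch.float32).unsqueeze(1)
    div = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))
    pe = torch.zeros(length, dim)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe


class SimpleInceptionModule(nn.Module):
    """1x1 bottleneck, three parallel multi-scale convolutions, and a pooled branch."""
    def __init__(self, in_ch=64, out_ch=64):
        super().__init__()
        branch_ch = out_ch // 4
        self.bottleneck = nn.Conv1d(in_ch, branch_ch, kernel_size=1)
        self.branches = nn.ModuleList(
            [nn.Conv1d(branch_ch, branch_ch, k, padding=k // 2) for k in (9, 19, 39)])
        self.pool_branch = nn.Sequential(
            nn.MaxPool1d(3, stride=1, padding=1), nn.Conv1d(in_ch, branch_ch, kernel_size=1))
        self.bn = nn.BatchNorm1d(out_ch)

    def forward(self, x):
        z = self.bottleneck(x)
        outs = [branch(z) for branch in self.branches] + [self.pool_branch(x)]
        return self.bn(torch.cat(outs, dim=1))


class GraviBERTSketch(nn.Module):
    def __init__(self, d_model=256, n_layers=4, n_heads=4, n_targets=4):
        super().__init__()
        # Frontend: initial conv + three pooled stages -> (B, 64, 512) for 4096 samples.
        stages = [nn.Conv1d(1, 64, kernel_size=7, padding=3), nn.BatchNorm1d(64), nn.ReLU()]
        for _ in range(3):
            stages += [SimpleInceptionModule(), nn.ReLU(), nn.MaxPool1d(2, stride=2)]
        self.frontend = nn.Sequential(*stages)
        self.project = nn.Linear(64, d_model)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        layer = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=512,
                                           dropout=0.1, activation="gelu", batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.head = nn.Sequential(                      # 256 -> 256 -> 64 -> 4
            nn.Linear(d_model, 256), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(256, 64), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(64, n_targets))

    def forward(self, x):                               # x: (B, 4096)
        feats = self.frontend(x.unsqueeze(1))           # (B, 64, 512)
        tokens = self.project(feats.transpose(1, 2))    # (B, 512, 256)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)        # (B, 513, 256)
        tokens = tokens + sinusoidal_encoding(tokens.size(1), tokens.size(2)).to(tokens.device)
        encoded = self.encoder(tokens)
        return self.head(encoded[:, 0])                 # regression from the [CLS] token
```

A forward pass on a batch of shape (B, 4096) yields four point estimates per waveform, matching the regression targets listed above.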

2. Training Methodology: Pretraining and Fine-Tuning

Training proceeds in two main stages:

  • BERT-style Pretraining:
    • Contiguous spans of the time series are masked (zeroed). Masks are generated from a two-state Markov chain with mean masked-segment length $l_m = 32$ and masking ratio $r = 0.15$ (see the masking sketch at the end of this section).
    • The masked input $x_m = x \odot m$ is encoded via $\mathcal{F}$ and $\mathcal{T}$. A prediction head reconstructs the projected latent representations at the masked positions:

    $$\mathcal{L}_\text{pretrain} = \frac{1}{|I|}\sum_{(b,i)\in I}\left\| \mathbf{z}_{\text{target},(b,i,\cdot)} - \mathbf{z}_{\text{pred},(b,i,\cdot)} \right\|_2^2.$$

    • This objective imposes long-range, multi-scale dependency learning and facilitates the encoding of universal inspiral–merger–ringdown waveform features.

  • Supervised Fine-Tuning:

    • Regression targets are $y = [m_1, m_2, \chi_\text{eff}, d]$, each normalized via min–max scaling. The mean squared error (MSE) loss is minimized:

    $$\mathcal{L}_\text{MSE}(\theta) = \frac{1}{B} \sum_{i=1}^B \| y_i - f(x_i;\theta) \|_2^2.$$

    • Prediction uncertainty is estimated via Monte Carlo dropout with $T = 100$ stochastic forward passes (sketched below, together with the mask sampler).
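
Two ingredients of this recipe lend themselves to short sketches. First, the span masking: the transition probabilities of the two-state Markov chain below are an assumption chosen so that the expected masked-segment length and overall masking ratio reproduce $l_m = 32$ and $r = 0.15$; the paper's exact sampler may differ.

```python
# Hedged sketch of a two-state Markov-chain span masker (not the authors' code).
import torch

def markov_mask(length=4096, l_m=32, r=0.15, generator=None):
    """Sample a {0, 1} mask whose zeroed spans have mean length l_m and overall density r."""
    p_unmask = 1.0 / l_m                # probability of leaving the masked state
    p_mask = r / (l_m * (1.0 - r))      # probability of entering the masked state
    u = torch.rand(length, generator=generator)
    mask = torch.ones(length)
    masked = bool(torch.rand(1, generator=generator) < r)   # initial state ~ target ratio
    for i in range(length):
        if masked:
            mask[i] = 0.0
            masked = bool(u[i] >= p_unmask)
        else:
            masked = bool(u[i] < p_mask)
    return mask
```

The masked input is then $x_m = x \odot m$ (e.g. `x_m = x * markov_mask(x.numel())`), and the pretraining loss is evaluated only at positions where the mask is zero. Second, Monte Carlo dropout: the helper below re-enables only the dropout layers at inference time and aggregates $T$ stochastic forward passes; `model` is assumed to be a network like the sketch in Section 1.

```python
@torch.no_grad()
def mc_dropout_predict(model, x, T=100):
    """Mean and spread of T stochastic forward passes with dropout kept active."""
    model.eval()
    for m in model.modules():
        if isinstance(m, torch.nn.Dropout):   # leave BatchNorm statistics frozen
            m.train()
    preds = torch.stack([model(x) for _ in range(T)])   # shape (T, B, 4)
    model.eval()
    return preds.mean(dim=0), preds.std(dim=0)
```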

3. Experimental Setup and Hyperparameters

Data generation involves sampling parameters from a structured grid:

  • Primary/secondary masses $m_1, m_2 \in [5, 150]\,M_\odot$.

  • Dimensionless spins $s_1, s_2 \in [-0.4, 0.4]$, combined into $\chi_\text{eff}$.

  • Distances scaled to a fixed SNR $\in \{10, 30\}$.

  • ET-B power spectral density, colored Gaussian noise, $f_\text{low} = 20$ Hz, sample rate 4096 Hz (a noise-generation sketch follows this list).

  • Data split: 80% train, 10% validation, 10% test.
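
As a rough illustration of the noise model, the snippet below colors white Gaussian noise with a one-sided detector PSD in the frequency domain. The normalization follows the standard discrete-Fourier recipe; `et_b_psd` in the usage comment stands in for an ET-B PSD callable and is an assumption, not something provided by the paper.

```python
# Minimal sketch: colored Gaussian noise from a one-sided PSD (assumed recipe).
import numpy as np

def colored_noise(psd, duration, sample_rate=4096, f_low=20.0, rng=None):
    """Time-domain Gaussian noise whose one-sided PSD is psd(f) above f_low."""
    rng = np.random.default_rng() if rng is None else rng
    n = int(duration * sample_rate)
    freqs = np.fft.rfftfreq(n, d=1.0 / sample_rate)
    band = freqs >= f_low
    amp = np.zeros_like(freqs)
    # For numpy's FFT convention, <|n~(f_k)|^2> = N * f_s * S(f_k) / 2,
    # split equally between the real and imaginary parts.
    amp[band] = np.sqrt(psd(freqs[band]) * sample_rate * n / 4.0)
    spectrum = amp * (rng.standard_normal(freqs.size) + 1j * rng.standard_normal(freqs.size))
    return np.fft.irfft(spectrum, n=n)

# Usage with a hypothetical PSD callable:
#   noise = colored_noise(et_b_psd, duration=1.0)   # 4096 samples of detector-like noise
```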

Training details:

  • Pretraining: AdamW, learning rate $1 \times 10^{-4}$, no weight decay, cosine-annealing LR schedule with 10% warmup, batch size 250, up to 200 epochs.

  • Fine-tuning: AdamW, learning rate $5 \times 10^{-4}$, weight decay $1 \times 10^{-2}$, batch size 250, up to 120 epochs, gradient clipping at 1.0 (see the optimizer sketch below).
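
The fine-tuning optimizer settings translate into a few lines of PyTorch. The warmup-plus-cosine schedule below uses a `LambdaLR`, which is one common way to realize "cosine annealing with 10% warmup"; `GraviBERTSketch` is the architecture sketch from Section 1, and `steps_per_epoch` is a placeholder value that depends on dataset size and batch size.

```python
# Hedged sketch of the fine-tuning optimizer configuration (scheduler choice assumed).
import math
import torch

model = GraviBERTSketch()                    # architecture sketch from Section 1
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-2)

steps_per_epoch = 400                        # placeholder: dataset size / batch size (250)
total_steps = 120 * steps_per_epoch
warmup_steps = int(0.1 * total_steps)

def lr_lambda(step):
    if step < warmup_steps:
        return step / max(1, warmup_steps)                            # linear warmup
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))                 # cosine decay to zero

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Inside the training loop, after loss.backward():
#   torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#   optimizer.step(); scheduler.step(); optimizer.zero_grad()
```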

4. Quantitative Performance and Evaluation

On in-domain ET-B test data, pretraining followed by fine-tuning yields marked improvements:

  • Convergence speed: up to $6.1\times$ faster (small dataset, SNR $= 10$) and $4.7\times$ faster (SNR $= 30$).

  • MAE reductions:

    • Up to 30% (e.g., large set, SNR $= 30$: $m_1$ 30%, $m_2$ 31%, $\chi_\text{eff}$ 11%, $d$ 26%).
    • Small set, SNR $= 10$: $m_1$ 8%, $m_2$ 6%, $\chi_\text{eff}$ 14%, $d$ 8%.
  • Absolute precision at SNR $= 10$ (large dataset):
    • Mean relative errors: $m_1$: 6.0%, $m_2$: 3.4%, $d$: 8.3%.
    • Effective-spin MAE $\simeq 3.8 \times 10^{-3}$.
  • $R^2$ scores: all $\geq 0.90$ for masses and distance, $\chi_\text{eff} \sim 0.93$.
  • Scaling law: test loss $L(D) \propto D^{-\alpha_D}$ with $\alpha_D \approx 0.37$ (SNR $= 10$) and $\approx 0.36$ (SNR $= 30$); a log–log fitting sketch follows this list.
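
The scaling exponent $\alpha_D$ can be recovered from measured (dataset size, test loss) pairs with a log–log linear fit, as sketched below. The numerical values are placeholders for illustration only, not results reported in the paper.

```python
# Illustrative log-log fit of a data scaling law L(D) ∝ D^(-alpha_D).
import numpy as np

dataset_sizes = np.array([1e4, 3e4, 1e5, 3e5])              # hypothetical D values
test_losses = np.array([2.1e-3, 1.4e-3, 9.0e-4, 6.0e-4])    # hypothetical L(D) values

# log L(D) = -alpha_D * log D + const  =>  alpha_D is minus the fitted slope.
slope, _ = np.polyfit(np.log(dataset_sizes), np.log(test_losses), deg=1)
print(f"fitted scaling exponent alpha_D ≈ {-slope:.2f}")
```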

5. Domain Adaptation and Cross-Approximant Transfer

GraviBERT demonstrates efficient adaptation to new noise and waveform models:

  • ET-B $\rightarrow$ LIGO–Virgo (LV) domain adaptation:
    • Fine-tuning on small LV sets (5k–10k samples) achieves up to a 47% MAE reduction for $\chi_\text{eff}$ (SNR $= 10$) and a $15\times$ convergence acceleration.
    • At 5k samples and SNR $= 10$: $m_1$ MAE drops by 26%, $m_2$ by 18%, $\chi_\text{eff}$ by 43%, and $d$ by 25%.
  • Cross-approximant generalization (SEOBNRv4 $\rightarrow$ IMRPhenomD):
    • Zero-shot inference retains $R^2 > 0.80$ for $m_1$, $m_2$, and $d$, with degradation for $\chi_\text{eff}$ ($R^2 \approx 0.65$).
    • Fine-tuning on 5k IMRPhenomD events yields up to a 44% MAE reduction for $\chi_\text{eff}$, $6$–$15\times$ faster convergence, and $R^2 > 0.9$ for the masses even at SNR $= 10$ (versus $0.74$–$0.87$ when trained from scratch).

These results indicate detector- and waveform-agnostic latent representations, enabling rapid transfer across noise environments and waveform models.

6. Foundation-Model Perspective and Downstream Applications

GraviBERT’s pretrained encoder and transformer backbone are adaptable to various downstream tasks by substituting or augmenting the regression head. Downstream application domains include:

  • Overlapping signal separation
  • Anomaly detection
  • Multi-messenger joint inference
  • Tests of general relativity
  • Searches for new physics

Current limitations include:

  • Point estimation only (no posterior inference)
  • Only four estimated parameters ($m_1$, $m_2$, $\chi_\text{eff}$, $d$)
  • Single-channel analysis
  • No architecture hyperparameter search

Planned future enhancements are full 15-parameter inference (including extrinsic and sky-location parameters), multi-channel/multi-detector extensions, low-frequency ET operation (down to $5$ Hz), and optimized transformer variants.

7. Mathematical Formulations and Metrics

Key formulas underpinning waveform modeling and training include:

  • Waveform Model (leading PN order):

$$h(t) \simeq \frac{\mathcal{M}_c^{5/3}}{d}\,\omega(t)^{2/3} \cos[2\Phi(t)], \qquad \mathcal{M}_c = \frac{(m_1 m_2)^{3/5}}{(m_1 + m_2)^{1/5}}$$

  • Pretraining Loss:

$$\mathcal{L}_\text{pretrain} = \frac{1}{|I|}\sum_{(b,i)\in I}\bigl\| \mathbf{z}_{\text{target},(b,i,\cdot)} - \mathbf{z}_{\text{pred},(b,i,\cdot)} \bigr\|_2^2$$

  • Fine-tuning MSE Loss:

$$\mathcal{L}_\text{MSE}(\theta) = \frac{1}{B}\sum_{i=1}^{B} \| y_i - f(x_i;\theta) \|_2^2$$

  • Evaluation Metrics:

$$\mathrm{MAE}_j = \frac{1}{n} \sum_{i=1}^n |y_{ij} - \hat{y}_{ij}|, \qquad R^2_j = 1 - \frac{\sum_i (y_{ij} - \hat{y}_{ij})^2}{\sum_i (y_{ij} - \bar{y}_j)^2}$$
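
For reference, both metrics can be evaluated per parameter with a few lines of NumPy, assuming `y_true` and `y_pred` are arrays of shape (n_samples, 4) ordered as $(m_1, m_2, \chi_\text{eff}, d)$.

```python
# Per-parameter MAE_j and R^2_j as defined above.
import numpy as np

def mae_per_parameter(y_true, y_pred):
    """Column-wise mean absolute error."""
    return np.mean(np.abs(y_true - y_pred), axis=0)

def r2_per_parameter(y_true, y_pred):
    """Column-wise coefficient of determination."""
    ss_res = np.sum((y_true - y_pred) ** 2, axis=0)
    ss_tot = np.sum((y_true - y_true.mean(axis=0)) ** 2, axis=0)
    return 1.0 - ss_res / ss_tot
```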

A plausible implication is that the masked reconstruction pretraining confers the ability to encode universal GW properties in latent representations, supporting generalization across domains and waveforms.


GraviBERT establishes a new paradigm for foundation-style models in GW astrophysics, introducing masked-segment reconstruction for noisy time series and demonstrating substantial gains in parameter inference accuracy and efficiency, domain adaptation, and extensibility to broader multi-messenger inference pipelines (Benedikt et al., 24 Dec 2025).
