- The paper introduces BranchSBM, a novel framework that explicitly models branched transitions by learning multiple time-dependent velocity fields and growth rates.
- It employs a multi-stage training process—trajectory interpolation, flow matching, and growth network learning—to minimize kinetic energy and state costs while ensuring mass conservation.
- The method outperforms single-branch approaches in applications like LiDAR navigation and cell differentiation by accurately reconstructing diverse target distributions.
Predicting how a system evolves over time, especially when starting from a known state and moving towards a desired state, is a fundamental challenge in many scientific and engineering fields. Existing methods for generative modeling, such as flow matching and Schrödinger Bridge Matching (SBM), are effective at learning continuous transformations between two distributions by modeling a single stochastic path. However, many real-world phenomena involve divergence, where a single initial state can evolve into multiple distinct outcomes. This limitation prevents standard SBM and similar methods from capturing complex, branched dynamics.
The paper "Branched Schrödinger Bridge Matching" (2506.09007) introduces BranchSBM, a novel framework designed to explicitly model these branched transitions. Instead of a single trajectory, BranchSBM learns multiple time-dependent velocity fields and associated growth processes. This allows it to represent how a population originating from a common initial distribution π0 can diverge into multiple distinct target distributions {π1,k}k=0K. The core idea is to solve the Branched Generalized Schrödinger Bridge (GSB) problem, which finds optimal stochastic paths that minimize an energy objective (kinetic energy plus a state cost) while satisfying endpoint distributions and accounting for mass changes (growth or destruction) along each branch.
BranchSBM formulates the Branched GSB problem as a sum of Unbalanced Conditional Stochastic Optimal Control (CondSOC) problems. For each branch $k$, the goal is to find an optimal drift field $u_{t,k}(X_t)$ and a growth rate $g_{t,k}(X_t)$ that minimize a weighted expected energy, conditioned on paired initial and target samples $(x_0, x_{1,k})$. The weight $w_{t,k}(X_t)$ of a sample on branch $k$ at time $t$ evolves according to its growth rate: $w_{t,k} = w_{0,k} + \int_0^t g_{s,k}(X_s)\,ds$. For the primary branch ($k=0$) the initial weight is 1, while for secondary branches ($k>0$) it is 0. The growth rates $g_{t,k}$ determine how mass is transferred between branches over time.
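Schematically, and suppressing the endpoint constraints, the branched objective couples the per-branch energies through these weights (a simplified rendering for intuition; the paper's full CondSOC objective also regularizes the growth rates):

$$
\min_{\{u_{t,k},\,g_{t,k}\}}\;\sum_{k=0}^{K}\mathbb{E}\!\left[\int_0^1 w_{t,k}(X_t)\Big(\tfrac{1}{2}\lVert u_{t,k}(X_t)\rVert^2 + V_t(X_t)\Big)\,\mathrm{d}t\right],
\qquad
w_{t,k} = w_{0,k} + \int_0^t g_{s,k}(X_s)\,\mathrm{d}s .
$$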
To implement BranchSBM, the authors propose a multi-stage training approach using neural networks to parameterize the drift and growth fields.
- Stage 1: Branched Neural Interpolant Optimization: A neural network $\varphi_{t,\eta}$ is trained to predict an optimal interpolating path $x_{t,\eta,k}$ between an initial sample $x_0$ and a target sample $x_{1,k}$ for each branch $k$. The interpolant is defined as $x_{t,\eta,k} = (1-t)\,x_0 + t\,x_{1,k} + t(1-t)\,\varphi_{t,\eta}(x_0, x_{1,k})$. The network is trained by minimizing a trajectory loss ($\mathcal{L}_{\text{traj}}$) comprising the kinetic energy of the path's velocity $\dot{x}_{t,\eta,k}$ and a task-specific state cost $V_t(x_{t,\eta,k})$. This stage effectively learns the desired conditional paths and their velocities given endpoint pairs.
```python
# Pseudocode for Stage 1 (trajectory interpolant).
# phi_eta, sample_time, calculate_state_cost, and the dataloader are assumed
# to be defined elsewhere; t has shape (batch, 1) so it broadcasts over dims.
import torch
from torch.func import jvp

for epoch in range(num_epochs_stage1):
    for x0_batch, x1_branches_batch in dataloader:  # x1_branches_batch: list of per-branch target batches
        t = sample_time(batch_size)  # sample time uniformly in [0, 1]
        total_loss = 0.0
        for k in range(num_branches):
            x1_k_batch = x1_branches_batch[k]
            # Evaluate the interpolant network and its exact time derivative
            # via a forward-mode Jacobian-vector product with tangent dt = 1
            phi_output, d_phi_output_dt = jvp(
                lambda tt: phi_eta(x0_batch, x1_k_batch, tt),
                (t,), (torch.ones_like(t),),
            )
            # Interpolated position and its velocity (product rule)
            xt_eta_k = (1 - t) * x0_batch + t * x1_k_batch + t * (1 - t) * phi_output
            dxt_eta_k_dt = (x1_k_batch - x0_batch
                            + (1 - 2 * t) * phi_output
                            + t * (1 - t) * d_phi_output_dt)
            # State cost V_t(x_{t,eta,k}) on the data manifold
            Vt_xt_eta_k = calculate_state_cost(xt_eta_k, data_manifold)
            # Trajectory loss: kinetic energy plus state cost
            loss_traj_k = 0.5 * (dxt_eta_k_dt ** 2).sum(dim=1).mean() + Vt_xt_eta_k.mean()
            total_loss = total_loss + loss_traj_k
        optimizer_phi.zero_grad()
        total_loss.backward()
        optimizer_phi.step()
```
- Stage 2: Flow Matching: A separate neural network $u_{t,k}^{\theta}$ is trained for each branch $k$. These networks are trained to match the conditional velocities $\dot{x}_{t,\eta,k}$ learned in Stage 1, using a conditional flow matching loss ($\mathcal{L}_{\text{flow}}$). This stage learns the state-dependent drift fields $u_{t,k}^{\theta}(X_t)$ for each branch.
```python
# Pseudocode for Stage 2 (flow network training).
# phi_star is the frozen interpolant from Stage 1; u_theta is a list holding
# the per-branch flow networks u^theta_{t,k}.
for epoch in range(num_epochs_stage2):
    for x0_batch, x1_branches_batch in dataloader:
        t = sample_time(batch_size)
        total_loss = 0.0
        for k in range(num_branches):
            x1_k_batch = x1_branches_batch[k]
            # Target position and velocity from the trained interpolant
            with torch.no_grad():
                phi_output, d_phi_output_dt = jvp(
                    lambda tt: phi_star(x0_batch, x1_k_batch, tt),
                    (t,), (torch.ones_like(t),),
                )
                xt_eta_k = (1 - t) * x0_batch + t * x1_k_batch + t * (1 - t) * phi_output
                target_velocity = (x1_k_batch - x0_batch
                                   + (1 - 2 * t) * phi_output
                                   + t * (1 - t) * d_phi_output_dt)
            # Predict the state-dependent drift with the branch-k flow network
            predicted_velocity = u_theta[k](xt_eta_k, t)
            # Conditional flow matching loss
            loss_flow_k = ((predicted_velocity - target_velocity) ** 2).sum(dim=1).mean()
            total_loss = total_loss + loss_flow_k
        optimizer_u.zero_grad()
        total_loss.backward()
        optimizer_u.step()
```
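The later stages, as well as inference, require simulating trajectories under the learned drifts. A minimal sketch of per-branch Euler integration (the step count and names are illustrative):

```python
import torch

def simulate_branch(u_theta_k, x0, num_steps=100):
    """Euler integration of dX_t = u_theta_k(X_t, t) dt from t = 0 to 1."""
    dt = 1.0 / num_steps
    xt, trajectory = x0, [x0]
    for i in range(num_steps):
        t = torch.full((x0.shape[0], 1), i * dt)
        xt = xt + u_theta_k(xt, t) * dt  # Euler step under the learned drift
        trajectory.append(xt)
    return trajectory  # states at each of the num_steps + 1 time points
```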
- Stage 3: Growth Network Training: The flow network parameters are frozen, and separate neural networks $g_{t,k}^{\phi}$ are trained for each branch to model the growth rates. This is done by minimizing a combination of losses: the Branched Energy Loss ($\mathcal{L}_{\text{energy}}$), which minimizes the energy of the trajectories weighted by the predicted mass of each branch; the Weight Matching Loss ($\mathcal{L}_{\text{match}}$), which matches the predicted final mass of each branch to the true target weights; and the Mass Conservation Loss ($\mathcal{L}_{\text{mass}}$), which enforces that the total mass across all branches sums to the expected total mass at every time. A growth penalty is included for regularization.
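As a rough illustration of the weight-based losses, the sketch below integrates the per-branch weights with an Euler scheme along simulated trajectories and computes the matching and conservation penalties. The names (g_phi, trajectories, target_weights) are hypothetical, and the expected total mass is taken to be 1 for simplicity:

```python
import torch

def stage3_weight_losses(g_phi, trajectories, target_weights, num_branches, num_steps):
    """trajectories[k][i]: batch of branch-k states at time step i (e.g. from simulate_branch)."""
    dt = 1.0 / num_steps
    batch_size = trajectories[0][0].shape[0]
    # Branch 0 starts with weight 1; secondary branches start with weight 0
    w = [torch.ones(batch_size) if k == 0 else torch.zeros(batch_size)
         for k in range(num_branches)]
    loss_mass = 0.0
    for i in range(num_steps):
        t = torch.full((batch_size, 1), i * dt)
        for k in range(num_branches):
            # Euler step of w_{t,k} = w_{0,k} + int_0^t g_{s,k}(X_s) ds
            w[k] = w[k] + g_phi[k](trajectories[k][i], t).squeeze(-1) * dt
        # Mass conservation: total mass should stay at its expected value (1 here)
        loss_mass = loss_mass + (sum(wk.mean() for wk in w) - 1.0) ** 2
    # Weight matching: final branch masses should match the true target weights
    loss_match = sum((w[k].mean() - target_weights[k]) ** 2
                     for k in range(num_branches))
    return loss_match, loss_mass / num_steps
```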
- Stage 4: Final Joint Training: All network parameters ($\theta$ for flow, $\phi$ for growth) are unfrozen and jointly optimized using the Stage 3 loss plus a Reconstruction Loss ($\mathcal{L}_{\text{recons}}$) that penalizes deviation of the simulated endpoint distribution from the true target distribution using nearest neighbors.
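A Chamfer-style nearest-neighbour penalty is one simple way to realize such a reconstruction loss; the sketch below (with illustrative names) compares simulated branch endpoints against the true target samples:

```python
import torch

def reconstruction_loss(x1_sim, x1_true):
    """Nearest-neighbour (Chamfer-style) distance between two point clouds."""
    d = torch.cdist(x1_sim, x1_true)  # pairwise distances, shape (n_sim, n_true)
    # Each simulated endpoint should lie near some true sample, and vice versa
    return d.min(dim=1).values.mean() + d.min(dim=0).values.mean()
```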
The state cost $V_t(X_t)$ plays a crucial role in guiding the trajectories. It is derived from a data-dependent Riemannian metric $G(X_t, \mathcal{D})$. For low-dimensional data (LiDAR, 2D scRNA-seq), the LAND metric is used, penalizing movement away from data points based on a Gaussian kernel. For high-dimensional data (gene-expression principal components), the RBF metric is employed, which learns to assign low cost within data clusters and high cost outside them.
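For intuition, here is one way the calculate_state_cost placeholder from the Stage 1 pseudocode could look in the low-dimensional, LAND-style case: penalize points according to a Gaussian-kernel density around the data. The bandwidth and exact functional form are illustrative, not the paper's definition:

```python
import torch

def calculate_state_cost(xt, data, sigma=0.1):
    # Squared distances from each trajectory point to every data point
    d2 = torch.cdist(xt, data) ** 2          # shape (batch, n_data)
    # Gaussian-kernel density estimate around the data manifold
    density = torch.exp(-d2 / (2 * sigma ** 2)).mean(dim=1)
    # Low cost near the manifold, rapidly growing cost away from it
    return -torch.log(density + 1e-8)
```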
BranchSBM's practical capabilities are demonstrated across three distinct applications:
- Branched LiDAR Surface Navigation: BranchSBM successfully navigates from a single starting distribution to two target distributions located on different sides of a 3D LiDAR manifold. The model learns non-linear paths that follow the terrain and accurately simulates the transfer of mass from the initial state to the two targets over time (Figures 1 and 2). Quantitatively, it significantly outperforms single-branch SBM, which collapses to an average of the two targets and fails to capture the branching, in reconstructing the target distributions (Table 1).
- Differentiating Single-Cell Population Dynamics: Applied to mouse hematopoiesis scRNA-seq data, BranchSBM models how homogeneous progenitor cells differentiate into two distinct cell fates. It learns branching trajectories that capture the observed dynamics between three time points, including accurately simulating the cell distribution at an intermediate, unseen time point (Figure 3). Single-branch SBM fails to capture the distinct destinations, resulting in poor intermediate and final state reconstruction (Table 2, Appendix Figure D.1).
- Cell-State Perturbation Modeling: BranchSBM is used to model how drug perturbations cause a single cell line to diverge into multiple heterogeneous states in high-dimensional gene expression space (50, 100, 150 PCs). The model successfully learns trajectories to 2 (Clonidine) or 3 (Trametinib) distinct perturbed cell clusters (Figure 4, Figure 5, Appendix Figure D.1). Benchmarking against single-branch SBM confirms that BranchSBM is necessary to accurately reconstruct the full set of divergent perturbed states in high dimensions, where single-branch SBM collapses modes (Tables 3 & 4).
The implementation uses standard MLPs for all networks. A Softplus activation is applied to the output of the secondary-branch growth networks to ensure non-negativity, reflecting that mass is only added to these branches from the primary branch. The multi-stage training, outlined in detail in Algorithm 1, is key to stable optimization. The use of Optimal Transport to pair initial and target samples in unpaired datasets provides a principled way to define the conditional endpoint pairs needed for training. While training involves multiple networks, the authors note that the computational overhead remains manageable compared to single-branch SBM, since each branch network is trained on its corresponding data subset. Inference is efficient, requiring only the simulation of trajectories from samples of the initial distribution.
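A minibatch version of this OT pairing can be realized with an exact assignment solver; the sketch below uses scipy's Hungarian algorithm and a squared Euclidean cost, which may differ from the paper's exact choice:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def ot_pair(x0, x1):
    """Reorder target samples x1 so that (x0[i], x1_paired[i]) are OT-coupled."""
    # Squared Euclidean transport cost between all initial/target pairs
    cost = ((x0[:, None, :] - x1[None, :, :]) ** 2).sum(-1)
    row_ind, col_ind = linear_sum_assignment(cost)
    return x1[col_ind]
```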
In conclusion, BranchSBM provides a theoretically grounded and practically effective method for modeling complex branched stochastic dynamics. By extending the Schrödinger Bridge framework to explicitly account for multiple target distributions and mass evolution via learned growth rates, it fills a critical gap in generative modeling capabilities, with demonstrated success in diverse applications like complex navigation and biological fate mapping.