Branched Schrödinger Bridge Matching (2506.09007v1)

Published 10 Jun 2025 in cs.LG and q-bio.QM

Abstract: Predicting the intermediate trajectories between an initial and target distribution is a central problem in generative modeling. Existing approaches, such as flow matching and Schrödinger Bridge Matching, effectively learn mappings between two distributions by modeling a single stochastic path. However, these methods are inherently limited to unimodal transitions and cannot capture branched or divergent evolution from a common origin to multiple distinct outcomes. To address this, we introduce Branched Schrödinger Bridge Matching (BranchSBM), a novel framework that learns branched Schrödinger bridges. BranchSBM parameterizes multiple time-dependent velocity fields and growth processes, enabling the representation of population-level divergence into multiple terminal distributions. We show that BranchSBM is not only more expressive but also essential for tasks involving multi-path surface navigation, modeling cell fate bifurcations from homogeneous progenitor states, and simulating diverging cellular responses to perturbations.

Summary

  • The paper introduces BranchSBM, a novel framework that explicitly models branched transitions by learning multiple time-dependent velocity fields and growth rates.
  • It employs a multi-stage training process (neural interpolant optimization, flow matching, growth network training, and final joint training) to minimize kinetic energy and state costs while ensuring mass conservation.
  • The method outperforms single-branch approaches in applications like LiDAR navigation and cell differentiation by accurately reconstructing diverse target distributions.

Predicting how a system evolves over time, especially when starting from a known state and moving towards a desired state, is a fundamental challenge in many scientific and engineering fields. Existing methods for generative modeling, such as flow matching and Schrödinger Bridge Matching (SBM), are effective at learning continuous transformations between two distributions by modeling a single stochastic path. However, many real-world phenomena involve divergence, where a single initial state can evolve into multiple distinct outcomes. This limitation prevents standard SBM and similar methods from capturing complex, branched dynamics.

The paper "Branched Schrödinger Bridge Matching" (2506.09007) introduces BranchSBM, a novel framework designed to explicitly model these branched transitions. Instead of a single trajectory, BranchSBM learns multiple time-dependent velocity fields and associated growth processes. This allows it to represent how a population originating from a common initial distribution π0\pi_0 can diverge into multiple distinct target distributions {π1,k}k=0K\{\pi_{1, k}\}_{k=0}^K. The core idea is to solve the Branched Generalized Schrödinger Bridge (GSB) problem, which finds optimal stochastic paths that minimize an energy objective (kinetic energy plus a state cost) while satisfying endpoint distributions and accounting for mass changes (growth or destruction) along each branch.

BranchSBM formulates the Branched GSB problem as the sum of Unbalanced Conditional Stochastic Optimal Control (CondSOC) problems. For each branch $k$, the goal is to find an optimal drift field $u_{t,k}(X_t)$ and a growth rate $g_{t,k}(X_t)$ that minimize a weighted expected energy, conditioned on paired initial and target samples $(\mathbf{x}_0, \mathbf{x}_{1,k})$. The weight $w_{t,k}(X_t)$ of a sample on branch $k$ at time $t$ evolves according to its growth rate $g_{t,k}$: $w_{t,k} = w_{0,k} + \int_0^t g_{s,k}(X_s)\, ds$. For the primary branch ($k=0$), the initial weight is 1, while for secondary branches ($k>0$), it is 0. The growth rates $g_{t,k}$ determine how mass is transferred between branches over time.
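
To make the weight dynamics concrete, here is a minimal Euler-discretization sketch of $w_{t,k} = w_{0,k} + \int_0^t g_{s,k}(X_s)\, ds$ (the `growth_nets` callables and the tensor shapes are illustrative assumptions, not the paper's interface):

    import torch

    def evolve_branch_weights(x_traj, t_grid, growth_nets):
        """Euler integration of w_{t,k} = w_{0,k} + int_0^t g_{s,k}(X_s) ds.

        x_traj: (num_steps, batch, dim) states along simulated trajectories
        t_grid: (num_steps,) increasing time points in [0, 1]
        growth_nets: list of callables g_k(x, t) -> (batch, 1), one per branch
        """
        num_steps, batch, _ = x_traj.shape
        K = len(growth_nets)
        w = torch.zeros(K, batch, 1)
        w[0] = 1.0  # primary branch starts with weight 1; secondary branches with 0
        weights_over_time = [w.clone()]
        for i in range(num_steps - 1):
            dt = t_grid[i + 1] - t_grid[i]
            for k, g_k in enumerate(growth_nets):
                w[k] = w[k] + g_k(x_traj[i], t_grid[i]) * dt  # w <- w + g dt
            weights_over_time.append(w.clone())
        return torch.stack(weights_over_time)  # (num_steps, K, batch, 1)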

To implement BranchSBM, the authors propose a multi-stage training approach using neural networks to parameterize the drift and growth fields.

  1. Stage 1: Branched Neural Interpolant Optimization: A neural network $\varphi_{t,\eta}$ is trained to predict an optimal interpolating path $\mathbf{x}_{t,\eta,k}$ between an initial sample $\mathbf{x}_0$ and a target sample $\mathbf{x}_{1,k}$ for each branch $k$. The interpolant is defined as $\mathbf{x}_{t,\eta,k} = (1-t)\mathbf{x}_0 + t\,\mathbf{x}_{1,k} + t(1-t)\,\varphi_{t,\eta}(\mathbf{x}_0, \mathbf{x}_{1,k})$. This network is trained by minimizing a trajectory loss ($\mathcal{L}_{\text{traj}}$) comprising the kinetic energy of the path's velocity $\dot{\mathbf{x}}_{t,\eta,k}$ and a task-specific state cost $V_t(\mathbf{x}_{t,\eta,k})$. This stage effectively learns the desired conditional paths and their velocities given endpoint pairs.
    # Pseudocode for Stage 1 (trajectory interpolant), PyTorch-style
    import torch

    for epoch in range(num_epochs_stage1):
        for batch in dataloader:
            x0_batch, x1_branches_batch = batch  # x1_branches_batch: list of target batches, one per branch
            t = sample_time(batch_size).requires_grad_(True)  # t ~ U[0, 1], shape (batch_size, 1)

            total_loss = 0.0
            for k in range(num_branches):
                x1_k_batch = x1_branches_batch[k]

                # Interpolated position: x_t = (1-t) x0 + t x1 + t(1-t) phi_eta(x0, x1, t)
                phi_output = phi_eta(x0_batch, x1_k_batch, t)
                xt_eta_k = (1 - t) * x0_batch + t * x1_k_batch + t * (1 - t) * phi_output

                # Time derivative of phi_output via autograd (each sample depends only on its own t)
                d_phi_output_dt = torch.autograd.grad(phi_output.sum(), t, create_graph=True)[0]

                # Conditional velocity dx_t/dt, obtained by differentiating the interpolant
                dxt_eta_k_dt = (x1_k_batch - x0_batch
                                + (1 - 2 * t) * phi_output
                                + t * (1 - t) * d_phi_output_dt)

                # State cost V_t(x_t) from the data-dependent metric
                Vt_xt_eta_k = calculate_state_cost(xt_eta_k, data_manifold)

                # Trajectory loss: kinetic energy + state cost
                loss_traj_k = 0.5 * (dxt_eta_k_dt ** 2).sum(dim=1).mean() + Vt_xt_eta_k.mean()
                total_loss = total_loss + loss_traj_k

            optimizer_phi.zero_grad()
            total_loss.backward()
            optimizer_phi.step()
  2. Stage 2: Flow Matching: A separate neural network $u_{t,k}^{\theta}$ is trained for each branch $k$. These networks are trained to match the conditional velocities $\dot{\mathbf{x}}_{t,\eta,k}$ learned in Stage 1, using a conditional flow matching loss ($\mathcal{L}_{\text{flow}}$). This stage learns the state-dependent drift fields $u_{t,k}^{\theta}(X_t)$ for each branch.
    # Pseudocode for Stage 2 (flow network training)
    # phi_star is the interpolant trained in Stage 1, held fixed here
    import torch

    for epoch in range(num_epochs_stage2):
        for batch in dataloader:
            x0_batch, x1_branches_batch = batch
            t = sample_time(batch_size).requires_grad_(True)  # grad on t needed for d phi/dt

            total_loss = 0.0
            for k in range(num_branches):
                x1_k_batch = x1_branches_batch[k]

                # Target position and velocity from the trained interpolant.
                # autograd needs grad enabled, so we detach the targets afterwards
                # rather than wrapping this in torch.no_grad()
                phi_output = phi_star(x0_batch, x1_k_batch, t)
                xt_eta_k = ((1 - t) * x0_batch + t * x1_k_batch
                            + t * (1 - t) * phi_output).detach()
                d_phi_output_dt = torch.autograd.grad(phi_output.sum(), t)[0]
                target_velocity = (x1_k_batch - x0_batch
                                   + (1 - 2 * t) * phi_output
                                   + t * (1 - t) * d_phi_output_dt).detach()

                # Predicted velocity from the branch-k flow network
                predicted_velocity = u_theta_k[k](xt_eta_k, t.detach())

                # Conditional flow matching loss
                loss_flow_k = ((predicted_velocity - target_velocity) ** 2).sum(dim=1).mean()
                total_loss = total_loss + loss_flow_k

            optimizer_u.zero_grad()
            total_loss.backward()
            optimizer_u.step()
  3. Stage 3: Growth Network Training: The flow network parameters are frozen. Separate neural networks $g_{t,k}^{\phi}$ are trained for each branch to model the growth rates. This is done by minimizing a combination of losses: the Branched Energy Loss ($\mathcal{L}_{\text{energy}}$), which minimizes the energy of the trajectories weighted by the predicted mass of each branch; the Weight Matching Loss ($\mathcal{L}_{\text{match}}$), which matches the predicted final mass of each branch to the true target weights; and the Mass Conservation Loss ($\mathcal{L}_{\text{mass}}$), which enforces that the total mass across all branches sums to the expected total mass at any time. A growth penalty is included for regularization. (A sketch of how these terms might be assembled appears after this list.)
  4. Stage 4: Final Joint Training: All network parameters ($\theta$ for flow, $\phi$ for growth) are unfrozen and jointly optimized using the Stage 3 loss plus a Reconstruction Loss ($\mathcal{L}_{\text{recons}}$) that penalizes deviation of the simulated endpoint distribution from the true target distribution using nearest neighbors.
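
To illustrate how the Stage 3 terms might be assembled (a sketch under the descriptions above; the tensor shapes, loss weights `lam_*`, and the assumption that total mass stays at 1 are illustrative, not the paper's exact definitions):

    def stage3_loss(w_pred, energies, target_final_weights, g_outputs,
                    lam_match=1.0, lam_mass=1.0, lam_g=0.01):
        """Assemble the Stage 3 objective from precomputed quantities.

        w_pred: (num_steps, K, batch, 1) branch weights along simulated paths
        energies: (num_steps, K, batch) per-step 0.5*||u_k||^2 + V_t values
        target_final_weights: (K,) true terminal mass fraction of each branch
        g_outputs: list of (num_steps, batch, 1) growth-rate outputs per branch
        """
        w = w_pred.squeeze(-1)  # (num_steps, K, batch)
        # L_energy: trajectory energy weighted by the predicted mass of each branch
        energy_loss = (w * energies).sum(dim=1).mean()
        # L_match: predicted final mass per branch vs. the true target weights
        match_loss = ((w[-1].mean(dim=-1) - target_final_weights) ** 2).sum()
        # L_mass: total mass across branches should equal the expected total
        # at every time (assumed to be 1 here)
        mass_loss = ((w.sum(dim=1) - 1.0) ** 2).mean()
        # Growth penalty for regularization
        growth_penalty = sum((g ** 2).mean() for g in g_outputs)
        return (energy_loss + lam_match * match_loss
                + lam_mass * mass_loss + lam_g * growth_penalty)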

The state cost $V_t(X_t)$ plays a crucial role in guiding the trajectories. It is derived from a data-dependent Riemannian metric $\mathbf{G}(X_t, \mathcal{D})$. For low-dimensional data (LiDAR, 2D scRNA-seq), the LAND metric is used, penalizing movement away from data points based on a Gaussian kernel. For high-dimensional data (gene expression PCs), the RBF metric is employed, which learns to assign low cost within data clusters and high cost outside.
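
As a minimal sketch in the spirit of the LAND-style penalty described above (the Gaussian-kernel form, bandwidth `h`, and log-density cost are assumptions, not the paper's exact metric):

    import torch

    def calculate_state_cost(x, data, h=0.1):
        """Penalize positions far from the data manifold.

        x: (batch, dim) interpolated positions
        data: (N, dim) reference points defining the manifold
        Returns a (batch,) cost that is low near data and high away from it.
        """
        sq_dists = torch.cdist(x, data) ** 2  # (batch, N) squared distances
        # Gaussian-kernel density estimate around each position
        density = torch.exp(-sq_dists / (2 * h ** 2)).mean(dim=1)
        # Low density => high cost; epsilon keeps the log finite
        return -torch.log(density + 1e-8)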

BranchSBM's practical capabilities are demonstrated across three distinct applications:

  • Branched LiDAR Surface Navigation: BranchSBM successfully navigates from a single starting distribution to two target distributions located on different sides of a 3D LiDAR manifold. The model learns non-linear paths that follow the terrain and accurately simulates the transfer of mass from the initial state to the two targets over time (Figures 1 and 2). Quantitatively, it significantly outperforms single-branch SBM (Table 1), which would simply average the two targets and fail to capture the branching.
  • Differentiating Single-Cell Population Dynamics: Applied to mouse hematopoiesis scRNA-seq data, BranchSBM models how homogeneous progenitor cells differentiate into two distinct cell fates. It learns branching trajectories that capture the observed dynamics between three time points, including accurately simulating the cell distribution at an intermediate, unseen time point (Figure 3). Single-branch SBM fails to capture the distinct destinations, resulting in poor intermediate and final state reconstruction (Table 2, Appendix Figure D.1).
  • Cell-State Perturbation Modeling: BranchSBM is used to model how drug perturbations cause a single cell line to diverge into multiple heterogeneous states in high-dimensional gene expression space (50, 100, 150 PCs). The model successfully learns trajectories to 2 (Clonidine) or 3 (Trametinib) distinct perturbed cell clusters (Figure 4, Figure 5, Appendix Figure D.1). Benchmarking against single-branch SBM confirms that BranchSBM is necessary to accurately reconstruct the full set of divergent perturbed states in high dimensions, where single-branch SBM collapses modes (Tables 3 & 4).

The implementation uses standard MLPs for the networks. Softplus activation is applied to the output of secondary branch growth networks to ensure non-negativity, reflecting that mass is only added to these branches from the primary. The multi-stage training, outlined in detail in Algorithm 1, is key to stable optimization. The use of Optimal Transport to pair initial and target samples in unpaired datasets provides a principled way to define the conditional endpoint pairs needed for training. While training involves multiple networks, the authors note that computational overhead remains manageable compared to single-branch SBM, as each branch network is trained on its corresponding data subset. Inference is efficient, requiring simulation of trajectories from a single initial sample.
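
For concreteness, a secondary-branch growth network could look like the following minimal sketch (the layer sizes and time-conditioning scheme are assumptions; only the Softplus output head comes from the text above):

    import torch
    import torch.nn as nn

    class SecondaryGrowthNet(nn.Module):
        """MLP growth network for a secondary branch (k > 0).

        Softplus keeps the predicted growth rate non-negative, reflecting
        that mass is only ever added to secondary branches.
        """
        def __init__(self, dim, hidden=256):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(dim + 1, hidden), nn.SiLU(),
                nn.Linear(hidden, hidden), nn.SiLU(),
                nn.Linear(hidden, 1),
            )
            self.softplus = nn.Softplus()

        def forward(self, x, t):
            # Condition on time by concatenating t to the state
            return self.softplus(self.net(torch.cat([x, t], dim=-1)))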

In conclusion, BranchSBM provides a theoretically grounded and practically effective method for modeling complex branched stochastic dynamics. By extending the Schrödinger Bridge framework to explicitly account for multiple target distributions and mass evolution via learned growth rates, it fills a critical gap in generative modeling capabilities, with demonstrated success in diverse applications like complex navigation and biological fate mapping.
