
Linear-Time Gromov Wasserstein Distances using Low Rank Couplings and Costs (2106.01128v2)

Published 2 Jun 2021 in cs.LG and stat.ML

Abstract: The ability to align points across two related yet incomparable point clouds (e.g. living in different spaces) plays an important role in machine learning. The Gromov-Wasserstein (GW) framework provides an increasingly popular answer to such problems, by seeking a low-distortion, geometry-preserving assignment between these points. As a non-convex, quadratic generalization of optimal transport (OT), GW is NP-hard. While practitioners often resort to solving GW approximately as a nested sequence of entropy-regularized OT problems, the cubic complexity (in the number $n$ of samples) of that approach is a roadblock. We show in this work how a recent variant of the OT problem that restricts the set of admissible couplings to those having a low-rank factorization is remarkably well suited to the resolution of GW: when applied to GW, we show that this approach is not only able to compute a stationary point of the GW problem in time $O(n2)$, but also uniquely positioned to benefit from the knowledge that the initial cost matrices are low-rank, to yield a linear time $O(n)$ GW approximation. Our approach yields similar results, yet orders of magnitude faster computation than the SoTA entropic GW approaches, on both simulated and real data.

Citations (54)

Summary

  • The paper introduces a novel linear-time algorithm for Gromov-Wasserstein distances by exploiting low-rank factorizations in both cost matrices and coupling structures.
  • It reduces the cubic complexity of traditional methods to linear complexity by approximating input costs and constraining the coupling matrix.
  • The approach demonstrates practical scalability on large-scale datasets such as single-cell genomics and human brain data while preserving alignment accuracy.

The paper "Linear-Time Gromov Wasserstein Distances using Low Rank Couplings and Costs" (Linear-Time Gromov Wasserstein Distances using Low Rank Couplings and Costs, 2021) addresses the computational bottleneck of the Gromov-Wasserstein (GW) problem, which is widely used for aligning and comparing data from different metric spaces, such as point clouds or distributions living in heterogeneous feature spaces. The standard approach to solving GW approximately, based on iteratively solving entropy-regularized Optimal Transport (OT) problems, suffers from cubic complexity O(n3)O(n^3) in the number of samples nn, making it impractical for large datasets. This paper proposes novel methods to reduce the GW computation time, ultimately achieving a linear-time algorithm by exploiting low-rank structures in both the input cost matrices and the coupling matrix.

The standard GW problem seeks a coupling matrix $P \in \mathbb{R}^{n \times m}$ between two discrete measures with $n$ and $m$ samples, represented by cost matrices $A \in \mathbb{R}^{n \times n}$ and $B \in \mathbb{R}^{m \times m}$ that encode the geometry within each space. The objective is to minimize a quadratic function $\mathcal{Q}_{A,B}(P)$ that measures the distortion introduced by the coupling $P$. The standard entropic GW approximation solves this non-convex problem iteratively using Mirror Descent, which boils down to solving a sequence of entropy-regularized OT problems with a synthetic cost matrix $C_t = -4 A P_{t-1} B$. The computational cost breaks down across three steps in each iteration (see the sketch after the list below):

  1. Updating the cost matrix $C_t = -4 A P_{t-1} B$, which requires $O(n^2 m + nm^2)$ operations.
  2. Evaluating the GW objective $\mathcal{Q}_{A,B}(P)$, which costs $O(n^2 m^2)$ if done naively, but can be computed in $O(n^2 m + nm^2)$ operations using a standard reformulation.
  3. Solving the entropy-regularized OT problem $\min_{P \in \Pi_{a,b}} \langle C_t, P \rangle - \varepsilon H(P)$ with Sinkhorn's algorithm, which takes $O(nm)$ operations per Sinkhorn iteration; the dominant cost per outer GW iteration remains the $O(n^2 m + nm^2)$ update of $C_t$.
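For reference, here is a minimal NumPy sketch of this baseline loop. It is illustrative only: the function names are ours, the Sinkhorn iteration count is fixed, and log-domain stabilization (needed in practice to avoid overflow in the exponential) is omitted.

```python
import numpy as np

def sinkhorn(C, a, b, eps, n_iters=100):
    # Entropy-regularized OT: min <C, P> - eps * H(P) over Pi(a, b).
    K = np.exp(-C / eps)                 # Gibbs kernel, O(nm) memory
    u = np.ones_like(a)
    for _ in range(n_iters):
        v = b / (K.T @ u)                # O(nm) per Sinkhorn iteration
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]   # coupling P = Diag(u) K Diag(v)

def entropic_gw_step(A, B, P_prev, a, b, eps):
    # One outer mirror-descent step for entropic GW.  Forming the synthetic
    # cost C_t = -4 A P_{t-1} B costs O(n^2 m + n m^2) and dominates,
    # which is cubic when n ~ m.
    C_t = -4.0 * A @ P_prev @ B
    return sinkhorn(C_t, a, b, eps)
```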

The paper tackles this cubic complexity by introducing two independent strategies and then showing how to combine them:

1. Low-rank (Approximated) Costs:

If the input cost matrices $A$ and $B$ admit low-rank factorizations $A = A_1 A_2^T$ and $B = B_1 B_2^T$, where $A_1, A_2 \in \mathbb{R}^{n \times d}$ and $B_1, B_2 \in \mathbb{R}^{m \times d'}$ with $d \ll n$ and $d' \ll m$, the cost of updating the synthetic matrix $C_t = -4 A_1 A_2^T P_{t-1} B_1 B_2^T$ drops substantially: first compute $G = A_2^T P_{t-1} B_1$ in $O(nmd + mdd')$ operations, then $C_t = -4 A_1 G B_2^T$ in $O(ndd' + nmd')$. If $d, d'$ are small constants, this reduces the update cost to $O(nm(d + d'))$ per outer iteration, i.e. $O(n^2)$ when $n = m$. Similarly, the evaluation of $\mathcal{Q}_{A,B}(P)$ can be sped up to $O(nmd + mdd')$. This strategy is particularly relevant for squared Euclidean distance matrices, which admit an exact low-rank factorization of rank $d + 2$, where $d$ is the ambient dimension. For general distance matrices, recent work allows computing low-rank approximations in nearly linear time, enabling this speedup even when an exact factorization is not available. Algorithm 2 outlines this "Quadratic Entropic-GW" approach.
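A minimal sketch of both ingredients, assuming the factor shapes above (helper names are ours, not the paper's):

```python
import numpy as np

def lowrank_cost_update(A1, A2, B1, B2, P):
    # C_t = -4 (A1 A2^T) P (B1 B2^T), parenthesized so that no intermediate
    # product costs more than O(nmd + nmd').
    # Shapes: A1, A2: (n, d); B1, B2: (m, d'); P: (n, m).
    G = (A2.T @ P) @ B1              # (d, d'): O(nmd) + O(m d d')
    return -4.0 * (A1 @ G) @ B2.T    # (n, m): O(n d d') + O(n m d')

def sq_euclidean_factors(X):
    # Exact rank-(d+2) factorization M = M1 @ M2.T of the squared Euclidean
    # distance matrix: M_ij = ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>.
    n, _ = X.shape
    sq = (X ** 2).sum(axis=1, keepdims=True)          # (n, 1) squared norms
    M1 = np.hstack([sq, np.ones((n, 1)), -2.0 * X])   # (n, d+2)
    M2 = np.hstack([np.ones((n, 1)), sq, X])          # (n, d+2)
    return M1, M2
```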

2. Low-rank Constraints for Couplings:

Instead of assuming low-rank input costs, the paper proposes constraining the coupling matrix $P$ to have a low nonnegative rank. This is achieved by restricting $P$ to the form $P = Q \text{Diag}(1/g) R^T$, where $Q \in \mathbb{R}^{n \times r}$ and $R \in \mathbb{R}^{m \times r}$ satisfy the marginal constraints $Q\mathbf{1}_r = a$, $R\mathbf{1}_r = b$, $Q^T\mathbf{1}_n = R^T\mathbf{1}_m = g$, and $g \in \Delta_r$ is a common inner marginal. This factorization implies that $P$ has nonnegative rank at most $r$. The GW problem is then reformulated as minimizing $\mathcal{Q}_{A,B}(Q \text{Diag}(1/g) R^T)$ over $(Q, R, g)$ in the feasible set $\mathcal{C}(a, b, r)$, and solved with a Mirror Descent scheme w.r.t. the KL divergence in the space of $(Q, R, g)$. Each step computes generalized kernel matrices $K^{(1)}, K^{(2)}, K^{(3)}$ and then solves a barycenter problem efficiently using Dykstra's algorithm (Algorithm 3). The initialization uses a low-rank approximation of a lower bound built from the vectors $A^{\odot 2}a$ and $B^{\odot 2}b$, and can itself be computed efficiently. While Dykstra's algorithm for the barycenter step takes $O((n+m)r)$ operations per inner iteration, computing the kernel matrices still requires multiplying the dense $A$ and $B$ against thin $n \times r$ and $m \times r$ factors (e.g. $A Q_k$ and $B R_k$), so this approach alone reduces the complexity per outer iteration to $O((n^2 + m^2)r)$.
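To make the feasible set $\mathcal{C}(a, b, r)$ concrete, here is a small sketch that builds one valid $(Q, R, g)$ triple and verifies the constraints. The independent coupling $ab^T$ is used purely as an illustration; it is not the paper's initialization.

```python
import numpy as np

def rank_r_coupling(a, b, r, seed=0):
    # With Q = a g^T and R = b g^T, the constraints Q 1_r = a, R 1_r = b,
    # Q^T 1_n = R^T 1_m = g all hold, and P = Q Diag(1/g) R^T = a b^T.
    g = np.random.default_rng(seed).dirichlet(np.ones(r))  # g in Delta_r, g > 0
    return np.outer(a, g), np.outer(b, g), g

n, m, r = 5, 7, 3
a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
Q, R, g = rank_r_coupling(a, b, r)
P = (Q / g) @ R.T   # Q Diag(1/g) R^T: dividing by g scales column k by 1/g_k
assert np.allclose(Q.sum(axis=1), a) and np.allclose(Q.sum(axis=0), g)
assert np.allclose(R.sum(axis=1), b) and np.allclose(R.sum(axis=0), g)
assert np.allclose(P.sum(axis=1), a) and np.allclose(P.sum(axis=0), b)
```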

3. Double Low-rank GW:

The key contribution is showing that combining both low-rank strategies yields a linear-time algorithm. If both cost matrices factorize ($A = A_1A_2^T$, $B = B_1B_2^T$) and the coupling is constrained to be low-rank ($P = Q \text{Diag}(1/g) R^T$), the critical product $APB$ becomes $A_1 A_2^T Q \text{Diag}(1/g) R^T B_1 B_2^T$, which can be kept in factored form: $APB = C_1 C_2$ with $C_1 = A_1 (A_2^T Q) \text{Diag}(1/g) \in \mathbb{R}^{n \times r}$ and $C_2 = (R^T B_1) B_2^T \in \mathbb{R}^{r \times m}$, computable in $O(nrd + mrd')$ operations without ever forming an $n \times m$ matrix. The Mirror Descent updates in the low-rank coupling formulation (Algorithm 3) involve the terms $A P_k B R_k \text{Diag}(1/g_k)$ (for $K^{(1)}$), $B P_k^T A Q_k \text{Diag}(1/g_k)$ (for $K^{(2)}$), and $\mathcal{D}(Q_k^T A P_k B R_k)$ (for $K^{(3)}$). Using $P_k = Q_k \text{Diag}(1/g_k) R_k^T$, these terms become:

  • $A Q_k \text{Diag}(1/g_k) R_k^T B R_k \text{Diag}(1/g_k)$
  • $B R_k \text{Diag}(1/g_k) Q_k^T A Q_k \text{Diag}(1/g_k)$
  • $\mathcal{D}(Q_k^T A Q_k \text{Diag}(1/g_k) R_k^T B R_k)$

If $A = A_1A_2^T$ and $B = B_1B_2^T$, these can be computed in $O(n(rd + r^2) + m(rd' + r^2) + r^2(d + d'))$ operations per outer iteration. For instance, the first term factors as $A_1 (A_2^T Q_k) \text{Diag}(1/g_k) (R_k^T B_1)(B_2^T R_k) \text{Diag}(1/g_k)$, whose pieces cost:
  • $A_2^T Q_k$: $O(nrd)$
  • $(A_2^T Q_k) \text{Diag}(1/g_k)$: $O(dr)$, since $A_2^T Q_k$ is $d \times r$
  • $R_k^T B_1$: $O(mrd')$
  • $B_2^T R_k$: $O(mrd')$; the remaining chain involves only $d \times r$, $r \times r$, and $r \times d'$ matrices, costing $O(r^2(d + d'))$, and the final multiplication by $A_1$ costs $O(nrd)$, so no $n \times m$ matrix is ever formed (see the sketch after this list).
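As a sketch of this factored evaluation (shapes as above, helper names ours): the thin factors $C_1, C_2$ of $APB$, and the $K^{(1)}$ building block assembled from them through $r \times r$ intermediates only.

```python
import numpy as np

def apb_factors(A1, A2, B1, B2, Q, R, g):
    # Thin factors C1: (n, r), C2: (r, m) with A P B = C1 @ C2,
    # where A = A1 A2^T, B = B1 B2^T, P = Q Diag(1/g) R^T.
    C1 = A1 @ ((A2.T @ Q) / g)    # O(nrd); dividing by g scales column k by 1/g_k
    C2 = (R.T @ B1) @ B2.T        # O(mrd')
    return C1, C2

def kernel_term_K1(A1, A2, B1, B2, Q, R, g):
    # (A P B) R Diag(1/g) without ever forming an n x m matrix.
    C1, C2 = apb_factors(A1, A2, B1, B2, Q, R, g)
    return C1 @ ((C2 @ R) / g)    # (r, r) core: O(mr^2), then O(nr^2)
```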

The linear-time evaluation follows from computing the gradients in the Mirror Descent update (Equation 6 in the paper) directly in factored form. The gradient w.r.t. $Q$ involves $APBR \text{Diag}(1/g)$, which under the factorizations equals $A_1 \left( A_2^T Q \text{Diag}(1/g) (R^T B_1)(B_2^T R) \text{Diag}(1/g) \right)$. Writing $M = A_2^T Q \in \mathbb{R}^{d \times r}$, $N = R^T B_1 \in \mathbb{R}^{r \times d'}$, and $O = B_2^T R \in \mathbb{R}^{d' \times r}$, the term in parentheses is the small chain $M \text{Diag}(1/g) N O \text{Diag}(1/g) \in \mathbb{R}^{d \times r}$, so only thin matrices are ever multiplied. The most expensive matrix products are:

  • $A_2^T Q$: $O(nrd)$
  • $R^T B_1$: $O(mrd')$
  • $B_2^T R$: $O(mrd')$
  • $M \text{Diag}(1/g) N O \text{Diag}(1/g)$: the product $NO$ takes $O(r^2 d')$, and the remaining small products take $O(dr^2 + d'r^2)$
  • The final multiplication by $A_1$, an $(n \times d) \times (d \times r)$ product: $O(nrd)$

The gradient w.r.t. $Q$ is therefore dominated by $O(nrd + mrd' + r^2(d + d'))$. If $r, d, d'$ are constants, this is $O(n+m)$; if they grow as $O(\text{poly}(\log n))$, it remains nearly linear. Evaluating the objective is equally cheap. Writing $\mathcal{Q}_{A,B}(P) = \langle A^{\odot 2} a, a \rangle + \langle B^{\odot 2} b, b \rangle - 2 \langle APB, P \rangle$, the inner product satisfies $\langle APB, P \rangle = \mathbf{1}_d^T (G_1 \odot G_2) \mathbf{1}_{d'}$ with $G_1 = A_1^T P B_2 \in \mathbb{R}^{d \times d'}$ and $G_2 = A_2^T P B_1 \in \mathbb{R}^{d \times d'}$; given the factored $P$, both can be computed in $O(nrd + mrd')$, and the final contraction costs $O(dd')$. The constant terms $\langle A^{\odot 2} a, a \rangle$ and $\langle B^{\odot 2} b, b \rangle$ can also be computed in nearly linear time, since $A^{\odot 2}$ inherits a factorization (of rank at most $d^2$) from $A = A_1A_2^T$: computing $\tilde{x} = A^{\odot 2} a$ costs $O(nd^2)$. Thus, under both low-rank assumptions, the computation per outer iteration becomes linear in $n+m$, namely $O((n+m)r(d+d') + r^2(d+d') + dd'r)$; with $r, d, d'$ small, this is $O(n+m)$.
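A sketch of these factored evaluations under the same shape assumptions (helper names ours); the $A^{\odot 2}a$ step uses the identity $\sum_j A_{ij}^2 a_j = A_1[i]^T (A_2^T \text{Diag}(a) A_2) A_1[i]$, which is where the $O(nd^2)$ cost comes from.

```python
import numpy as np

def inner_apb_p(A1, A2, B1, B2, Q, R, g):
    # <A P B, P> = 1_d^T (G1 ⊙ G2) 1_{d'}, with G1 = A1^T P B2 and
    # G2 = A2^T P B1 evaluated in factored form (P is never materialized).
    QD = Q / g                          # (n, r): Q Diag(1/g)
    G1 = (A1.T @ QD) @ (R.T @ B2)       # (d, d'): O(nrd + mrd')
    G2 = (A2.T @ QD) @ (R.T @ B1)       # (d, d')
    return float(np.sum(G1 * G2))       # contraction: O(dd')

def hadamard_sq_times(A1, A2, x):
    # (A ⊙ A) x for A = A1 A2^T in O(n d^2).
    V = A2.T @ (x[:, None] * A2)                # (d, d): A2^T Diag(x) A2
    return np.einsum('id,de,ie->i', A1, V, A1)  # row i: A1[i]^T V A1[i]

def lowrank_gw_objective(A1, A2, B1, B2, Q, R, g, a, b):
    # Q_{A,B}(P) = <A^{⊙2} a, a> + <B^{⊙2} b, b> - 2 <A P B, P>.
    const = hadamard_sq_times(A1, A2, a) @ a + hadamard_sq_times(B1, B2, b) @ b
    return const - 2.0 * inner_apb_p(A1, A2, B1, B2, Q, R, g)
```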

Implementation and Applications:

The paper provides algorithms for the quadratic methods (Algorithms 2 and 3) and for the linear method (Section 5, which combines aspects of both).

  • Initialization: A warm start is crucial for this non-convex problem. The proposed initialization solves a low-rank OT problem built from the vectors $A^{\odot 2}a$ and $B^{\odot 2}b$, which can be computed in linear time under low-rank cost assumptions.
  • Optimization: Mirror Descent is used. For the low-rank coupling method, each MD step requires solving a barycenter problem using Dykstra's algorithm. The paper shows experimentally that the number of Dykstra iterations does not depend heavily on $n$, which is favorable.
  • Hyperparameters: The method has hyperparameters such as the rank $r$ and the step size $\gamma$ (or the regularization $\varepsilon$ if double regularization is used). The paper explores the sensitivity to $\gamma$ and to a lower bound $\alpha$ on the entries of $g$, finding the method relatively robust. The choice of $r$ affects the quality of the approximation; ideally, it should relate to the intrinsic dimension or the number of clusters in the data.
  • Computational Cost: The paper provides clear complexity analyses: $O(n^3)$ for standard entropic GW, $O(n^2)$ for quadratic GW (low-rank costs or low-rank couplings separately), and $O(n)$ for linear GW (both low-rank costs and couplings); a small numerical consistency check of the factored formulas appears after this list.
  • Real-world Applications: The methods are demonstrated on single-cell genomics data (SNAREseq and Splatter) and a human brain dataset (BRAIN). These applications involve aligning point clouds representing cells characterized by different molecular features (e.g. gene expression and chromatin accessibility). The distance metric used is often based on k-NN graphs and shortest paths. While shortest-path distance matrices do not automatically admit low-rank factorizations the way squared Euclidean distances do, the quadratic version of the algorithm still applies by computing the full distance matrix. The linear version is demonstrated on the BRAIN dataset using squared Euclidean distances after PCA, where the low-rank factorization is available.
  • Performance: Experiments show that the proposed LR (quadratic) and Lin LR (linear) methods achieve similar GW loss and downstream task performance (such as cell-type alignment measured by FOSCTTM) compared to the standard Entropic-GW and MREC baselines, but are orders of magnitude faster, particularly at large scales ($n = m = 10^5$). Lin LR is shown to be the only viable method for very large datasets.
  • Limitations: The linearity relies on the assumption that the intrinsic dimensionality of the data (reflected in the rank of the cost matrices) and the required coupling rank $r$ are small relative to $n$. Tuning $\gamma$ may be necessary in practice.
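Finally, a quick numerical check that the factored sketches above agree with the dense formulas, reusing the hypothetical helpers `rank_r_coupling` and `lowrank_gw_objective` defined earlier (random data, so this validates the algebra only, not the paper's experimental results):

```python
import numpy as np

rng = np.random.default_rng(42)
n, m, d, dp, r = 200, 150, 4, 5, 6
A1, A2 = rng.normal(size=(n, d)), rng.normal(size=(n, d))
B1, B2 = rng.normal(size=(m, dp)), rng.normal(size=(m, dp))
a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
Q, R, g = rank_r_coupling(a, b, r)

A, B = A1 @ A2.T, B1 @ B2.T     # dense costs, formed here only for the check
P = (Q / g) @ R.T               # dense coupling, likewise
dense = (A**2 @ a) @ a + (B**2 @ b) @ b - 2.0 * np.sum((A @ P @ B) * P)
assert np.isclose(dense, lowrank_gw_objective(A1, A2, B1, B2, Q, R, g, a, b))
```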

Overall, the paper presents a significant step towards making Gromov-Wasserstein scalable by introducing and demonstrating the effectiveness of low-rank approaches for both the geometry of the input spaces and the structure of the coupling. The combined linear-time method offers a practical way to apply GW to large-scale problems previously inaccessible due to computational constraints.