Linear-Time Gromov Wasserstein Distances using Low Rank Couplings and Costs (2106.01128v2)
Abstract: The ability to align points across two related yet incomparable point clouds (e.g. living in different spaces) plays an important role in machine learning. The Gromov-Wasserstein (GW) framework provides an increasingly popular answer to such problems, by seeking a low-distortion, geometry-preserving assignment between these points. As a non-convex, quadratic generalization of optimal transport (OT), GW is NP-hard. While practitioners often resort to solving GW approximately as a nested sequence of entropy-regularized OT problems, the cubic complexity (in the number $n$ of samples) of that approach is a roadblock. We show in this work how a recent variant of the OT problem that restricts the set of admissible couplings to those having a low-rank factorization is remarkably well suited to the resolution of GW: when applied to GW, we show that this approach is not only able to compute a stationary point of the GW problem in time $O(n^2)$, but also uniquely positioned to benefit from the knowledge that the initial cost matrices are low-rank, to yield a linear time $O(n)$ GW approximation. Our approach yields similar results, yet orders of magnitude faster computation than the SoTA entropic GW approaches, on both simulated and real data.
Summary
- The paper introduces a novel linear-time algorithm for Gromov-Wasserstein distances by exploiting low-rank factorizations in both cost matrices and coupling structures.
- It reduces the cubic complexity of traditional methods to linear complexity by approximating input costs and constraining the coupling matrix.
- The approach demonstrates practical scalability on large-scale datasets such as single-cell genomics and human brain data while preserving alignment accuracy.
The paper "Linear-Time Gromov Wasserstein Distances using Low Rank Couplings and Costs" (2021) addresses the computational bottleneck of the Gromov-Wasserstein (GW) problem, which is widely used for aligning and comparing data from different metric spaces, such as point clouds or distributions living in heterogeneous feature spaces. The standard approach to solving GW approximately, based on iteratively solving entropy-regularized Optimal Transport (OT) problems, suffers from cubic complexity $O(n^3)$ in the number of samples $n$, making it impractical for large datasets. This paper proposes methods to reduce the GW computation time, ultimately achieving a linear-time algorithm by exploiting low-rank structure in both the input cost matrices and the coupling matrix.
The standard GW problem seeks a coupling matrix $P\in\mathbb{R}^{n\times m}$ between two discrete measures with weights $a\in\Delta_n$ and $b\in\Delta_m$, whose internal geometries are encoded by cost matrices $A\in\mathbb{R}^{n\times n}$ and $B\in\mathbb{R}^{m\times m}$. The objective is to minimize, over couplings $P\in\Pi_{a,b}$, the quadratic function $\mathcal{Q}_{A,B}(P)=\sum_{i,j,k,l}(A_{ik}-B_{jl})^2P_{ij}P_{kl}$, which measures the distortion introduced by $P$. The standard entropic GW approximation solves this non-convex problem iteratively with Mirror Descent, which boils down to solving a sequence of entropy-regularized OT problems with the synthetic cost matrix $C_t=-4AP_{t-1}B$. Each outer iteration involves three steps:
- Updating the cost matrix $C_t=-4AP_{t-1}B$, which requires $O(n^2m+nm^2)$ operations.
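- Evaluating the GW objective $\mathcal{Q}_{A,B}(P)$, which costs $O(n^2m^2)$ if the quadruple sum is computed naively, but only $O(n^2m+nm^2)$ with the reformulation $\mathcal{Q}_{A,B}(P)=\langle A^{\odot2}a,a\rangle+\langle B^{\odot2}b,b\rangle-2\langle APB,P\rangle$ (valid for $P\in\Pi_{a,b}$ and symmetric $A$, $B$).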
- Solving the entropy-regularized OT problem $\min_{P\in\Pi_{a,b}}\langle C_t,P\rangle-\varepsilon H(P)$ with Sinkhorn's algorithm costs $O(nm)$ operations per Sinkhorn iteration, so the dominant cost of each outer GW iteration remains the $O(n^2m+nm^2)$ update of $C_t$.
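To make these per-iteration costs concrete, here is a minimal NumPy sketch of the entropic GW loop. It assumes symmetric squared-Euclidean cost matrices rescaled to $[0,1]$, uniform weights, and the common variant in which each Mirror Descent step is exactly one Sinkhorn solve on $C_t$. It also checks the objective reformulation against the naive quadruple sum. All function names (`gw_objective`, `sinkhorn`, `entropic_gw`, ...) and parameter values (`eps`, `n_iter`, `n_outer`) are illustrative choices, not the paper's code.

```python
import numpy as np

def sq_dists(X):
    # Squared Euclidean distance matrix, rescaled to [0, 1] so exp(-C / eps) stays well conditioned.
    D = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    return D / D.max()

def gw_objective_naive(A, B, P):
    # Quadruple sum of (A[i,k] - B[j,l])^2 P[i,j] P[k,l]: O(n^2 m^2) time and memory.
    diff = A[:, None, :, None] - B[None, :, None, :]
    return float(np.einsum("ijkl,ij,kl->", diff ** 2, P, P))

def gw_objective(A, B, P, a, b):
    # <A^2 a, a> + <B^2 b, b> - 2 <APB, P>: valid for P in Pi(a, b) and symmetric B.
    return float(a @ (A ** 2) @ a + b @ (B ** 2) @ b - 2.0 * np.sum((A @ P @ B) * P))

def sinkhorn(C, a, b, eps, n_iter=1000):
    # Entropic OT; shifting C by a constant leaves the optimal coupling unchanged.
    K = np.exp(-(C - C.min()) / eps)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)    # O(nm)
        v = b / (K.T @ u)  # O(nm)
    return u[:, None] * K * v[None, :]

def entropic_gw(A, B, a, b, eps=0.5, n_outer=20):
    P = np.outer(a, b)  # independent coupling as initialization
    for _ in range(n_outer):
        C = -4.0 * A @ P @ B        # O(n^2 m + n m^2): the bottleneck
        P = sinkhorn(C, a, b, eps)  # O(nm) per Sinkhorn iteration
    return P

rng = np.random.default_rng(0)
n, m = 8, 6
A = sq_dists(rng.normal(size=(n, 2)))
B = sq_dists(rng.normal(size=(m, 3)))
a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
P0 = np.outer(a, b)
print(gw_objective_naive(A, B, P0), gw_objective(A, B, P0, a, b))  # equal up to rounding
P = entropic_gw(A, B, a, b)
print(gw_objective(A, B, P, a, b))
```

The dense product `A @ P @ B` inside `entropic_gw` is the $O(n^2m+nm^2)$ step; the Sinkhorn loop only touches $n\times m$ arrays.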
The paper tackles this cubic complexity by introducing two independent strategies and then showing how to combine them:
1. Low-rank (Approximated) Costs:
If the input cost matrices admit low-rank factorizations $A=A_1A_2^T$ and $B=B_1B_2^T$, with $A_1,A_2\in\mathbb{R}^{n\times d}$, $B_1,B_2\in\mathbb{R}^{m\times d'}$, $d\ll n$ and $d'\ll m$, the synthetic cost becomes $C_t=-4A_1A_2^TP_{t-1}B_1B_2^T$ and can be updated much faster: first form $G=A_2^TP_{t-1}B_1\in\mathbb{R}^{d\times d'}$ in $O(nmd+mdd')$ operations, then $C_t=-4A_1GB_2^T$ in $O(ndd'+nmd')$. If $d$ and $d'$ are small constants, the update costs $O(nm(d+d'))$, i.e. $O(nm)$ per outer iteration, which is quadratic ($O(n^2)$ when $n\approx m$). Similarly, the evaluation of $\mathcal{Q}_{A,B}(P)$ can be sped up to $O(nmd+mdd')$. This strategy is particularly relevant for squared Euclidean distance matrices, which admit an exact factorization of rank $d=p+2$ for points in $\mathbb{R}^p$. For general distance matrices, recent work computes low-rank approximations in nearly linear time, enabling this speedup even when no exact factorization is apparent. Algorithm 2 outlines this "Quadratic Entropic-GW" approach.
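A minimal NumPy sketch of the exact rank-$(p+2)$ factorization of a squared Euclidean distance matrix (for points in $\mathbb{R}^p$), followed by the factored cost update and the factored cross term $\langle APB,P\rangle$. The helper names are hypothetical, and the dense `A` and `B` are only formed to check the results.

```python
import numpy as np

def sqeuclid_factors(X):
    # D[i, k] = |x_i|^2 + |x_k|^2 - 2 <x_i, x_k> = F1[i] @ F2[k]; rank p + 2 for X of shape (n, p).
    sq = np.sum(X ** 2, axis=1, keepdims=True)
    ones = np.ones_like(sq)
    F1 = np.hstack([sq, ones, -2.0 * X])
    F2 = np.hstack([ones, sq, X])
    return F1, F2

def cost_update(A1, A2, B1, B2, P):
    # C = -4 A P B with A = A1 A2^T and B = B1 B2^T, never forming A or B.
    G = (A2.T @ P) @ B1            # d x d': O(nmd + m d d')
    return -4.0 * (A1 @ G) @ B2.T  # O(n d d' + n m d')

def cross_term(A1, A2, B1, B2, P):
    # <APB, P> = sum of entries of (A1^T P B2) * (A2^T P B1), both d x d'.
    G1 = (A1.T @ P) @ B2
    G2 = (A2.T @ P) @ B1
    return float(np.sum(G1 * G2))

rng = np.random.default_rng(1)
X, Y = rng.normal(size=(50, 3)), rng.normal(size=(40, 2))
A1, A2 = sqeuclid_factors(X)  # d = 3 + 2 = 5
B1, B2 = sqeuclid_factors(Y)  # d' = 2 + 2 = 4
P = rng.random((50, 40))
P /= P.sum()
C = cost_update(A1, A2, B1, B2, P)
```

The order of the products is the whole point: contracting the $n$ and $m$ dimensions against the thin factors first keeps every intermediate at size $d\times m$, $d\times d'$ or $n\times d'$.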
2. Low-rank Constraints for Couplings:
Instead of assuming low-rank input costs, the paper proposes constraining the coupling matrix $P$ to have low nonnegative rank. This is achieved by restricting $P$ to the form $P=Q\,\mathrm{Diag}(1/g)R^T$, where $Q\in\mathbb{R}_+^{n\times r}$ and $R\in\mathbb{R}_+^{m\times r}$ satisfy the marginal constraints $Q\mathbf{1}_r=a$, $R\mathbf{1}_r=b$, $Q^T\mathbf{1}_n=R^T\mathbf{1}_m=g$, and $g\in\Delta_r$ is a common intermediate marginal. This factorization implies that $P$ has nonnegative rank at most $r$. The GW problem is then reformulated as minimizing $\mathcal{Q}_{A,B}(Q\,\mathrm{Diag}(1/g)R^T)$ over $(Q,R,g)$ in the feasible set $\mathcal{C}(a,b,r)$, and solved with a Mirror Descent scheme with respect to the KL divergence in the space of $(Q,R,g)$. Each step computes generalized kernel matrices $K^{(1)},K^{(2)},K^{(3)}$ and then solves a barycenter problem efficiently with Dykstra's algorithm (Algorithm 3). The initialization uses a low-rank approximation of a lower bound based on the vectors $A^{\odot2}a$ and $B^{\odot2}b$, whose entries are the weighted squared row norms of $A$ and $B$; this initialization can itself be computed efficiently. Dykstra's algorithm takes $O((n+m)r)$ operations per inner iteration, but computing the kernel matrices still requires products such as $AQ_k$ and $BR_k$, which cost $O(n^2r)$ and $O(m^2r)$ for dense $A$ and $B$ (the $n\times m$ matrix $AP_kB$ itself never needs to be formed). Thus, this approach alone reduces the complexity per outer iteration to $O((n^2+m^2)r)$.
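The sketch below illustrates the low-rank coupling parametrization with dense costs. As a simplified stand-in for the paper's Dykstra barycenter step, the hypothetical `lr_projection` computes the KL projection of positive kernels onto the marginal constraints by alternating Bregman projections onto two affine sets. It deliberately omits the lower bound $\alpha$ on $g$: with only affine constraints, plain alternating projections suffice, whereas Dykstra's correction terms matter once a non-affine constraint such as $g\ge\alpha$ is added. The hypothetical `grad_term_dense_costs` then evaluates the $K^{(1)}$-type term $AP_kBR_k\mathrm{Diag}(1/g_k)$ without forming $P$.

```python
import numpy as np

def lr_projection(xi1, xi2, xi3, a, b, n_iter=2000):
    # KL projection of positive (xi1, xi2, xi3) onto {Q 1 = a, R 1 = b, Q^T 1 = R^T 1 = g}
    # by alternating Bregman projections onto two affine sets (no lower bound on g).
    Q, R, g = xi1.copy(), xi2.copy(), xi3.copy()
    for _ in range(n_iter):
        # Set 1: rescale rows so that Q 1_r = a and R 1_r = b; g is unchanged.
        Q *= (a / Q.sum(axis=1))[:, None]
        R *= (b / R.sum(axis=1))[:, None]
        # Set 2: closed-form common marginal g = (g * q * s)^(1/3), then rescale columns.
        q, s = Q.sum(axis=0), R.sum(axis=0)
        g = (g * q * s) ** (1.0 / 3.0)
        Q *= (g / q)[None, :]
        R *= (g / s)[None, :]
    return Q, R, g

def grad_term_dense_costs(A, B, Q, R, g):
    # A P B R Diag(1/g) with P = Q Diag(1/g) R^T, without forming the n x m matrix P.
    AQ = A @ Q                  # O(n^2 r)
    RtBR = R.T @ (B @ R)        # O(m^2 r + m r^2)
    return (AQ / g) @ RtBR / g  # O(n r^2); dividing by g scales the columns

rng = np.random.default_rng(2)
n, m, r = 30, 20, 4
a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
Q, R, g = lr_projection(rng.uniform(0.5, 1.5, (n, r)), rng.uniform(0.5, 1.5, (m, r)),
                        np.full(r, 1.0 / r), a, b)
X, Y = rng.normal(size=(n, 2)), rng.normal(size=(m, 2))
A = np.sqrt(np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1))  # Euclidean, not low rank
B = np.sqrt(np.sum((Y[:, None, :] - Y[None, :, :]) ** 2, axis=-1))
T = grad_term_dense_costs(A, B, Q, R, g)
```

Because $Q^T\mathbf{1}_n=g$ and $R^T\mathbf{1}_m=g$, the product $Q\,\mathrm{Diag}(1/g)R^T$ inherits the marginals $a$ and $b$ from the row sums of $Q$ and $R$; the only quadratic costs left are `A @ Q` and `B @ R`.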
3. Double Low-rank GW:
The key contribution is showing that combining both low-rank strategies yields a linear-time algorithm. Suppose both cost matrices factorize, $A=A_1A_2^T$ and $B=B_1B_2^T$, and the coupling is constrained to $P=Q\,\mathrm{Diag}(1/g)R^T$. The $n\times m$ matrix $APB$ then never needs to be formed: it admits a factorization $APB=C_1C_2$ of rank at most $r$, with $C_1=A_1\big((A_2^TQ)\,\mathrm{Diag}(1/g)\big)\in\mathbb{R}^{n\times r}$ and $C_2=(R^TB_1)B_2^T\in\mathbb{R}^{r\times m}$, both computable in $O(nrd+mrd')$ operations. The Mirror Descent updates of the low-rank coupling formulation (Algorithm 3) only require the following terms, written here with $P_k=Q_k\mathrm{Diag}(1/g_k)R_k^T$ substituted:
- $AP_kBR_k\mathrm{Diag}(1/g_k)=AQ_k\mathrm{Diag}(1/g_k)R_k^TBR_k\mathrm{Diag}(1/g_k)$, an $n\times r$ matrix used in $K^{(1)}$;
- $BP_k^TAQ_k\mathrm{Diag}(1/g_k)=BR_k\mathrm{Diag}(1/g_k)Q_k^TAQ_k\mathrm{Diag}(1/g_k)$, an $m\times r$ matrix used in $K^{(2)}$;
- $\mathcal{D}(Q_k^TAP_kBR_k)=\mathcal{D}(Q_k^TAQ_k\mathrm{Diag}(1/g_k)R_k^TBR_k)$, where $\mathcal{D}(\cdot)$ denotes the vector of diagonal entries of an $r\times r$ matrix, used in $K^{(3)}$.
With $A=A_1A_2^T$ and $B=B_1B_2^T$, all three can be computed in $O(n(rd+r^2)+m(rd'+r^2)+r^2(d+d'))$ operations per outer iteration. By contrast, multiplying the $n\times r$ and $r\times m$ factors of $APB$ out explicitly would cost $O(nmr)$, which is still quadratic.
The linear time comes from regrouping the products in the Mirror Descent update (Equation 6 in the paper) so that $n$ and $m$ are only ever contracted against thin factors. For the gradient with respect to $Q$,
$APBR\,\mathrm{Diag}(1/g)=A_1\big[M\,\mathrm{Diag}(1/g)\,NN'\,\mathrm{Diag}(1/g)\big]$, with $M=A_2^TQ\in\mathbb{R}^{d\times r}$, $N=R^TB_1\in\mathbb{R}^{r\times d'}$ and $N'=B_2^TR\in\mathbb{R}^{d'\times r}$,
where the bracket is a $d\times r$ matrix. The most expensive products are:
- $M=A_2^TQ$: $O(nrd)$
- $N=R^TB_1$ and $N'=B_2^TR$: $O(mrd')$ each
- the $d\times r$ core: $NN'$ in $O(r^2d')$, then $M\,\mathrm{Diag}(1/g)$ times this $r\times r$ matrix in $O(dr^2)$, plus $O(dr)$ for the diagonal scalings
- left-multiplying the core by $A_1\in\mathbb{R}^{n\times d}$: $O(ndr)$, giving the $n\times r$ result
The gradient with respect to $Q$ therefore costs $O(nrd+mrd'+r^2(d+d'))$, and the terms for $R$ and $g$ are handled analogously. If $r,d,d'$ are constants, this is $O(n+m)$; if they grow like $O(\log n)$ or $O(\mathrm{poly}(\log n))$, it remains nearly linear. The objective is evaluated the same way. Writing $\mathcal{Q}_{A,B}(P)=\langle A^{\odot2}a,a\rangle+\langle B^{\odot2}b,b\rangle-2\langle APB,P\rangle$, the cross term satisfies $\langle APB,P\rangle=\mathbf{1}_d^T(G_1\odot G_2)\mathbf{1}_{d'}$ with $G_1=A_1^TPB_2$ and $G_2=A_2^TPB_1$, both $d\times d'$. Using $P=Q\,\mathrm{Diag}(1/g)R^T$, each is computed as $(A_i^TQ)\,\mathrm{Diag}(1/g)\,(R^TB_j)$ in $O(nrd+mrd'+rdd')$, and the final sum costs $O(dd')$. The terms $\langle A^{\odot2}a,a\rangle$ and $\langle B^{\odot2}b,b\rangle$ do not depend on $P$ and are computed once: $A^{\odot2}$ factorizes with rank at most $d^2$ (via row-wise Kronecker products of $A_1$ and $A_2$), so $A^{\odot2}a$ costs $O(nd^2)$, and likewise $O(md'^2)$ for $B$. Thus, under both low-rank assumptions, each outer iteration costs $O(nrd+mrd'+r^2(d+d')+rdd')$, plus $O((n+m)r)$ per Dykstra iteration, which is $O(n+m)$ when $r,d,d'$ are small.
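A NumPy sketch of the double low-rank computations, assuming squared Euclidean costs (which factor exactly with rank $p+2$) and a hypothetical hard-cluster coupling chosen only because it is exactly feasible: point $i$ of the first cloud is sent to cluster $i \bmod r$, point $j$ of the second cloud to cluster $j \bmod r$. All helper names are illustrative, and dense matrices appear only in the checks.

```python
import numpy as np

def sqeuclid_factors(X):
    # Exact factorization of the squared Euclidean distance matrix: D = F1 @ F2.T, rank p + 2.
    sq = np.sum(X ** 2, axis=1, keepdims=True)
    ones = np.ones_like(sq)
    return np.hstack([sq, ones, -2.0 * X]), np.hstack([ones, sq, X])

def kron_rows(F):
    # Row-wise Kronecker product: (F1 @ F2.T) ** 2 == kron_rows(F1) @ kron_rows(F2).T (rank <= d^2).
    return np.einsum("ip,iq->ipq", F, F).reshape(F.shape[0], -1)

def double_lr_grad_term(A1, A2, B1, B2, Q, R, g):
    # A P B R Diag(1/g) = A1 [M Diag(1/g) N N' Diag(1/g)], never forming an n x m matrix.
    M = A2.T @ Q                   # d x r,  O(n r d)
    N = R.T @ B1                   # r x d', O(m r d')
    N2 = B2.T @ R                  # d' x r, O(m r d')
    core = (M / g) @ (N @ N2) / g  # d x r,  O(r^2 (d + d'))
    return A1 @ core               # n x r,  O(n d r)

def double_lr_cost_factors(A1, A2, B1, B2, Q, R, g):
    # A P B = C1 @ C2 with C1 (n x r) and C2 (r x m), in O(n r d + m r d').
    C1 = A1 @ ((A2.T @ Q) / g)
    C2 = (R.T @ B1) @ B2.T
    return C1, C2

def double_lr_objective(A1, A2, B1, B2, Q, R, g, a, b):
    # <A^2 a, a> + <B^2 b, b> - 2 <APB, P>; the first two terms would be cached in practice.
    const = a @ (kron_rows(A1) @ (kron_rows(A2).T @ a)) + b @ (kron_rows(B1) @ (kron_rows(B2).T @ b))
    G1 = ((A1.T @ Q) / g) @ (R.T @ B2)  # A1^T P B2, d x d'
    G2 = ((A2.T @ Q) / g) @ (R.T @ B1)  # A2^T P B1, d x d'
    return float(const - 2.0 * np.sum(G1 * G2))

rng = np.random.default_rng(3)
n, m, r = 40, 30, 5
X, Y = rng.normal(size=(n, 3)), rng.normal(size=(m, 2))
A1, A2 = sqeuclid_factors(X)
B1, B2 = sqeuclid_factors(Y)
a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
# Hypothetical hard-cluster coupling; with these sizes every cluster has mass 1/r on both sides.
Q = np.zeros((n, r))
Q[np.arange(n), np.arange(n) % r] = a
R = np.zeros((m, r))
R[np.arange(m), np.arange(m) % r] = b
g = Q.sum(axis=0)
T = double_lr_grad_term(A1, A2, B1, B2, Q, R, g)
obj = double_lr_objective(A1, A2, B1, B2, Q, R, g, a, b)
```

Every array created inside the three `double_lr_*` functions has at most one dimension of size $n$ or $m$; the largest ones are $n\times r$, $m\times r$ or, for the constant terms, $n\times d^2$.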
Implementation and Applications:
The paper provides algorithms for the quadratic (Algorithms 2 and 3) and linear (Section 5, which combines the ideas behind Algorithms 2 and 3) methods.
- Initialization: A warm start is crucial for non-convex optimization. The proposed initialization solves a low-rank OT problem built from the vectors $A^{\odot2}a$ and $B^{\odot2}b$ (the weighted squared row norms of $A$ and $B$), which can be computed in linear time under low-rank cost assumptions.
- Optimization: Mirror Descent is used. For the low-rank coupling method, each MD step requires solving a barycenter problem using Dykstra's algorithm. The paper shows experimentally that the number of Dykstra iterations doesn't heavily depend on n, which is favorable.
- Hyperparameters: The method has hyperparameters like the low rank r and the step size γ (or regularization ε if double regularization is used). The paper explores the sensitivity to γ and a lower bound α on entries of g, finding the method relatively robust. The choice of r affects the quality of the approximation; ideally, it should relate to the intrinsic dimension or number of clusters in the data.
- Computational Cost: The paper provides clear complexity analyses: $O(n^3)$ for standard entropic GW, $O(n^2)$ for quadratic GW (low-rank costs or low-rank couplings separately), and $O(n)$ for linear GW (both low-rank costs and couplings).
- Real-world Applications: The methods are demonstrated on single-cell genomics data (SNAREseq and Splatter) and a human brain dataset (BRAIN). These applications involve aligning point clouds representing cells characterized by different molecular features (e.g., gene expression and chromatin accessibility). The distance metric used is often based on k-NN graphs and shortest paths. While shortest path distance matrices don't automatically admit low-rank factorizations like Euclidean distance, the quadratic version of the algorithm is applicable by simply computing the full distance matrix. The linear version is demonstrated on the BRAIN dataset using squared Euclidean distance after PCA, where the low-rank factorization is available.
- Performance: Experiments show that the proposed LR (quadratic) and Lin LR (linear) methods achieve GW loss and downstream performance (such as cell alignment quality measured by FOSCTTM, the fraction of samples closer than the true match) similar to the standard Entropic-GW and MREC baselines, but are orders of magnitude faster, particularly at large scales ($n=m=10^5$). Lin LR is shown to be the only viable method for very large datasets.
- Limitations: The linearity relies on the assumption that the intrinsic dimensionality of the data (reflected in the rank of cost matrices) and the required rank r for the coupling are small relative to n. Tuning γ might be necessary in practice.
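The applications bullet above mentions distances computed as shortest paths on a k-NN graph. The sketch below builds such a matrix on a toy circle (the data, `k`, and the helper names are illustrative assumptions) with a dense Floyd-Warshall loop; for real data one would rather run Dijkstra on a sparse graph, for example with `scipy.sparse.csgraph.shortest_path`. It also shows one way to get factors $A_1,A_2$ for a general distance matrix via a truncated SVD. That is a swapped-in technique costing $O(n^3)$, not the nearly-linear-time low-rank approximation the paper refers to.

```python
import numpy as np

def knn_shortest_paths(X, k):
    # Shortest-path distances on a symmetrized k-NN graph whose edges carry Euclidean lengths.
    n = X.shape[0]
    E = np.sqrt(np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1))
    nbrs = np.argsort(E, axis=1)[:, 1:k + 1]  # column 0 is the point itself
    rows = np.repeat(np.arange(n), k)
    D = np.full((n, n), np.inf)
    D[rows, nbrs.ravel()] = E[rows, nbrs.ravel()]
    D = np.minimum(D, D.T)                    # keep an edge if either endpoint selected it
    np.fill_diagonal(D, 0.0)
    for p in range(n):                        # Floyd-Warshall, O(n^3): fine for a small sketch
        D = np.minimum(D, D[:, p:p + 1] + D[p:p + 1, :])
    return D

def svd_factors(D, d):
    # Truncated SVD: D ~= F1 @ F2.T with F1, F2 of shape (n, d); best rank-d Frobenius approximation.
    U, s, Vt = np.linalg.svd(D)
    return U[:, :d] * s[:d], Vt[:d].T

t = np.linspace(0.0, 2.0 * np.pi, 60, endpoint=False)
X = np.stack([np.cos(t), np.sin(t)], axis=1)  # evenly spaced points on a circle
D = knn_shortest_paths(X, k=4)
errors = [np.linalg.norm(D - F1 @ F2.T) for F1, F2 in (svd_factors(D, d) for d in (2, 5, 10, 20))]
```

If the approximation error at a small rank `d` is acceptable, the factors can be plugged into the linear-time method; otherwise the quadratic variant can use the full matrix `D` directly, as in the applications above.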
Overall, the paper presents a significant step towards making Gromov-Wasserstein scalable by introducing and demonstrating the effectiveness of low-rank approaches for both the geometry of the input spaces and the structure of the coupling. The combined linear-time method offers a practical way to apply GW to large-scale problems previously inaccessible due to computational constraints.