- The paper introduces Polar Express, a novel iterative method for computing matrix polar decomposition using optimally selected polynomial compositions for efficient, low-precision application in deep learning optimizers.
- Polar Express achieves rapid convergence, crucial for deep learning, by optimally selecting a unique polynomial for each iterative step to minimize worst-case spectral norm error.
- Integrating Polar Express into the Muon optimizer demonstrates improved training loss and faster convergence compared to alternative methods on GPT-2 model training experiments.
Computing the polar decomposition or matrix sign function of gradients is a crucial subroutine for certain advanced deep learning optimizers, notably Muon [jordan2024muon, bernstein2024oldoptimizernewnorm]. Unlike traditional numerical analysis applications where high accuracy is paramount, deep learning requires methods that are highly efficient, GPU-compatible, and can operate effectively in low precision like bfloat16, often sacrificing some accuracy for speed. Classical methods like Newton-Schulz suffer from slow initial convergence, while rational function methods require computationally expensive operations like QR decompositions or matrix inverses, which are not well-suited for current GPU architectures.
The paper "The Polar Express: Optimal Matrix Sign Methods and Their Application to the Muon Algorithm" (2505.16932) introduces a novel iterative method, Polar Express, designed specifically for this deep learning context. It computes an approximation to the polar decomposition $\polar(M)$ for a rectangular gradient matrix M.
The Muon Method
The Muon optimizer updates neural network weights $W_t$ at iteration $t$ based on the gradient momentum estimate $M_t = \beta M_{t-1} + (1-\beta) G_t$ using the rule:
$W_{t+1} = W_t - \lambda \polar(M_t)$, where $\lambda$ is the learning rate.
This update takes a step in the direction of $-\polar(M_t)$. If $M = U \Sigma V^T$ is the singular value decomposition (SVD), then $\polar(M) = U V^T$. This aligns the update with the orthogonal factor of the momentum matrix, which corresponds to the steepest-descent direction in the spectral norm. Efficiently computing $\polar(M_t)$ at each step is key to making Muon practical.
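For reference, the exact polar factor and a Muon-style update can be sketched directly from the SVD. This is a minimal NumPy sketch: the function names and the hyperparameter values in `muon_step` are illustrative, and the SVD here stands in for the polynomial iteration that makes Muon practical in training:

```python
import numpy as np

def polar_factor(M):
    """Exact polar factor via SVD: M = U @ diag(S) @ Vt  =>  polar(M) = U @ Vt.

    Reference implementation for clarity; Polar Express replaces this SVD
    with GPU-friendly polynomial iterations.
    """
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    return U @ Vt

def muon_step(W, M_prev, G, beta=0.95, lr=0.02):
    """One Muon-style update (beta and lr are illustrative values):
    momentum estimate, then a step along -polar(momentum)."""
    M = beta * M_prev + (1 - beta) * G      # momentum estimate M_t
    W_next = W - lr * polar_factor(M)       # W_{t+1} = W_t - lr * polar(M_t)
    return W_next, M
```

For a tall matrix $M$ the returned factor has orthonormal columns, i.e. $\polar(M)^T \polar(M) = I$.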
Polar Express Approach
Polar Express computes $\polar(M)$ using an iterative method based on compositions of odd polynomials. An odd polynomial $p(x) = a_0 x + a_1 x^3 + \dots + a_q x^{2q+1}$ can be applied to a matrix $M$ as $p(M) = a_0 M + a_1 M(M^T M) + \dots + a_q M(M^T M)^q$. This requires only matrix-matrix multiplications ($M^T M$, then products with $M$), which are highly efficient on GPUs.
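This matmul-only evaluation can be sketched in a few lines. A minimal NumPy sketch, assuming `coeffs` lists $a_0, \dots, a_q$ (the function name is illustrative); it evaluates $h(A) = a_0 I + a_1 A + \dots + a_q A^q$ with $A = M^T M$ by Horner's rule, then forms $p(M) = M\,h(A)$:

```python
import numpy as np

def apply_odd_poly(M, coeffs):
    """Apply the odd polynomial p(x) = sum_k coeffs[k] * x^(2k+1) to M:
    p(M) = a0*M + a1*M(M^T M) + ... + aq*M(M^T M)^q,
    using only matrix-matrix products."""
    A = M.T @ M                              # n x n Gram matrix
    I = np.eye(A.shape[0])
    H = coeffs[-1] * I                       # Horner's rule in A
    for a in reversed(coeffs[:-1]):
        H = A @ H + a * I
    return M @ H
```

Applying `apply_odd_poly` to $M$ acts on each singular value $\sigma_i$ as $\sigma_i \mapsto p(\sigma_i)$ while leaving the singular vectors unchanged.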
The method constructs a sequence of approximations $X_t$ via $X_0 = M$ (normalized) and $X_t = p_t(X_{t-1})$ for $t = 1, \dots, T$, where each $p_t$ is an odd polynomial of fixed degree $d$. The final result is $X_T = (p_T \circ \dots \circ p_1)(M)$. This iterative structure allows applying a high-degree composite polynomial $p = p_T \circ \dots \circ p_1$ (of degree $d^T$) using only $O(Td)$ matrix multiplications, as opposed to $O(d^T)$ for direct application.
The key innovation in Polar Express is the selection of the polynomials $p_t$. Instead of using the same polynomial repeatedly (as in Newton-Schulz), each $p_t$ is chosen optimally at each iteration to minimize the worst-case error in the spectral norm, given the current approximation $X_{t-1}$. This translates to minimizing $\max_{x \in [\ell_t, u_t]} |1 - p_t(x)|$, where $[\ell_t, u_t]$ is the current range of singular values of $X_{t-1}$ (assuming $\sigma(X_{t-1}) \subset [\ell_t, u_t]$; $\polar(X_{t-1})$ corresponds to mapping every singular value to 1).
This greedy approach for choosing $p_t$ is proven to yield a composite polynomial $p^\star = p_T \circ \dots \circ p_1$ that is optimal for approximating the constant function $x \mapsto 1$ over the initial singular value range $[\ell, u]$ in the minimax sense, i.e., $\max_{x \in [\ell, u]} |1 - p^\star(x)|$ is minimized among all compositions of $T$ odd polynomials of degree $d$. The error after $T$ iterations is bounded by $1 - \ell_{T+1}$, where $\ell_{t+1} = p_t(\ell_t)$ and $u_{t+1} = 2 - \ell_{t+1}$ define the evolution of the singular value bounds. This optimal selection strategy ensures rapid convergence, particularly in the initial iterations, which is crucial for deep learning where few iterations are performed. The method retains faster-than-exponential asymptotic convergence (quadratic for $d = 3$, cubic for $d = 5$).
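The recursion $\ell_{t+1} = p_t(\ell_t)$ is easy to illustrate numerically. The sketch below substitutes the fixed Newton-Schulz cubic $p(x) = \tfrac{3}{2}x - \tfrac{1}{2}x^3$ for the per-iteration optimal $p_t$ (which would require the Remez machinery of the offline stage); since this $p$ fixes $x = 1$, the upper bound stays at 1 and only the lower bound evolves, with worst-case error $1 - \ell_{T+1}$:

```python
def lower_bound_evolution(l1, T):
    """Track l_{t+1} = p(l_t) for the fixed Newton-Schulz cubic
    p(x) = 1.5x - 0.5x^3 (illustration only; Polar Express uses an
    optimal p_t per step). Returns the sequence l_2, ..., l_{T+1}
    and the worst-case error 1 - l_{T+1}."""
    p = lambda x: 1.5 * x - 0.5 * x ** 3
    ls, l = [], l1
    for _ in range(T):
        l = p(l)
        ls.append(l)
    return ls, 1.0 - l
```

Starting from $\ell_1 = 0.5$, the bound climbs toward 1 and the error shrinks quadratically, consistent with the $d = 3$ convergence rate stated above.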
Implementation Details
The Polar Express algorithm consists of two stages:
- Offline Stage: The sequence of optimal polynomials $p_1, \dots, p_T$ is computed once and stored. This involves solving the minimax approximation problem $\argmin_{p \in \mathbb{P}_d^{\mathrm{odd}}} \max_{x \in [\ell_t, u_t]} |1 - p(x)|$ for each $t$, given $[\ell_t, u_t]$.
- For d=3, there is a known closed-form solution based on a scaled Newton-Schulz polynomial [chen2014stable].
- For $d = 5$ (the recommended degree), the polynomials are computed using the Remez algorithm, which iteratively finds the polynomial satisfying the equioscillation property on $[\ell_t, u_t]$. This process is fast and accurate even for small $d$.
- The initial singular value range $[\ell_1, u_1]$ is needed. The input matrix $M$ is normalized by its Frobenius norm $\|M\|_F$, setting $u_1 = 1$. A practical lower bound $\ell_1$ (e.g., $10^{-3}$ for bfloat16) is chosen, since estimating the true $\sigma_{\min}(M)$ is expensive. A slightly incorrect $\ell_1$ only marginally impacts convergence.
- The offline stage is typically performed in high precision (float64).
- Online Stage: The precomputed polynomials $p_1, \dots, p_T$ are applied iteratively to the normalized input matrix $X_0 = M / (\|M\|_F + 10^{-7})$ using Horner's rule for polynomial evaluation. This stage is performed in low precision (bfloat16) on the GPU.
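A minimal sketch of the online stage, assuming degree $d = 5$ and a list of precomputed coefficient triples (the actual Polar Express coefficients come from the offline Remez stage and are not reproduced here; the function name is illustrative, and the sketch runs in float64 rather than bfloat16):

```python
import numpy as np

def polar_express_online(M, coeff_list, eps=1e-7):
    """Online-stage sketch: apply a precomputed sequence of odd degree-5
    polynomials p_t(x) = a x + b x^3 + c x^5 to the Frobenius-normalized
    input, evaluating each p_t Horner-style in A = X^T X."""
    X = M / (np.linalg.norm(M, 'fro') + eps)   # sets u_1 = 1
    I = np.eye(M.shape[1])
    for a, b, c in coeff_list:
        A = X.T @ X                            # n x n Gram matrix
        X = X @ (a * I + A @ (b * I + c * A))  # p_t(X) = X (aI + bA + cA^2)
    return X
```

As a sanity check, the classical fifth-order Newton-Schulz triple $(15/8, -10/8, 3/8)$ repeated at every iteration also drives $X^T X$ toward the identity, just more slowly than the optimal per-iteration coefficients would.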
Finite Precision Stability: Operating in low precision like bfloat16 requires specific adjustments:
- Safety Factor: Each polynomial $p_t(x)$ is modified to $p_t(x/1.01)$ to handle round-off errors that could push singular values above $u_t$. This causes convergence to a value slightly below 1, which is corrected in the final iteration.
- Cushioning: When $\ell_t$ is very small compared to $u_t$, the optimal polynomial may map some values in $[\ell_t, u_t]$ non-monotonically (decreasing values near $u_t$). To prevent loss of precision, the Remez algorithm is run on the slightly adjusted interval $[\max(\ell_t, u_t/10),\, u_t]$ whenever $\ell_t < u_t/10$, ensuring $x\,p_t(x)$ does not become too small [doi:10.1137/110857544, chen2014stable].
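The safety factor amounts to a simple rescaling of the precomputed coefficients: replacing $p(x)$ by $p(x/1.01)$ divides the coefficient of $x^{2k+1}$ by $1.01^{2k+1}$. A minimal sketch (function name illustrative):

```python
def apply_safety_factor(coeffs, s=1.01):
    """Replace p(x) by p(x/s): the coefficient a_k of x^(2k+1) becomes
    a_k / s^(2k+1). This guards against round-off pushing singular
    values above the interval's upper bound in bfloat16."""
    return [a / s ** (2 * k + 1) for k, a in enumerate(coeffs)]
```

This is done once in the offline stage, so the online stage pays no extra cost for the safeguard.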
A reference Python implementation for computing the coefficients of the optimal degree-5 polynomials for a given $\ell_1$ and number of iterations $T$ is provided in the paper. For Muon applications, using $d = 5$ and $T = 5$ or $6$ with $\ell_1 = 10^{-3}$ is recommended. The coefficients derived from this offline stage are then hardcoded for the online GPU computation.
Fast Application for Rectangular Matrices: For rectangular matrices $M \in \mathbb{R}^{m \times n}$ with $m \gg n$ (large aspect ratio $\alpha = m/n$), applying $p_t(X_{t-1})$ as $X_{t-1}\, h_t(X_{t-1}^T X_{t-1})$ (where $p_t(x) = x\, h_t(x^2)$) still involves an expensive $m \times n$ by $n \times n$ matrix multiplication in each iteration. A faster application method is possible when $\alpha$ is sufficiently large (e.g., $\alpha > 1.5$ for $T = 6$). This method computes $Y = X_0^T X_0$ once ($mn^2$ operations) and then iteratively computes $Q_t = Q_{t-1}\, h_t(R_t)$, where $R_t = Q_{t-1}^T Y Q_{t-1}$. This involves only $n \times n$ matrix multiplications within the loop. The final result is $X_0 Q_T$. The total cost is dominated by the initial $X_0^T X_0$ and final $X_0 Q_T$ multiplications ($2mn^2$) and the $T$ small matrix polynomial evaluations ($O(Tdn^3)$), significantly reducing the cost compared to the naive method when $m \gg n$. While faster on GPUs, this method can introduce numerical stability issues in bfloat16 for large $T$, potentially requiring a restart of the process every few iterations to maintain stability.
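A sketch of this tall-matrix variant, following the recursion above (function name illustrative; in exact arithmetic it matches the naive iteration, since $X_t = X_0 Q_t$ implies $R_t = Q_t^T Y Q_t = X_t^T X_t$):

```python
import numpy as np

def polar_express_tall(X0, coeff_list):
    """Tall-matrix (m >> n) sketch: form Y = X0^T X0 once, then iterate
    only on n x n matrices, multiplying back by X0 at the end.

    coeff_list holds (a, b, c) per iteration for
    p_t(x) = a x + b x^3 + c x^5, i.e. h_t(y) = a + b y + c y^2."""
    n = X0.shape[1]
    Y = X0.T @ X0                       # one m*n^2-cost product
    Q, I = np.eye(n), np.eye(n)
    for a, b, c in coeff_list:
        R = Q.T @ Y @ Q                 # n x n, equals X_t^T X_t
        Q = Q @ (a * I + R @ (b * I + c * R))   # Q_{t+1} = Q_t h_t(R_t)
    return X0 @ Q                       # one final m*n^2-cost product
```

In float64 this agrees with the naive iteration to machine precision; the stability caveat above applies only to the bfloat16 setting.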
Experimental Results:
Numerical experiments demonstrate the effectiveness of Polar Express:
- Convergence: On synthetic matrices and actual GPT-2 gradient matrices, Polar Express consistently achieves lower spectral norm error than Newton-Schulz, Jordan's method, and You's method for a given number of iterations/matrix multiplications, especially in the crucial initial phase (first 5-10 iterations).
- GPT-2 Training: When integrated into the Muon optimizer for training a 124M parameter GPT-2 model on the FineWeb dataset, muon-Polar Express achieves better final validation and training loss compared to muon-Jordan and muon-You across a range of learning rates. Since all tested matrix sign methods used 5 iterations of a degree-5 polynomial, the computational cost per Muon step was similar, translating the convergence advantage directly into wall-clock time savings.
The code snippet for computing the optimal degree-5 polynomial coefficients for the first few iterations is provided in the paper, allowing straightforward implementation of the offline stage.
In summary, Polar Express offers a theoretically grounded, empirically validated approach to computing the polar decomposition for deep learning optimizers like Muon, leveraging GPU-friendly polynomial iterations and optimizing them for rapid initial convergence in low precision.