The Polar Express: Optimal Matrix Sign Methods and Their Application to the Muon Algorithm

Published 22 May 2025 in cs.LG, cs.AI, cs.CL, cs.NA, math.NA, and math.OC | (2505.16932v2)

Abstract: Computing the polar decomposition and the related matrix sign function has been a well-studied problem in numerical analysis for decades. More recently, it has emerged as an important subroutine in deep learning, particularly within the Muon optimization framework. However, the requirements in this setting differ significantly from those of traditional numerical analysis. In deep learning, methods must be highly efficient and GPU-compatible, but high accuracy is often unnecessary. As a result, classical algorithms like Newton-Schulz (which suffers from slow initial convergence) and methods based on rational functions (which rely on QR decompositions or matrix inverses) are poorly suited to this context. In this work, we introduce Polar Express, a GPU-friendly algorithm for computing the polar decomposition. Like classical polynomial methods such as Newton-Schulz, our approach uses only matrix-matrix multiplications, making it GPU-compatible. Motivated by earlier work of Chen & Chow and Nakatsukasa & Freund, Polar Express adapts the polynomial update rule at each iteration by solving a minimax optimization problem, and we prove that it enjoys a strong worst-case optimality guarantee. This property ensures both rapid early convergence and fast asymptotic convergence. We also address finite-precision issues, making it stable in bfloat16 in practice. We apply Polar Express within the Muon optimization framework and show consistent improvements in validation loss on large-scale models such as GPT-2, outperforming recent alternatives across a range of learning rates.

Summary

  • The paper introduces Polar Express, a novel iterative method for computing matrix polar decomposition using optimally selected polynomial compositions for efficient, low-precision application in deep learning optimizers.
  • Polar Express achieves rapid convergence, crucial for deep learning, by optimally selecting a unique polynomial for each iterative step to minimize worst-case spectral norm error.
  • Integrating Polar Express into the Muon optimizer demonstrates improved training loss and faster convergence compared to alternative methods on GPT-2 model training experiments.

Computing the polar decomposition or matrix sign function of gradients is a crucial subroutine for certain advanced deep learning optimizers, notably Muon [jordan2024muon, bernstein2024oldoptimizernewnorm]. Unlike traditional numerical analysis applications where high accuracy is paramount, deep learning requires methods that are highly efficient, GPU-compatible, and can operate effectively in low precision like bfloat16, often sacrificing some accuracy for speed. Classical methods like Newton-Schulz suffer from slow initial convergence, while rational function methods require computationally expensive operations like QR decompositions or matrix inverses, which are not well-suited for current GPU architectures.

The paper "The Polar Express: Optimal Matrix Sign Methods and Their Application to the Muon Algorithm" (2505.16932) introduces a novel iterative method, Polar Express, designed specifically for this deep learning context. It computes an approximation to the polar factor $\mathrm{polar}(M)$ of a rectangular gradient matrix $M$.

The Muon Method

The Muon optimizer updates the neural network weights $W_t$ at iteration $t$ using the gradient momentum estimate $M_t = \beta M_{t-1} + (1-\beta) G_t$ via the rule $W_{t+1} = W_t - \lambda\,\mathrm{polar}(M_t)$, where $\lambda$ is the learning rate. This update takes a step in the direction of $-\mathrm{polar}(M_t)$. If $M = U \Sigma V^T$ is the singular value decomposition (SVD), then $\mathrm{polar}(M) = U V^T$. This aligns the update direction with the orthogonal component of the momentum matrix, which corresponds to the steepest-descent direction in the spectral norm. Efficiently computing $\mathrm{polar}(M_t)$ at each step is key to making Muon practical.
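As a concrete reference point (not the GPU-friendly path the paper proposes), the update can be sketched with an exact SVD-based polar factor. This is a minimal illustration; the names `polar_factor` and `muon_step` and the default hyperparameters are illustrative, not from the paper:

```python
import numpy as np

def polar_factor(M):
    """Exact orthogonal polar factor U V^T via SVD (reference only;
    Polar Express replaces this with matmul-only iterations)."""
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    return U @ Vt

def muon_step(W, M_prev, G, beta=0.95, lam=0.02):
    """One Muon update: momentum average, then a step along -polar(M_t)."""
    M = beta * M_prev + (1 - beta) * G
    return W - lam * polar_factor(M), M
```

The SVD is exact but far too slow on GPUs in low precision, which is exactly the gap Polar Express fills.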

Polar Express Approach

Polar Express computes $\mathrm{polar}(M)$ using an iterative method based on compositions of odd polynomials. An odd polynomial $p(x) = a_0 x + a_1 x^3 + \cdots + a_q x^{2q+1}$ can be applied to a matrix $M$ as $p(M) = a_0 M + a_1 M(M^T M) + \cdots + a_q M(M^T M)^q$. This requires only matrix-matrix multiplications ($M^T M$, then products with $M$), which are highly efficient on GPUs.

The method constructs a sequence of approximations $X_t$ via $X_0 = M$ (normalized) and $X_t = p_t(X_{t-1})$ for $t = 1, \dots, T$, where each $p_t$ is an odd polynomial of fixed degree $d$. The final result is $X_T = (p_T \circ \cdots \circ p_1)(M)$. This iterative structure applies a high-degree composite polynomial $p = p_T \circ \cdots \circ p_1$ (of degree $d^T$) using only $O(Td)$ matrix multiplications, as opposed to $O(d^T)$ for direct application.
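The matmul-only evaluation and the composition structure can be sketched as follows. This is a minimal illustration; the coefficients passed in are placeholders for the optimized ones from the offline stage:

```python
import numpy as np

def apply_odd_poly(X, coeffs):
    """Evaluate p(X) = a0 X + a1 X(X^T X) + ... + aq X(X^T X)^q
    using only matrix-matrix products."""
    A = X.T @ X                      # n x n Gram matrix
    P, out = X, coeffs[0] * X
    for a in coeffs[1:]:
        P = P @ A                    # X (X^T X)^k
        out = out + a * P
    return out

def compose(M, polys):
    """X_T = (p_T o ... o p_1)(M): degree d^T reached with O(T d) matmuls."""
    X = M
    for coeffs in polys:
        X = apply_odd_poly(X, coeffs)
    return X
```

For example, two composed degree-3 polynomials reach degree 9 at the cost of two short matmul loops rather than a degree-9 evaluation.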

The key innovation in Polar Express is the selection of the polynomials $p_t$. Instead of reusing the same polynomial at every step (as in Newton-Schulz), each $p_t$ is chosen optimally at each iteration to minimize the worst-case spectral-norm error given the current approximation $X_{t-1}$. This amounts to minimizing $\max_{x \in [\ell_t, u_t]} |1 - p_t(x)|$, where $[\ell_t, u_t]$ is the current range of singular values of $X_{t-1}$ (assuming $\sigma(X_{t-1}) \subset [\ell_t, u_t]$), since $\mathrm{polar}(X_{t-1})$ corresponds to mapping every singular value to $1$.

This greedy approach for choosing $p_t$ is proven to yield a composite polynomial $p^{\star} = p_T \circ \cdots \circ p_1$ that is optimal for approximating the function $x \mapsto 1$ over the initial singular-value range $[\ell, u]$ in the minimax sense, i.e., $\max_{x \in [\ell, u]} |1 - p^{\star}(x)|$ is minimized among all compositions of $T$ odd polynomials of degree $d$. The error after $T$ iterations is bounded by $1 - \ell_{T+1}$, where $\ell_{t+1} = p_t(\ell_t)$ and $u_{t+1} = 2 - \ell_{t+1}$ describe the evolution of the singular-value bounds. This optimal selection strategy ensures rapid convergence, particularly in the initial iterations, which is crucial for deep learning, where only a handful of iterations are performed. The method also retains faster-than-exponential asymptotic convergence (quadratic for $d=3$, cubic for $d=5$).
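The interval recurrence and the resulting error bound can be traced numerically. The sketch below uses the classical Newton-Schulz cubic as a stand-in for the optimal $p_t$ (whose coefficients come from the offline stage), so the numbers illustrate the recurrence itself rather than the optimal rate; the relation $u_{t+1} = 2 - \ell_{t+1}$ is exact only for the equioscillating optimal polynomials:

```python
def eval_odd(x, coeffs):
    """Scalar odd polynomial a0 x + a1 x^3 + ..."""
    return sum(a * x ** (2 * k + 1) for k, a in enumerate(coeffs))

def interval_evolution(polys, l1):
    """Track bounds via l_{t+1} = p_t(l_t), u_{t+1} = 2 - l_{t+1};
    the worst-case error after T steps is bounded by 1 - l_{T+1}."""
    l, bounds = l1, [(l1, 1.0)]
    for coeffs in polys:
        l = eval_odd(l, coeffs)
        bounds.append((l, 2.0 - l))
    return bounds, 1.0 - l
```

Running this with a small $\ell_1$ makes the slow early phase of a fixed Newton-Schulz iteration visible, which is precisely what the per-iteration optimal polynomials avoid.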

Implementation Details

The Polar Express algorithm consists of two stages:

  1. Offline Stage: The sequence of optimal polynomials $p_1, \dots, p_T$ is computed once and stored. This involves solving, for each $t$, the minimax approximation problem $\operatorname{argmin}_{p \in \mathbb{P}_d^{\mathrm{odd}}} \max_{x \in [\ell_t, u_t]} |1 - p(x)|$ over the current interval $[\ell_t, u_t]$.
    • For $d=3$, there is a known closed-form solution based on a scaled Newton-Schulz polynomial [chen2014stable].
    • For $d=5$ (the recommended degree), the polynomials are computed using the Remez algorithm, which iteratively finds the polynomial satisfying the equioscillation property on $[\ell_t, u_t]$. This process is fast and accurate even for small $d$.
    • The initial singular-value range $[\ell_1, u_1]$ is needed. The input matrix $M$ is normalized by its Frobenius norm $\|M\|_F$, setting $u_1 = 1$. A practical lower bound $\ell_1$ (e.g., $10^{-3}$ for bfloat16) is chosen, since estimating the true $\sigma_{\min}(M)$ is expensive; a slightly incorrect $\ell_1$ only marginally impacts convergence.
    • The offline stage is typically performed in high precision (float64).
  2. Online Stage: The precomputed polynomials $p_1, \dots, p_T$ are applied iteratively to the normalized input matrix $X_0 = M / (\|M\|_F + 10^{-7})$, using Horner's rule for polynomial evaluation. This stage is performed in low precision (bfloat16) on the GPU.
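A minimal sketch of the online stage in ordinary floating point (the paper runs it in bfloat16 on GPU): each odd polynomial $p(x) = x\,h(x^2)$ is applied as $X\,h(X^T X)$, with $h$ evaluated by Horner's rule. The Newton-Schulz cubic coefficients used in the test stand in for the precomputed optimal ones:

```python
import numpy as np

def polar_express_online(M, coeff_list, eps=1e-7):
    """Apply precomputed odd polynomials p(x) = x * h(x^2) iteratively:
    with A = X^T X, evaluate h(A) by Horner's rule, then X <- X @ h(A)."""
    X = M / (np.linalg.norm(M, 'fro') + eps)     # normalization sets u_1 = 1
    n = M.shape[1]
    I = np.eye(n, dtype=X.dtype)
    for coeffs in coeff_list:                    # coeffs = (a0, a1, ..., aq)
        A = X.T @ X
        H = coeffs[-1] * I
        for a in reversed(coeffs[:-1]):          # Horner: h(A) = a0 I + A(a1 I + ...)
            H = a * I + A @ H
        X = X @ H
    return X
```

In the actual algorithm the hardcoded coefficient lists from the offline stage replace the stand-in list, and the loop runs in bfloat16.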

Finite Precision Stability: Operating in low precision like bfloat16 requires specific adjustments:

  • Safety Factor: Each polynomial $p_t(x)$ is replaced by $p_t(x/1.01)$ to guard against round-off errors pushing singular values above $u_t$. This causes convergence to a value slightly below $1$, which is corrected in the final iteration.
  • Cushioning: When $\ell_t$ is very small relative to $u_t$, the optimal polynomial may map some values in $[\ell_t, u_t]$ non-monotonically (decreasing for values near $u_t$). To prevent loss of precision, when $\ell_t < u_t/10$ the Remez algorithm is run on the slightly adjusted interval $[\max(\ell_t, u_t/10), u_t]$, ensuring $p_t(x)/x$ does not become too small [doi:10.1137/110857544, chen2014stable].
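The safety factor is just a rescaling of the stored coefficients: replacing $p(x)$ by $p(x/s)$ divides the coefficient of $x^{2k+1}$ by $s^{2k+1}$, so it costs nothing at runtime. A minimal sketch:

```python
def apply_safety_factor(coeffs, s=1.01):
    """Turn p(x) = sum_k a_k x^(2k+1) into p(x/s) by rescaling coefficients."""
    return [a / s ** (2 * k + 1) for k, a in enumerate(coeffs)]
```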

A reference Python implementation for computing the coefficients of the optimal degree-5 polynomials, given $\ell_1$ and the number of iterations $T$, is provided in the paper. For Muon applications, using $d=5$ and $T=5$ or $6$ with $\ell_1 = 10^{-3}$ is recommended. The coefficients derived from this offline stage are then hardcoded for the online GPU computation.
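The paper computes these coefficients with the Remez algorithm; as an easy-to-verify alternative (not the paper's method), the same minimax problem can be discretized on a dense grid and solved as a linear program. The function name `minimax_odd_deg5` and the grid size are illustrative:

```python
import numpy as np
from scipy.optimize import linprog

def minimax_odd_deg5(l, u, grid=2000):
    """Approximate argmin over odd degree-5 p of max_{x in [l,u]} |1 - p(x)|,
    discretized on a grid and solved as an LP (a stand-in for Remez)."""
    xs = np.linspace(l, u, grid)
    V = np.stack([xs, xs**3, xs**5], axis=1)    # p(x) = V @ [a0, a1, a2]
    ones = np.ones(grid)
    # minimize E subject to -E <= 1 - p(x_i) <= E at every grid point:
    A_ub = np.block([[V, -ones[:, None]],       #  p(x) - E <= 1
                     [-V, -ones[:, None]]])     # -p(x) - E <= -1
    b_ub = np.concatenate([ones, -ones])
    c = np.array([0.0, 0.0, 0.0, 1.0])
    res = linprog(c, A_ub=A_ub, b_ub=b_ub,
                  bounds=[(None, None)] * 3 + [(0, None)])
    return tuple(res.x[:3]), res.x[3]
```

Running this repeatedly while updating $\ell_{t+1} = p_t(\ell_t)$ reproduces the offline stage's coefficient sequence up to discretization error.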

Fast Application for Rectangular Matrices: For rectangular matrices $M \in \mathbb{R}^{m \times n}$ with $m \gg n$ (large aspect ratio $\alpha = m/n$), applying $p_t(X_{t-1})$ as $X_{t-1} h_t(X_{t-1}^T X_{t-1})$ (where $p_t(x) = x\, h_t(x^2)$) still involves an expensive $m \times n$ by $n \times n$ multiplication in each iteration. A faster application method is possible when $\alpha$ is sufficiently large (e.g., $\alpha > 1.5$ for $T=6$). This method computes $Y = X_0^T X_0$ once ($mn^2$ operations) and then iteratively computes $Q_t = Q_{t-1} h_t(R_t)$ with $R_t = Q_{t-1}^T Y Q_{t-1}$, which involves only $n \times n$ multiplications inside the loop. The final result is $X_0 Q_T$. The total cost is dominated by the initial $X_0^T X_0$ and final $X_0 Q_T$ multiplications ($2mn^2$) plus the $T$ small polynomial evaluations ($O(Tdn^3)$), a significant saving over the naive method when $m \gg n$. While faster on GPUs, this method can introduce numerical stability issues in bfloat16 for large $T$, potentially requiring a restart every few iterations to maintain stability.
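A sketch of this tall-matrix variant, again with Newton-Schulz stand-in coefficients for $h_t(y) = a_0 + a_1 y + \cdots$ (the real ones come from the offline stage). Mathematically it produces the same iterates as the naive scheme, since $R_t = Q_{t-1}^T X_0^T X_0 Q_{t-1} = X_{t-1}^T X_{t-1}$:

```python
import numpy as np

def polar_tall(M, coeff_list, eps=1e-7):
    """Fast application for m >> n: form Y = X0^T X0 once, iterate with
    n x n products Q_t = Q_{t-1} h_t(Q_{t-1}^T Y Q_{t-1}), finish with X0 Q_T."""
    X0 = M / (np.linalg.norm(M, 'fro') + eps)
    Y = X0.T @ X0                              # one m n^2-cost product
    n = M.shape[1]
    I = np.eye(n)
    Q = I.copy()
    for coeffs in coeff_list:
        R = Q.T @ Y @ Q                        # R_t = X_{t-1}^T X_{t-1}, n x n
        H = coeffs[-1] * I
        for a in reversed(coeffs[:-1]):        # Horner evaluation of h_t(R)
            H = a * I + R @ H
        Q = Q @ H
    return X0 @ Q                              # second m n^2-cost product
```

In exact arithmetic this matches the naive iteration; in bfloat16 the repeated $n \times n$ products can accumulate error for large $T$, which is why the paper suggests restarting the scheme every few iterations.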

Experimental Results

Numerical experiments demonstrate the effectiveness of Polar Express:

  • Convergence: On synthetic matrices and actual GPT-2 gradient matrices, Polar Express consistently achieves lower spectral norm error than Newton-Schulz, Jordan's method, and You's method for a given number of iterations/matrix multiplications, especially in the crucial initial phase (first 5-10 iterations).
  • GPT-2 Training: When integrated into the Muon optimizer for training a 124M parameter GPT-2 model on the FineWeb dataset, muon-Polar Express achieves better final validation and training loss compared to muon-Jordan and muon-You across a range of learning rates. Since all tested matrix sign methods used 5 iterations of a degree-5 polynomial, the computational cost per Muon step was similar, translating the convergence advantage directly into wall-clock time savings.

The code snippet for computing the optimal degree-5 polynomial coefficients for the first few iterations is provided in the paper, allowing straightforward implementation of the offline stage.

In summary, Polar Express offers a theoretically grounded, empirically validated approach to computing the polar decomposition for deep learning optimizers like Muon, leveraging GPU-friendly polynomial iterations and optimizing them for rapid initial convergence in low precision.
