- The paper introduces a dynamical systems analysis of learned optimizers using a noisy quadratic model to assess stability and meta-training challenges.
- It proposes specific modifications—incorporating a nominal optimizer, heavy weight decay, and output preconditioning—to improve stability and robustness.
- Experiments with the STAR optimizer demonstrate faster meta-training, better final performance, and strong generalization across diverse models and tasks.
This paper investigates the stability and generalization problems often encountered with learned optimizers (LOs), which are neural networks trained to perform optimization (A Closer Look at Learned Optimization: Stability, Robustness, and Inductive Biases, 2022). While LOs can potentially accelerate machine learning model training, they frequently become unstable or perform poorly when applied to tasks different from those they were meta-trained on, or when run for longer than their meta-training horizon.
The authors use tools from dynamical systems theory, specifically analyzing the optimization process in a noisy quadratic model (NQM), to understand the stability properties of LOs. The NQM involves minimizing a quadratic loss $L(\phi_t) = \frac{1}{2}(\phi_t - \mu_t)^\top H (\phi_t - \mu_t)$, where the minimum $\mu_t$ is drawn i.i.d. from a distribution (e.g., $\mathcal{N}(0, \Sigma_\mu)$) at each step $t$. This setup models stochastic optimization with minibatches.
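For concreteness, here is a minimal JAX sketch of one NQM draw; the dimension, Hessian spectrum, and noise scale are illustrative assumptions, not the paper's settings.

```python
import jax
import jax.numpy as jnp

dim = 10
H = jnp.diag(jnp.linspace(0.1, 1.0, dim))  # fixed quadratic curvature (Hessian)
sigma_mu = 0.1                              # scale of the resampled minimum

def nqm_loss(phi, mu):
    """L(phi_t) = 0.5 * (phi_t - mu_t)^T H (phi_t - mu_t)."""
    d = phi - mu
    return 0.5 * d @ H @ d

# Each step draws a fresh minimum mu_t ~ N(0, sigma_mu^2 I), modeling minibatch noise.
key = jax.random.PRNGKey(0)
mu_t = sigma_mu * jax.random.normal(key, (dim,))
phi_t = jnp.ones(dim)
loss, grad = jax.value_and_grad(nqm_loss)(phi_t, mu_t)
```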
In the NQM, the parameter update is modeled as $\phi_{t+1} = \phi_t - (\alpha \nabla_t + P \nabla_t)$, where $\alpha \nabla_t$ is a "nominal" hand-designed optimizer step (like scaled gradient descent) and $P \nabla_t$ is the output of a linear learned optimizer represented by the matrix $P$. Since the gradient is $\nabla_t = H(\phi_t - \mu_t)$, this leads to the linear dynamics
$\phi_{t+1} = (I - (\alpha I + P)H)\,\phi_t + (\alpha I + P)H\,\mu_t.$
Stability is determined by the eigenvalues of the dynamics matrix $A = I - (\alpha I + P)H$. The system is stable if the spectral radius $\rho(A) = \max_i |\lambda_i(A)| < 1$. Instability ($\rho(A) \geq 1$) leads to diverging losses and unstable meta-gradients, hindering meta-training.
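This check is easy to compute directly. The sketch below builds $A$ for the illustrative NQM above (the specific $\alpha$ and the zero-initialized $P$ are assumptions) and evaluates its spectral radius:

```python
import jax.numpy as jnp

dim = 10
H = jnp.diag(jnp.linspace(0.1, 1.0, dim))

def spectral_radius(A):
    """rho(A) = max_i |lambda_i(A)|; the NQM iterates are stable iff rho(A) < 1."""
    return jnp.max(jnp.abs(jnp.linalg.eigvals(A)))

alpha = 0.5
P = jnp.zeros((dim, dim))  # "blank" learned component; a trained P shifts these eigenvalues
A = jnp.eye(dim) - (alpha * jnp.eye(dim) + P) @ H

print(spectral_radius(A))  # here rho(A) = 1 - 0.5 * 0.1 = 0.95 < 1, so the system is stable
```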
Based on this analysis, the paper proposes several modifications to improve LO stability and inductive bias:
- Nominal Optimizer Term: Incorporating a hand-designed optimizer component (like Adam or AggMo) ensures a baseline descent direction, improving stability, especially early in meta-training. An additional learned magnitude controller modulates this nominal term.
- Implementation: Add $f_g(z_t) = \beta_1 \exp(\beta_2 m_g(z_t))\, g(z_t)$ to the update, where $g(z_t)$ is the nominal update (e.g., an Adam step) and $m_g(z_t)$ is a learned magnitude output; see the sketch after this list.
- Heavy Weight Decay: Applying strong L2 regularization to the LO's parameters during meta-training discourages large outputs from the learned component, pulling the dynamics eigenvalues toward stability.
- Implementation: Use an AdamW meta-optimizer with a non-zero weight-decay hyperparameter on the LO's network weights (a minimal sketch follows the STAR description below).
- Output Preconditioning: Normalizing the output of the learned component using an adaptive preconditioner (similar to Adam's normalization) makes the update magnitude less dependent on the problem's Hessian and improves robustness.
- Implementation: Modify the blackbox term to $f_b(z_t) = \beta_3 \frac{d(z_t)}{v(z_t)} \exp(\beta_4 m_b(z_t))$, where $d(z_t)$ is the learned update direction and $v(z_t)$ is a preconditioner term (e.g., the RMS of gradients, as in Adam); see the sketch after this list.
- Stable Hidden States: Using stable update rules (like EMA) for the LO's internal state prevents internal dynamics from causing instability.
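Putting these pieces together, the sketch below shows what an elementwise STAR-style step could look like. The function name star_update, the $\beta$ values, the EMA decay, and the exact RMS preconditioner form are assumptions for illustration, not the paper's reference implementation.

```python
import jax.numpy as jnp

# Fixed scaling constants beta_1..beta_4 (illustrative values, not the paper's).
beta1, beta2, beta3, beta4 = 0.1, 1.0, 0.001, 1.0

def star_update(params, grad, acc, d, m_b, m_g, nominal_step):
    """One elementwise STAR-style step (a sketch, not the reference code).

    d, m_b, m_g  : per-parameter MLP outputs (direction, blackbox magnitude,
                   nominal magnitude), each the same shape as `params`.
    nominal_step : the hand-designed update g(z_t), e.g. an Adam or AggMo step.
    acc          : EMA of squared gradients -- a stable hidden state whose RMS
                   serves as the output preconditioner v(z_t).
    """
    acc = 0.999 * acc + 0.001 * grad**2
    v = jnp.sqrt(acc) + 1e-8                               # RMS preconditioner
    f_g = beta1 * jnp.exp(beta2 * m_g) * nominal_step      # nominal term, learned magnitude
    f_b = beta3 * (d / v) * jnp.exp(beta4 * m_b)           # preconditioned blackbox term
    return params - (f_g + f_b), acc
```

Note that if heavy weight decay drives the blackbox outputs d and m_b toward zero, the step collapses to roughly the nominal optimizer, which is the stability-preserving fallback the analysis motivates.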
These modifications are incorporated into an existing efficient, elementwise MLP-based learned optimizer (small_fc_lopt), resulting in the "Stabilized Through Ample Regularization" (STAR) optimizer. STAR adds only a few parameters (for the nominal magnitude controller) compared to the baseline.
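As a sketch of the heavy-weight-decay recipe, meta-training could use optax's AdamW on the LO's weights; the learning rate, weight-decay value, and meta_loss_fn below are illustrative placeholders for the unrolled meta-objective.

```python
import jax
import optax

# AdamW meta-optimizer with heavy weight decay on the LO's weights `theta`.
# The non-zero weight_decay shrinks the learned component's weights, which in
# the NQM analysis pulls the dynamics eigenvalues toward the stable region.
meta_opt = optax.adamw(learning_rate=3e-4, weight_decay=0.1)  # illustrative values

def meta_step(theta, opt_state, meta_loss_fn):
    grads = jax.grad(meta_loss_fn)(theta)
    updates, opt_state = meta_opt.update(grads, opt_state, params=theta)
    return optax.apply_updates(theta, updates), opt_state
```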
Experiments show that STAR:
- Meta-trains faster and achieves better final performance on meta-training tasks (MLP on Fashion MNIST, CNN on CIFAR10) compared to the purely blackbox baseline and a hyperparameter-controller variant.
- Remains stable and performs well when run for significantly more steps (e.g., 10k steps) than used during meta-training (2k steps), unlike the baseline blackbox LO which often diverges.
- Generalizes remarkably well to diverse, unseen tasks (different architectures like ResNet, LSTM, Transformer; different datasets like ImageNet, LM1B) even when only meta-trained on a small MLP/Fashion MNIST task. It often matches or outperforms tuned Adam on these tasks, while the baseline blackbox LO diverges.
The key takeaway is that explicitly incorporating stability-promoting inductive biases, guided by dynamical systems analysis, significantly improves the robustness, performance, and generalization capabilities of learned optimizers, making them more practical for real-world applications.