
Theoretical justification for the new embedding learning-rate scaling rule

Establish a theoretical justification for scaling the embedding-layer learning rate by 1/sqrt(fan-in) in the Unit-Scaled Maximal Update Parametrization (u-µP), clarifying why this rule should replace the Maximal Update Parametrization's width-independent embedding learning rate and under what assumptions it yields improved hyperparameter transfer across width.
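Written out, the contrast in question is roughly the following; this is a notational sketch only, and the base-learning-rate symbol η_base is introduced here for illustration rather than taken from the paper.

```latex
% Sketch: width scaling of the embedding learning rate under the two rules.
% \eta_{\mathrm{base}} is an illustrative base learning rate, not a symbol from the paper.
\begin{align*}
  \mu\text{P}:\quad & \eta_{\mathrm{emb}} = c_{\mathrm{emb}}\,\eta_{\mathrm{base}},
    \qquad c_{\mathrm{emb}} = \Theta(1) \ \text{(constant in width)} \\
  u\text{-}\mu\text{P}:\quad & \eta_{\mathrm{emb}} = c_{\mathrm{emb}}\,\eta_{\mathrm{base}},
    \qquad c_{\mathrm{emb}} = \frac{1}{\sqrt{\text{fan-in}}}
\end{align*}
```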


Background

The authors identify that the embedding learning-rate multiplier used in the Maximal Update Parametrization (µP) exhibits poor transfer across width for Llama-style transformer LLMs. They attribute this to a mis-specified scaling rule in µP, which keeps the embedding learning rate constant with width.

To remedy this, they propose a new scaling rule for the embedding learning rate, setting c_emb = 1/sqrt(fan-in), and show empirically that this rule improves transfer and large-width performance. However, they explicitly state that they do not provide a theoretical justification for this change and leave it to future work, making the derivation of such a justification an open problem.
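As a minimal sketch of how such a rule might be wired into training, the snippet below scales only the embedding parameter group's learning rate, assuming a PyTorch-style setup. The helper name embedding_lr_multiplier, the module names, the hyperparameter values, and the choice of which embedding dimension counts as the fan-in are all illustrative and not taken from the paper.

```python
import math
import torch

def embedding_lr_multiplier(fan_in: int, rule: str = "u-mup") -> float:
    """Width-scaling multiplier c_emb for the embedding learning rate.

    'mup'   : constant multiplier (the original µP rule).
    'u-mup' : 1/sqrt(fan-in), the rule proposed in the u-µP paper.
    Which embedding dimension counts as the fan-in follows the paper's
    convention; this sketch simply takes it as an argument.
    """
    if rule == "mup":
        return 1.0
    if rule == "u-mup":
        return 1.0 / math.sqrt(fan_in)
    raise ValueError(f"unknown rule: {rule}")

# Illustrative usage: only the embedding group's LR is rescaled; the rest
# of the model keeps the base learning rate. All names and numbers are
# hypothetical.
d_model, vocab_size, base_lr = 1024, 32_000, 1e-2
embedding_fan_in = d_model  # assumed here; the paper fixes the exact convention

model = torch.nn.ModuleDict({
    "embed": torch.nn.Embedding(vocab_size, d_model),
    "body": torch.nn.Linear(d_model, d_model),
})
param_groups = [
    {"params": model["embed"].parameters(),
     "lr": base_lr * embedding_lr_multiplier(fan_in=embedding_fan_in)},
    {"params": model["body"].parameters(), "lr": base_lr},
]
optimizer = torch.optim.AdamW(param_groups)
```

Under this sketch, only the embedding learning rate changes with width; providing a theoretical grounding for exactly that change is what the open problem asks for.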

References

"We offer no theoretical justification for our rule, which we leave to further work."

u-µP: The Unit-Scaled Maximal Update Parametrization (Blake et al., 24 Jul 2024, arXiv:2407.17465), Section "A new embedding LR rule" (sec:umup:emb_lr_rule)