- The paper identifies growth in linear layer outputs (QKV, Proj, FC2) as a key cause of LLM training instability and analyzes various methods to address it.
- The study evaluates layer normalization configurations (post-QKV, QK FC norm), softmax adjustments (temperature, capping), and existing techniques such as σReparam, LayerScale, and QK norm.
- Experiments show that the QKV norm and QK norm cap configurations tolerate roughly 1.5× higher learning rates without divergence and improve perplexity compared to using QK norm alone.
Methods of Improving LLM Training Stability: An Analysis
This paper addresses training stability in LLMs, a prominent concern as model sizes continue to scale. The authors, from NVIDIA, identify the growth of logits within attention layers as a significant source of training instability. By examining the output growth of all linear layers in transformer blocks, they extend prior work by Dehghani et al. (2023) and Wortsman et al. (2024).
Key Observations and Techniques
The authors highlight that the L2 norms of the outputs of the QKV, Proj, and FC2 layers grow more than twofold in a diverging run compared to a converging one (a monitoring sketch is given after the list below). This observation motivates a comparative analysis of several methods for improving training stability:
- Layer Normalization Configurations: The authors apply layer normalization (LN) not only after the QK layers, as previously recommended, but also after the Proj and FC2 layers. They additionally assess replacing pre-normalization with post-normalization for the QKV layers, and the potential benefits of combining QK layer normalization with softmax capping (see the block-level sketch after this list).
- Softmax Adjustment Mechanisms: The softmax operation is adjusted via a temperature and via logit capping. Both approaches limit the magnitude of the attention logits, reducing the risk of exploding gradients and divergence (the attention sketch after this list illustrates both, together with QK norm).
- Alternative Stabilization Techniques: The paper evaluates σReparam to constrain linear layer weights, LayerScale for adaptive feature scaling, and QK layer normalization (QK norm) as existing strategies to manage training stability.
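The paper reports these output norms from its training runs; the exact instrumentation is not described. Below is a minimal sketch of how such monitoring could be implemented in PyTorch with forward hooks, assuming hypothetical module names (`qkv`, `proj`, `fc2`) that would need to be adapted to the actual model:

```python
import torch.nn as nn

def attach_output_norm_hooks(model: nn.Module, layer_suffixes=("qkv", "proj", "fc2")):
    """Record the mean L2 norm of selected linear-layer outputs on each forward pass.

    `layer_suffixes` are hypothetical module names; adapt them to the
    transformer implementation actually being instrumented.
    """
    norms = {}

    def make_hook(name):
        def hook(module, inputs, output):
            # Average the per-token L2 norm over batch and sequence positions.
            norms[name] = output.detach().flatten(0, -2).norm(dim=-1).mean().item()
        return hook

    handles = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and name.split(".")[-1] in layer_suffixes:
            handles.append(module.register_forward_hook(make_hook(name)))
    return norms, handles
```

Logging these values every step makes the roughly twofold growth that precedes divergence visible before the loss itself degrades.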
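To make the normalization placements concrete, here is a minimal sketch of an attention block in which layer norm is applied after the fused QKV projection and after the output projection, in addition to the usual pre-LN. The module names, the fused QKV layout, and the position of the norms relative to the residual connection are assumptions for illustration, not the paper's implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionWithPostLinearNorm(nn.Module):
    """Self-attention block where LayerNorm is also applied to linear-layer outputs."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.pre_ln = nn.LayerNorm(d_model)        # standard pre-normalization
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.qkv_ln = nn.LayerNorm(3 * d_model)    # "QKV norm": bound the QKV output
        self.proj = nn.Linear(d_model, d_model)
        self.proj_ln = nn.LayerNorm(d_model)       # norm on the Proj output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, d = x.shape
        h = self.qkv_ln(self.qkv(self.pre_ln(x)))  # normalize the fused QKV output
        q, k, v = h.chunk(3, dim=-1)
        q = q.reshape(b, t, self.n_heads, -1).transpose(1, 2)
        k = k.reshape(b, t, self.n_heads, -1).transpose(1, 2)
        v = v.reshape(b, t, self.n_heads, -1).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).reshape(b, t, d)
        return x + self.proj_ln(self.proj(out))    # residual add after the Proj norm
```

The same pattern extends to the MLP by normalizing the FC2 output; whether the norm acts on the fused QKV output or on Q, K, and V separately is a design choice this sketch does not settle.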
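The softmax adjustments and QK norm can likewise be sketched in a few lines. The function below applies an optional LayerNorm to queries and keys, a softmax temperature, and a tanh-based soft cap on the logits; the function name, the cap value, and the exact form of the temperature are illustrative assumptions rather than the paper's reported settings:

```python
import math
import torch
import torch.nn as nn

def attention_probs(q, k, qk_norm=None, temperature=1.0, softmax_cap=None):
    """Attention probabilities with optional QK norm, temperature, and soft capping.

    q, k: (batch, heads, seq, head_dim). Causal masking is omitted for brevity.
    The tanh-based cap squashes logits into [-softmax_cap, softmax_cap] so they
    cannot grow without bound; values used here are illustrative.
    """
    if qk_norm is not None:              # QK norm: LayerNorm on queries and keys
        q, k = qk_norm(q), qk_norm(k)
    logits = q @ k.transpose(-2, -1) / (temperature * math.sqrt(q.size(-1)))
    if softmax_cap is not None:          # soft cap keeps |logits| <= softmax_cap
        logits = softmax_cap * torch.tanh(logits / softmax_cap)
    return torch.softmax(logits, dim=-1)

# Example: QK norm combined with soft capping (a "QK norm cap"-style configuration).
head_dim = 64
q = torch.randn(2, 8, 128, head_dim)
k = torch.randn(2, 8, 128, head_dim)
probs = attention_probs(q, k, qk_norm=nn.LayerNorm(head_dim), softmax_cap=30.0)
```

Because tanh saturates, the capped logits can never exceed the chosen bound, no matter how large the raw query-key dot products become.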
Experimental Setup and Findings
Experiments were conducted on a small GPT-2-style transformer with 830 million parameters. Using a diverse dataset and a sweep of learning rates, the authors assess the stability and effectiveness of each configuration.
- Combining QK norm with softmax capping (QK norm cap), as well as applying layer normalization after the QKV layers (QKV norm), allowed the learning rate to be increased by 1.5× without divergence, compared to models using QK norm alone.
- Perplexity improvements were evident in models using QKV norm, QK norm cap, and QK FC norm, indicating better language-modeling quality alongside the stability gains.
Implications and Future Work
The results underscore the importance of controlling the output magnitude of linear layers and show that a strategic combination of normalization placement and softmax adjustments can substantially improve training stability. The accompanying perplexity improvements suggest these interventions also benefit model quality, contributing both to the theoretical understanding of instability and to practical LLM training recipes.
Future research could scale these stability improvements to larger models and datasets and test their robustness in multilingual and multimodal settings. Exploring synergistic combinations of normalization techniques and gradient-control measures will be important for sustaining the scaling trends of LLMs.