The Expressivity Role of LayerNorm in Transformers' Attention
The paper examines an underexplored facet of Layer Normalization (LayerNorm) in Transformer architectures, challenging the common view that LayerNorm's only roles are normalizing activations in the forward pass and stabilizing gradients in the backward pass. The authors present a geometric interpretation of LayerNorm and show that it makes critical contributions to the expressivity of the multi-head attention mechanism in Transformers.
Geometric Interpretation and Components
LayerNorm operates through two combined geometric transformations:
- Projection: Input vectors are projected onto the (d-1)-dimensional hyperplane orthogonal to the all-ones vector [1, 1, ..., 1].
- Scaling: The projected vectors are rescaled to a uniform norm of √d.
These components provide foundational support to the subsequent attention layer.
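A minimal PyTorch sketch of this decomposition (not the paper's released code), assuming LayerNorm without its learnable affine parameters: subtracting the mean is exactly the projection onto the hyperplane orthogonal to [1, ..., 1], and dividing by the standard deviation rescales the result to norm √d.

```python
import torch

torch.manual_seed(0)
d = 16
x = torch.randn(d)

# Standard LayerNorm without learnable affine parameters (gamma=1, beta=0).
ln = torch.nn.LayerNorm(d, elementwise_affine=False, eps=0.0)
y_ln = ln(x)

# Geometric view: (1) project x onto the hyperplane orthogonal to the
# all-ones vector, (2) rescale the result to norm sqrt(d).
ones = torch.ones(d)
p = x - (x @ ones / d) * ones      # projection: remove the component along [1,...,1]
y_geo = p / p.norm() * d ** 0.5    # scaling: uniform norm sqrt(d)

print(torch.allclose(y_ln, y_geo, atol=1e-6))  # True, up to numerical error
```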
Importance of Projection
The projection component of LayerNorm makes it easy to form attention queries that attend uniformly to all keys: because every projected key lies in the hyperplane orthogonal to the all-ones vector, a query pointing along that vector produces identical scores for every key. This offloads part of the computation to LayerNorm, making certain functions simpler for the attention mechanism to learn. The authors demonstrate this experimentally on a task that computes the "majority" token in the input, where the Transformer converged faster because the orthogonal projection of the keys made it easier for queries to attend equally to every key.
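A toy sketch of this effect (illustrative values only, not the paper's experiments): once the keys are projected, any query along the all-ones direction scores every key identically, so the softmax attention is uniform.

```python
import torch

torch.manual_seed(0)
d, n = 16, 8
keys = torch.randn(n, d)

# Projection part of LayerNorm: every key now lies in the hyperplane
# orthogonal to the all-ones vector.
ones = torch.ones(d)
proj_keys = keys - (keys @ ones / d).unsqueeze(-1) * ones

# A query pointing along [1,...,1] is orthogonal to that hyperplane,
# so all dot products vanish and the attention distribution is uniform.
query = 3.0 * ones                       # any scalar multiple works
scores = proj_keys @ query / d ** 0.5
attn = torch.softmax(scores, dim=-1)
print(attn)                              # ~[1/n, ..., 1/n]
```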
Necessity of Scaling
Scaling addresses the problem of "unselectable" keys, ensuring that every key vector can receive the highest attention score for some query. Without scaling, a key that lies within the convex hull of the other keys can never be assigned the strictly highest score, which limits where the Transformer can focus its attention. Empirical results corroborate this: models without scaling exhibit a substantial fraction of unselectable keys, which translates into poorer performance and slower convergence.
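The sketch below is an illustration rather than the paper's measurement code: it tests selectability with a linear-programming convex-hull membership check (via scipy) on randomly drawn 2-D keys, an assumption chosen so that interior keys are common, and shows how rescaling every key to a common norm removes unselectable keys.

```python
import numpy as np
from scipy.optimize import linprog

def in_convex_hull(point, others):
    """Check whether `point` is a convex combination of the rows of `others`
    by solving a linear feasibility problem."""
    m = others.shape[0]
    A_eq = np.vstack([others.T, np.ones((1, m))])   # sum_j w_j * k_j = point, sum_j w_j = 1
    b_eq = np.concatenate([point, [1.0]])
    res = linprog(c=np.zeros(m), A_eq=A_eq, b_eq=b_eq,
                  bounds=[(0, None)] * m, method="highs")
    return res.success

def unselectable_fraction(keys):
    """A key inside the convex hull of the other keys can never receive
    the strictly highest attention score, for any query."""
    n = keys.shape[0]
    flags = [in_convex_hull(keys[i], np.delete(keys, i, axis=0)) for i in range(n)]
    return float(np.mean(flags))

rng = np.random.default_rng(0)
keys = rng.normal(size=(200, 2))                            # low dimension: many interior keys
print(unselectable_fraction(keys))                          # substantial fraction unselectable

scaled = keys / np.linalg.norm(keys, axis=1, keepdims=True) # LayerNorm-style rescaling
print(unselectable_fraction(scaled))                        # ~0: every key is an extreme point
```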
Experimental Results
Several experiments underscore the practical benefits of LayerNorm's components:
- Majority Function Computation: Transformers with LayerNorm learned queries orthogonal to the keys' hyperplane more effectively and converged faster (a toy data sketch follows this list).
- Language Modeling: Transformers without scaling exhibited a higher proportion of unselectable keys and reached higher training and validation losses than those with full LayerNorm.
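For concreteness, here is a hypothetical toy version of the majority task; the vocabulary size, sequence length, and batch size are assumptions and not the paper's exact setup.

```python
import torch

def majority_batch(batch_size=32, seq_len=21, vocab_size=5, seed=0):
    """Toy majority task: the label of each sequence is its most frequent token.
    All sizes here are illustrative assumptions."""
    g = torch.Generator().manual_seed(seed)
    tokens = torch.randint(0, vocab_size, (batch_size, seq_len), generator=g)
    labels = torch.mode(tokens, dim=1).values   # most frequent token per sequence
    return tokens, labels

x, y = majority_batch()
print(x.shape, y.shape)   # torch.Size([32, 21]) torch.Size([32])
```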
Implications
Theoretically, these findings highlight LayerNorm as more than a mere normalization technique. It plays a pivotal role in enhancing the expressive power of attention mechanisms in Transformers by restructuring the feature space into a more learnable form.
Practically, the insights suggest concrete design choices for neural network architects, especially for small models or models operating on long sequences: applying LayerNorm before the attention layers can significantly improve learning efficiency and model performance.
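A minimal sketch of such a pre-LN sub-block in PyTorch; the dimensions, head count, and use of nn.MultiheadAttention are illustrative assumptions, not the paper's implementation.

```python
import torch
import torch.nn as nn

class PreLNAttentionBlock(nn.Module):
    """Minimal pre-LN Transformer sub-block (illustrative sketch): LayerNorm is
    applied to the input before attention, so queries and keys see inputs that
    have been projected and rescaled as described above."""
    def __init__(self, d_model=64, n_heads=4):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, x):
        h = self.norm(x)                 # projection + scaling of the hidden states
        out, _ = self.attn(h, h, h)      # self-attention over the normalized states
        return x + out                   # residual connection

block = PreLNAttentionBlock()
x = torch.randn(2, 10, 64)
print(block(x).shape)                    # torch.Size([2, 10, 64])
```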
Future Directions
Building on these results, several intriguing questions arise:
- Could alternative normal vectors, learnable per layer, offer additional benefits?
- What would be the effects of enforcing orthogonality to multiple vectors in each layer?
These questions pave the way for further research into the geometric properties of normalization techniques and their implications for the expressivity of neural architectures.
In conclusion, this paper substantiates that LayerNorm significantly contributes to the efficacy of attention in Transformers. This understanding should shape both future theoretical explorations and the practical design of Transformer-based models. The authors make their code available, providing a resource for continued investigation into these mechanisms.