- The paper identifies that as sequence lengths increase, the variance in individual attention features decays, leading to a harmful distribution shift.
- The authors demonstrate that applying layer normalization to attention outputs significantly boosts accuracy on out-of-distribution longer sequences.
- Empirical results on both simple and modern models reveal that LN stabilizes global output statistics despite persistent individual feature variance decay.
This paper investigates the common problem of Transformers failing to generalize to sequence lengths longer than those seen during training. The authors propose a novel explanation termed the "vanishing variance problem".
The core idea, formalized in Proposition 1, is that under certain theoretical assumptions (i.i.d. inputs, zero-mean values), the variance of any single feature dimension in the attention output vector approaches zero as the input sequence length (N) tends to infinity. While these assumptions don't strictly hold in practice (due to positional encodings and language structure), the paper demonstrates empirically that the variance-decay trend persists even in modern LLMs like Llama-3.2-1B (Figure 1) and in simpler experimental setups.
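The decay is easy to reproduce in a toy simulation. The sketch below is a minimal Monte Carlo check of the Proposition 1 setting, assuming i.i.d. Gaussian queries, keys, and zero-mean values; the feature dimension and sample counts are illustrative choices, not the paper's configuration.

```python
# Minimal Monte Carlo check of the Proposition 1 setting: i.i.d. Gaussian queries,
# keys, and zero-mean values (dimensions and sample counts are illustrative only).
import torch

torch.manual_seed(0)
d = 64                 # feature dimension (hypothetical choice)
num_samples = 2000     # Monte Carlo samples per sequence length

for N in [16, 64, 256, 1024, 4096]:
    q = torch.randn(num_samples, 1, d)    # one query per sample
    k = torch.randn(num_samples, N, d)    # i.i.d. keys
    v = torch.randn(num_samples, N, d)    # i.i.d. zero-mean values
    attn = torch.softmax(q @ k.transpose(1, 2) / d**0.5, dim=-1)  # (S, 1, N)
    out = attn @ v                        # attention output, shape (S, 1, d)
    # Variance of a single output feature, estimated across samples, shrinks with N.
    print(f"N={N:5d}  var(feature 0) = {out[:, 0, 0].var().item():.5f}")
```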
This vanishing variance leads to a distribution shift when models trained on shorter sequences are tested on longer ones. Specifically, as sequence length increases:
- The variance of individual features decreases, concentrating values around their means (Figure 2, top row).
- Consequently, the global variance across all features in the attention output vector decreases (Figure 3, right).
- The global mean of the attention output vector also drifts (Figure 3, left).
This shift is problematic because subsequent layers (such as MLPs) were trained on distributions with different means and larger variances, which hinders their ability to process outputs generated from longer sequences (see the sketch below).
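The same idealized simulation can be used to track the statistics a downstream layer actually sees: the mean and variance computed across the d features of each attention output vector. The sketch below is again built on illustrative assumptions; note that in this zero-mean i.i.d. setting the global mean stays near zero, so only the variance shrinkage is visible, whereas the mean drift reported in the paper arises with realistic, non-zero-mean activations.

```python
# Track the "global" per-vector statistics (mean and variance across the d features)
# of the attention output as N grows, in the same idealized i.i.d. setting as above.
import torch

torch.manual_seed(0)
d, num_samples = 64, 2000

def attention_output(N: int) -> torch.Tensor:
    q = torch.randn(num_samples, 1, d)
    k = torch.randn(num_samples, N, d)
    v = torch.randn(num_samples, N, d)
    attn = torch.softmax(q @ k.transpose(1, 2) / d**0.5, dim=-1)
    return (attn @ v).squeeze(1)          # (num_samples, d)

for N in [16, 256, 4096]:
    out = attention_output(N)
    # Statistics across the feature dimension, averaged over samples.
    print(f"N={N:5d}  "
          f"global mean = {out.mean(dim=-1).mean().item():+.4f}  "
          f"global var  = {out.var(dim=-1).mean().item():.4f}")
```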
To mitigate this distribution shift, the paper proposes applying Layer Normalization (LN) directly to the attention outputs (O) before they are passed to subsequent layers. The LN step standardizes the output vector (adjusting for mean and variance shifts across the feature dimension) and applies learnable scale (γ) and shift (β) parameters.
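Concretely, this amounts to inserting a LayerNorm over the feature dimension of the attention output O. The sketch below is a minimal single-head PyTorch version; the placement of the normalization follows the paper's description, but the module structure and dimensions are illustrative assumptions rather than the paper's exact architecture.

```python
# Minimal sketch: single-head self-attention with LayerNorm applied directly to the
# attention output O before it reaches subsequent layers (illustrative module only).
import torch
import torch.nn as nn

class AttentionWithOutputLN(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.scale = d_model ** -0.5
        # LayerNorm over the feature dimension of O, with learnable scale (gamma)
        # and shift (beta), as described in the paper.
        self.out_ln = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        attn = torch.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)
        o = attn @ v                # attention output whose per-feature variance decays with N
        return self.out_ln(o)       # standardize O before passing it to subsequent layers

# Usage: the same module applies to any sequence length N without retraining.
x = torch.randn(2, 128, 64)        # (batch, N, d_model)
y = AttentionWithOutputLN(64)(x)
print(y.shape)                     # torch.Size([2, 128, 64])
```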
Experiments on order-invariant tasks (argmax retrieval and dictionary lookup), using a simplified single-layer, single-head Transformer without positional encodings, demonstrate the effectiveness of this approach. Models were trained on short sequences (e.g., up to length 16) and tested on significantly longer sequences (up to length 2^14 = 16,384).
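The exact task specifications are not reproduced in this summary; the snippet below is one plausible, hypothetical construction of an order-invariant argmax-retrieval batch, included only to illustrate the train-short / test-long protocol.

```python
# Hypothetical order-invariant "argmax retrieval" batch: the target is the value
# paired with the largest key, which does not depend on token order. All details
# here are assumptions, not the paper's exact task definition.
import torch

def argmax_retrieval_batch(batch_size: int, seq_len: int, vocab: int = 32):
    keys = torch.rand(batch_size, seq_len)                     # one scalar key per position
    values = torch.randint(0, vocab, (batch_size, seq_len))    # one token value per position
    # Target: the value at the position of the maximum key.
    targets = values.gather(1, keys.argmax(dim=1, keepdim=True)).squeeze(1)
    return keys, values, targets

# Train on short sequences, evaluate far out of distribution.
train_keys, train_vals, train_y = argmax_retrieval_batch(64, 16)
test_keys, test_vals, test_y = argmax_retrieval_batch(64, 2**14)
```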
Key findings include:
- Applying LN after attention outputs significantly improves accuracy on out-of-distribution (longer) sequence lengths compared to a baseline model without this LN step (Tables 1, 2).
- This improvement is statistically significant and holds even when combined with test-time adaptation techniques (like adaptive temperature scaling).
- LN helps stabilize the global mean and variance of attention outputs when tested on longer sequences (Figure 3), although it doesn't eliminate the underlying variance decay of individual features (Figure 2, bottom row).
- LN also helps mitigate the dispersion of attention weights at longer sequence lengths (Figure 4).
- Ablation studies show that even simple standardization (LN without learnable parameters) provides benefits, but full LN is generally superior (Table 3); the two variants are sketched below.
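In PyTorch terms, the two ablation variants can be expressed as follows; the paper does not specify its implementation, so this is an interpretation rather than the authors' code.

```python
# The two ablation variants expressed with PyTorch's nn.LayerNorm (an assumed
# implementation, not necessarily the paper's): plain standardization vs. full LN.
import torch.nn as nn

d_model = 64  # illustrative feature dimension
standardize_only = nn.LayerNorm(d_model, elementwise_affine=False)  # (x - mean) / std only
full_ln = nn.LayerNorm(d_model, elementwise_affine=True)            # adds learnable gamma, beta
```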
The paper concludes that the vanishing variance phenomenon contributes significantly to the length generalization problem in Transformers, and applying LN after attention outputs is a practical technique to partially alleviate the resulting distribution shift and improve performance on longer sequences. Future work suggestions include validating these findings on larger models and designing architectures inherently robust to sequence length variations.