- The paper finds that transformer decoders demonstrate near-perfect linearity (score ≈0.99) primarily due to the influence of residual connections.
- It shows that applying cosine similarity regularization during pretraining reduces layer linearity and enhances performance on benchmarks like SuperGLUE.
- The research proposes pruning and distillation techniques that replace highly linear layers with linear approximations to reduce model size while maintaining accuracy.
This paper (2405.12250) investigates the linearity properties of transformer decoders, such as GPT, LLaMA, OPT, and BLOOM, finding a surprisingly high degree of linearity in the transformations between sequential layers. This is quantified using a modified Procrustes similarity score, which measures how close the transformation between the embedding spaces of two consecutive layers is to a linear mapping. A score near 1 indicates high linearity. The paper finds this score is consistently close to 0.99 across different models.
A key insight is that this apparent linearity is heavily influenced by the residual connections. The norm of each transformer block's output (the update added to the residual stream) is remarkably low compared to the norm of the embeddings in the residual stream itself. Because each layer's contribution to the stream is small, adjacent embeddings in the main stream are very similar and thus appear highly linearly related. When the residual component is removed, the linearity score (computed on the embeddings without the residual contribution) drops significantly, suggesting the core block transformations are less linear than the high overall score implies. A quick way to check this norm imbalance is sketched below.
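The following is a minimal sketch of that check, assuming a Hugging Face decoder that exposes per-layer hidden states via `output_hidden_states=True`; the model name is just an example, and `hidden_states[i]` is taken here to be the residual stream after block i:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Example checkpoint; any decoder-only model that returns hidden states works.
name = "facebook/opt-125m"
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).eval()

inputs = tok("The quick brown fox jumps over the lazy dog.", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)

# hidden_states[0] is the embedding output; hidden_states[i] is the residual
# stream after block i, so consecutive differences approximate each block's
# contribution to the stream.
hs = out.hidden_states
for i in range(1, len(hs)):
    block_update = hs[i] - hs[i - 1]
    ratio = (block_update.norm() / hs[i - 1].norm()).item()
    print(f"layer {i:2d}: ||block contribution|| / ||residual stream|| = {ratio:.3f}")
```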
The paper analyzes how linearity changes during training. It observes that during pretraining, the average linearity score across layers decreases. However, during fine-tuning on various tasks (like SuperGLUE or reward modeling), linearity consistently increases (Table 1). This suggests that task-specific fine-tuning tends to reinforce the model's linear characteristics.
Based on these findings, the authors propose practical methods for model optimization:
- Regularized Pretraining: The paper explores adding regularization terms during pretraining to explicitly influence layer linearity. Two types of regularization are tested:
- MSE regularization: $L_{\text{MSE}} = \lambda \sum_i \lVert \text{emb}_i - \text{emb}_{i-1} \rVert^2$, aiming to minimize the distance between consecutive embeddings.
- Cosine similarity regularization: $L_{\text{cosine}} = \lambda \sum_i \big(1 - \cos(\text{emb}_i, \text{emb}_{i-1})\big)$, encouraging consecutive embeddings to have high cosine similarity (close to 1).
The cosine similarity regularization showed promising results. When applied during pretraining of small Mistral models (150M, 650M) on datasets like TinyStories and Tiny-textbooks, it successfully decreased layer linearity (Figure 1) while simultaneously improving performance on benchmarks like SuperGLUE (Table 2) and TinyStories evaluation (Table 3). This counter-intuitive result suggests that reducing linearity might push the model to leverage non-linearities more effectively in other parts of the architecture or residual stream.
Implementation Note: To implement cosine similarity regularization, you would modify the standard LLM pretraining loss. During a forward pass, capture the output embeddings of each transformer layer, then calculate the cosine similarity between the output of layer $i$ and the output of layer $i-1$ for all $i > 0$. The regularization term is the sum of $1 - \cos(\text{emb}_i, \text{emb}_{i-1})$ across all layers, scaled by a hyperparameter $\lambda$, and added to the standard language modeling loss. This requires modifying the model's forward pass to return intermediate embeddings and adding the regularization calculation to the training loop, as sketched below.
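Here is a minimal sketch of such a regularizer in PyTorch, assuming a Hugging Face-style causal LM that returns per-layer hidden states via `output_hidden_states=True`; the function names and the weight `lam` (for $\lambda$) are illustrative, not the paper's code:

```python
import torch
import torch.nn.functional as F

def cosine_linearity_regularizer(hidden_states, lam: float = 0.1):
    """L_cosine = lam * sum_i (1 - cos(emb_i, emb_{i-1})), averaged over
    batch and sequence positions. `hidden_states` is the tuple returned
    with output_hidden_states=True (embedding output + one entry per layer)."""
    penalty = hidden_states[0].new_zeros(())
    for prev, curr in zip(hidden_states[:-1], hidden_states[1:]):
        cos = F.cosine_similarity(curr, prev, dim=-1)  # per-token similarity
        penalty = penalty + (1.0 - cos).mean()
    return lam * penalty

def training_step(model, batch, lam: float = 0.1):
    # Standard language modeling loss plus the linearity regularizer.
    out = model(**batch, labels=batch["input_ids"], output_hidden_states=True)
    return out.loss + cosine_linearity_regularizer(out.hidden_states, lam=lam)
```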
- Pruning and Distillation: Leveraging the linearity observation, the paper proposes pruning the most linear layers. This allows for a slight reduction in model size with minimal performance loss. An enhanced strategy involves:
- Identifying and removing the most linear layers based on their linearity score.
- Replacing these removed layers with simple linear approximations (e.g., a dense matrix multiplication matching input/output dimensions).
- Applying distillation during fine-tuning or a dedicated training phase, specifically using layer-wise MSE loss to match the outputs of the pruned/linear-approximated layers to the outputs of the corresponding layers in the original, larger model (teacher model).
Implementation Note: To implement this pruning strategy:
* First, calculate linearity scores for all layers of the pretrained model on a representative dataset. You'll need access to intermediate layer outputs. The formula is:
$\text{linearity\_score} := 1 - \min_{A \in \mathbb{R}^{d \times d}} \lVert \tilde{X} A - \tilde{Y} \rVert_2^2$
where $\tilde{X}$ and $\tilde{Y}$ are centered and Frobenius-normalized embedding matrices from consecutive layers. The optimal linear transformation $A$ can be found with linear regression (least squares) or singular value decomposition; see the sketch after this list.
* Rank layers by their linearity score and select the desired number of layers to remove or replace.
* For simple removal (depth pruning), just remove the selected layers from the model's sequential structure.
* For linear approximation, replace the full transformer block (multi-head attention, MLP, LayerNorms, etc.) with a single torch.nn.Linear layer that maps from the input hidden dimension to the output hidden dimension.
* For distillation, load the original, larger model as the "teacher" and the pruned/approximated model as the "student". During training, pass the same input data through both models, capture the intermediate outputs of corresponding layers in each, and compute the MSE loss between them. The total student loss is the sum of the original task loss (e.g., the language modeling loss) and the distillation MSE loss, potentially weighted. Training then updates the student model's parameters, especially the newly added linear layers if applicable; a sketch of this objective is given below.
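A minimal sketch of the scoring and replacement steps, assuming embeddings are collected as (n_tokens × d) matrices from consecutive layers; the least-squares solve stands in for whatever closed-form solution the authors use, and the `LinearBlock` interface is a guess at a drop-in replacement for a Hugging Face decoder block:

```python
import torch
import torch.nn as nn

def linearity_score(X: torch.Tensor, Y: torch.Tensor) -> float:
    """1 - min_A ||X_tilde A - Y_tilde||_2^2, with X and Y centered and
    Frobenius-normalized as in the formula above."""
    Xc = X - X.mean(dim=0, keepdim=True)
    Yc = Y - Y.mean(dim=0, keepdim=True)
    Xc = Xc / Xc.norm(p="fro")
    Yc = Yc / Yc.norm(p="fro")
    A = torch.linalg.lstsq(Xc, Yc).solution   # optimal linear map
    residual = Xc @ A - Yc
    return 1.0 - residual.norm(p="fro").pow(2).item()

class LinearBlock(nn.Module):
    """Replacement for a pruned decoder block: a single linear map that
    ignores attention-related arguments and mimics the block's output tuple."""
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, hidden_states, *args, **kwargs):
        return (self.proj(hidden_states),)

# Hypothetical usage: score consecutive layer pairs, then swap the most linear block.
# scores = [linearity_score(emb[i], emb[i + 1]) for i in range(num_layers - 1)]
# model.model.layers[most_linear_idx] = LinearBlock(hidden_dim)
```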
The paper demonstrates that pruning with linear replacements and distillation leads to lower perplexity compared to just removing layers (Figure 2, Figure 3). Results on tasks like ARC-easy also show performance retention with these techniques (Figure 4).
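For the distillation step, a rough sketch of the combined objective follows, where `layer_map` is an assumed alignment between the pruned student's hidden-state indices and the teacher's original layers:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student, teacher, batch, layer_map, alpha: float = 1.0):
    """Total loss = student language modeling loss + alpha * layer-wise MSE
    between corresponding intermediate outputs of student and teacher."""
    with torch.no_grad():
        t_out = teacher(**batch, output_hidden_states=True)
    s_out = student(**batch, labels=batch["input_ids"], output_hidden_states=True)

    # Layer-wise MSE between corresponding intermediate outputs
    mse = sum(
        F.mse_loss(s_out.hidden_states[s_i], t_out.hidden_states[t_i])
        for s_i, t_i in layer_map.items()
    )
    return s_out.loss + alpha * mse
```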
Practical Implications and Considerations:
- Model Compression: The pruning techniques offer a way to reduce model size and computational cost for inference, which is critical for deploying LLMs on resource-constrained devices or reducing inference latency.
- Optimization Strategy: Identifying highly linear layers can guide structural pruning efforts, focusing on parts of the model that might be less critical for complex non-linear computations.
- Pretraining Regularization: The cosine regularization suggests a novel way to potentially improve model efficiency and performance during the initial training phase, rather than relying solely on post-hoc optimization techniques. This could lead to intrinsically more efficient architectures.
- Computational Cost: Measuring linearity requires forward passes and matrix computations on intermediate embeddings. Regularized pretraining adds computation to the training loop. Pruning reduces inference cost but distillation requires additional training time and compute.
- Generalizability: The analysis focuses on decoders. Applying these findings directly to encoder-only or encoder-decoder models would require further investigation. The effectiveness of pruning and regularization techniques may also vary depending on the specific model architecture, size, and downstream task.
Overall, this research provides empirical evidence for the surprising linearity in transformer decoders and translates this observation into concrete methods for model optimization through regularization, pruning, and distillation.