Transformer Feed-Forward Layers Are Key-Value Memories (2012.14913v2)

Published 29 Dec 2020 in cs.CL

Abstract: Feed-forward layers constitute two-thirds of a transformer model's parameters, yet their role in the network remains under-explored. We show that feed-forward layers in transformer-based language models operate as key-value memories, where each key correlates with textual patterns in the training examples, and each value induces a distribution over the output vocabulary. Our experiments show that the learned patterns are human-interpretable, and that lower layers tend to capture shallow patterns, while upper layers learn more semantic ones. The values complement the keys' input patterns by inducing output distributions that concentrate probability mass on tokens likely to appear immediately after each pattern, particularly in the upper layers. Finally, we demonstrate that the output of a feed-forward layer is a composition of its memories, which is subsequently refined throughout the model's layers via residual connections to produce the final output distribution.

Unveiling the Function of Feed-Forward Layers in Transformer Models

Overview

Feed-forward layers account for roughly two-thirds of a transformer's parameters, yet their specific role has received little attention. This paper shows that feed-forward layers operate as key-value memories: the keys act as pattern detectors over the input, and the values shape the model's output distribution. The analysis finds that lower layers capture shallow input patterns while upper layers capture more semantic ones, and it traces how the contributions of these memories combine into the final prediction.

Feed-Forward Layers as Neural Memory

Feed-forward layers in transformers have been under-explored despite holding most of the model's parameters. This paper argues that they function as key-value memories: the rows of the first parameter matrix act as keys that detect specific patterns in the input text, and the rows of the second matrix act as values that each induce a distribution over the model's output vocabulary. The correspondence follows from writing the two computations side by side. With bias terms omitted, a feed-forward layer computes FF(x) = f(x · K^T) · V, where f is a non-linearity such as ReLU, while a key-value neural memory computes MN(x) = softmax(x · K^T) · V. The two differ only in the function applied to the key scores, so in both cases the output is a weighted sum of value vectors whose weights (the memory coefficients) measure how strongly the input matches each key.
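The correspondence can be sketched in NumPy. With biases omitted, a feed-forward layer computes f(x · K^T) · V and a neural memory computes softmax(x · K^T) · V. The sketch below assumes ReLU as f, toy dimensions, and random weights; the names feed_forward and neural_memory are illustrative, not the authors' code. It checks that the feed-forward output is exactly the coefficient-weighted sum of the value vectors.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    # Numerically stable softmax over the last axis.
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def feed_forward(x, K, V, f=relu):
    """FF(x) = f(x K^T) V with biases omitted.

    x: input vector, shape (d,)
    K: keys, one per row, shape (d_m, d)
    V: values, one per row, shape (d_m, d)
    Returns the layer output (d,) and the memory coefficients (d_m,).
    """
    coef = f(K @ x)          # how strongly x matches each key
    return coef @ V, coef    # coefficient-weighted sum of value vectors

def neural_memory(x, K, V):
    """MN(x) = softmax(x K^T) V: same structure, normalized weights."""
    coef = softmax(K @ x)
    return coef @ V, coef

# Toy sizes and random weights, chosen only for illustration.
rng = np.random.default_rng(0)
d, d_m = 8, 32
x = rng.normal(size=d)
K = rng.normal(size=(d_m, d))
V = rng.normal(size=(d_m, d))

ff_out, ff_coef = feed_forward(x, K, V)
mn_out, mn_coef = neural_memory(x, K, V)

# The FF output is literally a sum of memories: sum_i coef_i * v_i.
explicit = sum(ff_coef[i] * V[i] for i in range(d_m))
print(np.allclose(ff_out, explicit))  # True
```

Swapping relu for softmax is the only difference between the two functions, which is why the feed-forward layer can be read as an unnormalized key-value memory.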

Patterns Captured by Keys

To study what the keys represent, the authors retrieve, for each key, the training prefixes that activate it most strongly (its trigger examples) and have human experts annotate the patterns those prefixes share. Each key turns out to be associated with distinct, human-interpretable patterns. Lower layers tend to detect shallow patterns, such as specific n-grams, while upper layers detect more semantic ones, such as a shared topic. This stratification is consistent with hierarchical processing in neural networks, where early layers focus on low-level features and later layers on more abstract concepts.
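One way to find what a key responds to is to rank input prefixes by that key's memory coefficient ReLU(x · k_i) and inspect the top-ranked ones. The following is a minimal sketch of that ranking step, assuming you already have the hidden states entering a given feed-forward layer at the last position of each prefix (the array X, a hypothetical input) and the layer's keys K. The function name top_trigger_examples, the parameter t, and the toy data are illustrative; the human annotation step is not modeled.

```python
import numpy as np

def top_trigger_examples(X, K, t):
    """Rank prefixes by memory coefficient for every key.

    X: hidden states entering the FF layer, one row per prefix, shape (n, d)
    K: keys, one per row, shape (d_m, d)
    t: number of trigger examples to keep per key
    Returns a (d_m, t) array of prefix indices, strongest first.
    """
    coef = np.maximum(X @ K.T, 0.0)         # (n, d_m) memory coefficients
    order = np.argsort(-coef, axis=0)[:t]   # top-t prefixes per key column
    return order.T

# Hypothetical toy data: 3 prefixes, 2 keys in a 2-d space.
prefixes = ["prefix 0", "prefix 1", "prefix 2"]
X = np.array([[3.0, 0.0],
              [1.0, 1.0],
              [0.0, 2.0]])
K = np.array([[1.0, 0.0],
              [0.0, 1.0]])

for key_id, rows in enumerate(top_trigger_examples(X, K, t=2)):
    print(key_id, [prefixes[j] for j in rows])
# 0 ['prefix 0', 'prefix 1']
# 1 ['prefix 2', 'prefix 1']
```

In a real model, X would be collected by running the model over a corpus and recording the layer inputs, and the returned prefixes would then be read by a person to name the shared pattern.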

Values as Output Distributions

Turning to the values, the paper converts each value vector into a distribution over the output vocabulary by multiplying it by the model's output embedding matrix and applying a softmax. These distributions complement the patterns detected by the corresponding keys: the top-ranked token of a value often matches the token that follows the key's trigger examples, and this agreement becomes more pronounced in the upper layers. This suggests that as information flows upward, the upper layers increasingly translate detected patterns into next-token predictions.
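Reading a value as a distribution over the vocabulary can be sketched as p_i = softmax(v_i · E), where E is the output embedding matrix, together with a simple agreement check between each value's top token and the token that follows its key's top trigger example. Shapes, function names, and the one-hot toy embeddings are assumptions for illustration; in a real model E would come from the trained output layer and next_tokens from the trigger-example analysis.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def value_distributions(V, E):
    """p_i = softmax(v_i E) for every value vector.

    V: values, shape (d_m, d); E: output embeddings, shape (d, vocab).
    Returns one distribution over the vocabulary per memory, (d_m, vocab).
    """
    return softmax(V @ E)

def agreement_rate(P, next_tokens):
    """Fraction of memories whose top-ranked token equals the token that
    follows the memory's top trigger example (next_tokens, shape (d_m,))."""
    return float(np.mean(P.argmax(axis=-1) == next_tokens))

# Toy setup: one-hot output embeddings over a 4-token vocabulary.
vocab = ["the", "cat", "sat", "mat"]
E = np.eye(4)
V = np.array([[0.0, 5.0, 0.0, 0.0],   # promotes "cat"
              [0.0, 0.0, 0.0, 5.0],   # promotes "mat"
              [5.0, 0.0, 0.0, 0.0]])  # promotes "the"
P = value_distributions(V, E)
print([vocab[i] for i in P.argmax(axis=-1)])   # ['cat', 'mat', 'the']
print(agreement_rate(P, np.array([1, 3, 2])))  # 0.6666666666666666
```

Computing this rate separately for each layer is one way to see whether key-value agreement rises toward the top of the model.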

Memory Aggregation and Model Prediction

The paper then examines how the model combines these individual memories, both within a layer and across layers, to produce the final output distribution. Within a layer, the output is a weighted sum of many value vectors, and the layer's prediction often differs from the prediction of any single memory, so it cannot be attributed to the most strongly activated memory alone. Across layers, residual connections carry the prediction forward and each feed-forward layer refines it, in a bottom-up process where detected patterns are incrementally merged into the model's output.
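Both levels of aggregation can be illustrated with a toy sketch under simplifying assumptions: attention sublayers and layer normalization are omitted, the output embeddings are one-hot, and all weights are hand-picked. The first function compares a layer's top token with the top token of each active memory; the second tracks the residual stream's top token after each feed-forward layer. The function names are hypothetical, and neither function reproduces the paper's exact measurement protocol.

```python
import numpy as np

def ff(x, K, V):
    """Feed-forward layer as memories: ReLU(x K^T) V, biases omitted."""
    return np.maximum(K @ x, 0.0) @ V

def layer_vs_memories(x, K, V, E):
    """Return the layer's top token and the set of top tokens of the
    individual memories that are active (nonzero coefficient) for x."""
    coef = np.maximum(K @ x, 0.0)
    layer_top = int(np.argmax(ff(x, K, V) @ E))
    memory_tops = {int(np.argmax(V[i] @ E)) for i in np.nonzero(coef)[0]}
    return layer_top, memory_tops

def residual_predictions(x, layers, E):
    """Top token of the residual stream after each FF layer.
    Each layer adds its output to the stream: h <- h + FF(h)."""
    h = x.copy()
    tops = []
    for K, V in layers:
        h = h + ff(h, K, V)
        tops.append(int(np.argmax(h @ E)))
    return tops

E = np.eye(3)  # one-hot output embeddings over a 3-token vocabulary

# Within a layer: two memories promote tokens 0 and 1, yet their sum
# promotes token 2, which neither memory ranks first on its own.
K = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0]])
V = np.array([[1.0, 0.0, 0.9],
              [0.0, 1.0, 0.9]])
print(layer_vs_memories(np.array([1.0, 1.0, 0.0]), K, V, E))  # (2, {0, 1})

# Across layers: each FF layer shifts the residual stream's top token.
layers = [(np.array([[1.0, 0.0, 0.0]]), np.array([[0.0, 2.0, 0.0]])),
          (np.array([[0.0, 1.0, 0.0]]), np.array([[0.0, 0.0, 3.0]]))]
print(residual_predictions(np.array([1.0, 0.0, 0.0]), layers, E))  # [1, 2]
```

In the first case the layer's prediction emerges from the composition of memories rather than from any single one; in the second, the residual stream's prediction is refined step by step as each layer adds its memories' contribution.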

Implications and Future Directions

These findings matter for both the theoretical understanding and the practical use of transformer models. They clarify how feed-forward layers contribute to the model's ability to process and predict linguistic patterns. Practically, this view could inform more efficient architectures and interpretability tools that target these layers directly. Exploring whether the findings carry over from language modeling to other transformer-based applications is a natural direction for future work.

Conclusion

In summary, by casting feed-forward layers as memories that detect input patterns and shape output distributions, this paper sheds light on the critical yet underexplored function of these layers within transformer models. This understanding enriches current knowledge of transformer architectures and opens new avenues for making these models more interpretable, efficient, and versatile.

Authors (4)
  1. Mor Geva (58 papers)
  2. Roei Schuster (14 papers)
  3. Jonathan Berant (107 papers)
  4. Omer Levy (70 papers)
Citations (617)