This paper introduces a new way to improve how large language models (LLMs) handle very long texts: augmenting the Transformer with an explicit, learnable memory.
Background and Relevance
Transformer-based models, like GPT-3 (Generative Pre-trained Transformer 3) and BERT (Bidirectional Encoder Representations from Transformers), have been very successful in many language tasks. However, they struggle when they need to understand and reason about very long contexts, largely because the few relevant facts get buried in a sea of irrelevant information spread across the input.
The Large Memory Model (LM2)
To solve this problem, the researchers created the Large Memory Model (LM2), which adds a special memory module to the standard Transformer architecture. This memory module acts like a separate storage space where the model can keep track of important information it has seen.
Here's how the LM2 works:
- Memory Initialization: The memory module starts with a "memory bank," which is like a collection of slots where information can be stored. Each slot is initially set to a neutral state.
- The memory bank is represented by $\mathbf{M} \in \mathbb{R}^{N \times d \times d}$, where:
  - $N$ is the number of memory slots,
  - $d$ is the hidden dimension of each slot,
  - $\mathbb{R}$ denotes the real numbers.
- Each memory slot is initialized as an identity matrix: $\mathbf{M}_i = \mathbf{I}_{d \times d}$, where $i \in \{1, \dots, N\}$ and $\mathbf{I}_{d \times d}$ is the identity matrix. (A short code snippet of this initialization is given below.)
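For readers who prefer code, here is a tiny, purely illustrative PyTorch snippet of this initialization. It is not the authors' implementation; the sizes and variable names are made up for the example.

```python
import torch

N, d = 16, 64                      # number of memory slots and hidden dimension (illustrative sizes)
# Memory bank M: N slots, each a d x d matrix initialized to the identity ("neutral" state)
M = torch.eye(d).repeat(N, 1, 1)   # shape (N, d, d)
```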
- Memory Information Flow: When the model processes new input, it uses a technique called "cross attention" to compare the input to the information stored in the memory bank. This helps the model find the memory slots that contain the most relevant information.
- Input embeddings act as the query, while the memory bank serves as both the key and the value store.
The input embeddings $\mathbf{E}^t \in \mathbb{R}^{T \times d}$ (where $T$ is the sequence length) and the memory bank $\mathbf{M}^t$ are projected into query ($\mathbf{Q}$), key ($\mathbf{K}$), and value ($\mathbf{V}$) spaces:
$$\mathbf{Q} = \mathbf{E}^t W^Q, \qquad \mathbf{K} = \mathbf{M}^t W^K, \qquad \mathbf{V} = \mathbf{M}^t W^V$$
where $W^Q, W^K, W^V \in \mathbb{R}^{d \times d}$ are learnable projection matrices, and the superscript $t$ denotes decoder block $t$:
- $\mathbf{E}^t$ is the input embedding at decoder block $t$
- $\mathbf{M}^t$ is the memory bank at decoder block $t$
- $W^Q$ is the query projection matrix
- $W^K$ is the key projection matrix
- $W^V$ is the value projection matrix
* The attention scores are computed as the scaled dot product of the query and key matrices:
  $$\mathbf{A} = \mathrm{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}\right) \in \mathbb{R}^{T \times N}$$
  where $\mathbf{A}$ represents the alignment between the input sequence and the memory slots, $\mathbf{Q}$ is the query matrix, $\mathbf{K}$ is the key matrix, $d$ is the hidden dimension of each slot, $\mathbb{R}$ denotes the real numbers, $T$ is the sequence length, and $N$ is the number of memory slots.
* The resultant attention output is
  $$\mathbf{E}_{\text{mem}} = \mathbf{A}\mathbf{V} \in \mathbb{R}^{T \times d}$$
  where $\mathbf{E}_{\text{mem}}$ integrates information from the input and the memory, $\mathbf{A}$ is the alignment between the input sequence and the memory slots, and $\mathbf{V}$ is the value matrix.
* To control how much the memory influences the model's output, an "output gate" is used. This gate decides how much of the information retrieved from memory should be passed on to the next layer:
  $$g_{\text{out}} = \sigma(\mathbf{E}_{\text{mem}} W_{\text{out}})$$
  where $g_{\text{out}}$ is the output gate, $W_{\text{out}} \in \mathbb{R}^{d \times d}$ is a learnable parameter matrix, $\mathbf{E}_{\text{mem}}$ integrates information from the input and the memory, and $\sigma$ is the sigmoid activation function.
* The gated memory output is then computed as:
  $$\mathbf{E}_{\text{gated}} = g_{\text{out}} \cdot \mathbf{M}^t$$
  where $\mathbf{E}_{\text{gated}}$ is the gated memory output, $g_{\text{out}}$ is the output gate, and $\mathbf{M}^t$ is the memory bank at decoder block $t$.
* The gated memory output is integrated into the standard attention flow of the Transformer decoder through a skip connection. Specifically, the output of the self-attention mechanism, $\mathbf{E}_{\text{attn}}$, is combined with the gated memory output as
  $$\mathbf{E}^{t+1} = \mathbf{E}_{\text{attn}} + \mathbf{E}_{\text{gated}}$$
  where $\mathbf{E}^{t+1}$ is the combined output passed to the next decoder layer, $\mathbf{E}_{\text{attn}}$ is the output of the self-attention mechanism, and $\mathbf{E}_{\text{gated}}$ is the gated memory output. (A minimal code sketch of this read path is given below.)
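To make the read path concrete, below is a minimal PyTorch sketch of the cross-attention between the input and the memory bank, the output gate, and the skip connection. It is a simplified illustration rather than the authors' implementation: each memory slot is treated as a $d$-dimensional vector (so the bank is $N \times d$ rather than $N \times d \times d$), the output gate is applied to the retrieved memory $\mathbf{E}_{\text{mem}}$ so the shapes line up, and all class and variable names (e.g. `MemoryReader`) are invented for this example.

```python
import torch
import torch.nn as nn

class MemoryReader(nn.Module):
    """Sketch of an LM2-style memory read path (illustrative, not the official code)."""

    def __init__(self, d: int):
        super().__init__()
        self.W_Q = nn.Linear(d, d, bias=False)    # query projection for input embeddings
        self.W_K = nn.Linear(d, d, bias=False)    # key projection for memory slots
        self.W_V = nn.Linear(d, d, bias=False)    # value projection for memory slots
        self.W_out = nn.Linear(d, d, bias=False)  # output-gate projection

    def forward(self, E, M, E_attn):
        # E:      (T, d) input embeddings at this decoder block
        # M:      (N, d) memory bank, simplified to one vector per slot
        # E_attn: (T, d) output of the block's ordinary self-attention
        d = E.size(-1)
        Q, K, V = self.W_Q(E), self.W_K(M), self.W_V(M)

        # A = softmax(Q K^T / sqrt(d)): alignment between tokens and memory slots, shape (T, N)
        A = torch.softmax(Q @ K.T / d ** 0.5, dim=-1)
        E_mem = A @ V                              # (T, d): information retrieved from memory

        # Output gate decides how much retrieved memory reaches the next layer
        g_out = torch.sigmoid(self.W_out(E_mem))   # (T, d)
        E_gated = g_out * E_mem

        # Skip connection: combine the self-attention output with the gated memory output
        return E_attn + E_gated, E_mem
```

Under these assumptions the module returns both the combined output $\mathbf{E}_{\text{attn}} + \mathbf{E}_{\text{gated}}$ and the retrieved memory $\mathbf{E}_{\text{mem}}$, which the update step described next consumes.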
- Memory Updates: The memory module also needs to update its contents to store new information and remove irrelevant information. This is done using three "gates":
* Input Gate: This gate decides how much of the new input should be written into the memory:
  $$g_{\text{in}} = \sigma(\mathbf{E}^t W_{\text{in}})$$
  where $g_{\text{in}}$ is the input gate, $W_{\text{in}} \in \mathbb{R}^{d \times d}$ is a learnable parameter matrix, $\mathbf{E}^t$ is the current input representation, and $\sigma$ is the sigmoid activation function.
* Forget Gate: This gate decides which parts of the existing memory should be erased or forgotten:
  $$g_{\text{forget}} = \sigma(\mathbf{E}_{\text{mem}} W_{\text{forget}})$$
  where $g_{\text{forget}}$ is the forget gate, $W_{\text{forget}} \in \mathbb{R}^{d \times d}$ is a learnable parameter matrix, $\mathbf{E}_{\text{mem}}$ integrates information from the input and the memory, and $\sigma$ is the sigmoid activation function.
* Output Gate: This gate, described earlier, controls how much of the memory content is used to generate the final output.
* The updated memory state is:
  $$\mathbf{M}^{t+1} = g_{\text{in}} \cdot \tanh(\mathbf{E}_{\text{mem}}) + g_{\text{forget}} \cdot \mathbf{M}^t$$
  where the $\tanh$ function is applied to keep the new memory content bounded, $\mathbf{M}^{t+1}$ is the updated memory state, $g_{\text{in}}$ is the input gate, $\mathbf{E}_{\text{mem}}$ integrates information from the input and the memory, $g_{\text{forget}}$ is the forget gate, $\tanh$ is the hyperbolic tangent function, and $\mathbf{M}^t$ is the current memory state. (A sketch of this update step, with a toy usage example, follows below.)
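Continuing the same simplified sketch (vector-valued slots, invented names), the memory update with input and forget gates might look as follows. One detail the equations above leave open is how the token-wise gates are reduced to act on the $N$ memory slots; averaging over the token dimension is used here purely as an illustrative assumption.

```python
import torch
import torch.nn as nn

class MemoryWriter(nn.Module):
    """Sketch of an LM2-style memory update (illustrative, not the official code)."""

    def __init__(self, d: int):
        super().__init__()
        self.W_in = nn.Linear(d, d, bias=False)      # input-gate projection
        self.W_forget = nn.Linear(d, d, bias=False)  # forget-gate projection

    def forward(self, E, E_mem, M):
        # E:     (T, d) current input representation
        # E_mem: (T, d) memory-attention output from the read path
        # M:     (N, d) current memory bank (simplified to one vector per slot)

        # Token-wise gates, averaged over the sequence so they broadcast over the N slots
        g_in = torch.sigmoid(self.W_in(E)).mean(dim=0, keepdim=True)              # (1, d)
        g_forget = torch.sigmoid(self.W_forget(E_mem)).mean(dim=0, keepdim=True)  # (1, d)
        new_content = torch.tanh(E_mem).mean(dim=0, keepdim=True)                 # (1, d), bounded by tanh

        # M_{t+1} = g_in * tanh(E_mem) + g_forget * M_t
        return g_in * new_content + g_forget * M


# Hypothetical end-to-end usage with toy sizes (MemoryReader is the sketch from the read path above)
T, N, d = 8, 4, 16
E = torch.randn(T, d)        # input embeddings
E_attn = torch.randn(T, d)   # stand-in for the self-attention output
M = torch.eye(N, d)          # simplified "neutral" memory bank

reader, writer = MemoryReader(d), MemoryWriter(d)
E_next, E_mem = reader(E, M, E_attn)   # read: cross-attention + output gate + skip connection
M_next = writer(E, E_mem, M)           # write: input/forget-gated memory update
```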
Experiments and Results
The researchers tested the LM2 on a dataset called BABILong, which is designed to test how well models can reason about long contexts. The LM2 outperformed other models, including a memory-augmented model called RMT (Recurrent Memory Transformer) and a baseline model called Llama-3.2. The LM2 was better at multi-hop inference (answering questions that require multiple steps of reasoning), numerical reasoning, and question-answering in long contexts.
To make sure that the memory module didn't hurt the model's ability to perform general tasks, the researchers also tested it on the MMLU dataset, which covers a wide range of academic subjects. The LM2 performed better than a standard Llama model, showing that the memory module can actually improve performance on general tasks as well.
Key Findings
- The LM2's memory module helps it to better understand and reason about long contexts.
- The memory module does not degrade performance on general tasks and can even improve it.
- The way the memory module is integrated into the Transformer architecture is important for achieving the best performance.
- The memory module stores and retrieves information in a way that is relevant to the task at hand.
- The memory module adapts its contents at test time, focusing on the information most relevant to the current input.
Contributions
- The paper introduces a new memory-augmented Transformer architecture that can capture and use long-term dependencies in data.
- The paper proposes a new way to integrate memory into the Transformer architecture, which allows the model to maintain its original capabilities while also benefiting from the memory module.
- The paper shows that the LM2 outperforms existing models on long context reasoning tasks.
In summary, this paper presents a promising new approach to improving the ability of LLMs to handle long contexts by giving them a better memory.