LM2: Large Memory Models (2502.06049v1)

Published 9 Feb 2025 in cs.CL and cs.AI

Abstract: This paper introduces the Large Memory Model (LM2), a decoder-only Transformer architecture enhanced with an auxiliary memory module that aims to address the limitations of standard Transformers in multi-step reasoning, relational argumentation, and synthesizing information distributed over long contexts. The proposed LM2 incorporates a memory module that acts as a contextual representation repository, interacting with input tokens via cross attention and updating through gating mechanisms. To preserve the Transformer's general-purpose capabilities, LM2 maintains the original information flow while integrating a complementary memory pathway. Experimental results on the BABILong benchmark demonstrate that the LM2 model outperforms both the memory-augmented RMT model by 37.1% and the baseline Llama-3.2 model by 86.3% on average across tasks. LM2 exhibits exceptional capabilities in multi-hop inference, numerical reasoning, and large-context question-answering. On the MMLU dataset, it achieves a 5.0% improvement over a pre-trained vanilla model, demonstrating that its memory module does not degrade performance on general tasks. Further, in our analysis, we explore memory interpretability, the effectiveness of the memory module, and test-time behavior. Our findings emphasize the importance of explicit memory in enhancing Transformer architectures.

This paper introduces a new way to improve how LLMs deal with long pieces of text by giving them a better memory.

Background and Relevance

Transformer-based models, like GPT-3 (Generative Pre-trained Transformer 3) and BERT (Bidirectional Encoder Representations from Transformers), have been very successful in many language tasks. However, they struggle when they need to understand and reason about very long contexts. This is because they sometimes have trouble finding the important information in a sea of irrelevant data.

The Large Memory Model (LM2)

To solve this problem, the researchers created the Large Memory Model (LM2), which adds a special memory module to the standard Transformer architecture. This memory module acts like a separate storage space where the model can keep track of important information it has seen.

Here's how the LM2 works:

  1. Memory Initialization: The memory module starts with a "memory bank," which is like a collection of slots where information can be stored. Each slot is initially set to a neutral state.
    • The memory bank is represented by $\mathbf{M} \in \mathbb{R}^{N \times d \times d}$, where:
      • $N$ is the number of memory slots.
      • $d$ is the hidden dimension of each slot.
      • $\mathbb{R}$ denotes the real numbers.
      • Each memory slot is initialized as an identity matrix: $\mathbf{M}_r = \mathbf{I}_{d \times d}$, where $r \in \{1, \dots, N\}$ and $\mathbf{I}_{d \times d}$ is the identity matrix.
  2. Memory Information Flow: When the model processes new input, it uses a technique called "cross attention" to compare the input to the information stored in the memory bank. This helps the model find the memory slots that contain the most relevant information.
    • Input embeddings $\mathbf{E}$ act as the query, while the memory bank $\mathbf{M}$ serves as both the key and the value store.
    • The input embeddings $\mathbf{E} \in \mathbb{R}^{T \times d}$ (where $T$ is the sequence length) and the memory bank $\mathbf{M} \in \mathbb{R}^{N \times d}$ are projected into query ($\mathbf{Q}$), key ($\mathbf{K}$), and value ($\mathbf{V}$) spaces:

      $$\mathbf{Q} = \mathbf{E}_t \mathbf{W}^Q, \quad \mathbf{K} = \mathbf{M}_t \mathbf{W}^K, \quad \mathbf{V} = \mathbf{M}_t \mathbf{W}^V,$$

      where $\mathbf{W}^Q, \mathbf{W}^K, \mathbf{W}^V \in \mathbb{R}^{d \times d}$ are learnable projection matrices and $t$ indexes the decoder block.

      • $\mathbf{E}_t$ is the input embedding at decoder block $t$
      • $\mathbf{M}_t$ is the memory bank at decoder block $t$
      • $\mathbf{W}^Q$ is the query projection matrix
      • $\mathbf{W}^K$ is the key projection matrix
      • $\mathbf{W}^V$ is the value projection matrix

    • The attention scores are computed as the scaled dot product of the query and key matrices:

      $$\mathbf{A} = \text{softmax}\!\left(\frac{\mathbf{Q} \mathbf{K}^\top}{\sqrt{d}}\right),$$

      where $\mathbf{A} \in \mathbb{R}^{T \times N}$ represents the alignment between the input sequence ($T$ tokens) and the $N$ memory slots.

    • The resulting attention output is

      $$\mathbf{E}_\text{mem} = \mathbf{A} \mathbf{V},$$

      where $\mathbf{E}_\text{mem} \in \mathbb{R}^{T \times d}$ integrates information from the input and the memory, and $\mathbf{V}$ is the value matrix.

    • To control how much the memory influences the model's output, an "output gate" is used. This gate decides how much of the information retrieved from memory should be passed on to the next layer:

      $$g_\text{out} = \sigma\!\left(\mathbf{E}_\text{mem} \mathbf{W}_\text{out}\right),$$

      where $g_\text{out}$ is the output gate, $\mathbf{W}_\text{out} \in \mathbb{R}^{d \times d}$ is a learnable parameter matrix, and $\sigma$ is the sigmoid activation function.

    • The gated memory output is then computed as

      $$\mathbf{E}_\text{gated} = g_\text{out} \cdot \mathbf{M}_t,$$

      where $\mathbf{E}_\text{gated}$ is the gated memory output and $\mathbf{M}_t$ is the memory bank at decoder block $t$.

    • The gated memory output is integrated into the standard attention flow of the Transformer decoder through a skip connection. Specifically, the output of the self-attention mechanism, $\mathbf{E}_\text{attn}$, is combined with the gated memory output as

      $$\mathbf{E}_\text{next} = \mathbf{E}_\text{attn} + \mathbf{E}_\text{gated},$$

      where $\mathbf{E}_\text{next}$ is the combined output passed to the next decoder layer.
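
To make the read path concrete, the short PyTorch sketch below walks through the cross attention from input tokens to memory slots, the output gate, and the skip connection. It is illustrative only: it uses the vector-per-slot shapes from the equations above, gates the attention-weighted view of $\mathbf{M}_t$ (rather than $\mathbf{M}_t$ directly) so the tensor shapes line up, and its class and variable names are not taken from the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MemoryRead(nn.Module):
    """Sketch of the LM2 memory read path: cross attention to the memory bank,
    an output gate, and a skip connection. Illustrative only, not the authors'
    implementation."""

    def __init__(self, d: int):
        super().__init__()
        self.W_q = nn.Linear(d, d, bias=False)    # W^Q
        self.W_k = nn.Linear(d, d, bias=False)    # W^K
        self.W_v = nn.Linear(d, d, bias=False)    # W^V
        self.W_out = nn.Linear(d, d, bias=False)  # W_out (output gate)
        self.d = d

    def forward(self, E_t, M_t, E_attn):
        # E_t:    (T, d) input embeddings at decoder block t
        # M_t:    (N, d) memory bank, one d-dimensional vector per slot
        # E_attn: (T, d) output of the block's ordinary self-attention
        Q = self.W_q(E_t)                               # (T, d)
        K = self.W_k(M_t)                               # (N, d)
        V = self.W_v(M_t)                               # (N, d)

        A = F.softmax(Q @ K.T / self.d ** 0.5, dim=-1)  # (T, N) token-to-slot alignment
        E_mem = A @ V                                   # (T, d) retrieved memory

        g_out = torch.sigmoid(self.W_out(E_mem))        # (T, d) output gate
        # The summary writes E_gated = g_out * M_t; here we gate the per-token,
        # attention-weighted view of the memory so shapes match -- an assumption
        # of this sketch.
        E_gated = g_out * (A @ M_t)                     # (T, d)

        return E_attn + E_gated, E_mem, A               # E_next plus read-out for the update step


# Toy usage: T=6 tokens, N=4 memory slots, hidden size d=8.
T, N, d = 6, 4, 8
read = MemoryRead(d)
E_next, E_mem, A = read(torch.randn(T, d), torch.randn(N, d), torch.randn(T, d))
print(E_next.shape)  # torch.Size([6, 8])
```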

  3. Memory Updates: The memory module also needs to update its contents to store new information and remove irrelevant information. This is done using three "gates":
    • Input Gate: This gate decides how much of the new input should be written into the memory.

      $$g_\text{in} = \sigma\bigl(\mathbf{E}_t \mathbf{W}_\text{in}\bigr),$$

      where $g_\text{in}$ is the input gate, $\mathbf{W}_\text{in} \in \mathbb{R}^{d \times d}$ is a learnable parameter matrix, $\mathbf{E}_t$ is the current input representation, and $\sigma$ is the sigmoid activation function.

    • Forget Gate: This gate decides which parts of the existing memory should be erased or forgotten:

      $$g_\text{forget} = \sigma\bigl(\mathbf{E}_\text{mem} \mathbf{W}_\text{forget}\bigr),$$

      where $g_\text{forget}$ is the forget gate, $\mathbf{W}_\text{forget} \in \mathbb{R}^{d \times d}$ is a learnable parameter matrix, and $\mathbf{E}_\text{mem}$ is the memory read-out defined above.

    • Output Gate: This gate, described earlier, controls how much of the memory content is used to generate the final output.

    • The updated memory state is

      $$\mathbf{M}_{t+1} = g_\text{in} \cdot \tanh(\mathbf{E}_\text{mem}) + g_\text{forget} \cdot \mathbf{M}_{t},$$

      where the $\tanh$ function keeps the new memory content bounded, $\mathbf{M}_{t+1}$ is the updated memory state, and $\mathbf{M}_{t}$ is the current memory state.
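
The update step can be sketched in the same spirit. In the snippet below (again illustrative, not the authors' implementation), the token-indexed gates and the tanh-bounded write content are mean-pooled over the sequence so they broadcast over the $(N, d)$ memory bank; treat that pooling, and the variable names, as assumptions made for this sketch.

```python
import torch

def memory_update(E_t, E_mem, M_t, W_in, W_forget):
    """Sketch of the LM2 memory update: input gate, forget gate, and a
    tanh-bounded write. Illustrative only, not the authors' implementation.

    E_t:            (T, d) current input representation
    E_mem:          (T, d) memory read-out from the cross-attention step
    M_t:            (N, d) current memory bank
    W_in, W_forget: (d, d) learnable gate matrices
    """
    g_in = torch.sigmoid(E_t @ W_in)            # (T, d) input gate
    g_forget = torch.sigmoid(E_mem @ W_forget)  # (T, d) forget gate

    # Mean-pool the token-indexed gates and write content over the sequence so
    # they broadcast against the (N, d) memory bank -- a simplifying assumption
    # of this sketch, not a detail given in the paper summary.
    write = (g_in * torch.tanh(E_mem)).mean(dim=0)  # (d,) new, bounded content
    keep = g_forget.mean(dim=0)                     # (d,) how much of each slot to retain

    return write + keep * M_t                       # (N, d) updated memory M_{t+1}


# Continuing the toy shapes from the read-path sketch above.
T, N, d = 6, 4, 8
M_next = memory_update(torch.randn(T, d), torch.randn(T, d),
                       torch.randn(N, d), torch.randn(d, d), torch.randn(d, d))
print(M_next.shape)  # torch.Size([4, 8])
```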

Experiments and Results

The researchers tested the LM2 on a dataset called BABILong, which is designed to test how well models can reason about long contexts. The LM2 outperformed other models, including a memory-augmented model called RMT (Recurrent Memory Transformer) and a baseline model called Llama-3.2. The LM2 was better at multi-hop inference (answering questions that require multiple steps of reasoning), numerical reasoning, and question-answering in long contexts.

To make sure that the memory module didn't hurt the model's ability to perform general tasks, the researchers also tested it on the MMLU dataset, which covers a wide range of academic subjects. The LM2 performed better than a standard Llama model, showing that the memory module can actually improve performance on general tasks as well.

Key Findings

  • The LM2's memory module helps it to better understand and reason about long contexts.
  • The memory module does not degrade performance on general tasks and can even improve it.
  • The way the memory module is integrated into the Transformer architecture is important for achieving the best performance.
  • The memory module stores and retrieves information in a way that is relevant to the task at hand.
  • The memory module adapts its contents during testing to focus on the most important information.

Contributions

  • The paper introduces a new memory-augmented Transformer architecture that can capture and use long-term dependencies in data.
  • The paper proposes a new way to integrate memory into the Transformer architecture, which allows the model to maintain its original capabilities while also benefiting from the memory module.
  • The paper shows that the LM2 outperforms existing models on long context reasoning tasks.

In summary, this paper presents a promising new approach to improving the ability of LLMs to handle long contexts by giving them a better memory.

Authors (8)
  1. Jikun Kang (7 papers)
  2. Wenqi Wu (8 papers)
  3. Filippos Christianos (19 papers)
  4. Alex J. Chan (15 papers)
  5. Fraser Greenlee (2 papers)
  6. George Thomas (21 papers)
  7. Marvin Purtorab (2 papers)
  8. Andy Toulis (2 papers)

HackerNews

  1. LM2: Large Memory Models (110 points, 30 comments)

Reddit

  1. LM2: Large Memory Models (23 points, 9 comments)