ATLAS: Learning to Optimally Memorize the Context at Test Time (2505.23735v1)

Published 29 May 2025 in cs.CL and cs.AI

Abstract: Transformers have been established as the most popular backbones in sequence modeling, mainly due to their effectiveness in in-context retrieval tasks and the ability to learn at scale. Their quadratic memory and time complexity, however, bound their applicability in longer sequences and so have motivated researchers to explore effective alternative architectures such as modern recurrent neural networks (a.k.a. long-term recurrent memory modules). Despite their recent success in diverse downstream tasks, they struggle in tasks that require long context understanding and extrapolation to longer sequences. We observe that these shortcomings come from three disjoint aspects in their design: (1) limited memory capacity that is bounded by the architecture of memory and feature mapping of the input; (2) online nature of update, i.e., optimizing the memory only with respect to the last input; and (3) less expressive management of their fixed-size memory. To enhance all these three aspects, we present ATLAS, a long-term memory module with high capacity that learns to memorize the context by optimizing the memory based on the current and past tokens, overcoming the online nature of long-term memory models. Building on this insight, we present a new family of Transformer-like architectures, called DeepTransformers, that are strict generalizations of the original Transformer architecture. Our experimental results on language modeling, common-sense reasoning, recall-intensive, and long-context understanding tasks show that ATLAS surpasses the performance of Transformers and recent linear recurrent models. ATLAS further improves the long context performance of Titans, achieving +80% accuracy in 10M context length of BABILong benchmark.

This paper, "ATLAS: Learning to Optimally Memorize the Context at Test Time" (Behrouz et al., 29 May 2025), introduces Atlas, a novel long-term neural memory module designed to enhance performance in tasks requiring long context understanding and memorization, overcoming limitations observed in traditional Transformers and recent linear recurrent neural networks (RNNs). The core idea is to learn to memorize the context within a sliding window at test time, rather than just individual tokens, by addressing three key shortcomings of existing models: limited memory capacity, online memory updates, and less expressive memory management.

Transformers excel at in-context retrieval but suffer from quadratic memory and time complexity with respect to sequence length, limiting their application in long contexts. Recent RNNs offer efficiency but often struggle with long-context understanding and extrapolation, primarily due to their fixed-size memory, online update nature (updating memory based only on the current input), limited memory capacity, and reliance on simple gradient-based memory optimization.

Atlas and related proposed architectures tackle these issues through several innovations:

  1. Enhanced Memory Capacity: The paper argues that memory capacity, defined as the maximum number of linearly independent key-value pairs a memory can perfectly map, is a bottleneck. It is shown that matrix-valued memory with a Delta update rule has sub-linear capacity in terms of its parameters. Deep memory modules (MLPs) increase capacity, but the paper proposes using higher-order feature mappings, specifically polynomial kernels $\phi_p(\cdot)$, on input keys and queries. This increases the effective dimensionality of the keys, theoretically boosting memory capacity to $\mathcal{O}(d_k^p)$ for a matrix memory with polynomial degree $p$. This is motivated by approximating the non-separable exponential kernel of Transformer attention and providing an input feature gating mechanism.
  2. Omega Rule for Context Memorization: To move beyond online updates, the paper introduces the Omega rule. Instead of optimizing the memory with respect to only the current token's error $\ell(M; k_t, v_t)$, the memory is optimized based on a sliding window of past tokens: $\min_{M} \sum_{i = t - c + 1}^{t} \gamma^{(t)}_i \left\| M\left(\phi(k_i)\right) - v_i \right\|^2_2$, where $c$ is the window length and $\gamma^{(t)}_i$ are input-dependent decay terms. This allows the model to memorize a local context, not just individual tokens, and provides in-context pruning ability via the $\gamma^{(t)}_i$ gates. OmegaNet is presented as an architecture using this rule with polynomial kernels and deep memory.
  3. Improved Memory Management (Atlas): Building on the Omega rule, Atlas enhances memory management by employing the Muon optimizer [jordan2024muon] to optimize the internal memory objective. Muon approximates second-order information, aiming for a more "locally optimal" memory update than simple gradient descent. The update rule for Atlas becomes:

    $$M_t = \alpha_t M_{t-1} - \eta_t \, \mathrm{NewtonSchulz}_k(\mathcal{S}_t)$$

    $$\mathcal{S}_t = \theta_t \mathcal{S}_{t-1} + \nabla \sum_{i = t - c + 1}^{t} \gamma^{(t)}_i \left\| M\left(\phi^{*}(k_i)\right) - v_i \right\|^2_2$$

    where $\phi^{*}(\cdot)$ represents a potentially infinite-dimensional feature map (such as the one induced by the exponential kernel in Transformers), and $\mathrm{NewtonSchulz}_k(\cdot)$ runs $k$ Newton-Schulz iterations that approximately orthogonalize the momentum $\mathcal{S}_t$, as in Muon, using only matrix multiplications for efficiency.

  4. DeepTransformers Family: The paper reformulates the connection between associative memory and Transformers. Softmax attention can be seen as a non-parametric solution to an $\ell_2$ regression problem involving an exponential feature map $\phi^{*}(\cdot)$ and global optimization over the sequence. By introducing deep memory and different optimization objectives/rules with this $\phi^{*}(\cdot)$, the authors derive a family of models called DeepTransformers.
    • DeepTransformers: An unnormalized version using deep memory and a Hebbian-like update derived from optimizing $\sum_i \ell(M(\phi^*(k_i)), v_i)$ with respect to the previous memory state. This is shown to be a strict generalization of the original unnormalized Transformer output.
    • Deep Omega Transformers (Dot): An unnormalized version using deep memory and the Omega rule objective with $\phi^{*}(\cdot)$. This generalizes Transformers with a Delta rule-like update and context memorization.
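The capacity argument in item 1 can be checked numerically. The sketch below is a minimal illustration, not the paper's implementation: it assumes a linear (matrix) memory fitted by least squares rather than learned online, and it implements $\phi_p$ as a plain $p$-fold Kronecker power of the key. The helper names `poly_features` and `max_recall_error` are hypothetical. With $d_k = 4$, raw keys cannot store 8 arbitrary key-value pairs exactly, while degree-2 features (dimension $d_k^2 = 16$) can.

```python
import numpy as np

def poly_features(K, p):
    """Map each row k of K to its p-fold Kronecker power kron(k, ..., k) (dimension d_k**p).

    Hypothetical helper: a plain tensor power with no normalization or
    lower-degree terms, used only to show how the feature dimension grows.
    """
    feats = K
    for _ in range(p - 1):
        # Row-wise outer product, flattened: (n, a) and (n, d_k) -> (n, a * d_k)
        feats = np.einsum("ni,nj->nij", feats, K).reshape(K.shape[0], -1)
    return feats

def max_recall_error(Phi, V):
    """Best linear memory M with M @ phi_i close to v_i (least squares); return the worst recall error."""
    # lstsq solves Phi @ X = V in the least-squares sense, so X plays the role of M transposed
    M_T, *_ = np.linalg.lstsq(Phi, V, rcond=None)
    return float(np.abs(Phi @ M_T - V).max())

rng = np.random.default_rng(0)
d_k, d_v, n = 4, 3, 8                 # 8 key-value pairs, only 4-dimensional keys
K = rng.standard_normal((n, d_k))
V = rng.standard_normal((n, d_v))

err_raw = max_recall_error(poly_features(K, 1), V)   # 4 features < 8 pairs: exact recall impossible
err_poly = max_recall_error(poly_features(K, 2), V)  # 16 features (10 distinct monomials) >= 8 pairs
print(f"max recall error, raw keys: {err_raw:.3f}; degree-2 features: {err_poly:.1e}")
```

Because a Kronecker power repeats symmetric monomials, the 16 degree-2 coordinates contain only 10 distinct ones, so the effective dimension sits a little below $d_k^p$; the growth is still polynomial in $d_k$.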
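The Omega-rule objective in item 2 and its gradient are easy to write down for a linear memory. The sketch below assumes an identity feature map and a matrix memory instead of the paper's polynomial features and deep MLP memory, and it uses fixed, hypothetical gate values $\gamma_i$ in place of learned, input-dependent ones. It illustrates two properties stated in the text: with $c = 1$, one gradient step reduces to the online delta rule, and a gate of $\gamma_i = 0$ removes token $i$ from the window (in-context pruning).

```python
import numpy as np

def omega_loss(M, K, V, gammas):
    """Omega objective over one window: sum_i gamma_i * ||M k_i - v_i||^2 (identity feature map)."""
    R = K @ M.T - V                                  # row i holds the residual M k_i - v_i
    return float(np.sum(gammas * np.sum(R**2, axis=1)))

def omega_grad(M, K, V, gammas):
    """Gradient of omega_loss with respect to M: 2 * sum_i gamma_i (M k_i - v_i) k_i^T."""
    R = K @ M.T - V
    return 2.0 * (gammas[:, None] * R).T @ K         # shape (d_v, d_k), same as M

rng = np.random.default_rng(1)
d_k, d_v, c, t, eta = 4, 3, 3, 9, 0.05
K = rng.standard_normal((10, d_k))
V = rng.standard_normal((10, d_v))
M = rng.standard_normal((d_v, d_k))

# c = 1: one gradient step on the current token only is the online delta rule
M_omega_c1 = M - eta * omega_grad(M, K[t:t + 1], V[t:t + 1], np.ones(1))
M_delta = M - 2 * eta * np.outer(M @ K[t] - V[t], K[t])

# c = 3 with hypothetical gates: gamma = 0 prunes the oldest token from the window
Kw, Vw = K[t - c + 1:t + 1], V[t - c + 1:t + 1]
gammas = np.array([0.0, 0.5, 1.0])
M_omega_c3 = M - eta * omega_grad(M, Kw, Vw, gammas)
```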
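The Atlas update in item 3 combines momentum over Omega-rule gradients with a Newton-Schulz step. The sketch below is hedged in two ways. First, it uses the classic cubic Newton-Schulz iteration $X \leftarrow 1.5X - 0.5XX^\top X$, which converges to the orthogonal polar factor $UV^\top$ of $\mathcal{S}_t = U\Sigma V^\top$, whereas Muon uses a tuned quintic polynomial with only a few steps. Second, it keeps the memory linear with an identity feature map and constant, hypothetical hyperparameters (`theta=0.9`, `eta=0.05`) in place of the paper's learned, input-dependent $\alpha_t$, $\eta_t$, $\theta_t$. The names `newton_schulz`, `window_grad`, and `atlas_step` are illustrative.

```python
import numpy as np

def newton_schulz(S, steps=50):
    """Cubic Newton-Schulz iteration converging to the polar factor U V^T of S = U diag(s) V^T.

    Muon itself uses a tuned quintic polynomial and about 5 steps; this classic
    cubic form is slower but converges exactly, which makes it easy to check.
    """
    X = S / (np.linalg.norm(S) + 1e-12)        # Frobenius scaling puts all singular values in [0, 1]
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ (X.T @ X)      # matrix multiplications only: no SVD, no explicit inverse
    return X

def window_grad(M, K, V, gammas):
    """Omega-rule gradient 2 * sum_i gamma_i (M k_i - v_i) k_i^T (linear memory, identity features)."""
    R = K @ M.T - V
    return 2.0 * (gammas[:, None] * R).T @ K

def atlas_step(M, S, K, V, gammas, alpha=1.0, theta=0.9, eta=0.05, ns_steps=50):
    """One Atlas-style step with hypothetical constant hyperparameters instead of learned gates."""
    S = theta * S + window_grad(M, K, V, gammas)      # S_t = theta_t S_{t-1} + gradient
    M = alpha * M - eta * newton_schulz(S, ns_steps)  # M_t = alpha_t M_{t-1} - eta_t NS_k(S_t)
    return M, S

rng = np.random.default_rng(2)
S0 = rng.standard_normal((5, 3))
U, _, Vt = np.linalg.svd(S0, full_matrices=False)
ns_error = np.abs(newton_schulz(S0) - U @ Vt).max()   # should be near machine precision

# Stream 8 tokens through the update with window c = 2 and unit gates
d_k, d_v, c = 3, 5, 2
K, V = rng.standard_normal((8, d_k)), rng.standard_normal((8, d_v))
M, S = np.zeros((d_v, d_k)), np.zeros((d_v, d_k))
for t in range(len(K)):
    lo = max(0, t - c + 1)
    M, S = atlas_step(M, S, K[lo:t + 1], V[lo:t + 1], np.ones(t + 1 - lo))
```

Orthogonalizing the momentum equalizes the step size across its singular directions, which is the sense in which Muon-style updates use more than the raw gradient direction.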
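The two views in item 4 can be checked on a toy example. Softmax attention with weights $w_i = \exp(q^\top k_i)$ returns $o = \sum_i w_i v_i / \sum_i w_i$, which is exactly the minimizer of the weighted $\ell_2$ regression $\sum_i w_i \|o - v_i\|_2^2$ (a Nadaraya-Watson estimator). Its unnormalized numerator $\sum_i \exp(q^\top k_i) v_i$ equals $M\phi^*(q)$ for the Hebbian memory $M = \sum_i v_i \phi^*(k_i)^\top$. Because $\phi^*$ is infinite-dimensional, the sketch below substitutes a truncated Taylor feature map (the hypothetical helper `taylor_exp_features`) and omits the usual $1/\sqrt{d}$ attention scaling; it illustrates the identity rather than reproducing DeepTransformers.

```python
import numpy as np
from math import factorial

def taylor_exp_features(X, P):
    """Truncated stand-in for phi*: <phi(x), phi(y)> = sum_{p=0}^{P} (x.y)^p / p!, close to exp(x.y).

    Hypothetical helper: stacks the Kronecker powers of x, the degree-p block scaled by 1/sqrt(p!).
    """
    n = X.shape[0]
    power = np.ones((n, 1))                        # degree-0 block
    blocks = [power]
    for p in range(1, P + 1):
        power = np.einsum("ni,nj->nij", power, X).reshape(n, -1)   # degree-p Kronecker power
        blocks.append(power / np.sqrt(factorial(p)))
    return np.concatenate(blocks, axis=1)

rng = np.random.default_rng(3)
d, n, P = 2, 6, 10
K = 0.2 * rng.standard_normal((n, d))              # small keys keep the Taylor truncation accurate
V = rng.standard_normal((n, 4))
q = 0.2 * rng.standard_normal(d)

w = np.exp(K @ q)                                  # exponential-kernel weights exp(q . k_i)
o_softmax = (w @ V) / w.sum()                      # normalized softmax-attention output

# l2-regression view: o_softmax minimizes sum_i w_i ||o - v_i||^2, so the gradient vanishes there
grad_at_o = 2 * (w[:, None] * (o_softmax - V)).sum(axis=0)

# Hebbian-memory view of the unnormalized output: M = sum_i v_i phi(k_i)^T, read out with phi(q)
M = V.T @ taylor_exp_features(K, P)                # (4, D): sum of outer products v_i phi(k_i)^T
o_hebbian = M @ taylor_exp_features(q[None, :], P)[0]
o_unnormalized = w @ V                             # sum_i exp(q . k_i) v_i
```

The same tensor-power construction is what ties item 4 back to item 1: a finite polynomial feature map is a truncation of the exponential kernel's infinite feature map.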

Implementation Details and Parallelization:

A key practical aspect is parallelizing the training of the Omega rule and Atlas. A naive implementation would require materializing gradients for the entire window $c$, which is memory-intensive. The paper proposes a chunk-wise parallelization strategy: the sequence is divided into chunks, and within each chunk, gradients are computed with respect to the memory state $M_{t'}$ at the beginning of the chunk. A sliding-window mask is applied during gradient computation (e.g., within einsum operations) so that only the last $c$ tokens, drawn from the current chunk and the preceding chunks, contribute.

For Atlas with Muon, the momentum term $\mathcal{S}_t$ can be calculated recursively from past gradients, independently of $M_t$. Since gradients within a chunk are computed relative to $M_{t'}$, the momentum terms can be parallelized chunk-wise. The Newton-Schulz operation on $\mathcal{S}_t$ is also parallelizable, making the overall training process efficient without significant overhead compared to online methods. The memory module $M(\cdot)$ is typically implemented as a 2-layer MLP with residual connections and optional gating (Atlas++).
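The chunk-wise scheme can be sketched for the simplest case: a linear memory, $\alpha_t = 1$, plain gradient descent (no Muon), an identity feature map, and gates $\gamma_i$ that depend only on the token index rather than on $t$. All of these are simplifying assumptions, and the function names are hypothetical. The point is the masking: every step in a chunk reuses residuals computed once at the chunk-start memory, and a sliding-window mask selects which of them contribute, reproducing a direct per-step loop.

```python
import numpy as np

def chunk_grads_loop(M0, K, V, gammas, start, end, c):
    """Reference: Omega gradient for each step t in [start, end), all evaluated at chunk-start memory M0."""
    grads = []
    for t in range(start, end):
        g = np.zeros_like(M0)
        for i in range(max(0, t - c + 1), t + 1):             # sliding window of the last c tokens
            g += 2.0 * gammas[i] * np.outer(M0 @ K[i] - V[i], K[i])
        grads.append(g)
    return np.stack(grads)

def chunk_grads_masked(M0, K, V, gammas, start, end, c):
    """Same gradients in one einsum: residuals computed once, then a sliding-window mask."""
    lo = max(0, start - c + 1)                                 # earliest token visible from this chunk
    idx = np.arange(lo, end)                                   # token indices i
    steps = np.arange(start, end)[:, None]                     # step indices t (column vector)
    mask = (idx[None, :] <= steps) & (idx[None, :] >= steps - c + 1)
    W = 2.0 * mask * gammas[lo:end][None, :]                   # (chunk, visible) weights
    R = K[lo:end] @ M0.T - V[lo:end]                           # residuals M0 k_i - v_i
    return np.einsum("ti,ie,id->ted", W, R, K[lo:end])         # (chunk, d_v, d_k)

def train_chunkwise(K, V, gammas, c, chunk, eta=0.05):
    """Linear memory, alpha_t = 1, plain gradient descent; gradients are frozen at each chunk start."""
    M = np.zeros((V.shape[1], K.shape[1]))
    for start in range(0, len(K), chunk):
        end = min(start + chunk, len(K))
        G = chunk_grads_masked(M, K, V, gammas, start, end, c)
        M = M - eta * G.sum(axis=0)   # memory at chunk end: M_{t'} minus eta times the summed gradients
    return M

rng = np.random.default_rng(4)
K, V = rng.standard_normal((12, 3)), rng.standard_normal((12, 2))
gammas = rng.uniform(0.5, 1.0, size=12)   # hypothetical per-token gates
M0 = rng.standard_normal((2, 3))
masked_matches_loop = np.allclose(chunk_grads_loop(M0, K, V, gammas, 4, 8, c=3),
                                  chunk_grads_masked(M0, K, V, gammas, 4, 8, c=3))
M_chunked = train_chunkwise(K, V, gammas, c=3, chunk=4)
```

Larger chunks expose more parallelism but evaluate more gradients at a stale memory state; with `chunk=1` the procedure reduces to the fully sequential Omega update.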

Experimental Evaluation:

Extensive experiments are conducted across various benchmarks:

  • Language Modeling & Common-Sense Reasoning: Atlas, OmegaNet, DeepTransformers, and Dot are evaluated on perplexity (Wikitext, LMB) and accuracy (LMB, PIQA, HellaSwag, WinoGrande, ARC-e/c, SIQA, BoolQ). Atlas and OmegaNet outperform recurrent baselines (RetNet, DeltaNet, Titans, Memora) and competitive hybrid models (Samba, Gated DeltaNet-H2). DeepTransformers and Dot also outperform Transformer++. This validates the benefits of deep memory, context memorization (Omega rule), and locally optimal memory management (Muon).
  • Needle In a Haystack (S-NIAH): Atlas and its hybrid variants, along with DeepTransformers and Dot, show strong extrapolation performance to long contexts (up to 16K tokens), outperforming recurrent baselines. This highlights the improved effective context length and memory capacity.
  • BABILong Benchmark: Atlas maintains performance up to 10M context length, achieving +80% accuracy where Titans' performance drops, demonstrating superior scaling in ultra-long sequences.
  • MAD Synthetic Benchmark: Atlas achieves state-of-the-art results across tasks testing compression, noisy in-context recall (ICR), fuzzy ICR, selective copying, and memorization, particularly excelling in memorization, which supports the claims about enhanced memory capacity.
  • In-context Recall: While Transformers remain the best on standard in-context recall tasks (SWDE, NQ, DROP, FDA, SQUAD, TQA), Atlas and OmegaNet significantly close the gap compared to other recurrent models.
  • Associative Recall (MQAR): Atlas and Dot perform well, with Atlas showing the best performance per memory size among tested models.

Ablation Studies:

Ablations confirm that all proposed components contribute positively to Atlas's performance: gated MLP memory architecture, hybrid variants (MAG and MAL improve performance, with MAG being slightly better), Muon optimizer (outperforms simple gradient descent), Omega rule ($c>1$ is better than $c=1$), and polynomial feature mapping.

Scaling Patterns:

Atlas and OmegaNet demonstrate favorable scaling patterns with increasing model size (up to 1.3B parameters) and training context length, achieving lower perplexity compared to baselines at larger scales.

In conclusion, the paper successfully introduces Atlas, a novel recurrent architecture with a high-capacity, deep, and locally optimal memory module that learns to memorize context using the Omega rule and Muon optimizer. It also presents DeepTransformers and Dot as generalizations of Transformers incorporating deep memory and advanced learning rules. These models achieve state-of-the-art performance among recurrent architectures and often surpass Transformer baselines on various benchmarks, particularly excelling in long-context tasks due to their enhanced memory management and capacity. The proposed parallelization strategy makes these architectures practical for large-scale training and deployment.

Authors (8)
  1. Ali Behrouz
  2. Zeman Li
  3. Praneeth Kacham
  4. Majid Daliri
  5. Yuan Deng
  6. Peilin Zhong
  7. Meisam Razaviyayn
  8. Vahab Mirrokni