This paper, "ATLAS: Learning to Optimally Memorize the Context at Test Time" (Behrouz et al., 29 May 2025 ), introduces Atlas, a novel long-term neural memory module designed to enhance performance in tasks requiring long context understanding and memorization, overcoming limitations observed in traditional Transformers and recent linear recurrent neural networks (RNNs). The core idea is to learn to memorize the context within a sliding window at test time, rather than just individual tokens, by addressing three key shortcomings of existing models: limited memory capacity, online memory updates, and less expressive memory management.
Transformers excel at in-context retrieval but suffer from quadratic memory and time complexity with respect to sequence length, limiting their application in long contexts. Recent RNNs offer efficiency but often struggle with long-context understanding and extrapolation, primarily due to their fixed-size memory, online update nature (updating memory based only on the current input), limited memory capacity, and reliance on simple gradient-based memory optimization.
Atlas and related proposed architectures tackle these issues through several innovations:
- Enhanced Memory Capacity: The paper argues that memory capacity, defined as the maximum number of linearly independent key-value pairs a memory can perfectly map, is a bottleneck. It is shown that matrix-valued memory with a Delta update rule has sub-linear capacity in terms of its parameter count. Deep memory modules (MLPs) increase capacity, but the paper additionally proposes applying higher-order feature mappings, specifically polynomial kernels $\phi_p(\cdot)$, to input keys and queries. This increases the effective dimensionality of the keys, theoretically boosting memory capacity to $\mathcal{O}(d^p)$ for a matrix memory with key dimension $d$ and polynomial degree $p$ (a feature-map sketch follows this list). This is motivated by approximating the non-separable exponential kernel of Transformer attention and by providing an input feature gating mechanism.
- Omega Rule for Context Memorization: To move beyond purely online updates, the paper introduces the Omega rule. Instead of optimizing the memory with respect to only the current token's error $\|\mathcal{M}(k_t) - v_t\|^2$, the memory is optimized over a sliding window of past tokens, $\mathcal{L}_t = \sum_{i=t-c+1}^{t} \gamma_i^{(t)} \,\|\mathcal{M}(k_i) - v_i\|^2$, where $c$ is the window length and the $\gamma_i^{(t)}$ are input-dependent decay terms. This allows the model to memorize a local context, not just individual tokens, and the gates $\gamma_i^{(t)}$ provide an in-context pruning ability. OmegaNet is presented as an architecture using this rule together with polynomial kernels and deep memory (see the Omega-rule sketch after this list).
- Improved Memory Management (Atlas): Building on the Omega rule, Atlas enhances memory management by employing the Muon optimizer [jordan2024muon] to optimize the internal memory objective. Muon approximates second-order information, aiming for a more "locally optimal" memory update than simple gradient descent. The update rule for Atlas becomes
$$S_t = \theta_t S_{t-1} - \eta_t \nabla \sum_{i=t-c+1}^{t} \gamma_i^{(t)} \,\big\| \mathcal{M}_{t-1}(\phi^*(k_i)) - v_i \big\|^2, \qquad \mathcal{M}_t = \alpha_t \mathcal{M}_{t-1} + \mathrm{NewtonSchulz}_k(S_t),$$
where $\phi^*(\cdot)$ represents a potentially infinite-dimensional feature map (like the exponential kernel in Transformers), and $\mathrm{NewtonSchulz}_k$ approximates the required matrix inverse using only matrix multiplications for efficiency (an update-step sketch also follows this list).
- DeepTransformers Family: The paper reformulates the connection between associative memory and Transformers. Softmax attention can be seen as a non-parametric solution to an $\ell_2$ regression problem involving an exponential feature map $\phi^*$ and global optimization over the sequence. By introducing deep memory and different optimization objectives/rules with this $\phi^*$, the authors derive a family of models called DeepTransformers.
- DeepTransformers: An unnormalized version using deep memory and a Hebbian-like update derived from optimizing the objective with respect to the previous memory state. This is shown to be a strict generalization of the original unnormalized Transformer output.
- Deep Omega Transformers (Dot): An unnormalized version using deep memory and the Omega rule objective with the exponential feature map $\phi^*$. This generalizes Transformers with a Delta rule-like update and context memorization.
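To make the capacity and Omega-rule items above concrete, here is a minimal PyTorch sketch under assumed shapes and a sigmoid gate for the decays; `poly_features`, `MLPMemory`, and `omega_loss` are illustrative names rather than the authors' implementation. It applies a polynomial feature map to keys, uses a small 2-layer MLP as the deep memory, and evaluates the sliding-window Omega objective over the last c tokens.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def poly_features(x: torch.Tensor, degree: int = 2) -> torch.Tensor:
    """Degree-p polynomial feature map: concat(x, x ⊗ x, ...) along the last dim.
    Illustrative full tensor-product version; real implementations may prefer
    cheaper factorizations. Input (..., d) -> output (..., d + d^2 + ... + d^p)."""
    feats = [x]
    cur = x
    for _ in range(degree - 1):
        cur = torch.einsum('...i,...j->...ij', cur, x).flatten(-2)
        feats.append(cur)
    return torch.cat(feats, dim=-1)


class MLPMemory(nn.Module):
    """Deep memory M(.): a 2-layer MLP with a residual connection."""

    def __init__(self, d_in: int, d_hidden: int, d_out: int):
        super().__init__()
        self.w1 = nn.Linear(d_in, d_hidden)
        self.w2 = nn.Linear(d_hidden, d_out)
        self.res = nn.Linear(d_in, d_out)

    def forward(self, k: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(k))) + self.res(k)


def omega_loss(memory: nn.Module, keys, values, gammas, c: int, t: int):
    """Omega-rule objective at step t: decayed squared errors over the last c
    tokens, L_t = sum_i gamma_i * ||M(phi(k_i)) - v_i||^2 (keys are pre-mapped)."""
    lo = max(0, t - c + 1)
    k_win, v_win, g_win = keys[lo:t + 1], values[lo:t + 1], gammas[lo:t + 1]
    err = memory(k_win) - v_win               # (window, d_v)
    return (g_win * err.pow(2).sum(-1)).sum()


# Toy usage with made-up shapes: a length-T sequence of keys/values.
T, d, d_v, c = 32, 16, 16, 8
phi_k = poly_features(torch.randn(T, d), degree=2)   # (T, d + d^2)
v = torch.randn(T, d_v)
gamma = torch.sigmoid(torch.randn(T))                # input-dependent decay gates
mem = MLPMemory(phi_k.shape[-1], 64, d_v)
loss = omega_loss(mem, phi_k, v, gamma, c=c, t=T - 1)
loss.backward()                                      # gradient consumed by the memory update
```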
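The Atlas update itself can be read as a momentum step over the Omega-rule gradient followed by Newton-Schulz post-processing, as in Muon. The sketch below is a hedged approximation of that reading: `newton_schulz` uses the commonly cited Muon quintic coefficients, and `atlas_step` with its defaults for `theta`, `eta`, and `alpha` is illustrative rather than the paper's exact recurrence or hyperparameters.

```python
import torch


def newton_schulz(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Newton-Schulz iteration as popularized by Muon: approximately orthogonalizes G
    (roughly (G G^T)^{-1/2} G) using only matrix multiplications. The quintic
    coefficients below are the commonly used Muon values (an assumption here)."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + eps)                  # normalize so the iteration converges
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X


def atlas_step(W, S_prev, grad, theta=0.9, eta=0.01, alpha=0.999, ns_steps=5):
    """One hedged Atlas-style update for a single matrix-valued memory weight W:
    accumulate momentum S over the Omega-rule gradient, post-process it with
    Newton-Schulz, and decay the previous memory state."""
    S = theta * S_prev - eta * grad
    W_new = alpha * W + newton_schulz(S, steps=ns_steps)
    return W_new, S


# Toy usage with random stand-ins (the gradient would come from an Omega-rule
# objective such as the omega_loss sketch in this section).
W = torch.randn(64, 64)
S = torch.zeros_like(W)
g = torch.randn_like(W)
W, S = atlas_step(W, S, g)
```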
Implementation Details and Parallelization:
A key practical aspect is parallelizing Omega-rule and Atlas training. A naive implementation would require materializing gradients for the entire length-$c$ window, which is memory-intensive. The paper instead proposes a chunk-wise parallelization strategy. The sequence is divided into chunks, and within each chunk gradients are computed with respect to the memory state at the beginning of the chunk. A sliding window mask is applied during gradient computation (e.g., within einsum operations) so that each step only accumulates contributions from its last $c$ tokens, whether they lie in the current chunk or in preceding chunks. For Atlas with Muon, the momentum term $S_t$ can be calculated recursively from past gradients, independently of the intermediate memory states. Since gradients within a chunk are computed relative to the chunk's initial memory state, the momentum terms can be parallelized chunk-wise, and the Newton-Schulz operation on $S_t$ is likewise parallelizable, making the overall training process efficient without significant overhead compared to online methods. The memory module is typically implemented as a 2-layer MLP with residual connections and optional gating (Atlas++). A sketch of the sliding-window mask follows.
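As a rough illustration of the masking described above, the hypothetical helper below builds the boolean sliding-window mask that restricts each step's gradient contributions to its last c tokens during chunk-wise computation; the indexing conventions are assumed, and the paper applies such a mask inside its einsum-based gradient computation rather than as a standalone function.

```python
import torch


def sliding_window_mask(chunk_start: int, chunk_len: int, seq_len: int, c: int) -> torch.Tensor:
    """Boolean mask of shape (chunk_len, seq_len). Entry (i, j) is True iff global
    token j lies in the length-c window ending at global position chunk_start + i,
    so contributions may come from the current chunk or preceding chunks, but
    never from more than c tokens back."""
    q = torch.arange(chunk_start, chunk_start + chunk_len).unsqueeze(1)  # step positions
    k = torch.arange(seq_len).unsqueeze(0)                               # token positions
    return (k <= q) & (k > q - c)


# Example: third chunk of length 64 in a 192-token prefix, window length c = 16.
# The mask zeroes out per-token squared-error terms that fall outside each step's window.
mask = sliding_window_mask(chunk_start=128, chunk_len=64, seq_len=192, c=16)
```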
Experimental Evaluation:
Extensive experiments are conducted across various benchmarks:
- LLMing & Common-Sense Reasoning: Atlas, OmegaNet, DeepTransformers, and Dot are evaluated on perplexity (Wikitext, LMB) and accuracy (LMB, PIQA, HellaSwag, WinoGrande, ARC-e/c, SIQA, BoolQ). Atlas and OmegaNet outperform recurrent baselines (RetNet, DeltaNet, Titans, Memora) and competitive hybrid models (Samba, Gated DeltaNet-H2). DeepTransformers and Dot also outperform Transformer++. This validates the benefits of deep memory, context memorization (Omega rule), and locally optimal memory management (Muon).
- Needle In a Haystack (S-NIAH): Atlas and its hybrid variants, along with DeepTransformers and Dot, show strong extrapolation performance to long contexts (up to 16K tokens), outperforming recurrent baselines. This highlights the improved effective context length and memory capacity.
- BABILong Benchmark: Atlas maintains performance up to 10M context length, achieving over 80% accuracy where Titans' performance drops, demonstrating superior scaling to ultra-long sequences.
- MAD Synthetic Benchmark: Atlas achieves state-of-the-art results across tasks testing compression, noisy ICR, fuzzy ICR, selective copying, and memorization, particularly excelling in memorization, supporting the claims about enhanced memory capacity.
- In-context Recall: While Transformers remain the best on standard in-context recall tasks (SWDE, NQ, DROP, FDA, SQUAD, TQA), Atlas and OmegaNet significantly close the gap compared to other recurrent models.
- Associative Recall (MQAR): Atlas and Dot perform well, with Atlas showing the best performance per memory size among tested models.
Ablation Studies:
Ablations confirm that all proposed components contribute positively to Atlas's performance: the gated MLP memory architecture, the hybrid variants (both MAG and MAL improve performance, with MAG slightly better), the Muon optimizer (which outperforms simple gradient descent), the Omega rule (a sliding-window objective with $c > 1$ is better than the online case $c = 1$), and the polynomial feature mapping.
Scaling Patterns:
Atlas and OmegaNet demonstrate favorable scaling patterns with increasing model size (up to 1.3B parameters) and training context length, achieving lower perplexity compared to baselines at larger scales.
In conclusion, the paper successfully introduces Atlas, a novel recurrent architecture with a high-capacity, deep, and locally optimal memory module that learns to memorize context using the Omega rule and Muon optimizer. It also presents DeepTransformers and Dot as generalizations of Transformers incorporating deep memory and advanced learning rules. These models achieve state-of-the-art performance among recurrent architectures and often surpass Transformer baselines on various benchmarks, particularly excelling in long-context tasks due to their enhanced memory management and capacity. The proposed parallelization strategy makes these architectures practical for large-scale training and deployment.