Fast-WAM: Fast-Weight Associative Memory
- Fast-WAM is a neural architecture that integrates an LSTM controller with a fast-weight memory to support rapid, hetero-associative binding.
- It uses a third-order tensor for storing key-value mappings via outer products, enabling multi-step, compositional inference over sequential data.
- Empirical results show Fast-WAM outperforms traditional models on structured reasoning and meta-reinforcement learning while maintaining a compact parameter footprint.
Fast-Weight Associative Memory (Fast-WAM) refers to a neural network architecture that augments standard recurrent models with a rapidly updated, compositional associative memory. The design draws upon principles of fast weights and hetero-associative binding, enabling efficient storage and inference over arbitrary associations between keys and values within sequential data. Fast-WAM models have demonstrated significant empirical gains on structured reasoning benchmarks, meta-reinforcement learning for partially observable Markov decision processes, and challenging compositional language tasks, all while maintaining a compact parameter footprint (Schlag et al., 2020).
1. Architectural Overview
Fast-WAM consists of two principal components: a slow-weight controller, typically a Long Short-Term Memory (LSTM) network, and a fast-weight memory (FWM) implemented as a third-order tensor $A_t \in \mathbb{R}^{d \times d \times d}$. The LSTM maintains a hidden state $\mathbf{h}_t$ and a cell state $\mathbf{c}_t$, updated by the standard LSTM recurrence; the hidden state $\mathbf{h}_t$ is the basis for generating both the memory write and read signals.
For computational convenience, the fast-weight memory is reshaped to $A_t \in \mathbb{R}^{d^2 \times d}$, which stores hetero-associative mappings from key pairs $(\mathbf{k}_1, \mathbf{k}_2)$ to value vectors $\mathbf{v}$. At each time step, the system performs:
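For reference, one common form of the standard LSTM recurrence is sketched below, with $\mathbf{x}_t$ the embedding of the current token and $[\cdot\,;\cdot]$ denoting concatenation. The weight names $U_i, U_f, U_o, U_g$ and biases $\mathbf{b}_\ast$ are hypothetical labels, and peephole connections are assumed absent; the Fast-WAM description itself only specifies a standard LSTM.

$$\begin{aligned}
\mathbf{i}_t &= \sigma\big(U_i[\mathbf{x}_t; \mathbf{h}_{t-1}] + \mathbf{b}_i\big), &
\mathbf{f}_t &= \sigma\big(U_f[\mathbf{x}_t; \mathbf{h}_{t-1}] + \mathbf{b}_f\big), \\
\mathbf{o}_t &= \sigma\big(U_o[\mathbf{x}_t; \mathbf{h}_{t-1}] + \mathbf{b}_o\big), &
\mathbf{g}_t &= \tanh\big(U_g[\mathbf{x}_t; \mathbf{h}_{t-1}] + \mathbf{b}_g\big), \\
\mathbf{c}_t &= \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \mathbf{g}_t, &
\mathbf{h}_t &= \mathbf{o}_t \odot \tanh(\mathbf{c}_t).
\end{aligned}$$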
- A write: binding a new association.
- A read: one or more inference steps to retrieve values via content-based lookup.
All slow weights are learned by gradient descent, and the system trains end-to-end by backpropagation through time.
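The reshape from $d \times d \times d$ to $d^2 \times d$ is only a change of view: contracting the tensor with a key pair on its first two axes gives the same vector as multiplying the transposed matrix by $\operatorname{vec}(\mathbf{k}_1 \otimes \mathbf{k}_2)$. A minimal NumPy check, assuming row-major flattening (NumPy's default) and an arbitrary size $d = 4$; the names `A3`, `lookup_tensor`, and `lookup_matrix` are illustrative:

```python
import numpy as np

d = 4
rng = np.random.default_rng(0)
A3 = rng.standard_normal((d, d, d))    # third-order fast-weight tensor
k1, k2 = rng.standard_normal(d), rng.standard_normal(d)

# Tensor view: contract the first two axes with the key pair
lookup_tensor = np.einsum("ijk,i,j->k", A3, k1, k2)

# Matrix view: A is (d^2, d); the key pair becomes vec(k1 ⊗ k2) of length d^2
A = A3.reshape(d * d, d)
key = np.outer(k1, k2).ravel()
lookup_matrix = A.T @ key

print(np.allclose(lookup_tensor, lookup_matrix))
```

Row `i * d + j` of the reshaped matrix is the fibre `A3[i, j, :]`, and entry `i * d + j` of the flattened key is `k1[i] * k2[j]`, which is why the two views agree.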
2. Memory Update and Retrieval Mechanisms
The fundamental mechanics of Fast-WAM center on the explicit, rapidly modifiable memory updates and compositional reading procedures.
Memory Write (Update) Rule
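Given the LSTM hidden state $\mathbf{h}_t$, a linear projection produces the write gate $\beta_t$, two keys $\mathbf{k}_1, \mathbf{k}_2 \in \mathbb{R}^d$, and the write value $\mathbf{v} \in \mathbb{R}^d$:
$$[\,w_\beta;\ \mathbf{w}_w\,] = W_{\text{write}}\,\mathbf{h}_t, \qquad \beta_t = \sigma(w_\beta), \qquad [\,\mathbf{k}_1;\ \mathbf{k}_2;\ \mathbf{v}\,] = \tanh(\mathbf{w}_w)$$
The memory update first retrieves the value $\mathbf{v}_{\text{old}}$ currently bound to the key pair, then removes it and inserts the new value, both scaled by $\beta_t$, and finally renormalizes:
$$\mathbf{v}_{\text{old}} = A_{t-1}^\top \operatorname{vec}(\mathbf{k}_1 \otimes \mathbf{k}_2)$$
$$A'_t = A_{t-1} + \beta_t \operatorname{vec}(\mathbf{k}_1 \otimes \mathbf{k}_2) \otimes (\mathbf{v} - \mathbf{v}_{\text{old}})$$
$$A_t = \frac{A'_t}{\max\big(1, \lVert A'_t \rVert\big)}$$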
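This rule is a single gated, Hebbian-like (delta-rule) step: subtracting $\mathbf{v}_{\text{old}}$ keeps repeated writes to the same key pair from accumulating, and the renormalization bounds $\lVert A_t \rVert$ by 1, so the memory cannot grow without limit.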
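To see why the $\mathbf{v} - \mathbf{v}_{\text{old}}$ term matters, the sketch below writes two different values to the same key pair, once with the gated update and once with a plain Hebbian outer-product sum (a comparison baseline, not part of Fast-WAM). It assumes unit-norm one-hot keys, $\beta_t = 1$, small values so that renormalization stays inactive, and the Frobenius norm for $\lVert \cdot \rVert$; the helper name `fwm_write` is hypothetical.

```python
import numpy as np

def fwm_write(A, k1, k2, v, beta):
    """Gated write to a (d^2, d) fast-weight matrix A (hypothetical helper)."""
    key = np.outer(k1, k2).ravel()                   # vec(k1 ⊗ k2)
    v_old = A.T @ key                                # value currently bound to this key pair
    A_new = A + beta * np.outer(key, v - v_old)
    return A_new / max(1.0, np.linalg.norm(A_new))   # Frobenius-norm renormalization (assumed)

d = 4
k1, k2 = np.eye(d)[0], np.eye(d)[1]                  # unit-norm keys
v1 = np.array([0.5, 0.0, 0.0, 0.0])
v2 = np.array([0.0, 0.5, 0.0, 0.0])
key = np.outer(k1, k2).ravel()

# Gated update: the second write replaces the first binding
A = np.zeros((d * d, d))
A = fwm_write(A, k1, k2, v1, beta=1.0)
A = fwm_write(A, k1, k2, v2, beta=1.0)
print(A.T @ key)    # equals v2

# Plain Hebbian sum for comparison: the two values pile up
H = np.outer(key, v1) + np.outer(key, v2)
print(H.T @ key)    # equals v1 + v2
```

With $\beta_t < 1$ the gated write only moves the binding part of the way from $\mathbf{v}_{\text{old}}$ toward $\mathbf{v}$, which lets the controller decide how strongly a new fact overrides an old one.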
Memory Read (Retrieval) Rule
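A series of content-based queries allows multi-step inference. The read vectors are projections of the hidden state:
$$\mathbf{n}^{(0)} = \tanh(W_n \mathbf{h}_t), \qquad \mathbf{e}^{(i)} = \tanh\big(W_e^{(i)} \mathbf{h}_t\big), \quad i = 1, \dots, N_r$$
For $i = 1$ to $N_r$, the intermediate retrieved values are computed recursively, where $\operatorname{LN}$ denotes layer normalization:
$$\mathbf{n}^{(i)} = \operatorname{LN}\Big(A_t^\top \operatorname{vec}\big(\mathbf{n}^{(i-1)} \otimes \mathbf{e}^{(i)}\big)\Big)$$
The final retrieved value is projected and added residually to the controller output:
$$\mathbf{r}_t = \mathbf{h}_t + W_o\, \mathbf{n}^{(N_r)}$$
A softmax over $W_{\text{out}} \mathbf{r}_t$ produces the predictive output $\hat{\mathbf{y}}_{t+1} = \operatorname{softmax}(W_{\text{out}} \mathbf{r}_t)$.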
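Because $\langle \operatorname{vec}(\mathbf{a} \otimes \mathbf{b}), \operatorname{vec}(\mathbf{k}_1 \otimes \mathbf{k}_2) \rangle = (\mathbf{a}^\top \mathbf{k}_1)(\mathbf{b}^\top \mathbf{k}_2)$, a stored value contributes to a lookup only in proportion to how well the query matches both of its keys, and feeding each result back in as the next query chains associations. The sketch below stores two hypothetical facts with one-hot vectors, $(\mathbf{a}, \mathbf{r}) \mapsto \mathbf{b}$ and $(\mathbf{b}, \mathbf{r}) \mapsto \mathbf{c}$, then reads with $\mathbf{n}^{(0)} = \mathbf{a}$ and $\mathbf{e}^{(1)} = \mathbf{e}^{(2)} = \mathbf{r}$. Setting the queries by hand (in the model they are $\tanh$ projections of $\mathbf{h}_t$), the parameter-free layer norm, the Frobenius norm, and the helper names are simplifying assumptions.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Parameter-free layer normalization (no learned gain or bias, an assumption)
    return (x - x.mean()) / np.sqrt(x.var() + eps)

def fwm_write(A, k1, k2, v, beta=1.0):
    # Gated delta-rule write with Frobenius-norm renormalization (hypothetical helper)
    key = np.outer(k1, k2).ravel()
    A_new = A + beta * np.outer(key, v - A.T @ key)
    return A_new / max(1.0, np.linalg.norm(A_new))

def fwm_read(A, n0, queries):
    # n^(i) = LN(A^T vec(n^(i-1) ⊗ e^(i))) for each query e^(i) (hypothetical helper)
    n = n0
    for e_i in queries:
        n = layer_norm(A.T @ np.outer(n, e_i).ravel())
    return n

d = 8
a, b, c, r = np.eye(d)[:4]         # one-hot entities a, b, c and relation r
A = np.zeros((d * d, d))
A = fwm_write(A, a, r, b)          # (a, r) -> b
A = fwm_write(A, b, r, c)          # (b, r) -> c

one_hop = fwm_read(A, a, [r])      # N_r = 1
two_hop = fwm_read(A, a, [r, r])   # N_r = 2
print(np.argmax(one_hop), np.argmax(two_hop))   # 1 2  (b, then c)
```

Layer normalization shifts and rescales each intermediate result, so the second query is not exactly $\mathbf{b}$; it still correlates most strongly with $\mathbf{b}$, which is enough for the second lookup to land on $\mathbf{c}$.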
3. End-to-End Workflow and Pseudocode
The following sequence details the full forward computation:
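- Token Embedding: $\mathbf{x}_t = \operatorname{Embedding}[x_t]$ for the input token $x_t$.
- LSTM Update: $(\mathbf{h}_t, \mathbf{c}_t) = \operatorname{LSTM}(\mathbf{x}_t, \mathbf{h}_{t-1}, \mathbf{c}_{t-1})$.
- Write to FWM: Compute the write gate, keys, and value. Retrieve $\mathbf{v}_{\text{old}}$. Apply the Hebbian update and renormalization.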
- Multi-step Read: Generate query vectors and perform iterative content-based lookups.
- Output Computation: Add retrieved memory output to controller state, apply an output projection and softmax.
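A compact, runnable NumPy rendering of this forward pass follows. The layout of the parameter dictionary `P` and the shapes it implies, the bias-augmented single-layer LSTM, the parameter-free layer norm, and the Frobenius norm used for renormalization are illustrative assumptions rather than details fixed by the description above:

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def layer_norm(x, eps=1e-5):
    return (x - x.mean()) / np.sqrt(x.var() + eps)  # no learned gain/bias

def lstm(W, e, h, c):
    # One LSTM step; W maps [e; h; 1] to the stacked gates (i, f, o, g)
    i, f, o, g = np.split(W @ np.concatenate([e, h, [1.0]]), 4)
    c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    return sigmoid(o) * np.tanh(c), c

def fast_wam(xs, P, d_h, d):  # P: dict of parameter arrays (names/shapes assumed)
    h, c, A = np.zeros(d_h), np.zeros(d_h), np.zeros((d * d, d))
    ys = []
    for x_t in xs:
        e = P["Embedding"][x_t]
        h, c = lstm(P["W_lstm"], e, h, c)
        # Write to fast weights: gate, two keys and a value from h
        z = P["W_write"] @ h                       # shape (1 + 3d,)
        beta = sigmoid(z[0])
        k1, k2, v = np.split(np.tanh(z[1:]), 3)
        key = np.outer(k1, k2).ravel()             # vec(k1 ⊗ k2)
        v_old = A.T @ key                          # value currently bound to key
        A_ = A + beta * np.outer(key, v - v_old)
        A = A_ / max(1.0, np.linalg.norm(A_))      # Frobenius-norm renormalization
        # Read from fast weights: N_r chained lookups, one per matrix in W_e
        m = np.tanh(P["W_n"] @ h)
        for W_e_i in P["W_e"]:
            e_i = np.tanh(W_e_i @ h)
            m = layer_norm(A.T @ np.outer(m, e_i).ravel())
        r = h + P["W_o"] @ m
        logits = P["W_out"] @ r
        y = np.exp(logits - logits.max())
        ys.append(y / y.sum())                     # softmax prediction of x_{t+1}
    return ys
```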
4. Empirical Performance and Ablations
Fast-WAM achieves strong empirical results on challenging compositional reasoning and sequential inference tasks:
- On concatenated-bAbI (catbAbI), Fast-WAM attains 96.8% QA-mode accuracy, surpassing LSTM (81%), Transformer-XL (87.7%), and Metalearned Neural Memory (89.0%). With read, accuracy exceeding 95% is maintained, while ablations reducing the compositionality of keys (e.g., concatenation in place of outer product) produce a ≈5% drop.
- In meta-reinforcement learning on POMDP structured grids, Fast-WAM generalizes more robustly across tasks than larger LSTM controllers, which can overfit.
- On word-level language modeling benchmarks such as Penn Treebank and WikiText-2, Fast-WAM improves test perplexity by ∼3 on PTB and ∼2 on WT2 over a regularized AWD-LSTM baseline. Analysis of token predictivity reveals that Fast-WAM especially reduces surprisal on rare proper nouns following their first mention.
The entire model footprint remains compact (e.g., 694K parameters for Fast-WAM vs. 1.1M for MNM and 10.5M for Transformer-XL on catbAbI).
5. Comparison to Prior Associative Memory Approaches
Fast-WAM embodies the class of fast-weight memory models originally introduced by Schmidhuber (1992), but with significant advances:
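- Hetero-associative memory: Unlike Hopfield networks, which are auto-associative and associate each pattern with itself, Fast-WAM supports arbitrary mappings from a key pair $(\mathbf{k}_1, \mathbf{k}_2)$ to a value $\mathbf{v}$ in a single outer-product step.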
- Slot-less design: In contrast to Neural Turing Machines or Memory Networks, Fast-WAM does not rely on addressing fixed memory slots, eliminating memory-allocation complexity and the risk of collision. All associations are represented implicitly in the dynamic weight tensor.
- Compositional inference chains: Multi-step readout (with $N_r > 1$) supports chaining associations, enabling inferences such as transitivity, whereas slot memories and NTMs require explicit controller logic to simulate such chaining.
- Capacity and speed: With an outer-product tensor of size $d \times d \times d$, the associator can, in principle, store up to $d^2$ orthogonal associations, one per orthogonal direction of the key space $\operatorname{vec}(\mathbf{k}_1 \otimes \mathbf{k}_2) \in \mathbb{R}^{d^2}$, compared with one association per slot in slot-based memories. Each write and each read step costs $O(d^3)$ time and the memory occupies $O(d^3)$ space, which remains practical for $d$ in the range 32–64.
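The $d^2$ capacity figure can be checked directly: one-hot key pairs $(\mathbf{e}_i, \mathbf{e}_j)$ give $d^2$ mutually orthogonal vectorized keys. The sketch below fills a $d = 8$ memory with all 64 such associations and confirms that each stored value is retrieved exactly. The random values are scaled to norm $1/d$ so that the total Frobenius norm never exceeds 1 and renormalization stays inactive; this scaling, the sizes, and the helper name `fwm_write` are illustrative assumptions.

```python
import numpy as np

def fwm_write(A, k1, k2, v, beta=1.0):
    # Gated delta-rule write (hypothetical helper); the outer product touches all d^3 entries
    key = np.outer(k1, k2).ravel()
    A_new = A + beta * np.outer(key, v - A.T @ key)
    return A_new / max(1.0, np.linalg.norm(A_new))   # Frobenius-norm renormalization (assumed)

d = 8
rng = np.random.default_rng(0)
I = np.eye(d)
pairs = [(i, j) for i in range(d) for j in range(d)]          # d^2 orthogonal key pairs
values = rng.standard_normal((len(pairs), d))
values /= d * np.linalg.norm(values, axis=1, keepdims=True)   # each value has norm 1/d

A = np.zeros((d * d, d))
for (i, j), v in zip(pairs, values):
    A = fwm_write(A, I[i], I[j], v)

retrieved = np.array([A.T @ np.outer(I[i], I[j]).ravel() for i, j in pairs])
print(len(pairs), A.size, np.allclose(retrieved, values))
```

Learned keys are generally not orthogonal, so in practice interference between overlapping keys sets in well before this upper bound.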
6. Training, Stability, and Implementation Details
Training proceeds by standard gradient descent over all slow matrices and parameters. During each sequence step, memory operations are fully differentiable, and the memory tensor is not a learned parameter but a mutable, in-graph state variable. Key implementation details to maintain stability include:
- Renormalization: After each write, the memory is divided by $\max(1, \lVert A'_t \rVert)$, which prevents unbounded growth.
- Layer normalization: Applied to memory read outputs, stabilizing gradients and facilitating learning over variable-length sequences.
- Truncated backpropagation through time (tBPTT): Used for efficient training (e.g., 200 time steps on the catbAbI task).
The design enables highly data-efficient learning and robust generalization even in low-shot, compositional reasoning regimes.
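The effect of the renormalization step can be illustrated by applying the same sequence of random writes with and without it. The $\tanh$-squashed Gaussian keys and values, $\beta_t = 0.5$, $d = 16$, 20 steps, the Frobenius norm, and the helper name `write` are assumptions chosen for illustration.

```python
import numpy as np

def write(A, k1, k2, v, beta, renormalize=True):
    # Gated delta-rule write; optionally skip the renormalization (hypothetical helper)
    key = np.outer(k1, k2).ravel()
    A_new = A + beta * np.outer(key, v - A.T @ key)
    return A_new / max(1.0, np.linalg.norm(A_new)) if renormalize else A_new

d, steps, beta = 16, 20, 0.5
rng = np.random.default_rng(1)
A_norm = np.zeros((d * d, d))
A_raw = np.zeros((d * d, d))
norms_norm, norms_raw = [], []
for _ in range(steps):
    # Keys and value squashed by tanh, as in the write rule
    k1, k2, v = np.tanh(rng.standard_normal((3, d)))
    A_norm = write(A_norm, k1, k2, v, beta)
    A_raw = write(A_raw, k1, k2, v, beta, renormalize=False)
    norms_norm.append(np.linalg.norm(A_norm))
    norms_raw.append(np.linalg.norm(A_raw))

print(max(norms_norm), max(norms_raw))
```

With renormalization the Frobenius norm never exceeds 1; without it, even the first write already exceeds 1, because the key $\operatorname{vec}(\mathbf{k}_1 \otimes \mathbf{k}_2)$ is not unit-norm.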
7. Scientific and Practical Significance
Fast-WAM demonstrates that rapid, compositional hetero-associative memory mechanisms can efficiently augment recurrent models for structured inference and meta-learning. The architecture combines the compactness and simplicity of an LSTM controller with the operational flexibility of fast, compositional memory updates and multi-step relation chaining. Empirical findings establish that this combination substantially augments model capacity and generalization, outperforming both traditional recurrent models and slot-based external memories on compositional tasks and challenging meta-RL domains (Schlag et al., 2020).
A plausible implication is that such architectures will see further adoption in neural models that require high data efficiency and the capacity to reason compositionally over sequences, including future directions in continual learning, rapid adaptation, and scalable, memory-augmented language modeling.