
Fast-WAM: Fast-Weight Associative Memory

Updated 24 March 2026
  • Fast-WAM is a neural architecture that integrates an LSTM controller with a fast-weight memory to support rapid hetero-associative binding.
  • It uses a third-order tensor for storing key-value mappings via outer products, enabling multi-step, compositional inference over sequential data.
  • Empirical results show Fast-WAM outperforms traditional models on structured reasoning and meta-reinforcement learning while maintaining a compact parameter footprint.

Fast-Weight Associative Memory (Fast-WAM) refers to a neural network architecture that augments standard recurrent models with a rapidly updated, compositional associative memory. The design draws upon principles of fast weights and hetero-associative binding, enabling efficient storage and inference over arbitrary associations between keys and values within sequential data. Fast-WAM models have demonstrated significant empirical gains on structured reasoning benchmarks, meta-reinforcement learning for partially observable Markov decision processes, and challenging compositional language tasks, all while maintaining a compact parameter footprint (Schlag et al., 2020).

1. Architectural Overview

Fast-WAM consists of two principal components: a slow-weight controller, typically a Long Short-Term Memory (LSTM) network, and a fast-weight memory (FWM) implemented as a third-order tensor. The LSTM maintains a hidden state $h_t \in \mathbb{R}^{d_\mathrm{LSTM}}$ and a cell state $c_t$, updated by the standard LSTM recurrence; the hidden state is the source of both the memory write and read signals.

The fast-weight memory $A_t \in \mathbb{R}^{d_F \times d_F \times d_F}$ is reshaped for computational convenience into a matrix $A_t \in \mathbb{R}^{d_F^2 \times d_F}$, which stores hetero-associative mappings from key pairs to value vectors. At each time step, the system performs:

  • A write: binding a new association $(k_{1t}, k_{2t}) \mapsto v_t$.
  • A read: one or more inference steps to retrieve values via content-based lookup.

All slow weights are learned by gradient descent, and the system trains end-to-end by backpropagation through time.
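Because the reshape only regroups indices, reading the third-order tensor with both keys and reading the flattened matrix with $\mathrm{vec}(k_{1t} \otimes k_{2t})$ give the same value vector. A minimal NumPy check, assuming $d_F = 4$ and random entries; the transpose on the flattened matrix follows from storing it as $d_F^2 \times d_F$:

```python
import numpy as np

d_F = 4
rng = np.random.default_rng(0)

# Third-order fast-weight tensor A3[i, j, l]: key indices i, j and value index l
A3 = rng.standard_normal((d_F, d_F, d_F))
k1 = np.tanh(rng.standard_normal(d_F))
k2 = np.tanh(rng.standard_normal(d_F))

# Read as a contraction over both key indices
v_tensor = np.einsum("ijl,i,j->l", A3, k1, k2)

# Same read with A reshaped to (d_F^2) x d_F and the key pair flattened to vec(k1 ⊗ k2)
A2 = A3.reshape(d_F * d_F, d_F)
key = np.outer(k1, k2).ravel()   # row-major: index i * d_F + j holds k1[i] * k2[j]
v_matrix = A2.T @ key

print(np.allclose(v_tensor, v_matrix))  # True
```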

2. Memory Update and Retrieval Mechanisms

The fundamental mechanics of Fast-WAM center on the explicit, rapidly modifiable memory updates and compositional reading procedures.

Memory Write (Update) Rule

Upon receiving the LSTM hidden state $h_t$, a single linear projection produces the pre-activations for the write gate $\beta_t \in (0,1)$, the keys $k_{1t}, k_{2t} \in \mathbb{R}^{d_F}$, and the write value $v_t \in \mathbb{R}^{d_F}$:

$[w_{\beta t}; w_{\mathrm{write}\,t}] = W_{\mathrm{write}} h_t$

$\beta_t = \sigma(w_{\beta t}), \quad [k_{1t}; k_{2t}; v_t] = \tanh(w_{\mathrm{write}\,t})$

The memory update first reads the value currently bound to the key pair $(k_{1t}, k_{2t})$, then replaces it with the new value, interpolating through the gate $\beta_t$:

$v_{\mathrm{old}} = A_{t-1}^\top \mathrm{vec}(k_{1t} \otimes k_{2t})$

$A'_t = A_{t-1} + \beta_t \, \mathrm{vec}(k_{1t} \otimes k_{2t}) \otimes (v_t - v_{\mathrm{old}})$

$A_t = A'_t / \max(1, \|A'_t\|_2)$

This rule implements a single gated update in the style of the delta rule: a Hebbian outer-product write corrected by the value already stored, so that rebinding a key pair overwrites its old value instead of accumulating on top of it, while the renormalization prevents the memory from growing without bound.
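A minimal NumPy sketch of the write rule above, including renormalization. The helper name fwm_write, the use of the Frobenius norm for $\|A'_t\|_2$, and the hand-picked unit-norm keys and values are illustrative assumptions (a trained controller would produce keys and values through the tanh projection). With $\beta_t = 1$ and a unit-norm key, subtracting $v_\mathrm{old}$ makes a second write to the same key pair replace the first binding:

```python
import numpy as np

def fwm_write(A, k1, k2, v, beta):
    """One gated write; A has shape (d_F**2, d_F). Hypothetical helper."""
    key = np.outer(k1, k2).ravel()                   # vec(k1 ⊗ k2)
    v_old = A.T @ key                                # value currently bound to (k1, k2)
    A_new = A + beta * np.outer(key, v - v_old)      # move the binding toward v
    return A_new / max(1.0, np.linalg.norm(A_new))   # renormalize (Frobenius norm, assumed)

d_F = 4
k1 = np.array([1.0, 0.0, 0.0, 0.0])                 # unit-norm keys: ||vec(k1 ⊗ k2)|| = 1
k2 = np.array([0.0, 1.0, 0.0, 0.0])
v1 = np.array([0.1, -0.2, 0.3, 0.0])
v2 = np.array([-0.4, 0.0, 0.1, 0.2])

A = np.zeros((d_F * d_F, d_F))
A = fwm_write(A, k1, k2, v1, beta=1.0)
A = fwm_write(A, k1, k2, v2, beta=1.0)               # rebind the same key pair
key = np.outer(k1, k2).ravel()
print(np.allclose(A.T @ key, v2))                    # True: v2 replaced v1
```

With beta=0.5 on the second write, the same read would instead return the average of v1 and v2, which shows how the gate interpolates between the old and new binding. For keys that are not unit-norm, the value read back right after a write (before renormalization) is $v_\mathrm{old} + \beta_t \|\mathrm{vec}(k_{1t} \otimes k_{2t})\|^2 (v_t - v_\mathrm{old})$.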

Memory Read (Retrieval) Rule

A series of $N_r$ content-based lookups allows multi-step inference. The controller state generates an initial query and $N_r$ further query vectors:

$u^{(0)}_t = \tanh(W_n h_t)$

$u^{(i)}_t = \tanh(W^{(i)}_e h_t), \quad i = 1, \dots, N_r$

Starting from $m^{(0)}_t = u^{(0)}_t$, the intermediate retrieved values are computed recursively for $i = 1$ to $N_r$, each lookup keyed by the previous result: $m^{(i)}_t = \mathrm{LayerNorm}\big(A_t^\top \mathrm{vec}(m^{(i-1)}_t \otimes u^{(i)}_t)\big)$. The final retrieved value is projected and added residually to the controller output: $m_t = W_o m^{(N_r)}_t, \quad r_t = h_t + m_t$. An output projection followed by a softmax maps $r_t$ to the predictive distribution.
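Chaining works because each retrieved value becomes the first key of the next lookup, $m^{(i)} = \mathrm{LayerNorm}(A^\top \mathrm{vec}(m^{(i-1)} \otimes u^{(i)}))$. The sketch below is an illustration under stated assumptions: hand-built vectors (rows of an 8×8 Hadamard matrix, zero-mean with unit variance, so LayerNorm leaves them unchanged), a LayerNorm without learned gain or bias, fixed query vectors in place of the learned $W^{(i)}_e h_t$, and hypothetical names fwm_read and rel. It stores $a \rightarrow b$ and $b \rightarrow c$ under the same relation vector and recovers $c$ from $a$ in $N_r = 2$ lookups:

```python
import numpy as np

def layer_norm(z, eps=1e-5):
    # LayerNorm without learned gain/bias (an assumption for this sketch)
    return (z - z.mean()) / np.sqrt(z.var() + eps)

def fwm_read(A, m0, queries):
    """Chained read m^(i) = LN(A^T vec(m^(i-1) ⊗ u^(i))); returns every m^(i)."""
    ms, m = [], m0
    for u in queries:
        m = layer_norm(A.T @ np.outer(m, u).ravel())
        ms.append(m)
    return ms

# Mutually orthogonal, zero-mean, unit-variance vectors: rows 1..7 of a Sylvester Hadamard matrix
H2 = np.array([[1.0, 1.0], [1.0, -1.0]])
H8 = np.kron(np.kron(H2, H2), H2)
a, b, c, rel = H8[1], H8[2], H8[3], H8[4]

# Store (a, rel) -> b and (b, rel) -> c directly as superimposed outer products
A = np.outer(np.outer(a, rel).ravel(), b) + np.outer(np.outer(b, rel).ravel(), c)

m1, m2 = fwm_read(A, a, [rel, rel])
print(np.allclose(m1, b), np.allclose(m2, c))  # True True
```

The first lookup returns $b$; feeding it back as the key of the second lookup returns $c$, so $a \rightarrow c$ emerges without ever being stored.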

3. End-to-End Workflow and Pseudocode

The following sequence details the full forward computation:

  1. Token Embedding: $e_t = \mathrm{Embedding}(x_t)$.
  2. LSTM Update: $(h_t, c_t) = \mathrm{LSTM}(e_t, h_{t-1}, c_{t-1})$.
  3. Write to FWM: Compute the write gate, keys, and value. Retrieve $v_{\mathrm{old}}$. Apply the Hebbian-style update and renormalization.
  4. Multi-step Read: Generate $N_r + 1$ query vectors and perform $N_r$ iterative content-based lookups.
  5. Output Computation: Add retrieved memory output to controller state, apply an output projection and softmax.

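A succinct illustration, written as runnable Python/NumPy. The parameter dictionary P and its key names, the LSTM gate layout, the Frobenius norm in the renormalization, and the LayerNorm without learned gain or bias are assumptions of this sketch, not details taken from the original implementation:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def layer_norm(z, eps=1e-5):
    return (z - z.mean()) / np.sqrt(z.var() + eps)   # no learned gain/bias (assumption)

def fast_wam_forward(xs, P, d_F, N_r):
    """Return next-token distributions for token ids xs; P holds hypothetical weights."""
    h = np.zeros(P["W_n"].shape[1])
    c = np.zeros_like(h)
    A = np.zeros((d_F * d_F, d_F))                    # fast weights, (d_F^2) x d_F
    ys = []
    for x_t in xs:
        e = P["Embedding"][x_t]
        # Standard LSTM step; the (i, f, o, g) gate order is an assumption
        i, f, o, g = np.split(P["W_lstm"] @ np.concatenate([e, h]) + P["b_lstm"], 4)
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        # Write to fast weights
        w = P["W_write"] @ h                          # [w_beta; w_write]
        beta = sigmoid(w[0])
        k1, k2, v = np.split(np.tanh(w[1:]), 3)
        key = np.outer(k1, k2).ravel()                # vec(k1 ⊗ k2)
        v_old = A.T @ key
        A_ = A + beta * np.outer(key, v - v_old)
        A = A_ / max(1.0, np.linalg.norm(A_))
        # Read from fast weights: each lookup is keyed by the previous result
        m = np.tanh(P["W_n"] @ h)                     # m^(0) = u^(0)
        for i_r in range(N_r):
            u = np.tanh(P["W_e"][i_r] @ h)            # u^(i)
            m = layer_norm(A.T @ np.outer(m, u).ravel())
        r = h + P["W_o"] @ m
        logits = P["W_out"] @ r
        p = np.exp(logits - logits.max())
        ys.append(p / p.sum())                        # y_{t+1} = softmax(W_out r_t)
    return np.array(ys)
```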
This design provides a computationally tractable means of performing chained compositional inference (e.g., $a \rightarrow b$ and $b \rightarrow c$ yield $a \rightarrow c$ by cascading lookups), which conventional slot-based external memories can only emulate through explicit controller logic.

4. Empirical Performance and Ablations

Fast-WAM achieves strong empirical results on challenging compositional reasoning and sequential inference tasks:

  • On concatenated-bAbI (catbAbI), Fast-WAM attains 96.8% accuracy in QA mode, surpassing LSTM (81%), Transformer-XL (87.7%), and Metalearned Neural Memory (89.0%). With a single read ($N_r = 1$), accuracy stays above 95%, while ablations that reduce the compositionality of keys (e.g., concatenation in place of the outer product) cause a drop of roughly 5%.
  • In meta-reinforcement learning on structured grid POMDPs, Fast-WAM generalizes across tasks more robustly than larger LSTM controllers, which can overfit.
  • On word-level language modeling benchmarks (Penn Treebank and WikiText-2), Fast-WAM reduces test perplexity by about 3 points on PTB and about 2 on WT2 relative to a regularized AWD-LSTM baseline. An analysis of per-token predictions shows that Fast-WAM especially reduces surprisal on rare proper nouns after their first mention.

The entire model footprint remains compact (e.g., 694K parameters for Fast-WAM vs. 1.1M for MNM and 10.5M for Transformer-XL on catbAbI).

5. Comparison to Prior Associative Memory Approaches

Fast-WAM embodies the class of fast-weight memory models originally introduced by Schmidhuber (1992), but with significant advances:

  • Hetero-associative memory: Unlike Hopfield (auto-associative) networks, Fast-WAM supports arbitrary $k \rightarrow v$ mappings, each written in a single outer-product step.
  • Slot-less design: In contrast to Neural Turing Machines or Memory Networks, Fast-WAM does not address fixed memory slots, which eliminates memory-allocation logic. All associations are represented implicitly, superimposed in the dynamic weight tensor.
  • Compositional inference chains: Multi-step readout ($N_r > 1$) supports chaining associations, enabling inferences such as transitivity, whereas slot memories and NTMs require explicit controller logic to simulate such chaining.
  • Capacity and speed: Each key is the outer product of two $d_F$-dimensional vectors and therefore lives in $\mathbb{R}^{d_F^2}$, so the associator can in principle store up to $d_F^2$ associations with mutually orthogonal keys, compared with only as many as there are slots in a slot-based memory. Per-step compute and memory cost is $O(d_F^3)$, which remains practical for $d_F$ in the range 32–64.
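A small numerical check of the capacity and cost statements, assuming $d_F = 4$, standard-basis key vectors, and random values, with renormalization omitted. The $d_F^2 = 16$ key pairs $(e_i, e_j)$ produce mutually orthogonal keys, so all 16 bindings superimposed in one $d_F^2 \times d_F$ matrix are retrieved exactly, and a single lookup touches all $d_F^3$ entries:

```python
import numpy as np

d_F = 4
rng = np.random.default_rng(1)
I = np.eye(d_F)

# d_F^2 key pairs (e_i, e_j) give orthonormal keys vec(e_i ⊗ e_j) in R^(d_F^2)
keys = [np.outer(I[i], I[j]).ravel() for i in range(d_F) for j in range(d_F)]
values = rng.standard_normal((d_F * d_F, d_F))

# Superimpose all d_F^2 bindings in one (d_F^2) x d_F matrix (renormalization omitted)
A = sum(np.outer(k, v) for k, v in zip(keys, values))

retrieved = np.stack([A.T @ k for k in keys])
print(len(keys), np.allclose(retrieved, values))  # 16 True
print(A.size)  # 64 = d_F**3 multiply-adds per lookup A.T @ key
```

Keys produced by the tanh projections of a trained controller are generally not exactly orthogonal, so retrieval there is approximate and degrades as more associations are stored.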

6. Training, Stability, and Implementation Details

Training proceeds by standard gradient descent over all slow weights. Within each sequence step, memory operations are fully differentiable, and the memory tensor $A_t$ is not a learned parameter but a mutable, in-graph state variable. Key implementation details for maintaining stability include:

  • Renormalization: After each write, $A_t$ is renormalized to prevent unbounded growth.
  • Layer normalization: Applied to memory read outputs, stabilizing gradients and facilitating learning over variable-length sequences.
  • Truncated backpropagation through time (tBPTT): Used for efficient training (e.g., 200 time steps on the catbAbI task).
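A quick check of the renormalization step, assuming random tanh-squashed keys and values and random sigmoid gates as stand-ins for the controller's projections (the controller itself is omitted), and using the Frobenius norm for $\|A'_t\|_2$ (also an assumption). Whatever the inputs, dividing by $\max(1, \|A'_t\|)$ keeps the norm of the fast-weight state at most 1 over a long run of writes:

```python
import numpy as np

rng = np.random.default_rng(2)
d_F, T = 8, 1000
A = np.zeros((d_F * d_F, d_F))
norms = []
for t in range(T):
    # Random stand-ins for the controller's tanh keys/value and sigmoid gate
    k1, k2, v = np.tanh(rng.standard_normal((3, d_F)))
    beta = 1.0 / (1.0 + np.exp(-rng.standard_normal()))
    key = np.outer(k1, k2).ravel()
    v_old = A.T @ key
    A_new = A + beta * np.outer(key, v - v_old)
    A = A_new / max(1.0, np.linalg.norm(A_new))   # Frobenius norm (assumed)
    norms.append(np.linalg.norm(A))

print(max(norms) <= 1.0 + 1e-12)  # True: the fast-weight state stays bounded
```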

Together, these choices support data-efficient learning and robust generalization in the compositional reasoning settings described above.

7. Scientific and Practical Significance

Fast-WAM demonstrates that rapid, compositional hetero-associative memory mechanisms can efficiently augment recurrent models for structured inference and meta-learning. The architecture combines the compactness and simplicity of an LSTM controller with the flexibility of fast, compositional memory updates and multi-step relation chaining. The reported results indicate that this combination substantially improves model capacity and generalization, outperforming both traditional recurrent models and memory-augmented baselines such as Metalearned Neural Memory and Transformer-XL on compositional tasks, as well as larger LSTM controllers in challenging meta-RL domains (Schlag et al., 2020).

A plausible implication is that such architectures will see further adoption in neural models that require high data efficiency and the capacity to reason compositionally over sequences, including future directions in continual learning, rapid adaptation, and scalable, memory-augmented language modeling.

References (1)
