Gated Slot Attention: Efficient Sequence Modeling
- Gated Slot Attention is a memory-efficient recurrent sequence model that integrates slot-wise gating for adaptive forgetting and context-aware memory reading.
- It employs a two-layer GLA structure with an intervening softmax, enabling constant-memory inference and linear-time training through effective slot updates.
- GSA bridges Transformer-to-RNN finetuning gaps by retaining inductive biases while reducing computational overhead and enhancing recall in resource-constrained settings.
Gated Slot Attention (GSA) is a memory-efficient, recurrent sequence modeling mechanism that integrates slot-wise gating into a two-pass linear attention framework. It augments Attention with Bounded-memory-Control (ABC) by incorporating a data-dependent forgetting mechanism, drawing on Gated Linear Attention (GLA). GSA uses a two-layer GLA structure with an intervening softmax, enabling context-aware memory reading and adaptive slot forgetting while maintaining a compact recurrent state. This structure provides linear-time training and constant-memory inference, making GSA suitable both for in-context recall tasks and for finetuning large pretrained Transformers into recurrent neural networks (RNNs) with minimal retraining overhead (Zhang et al., 2024).
1. Architectural Foundations and Relation to Predecessors
GSA merges architectural elements from ABC and GLA:
- ABC (Attention with Bounded-memory-Control) reformulates attention as two sequential linear-attention passes joined by a softmax activation, yielding a memory-efficient design but lacking explicit forgetting.
- GLA (Gated Linear Attention) introduces data-dependent, slot-wise forget gates to linear attention, enabling dynamic control over slot retention but without the inductive bias provided by softmax.
GSA inherits ABC’s two-pass structure and replaces each pass with a GLA module. In each GSA layer, the input $x_t$ is projected (with a linear layer followed by a Swish activation) to produce $q_t$, $k_t$, $v_t$. Slot-wise forget gates $\alpha_t$, computed as $\alpha_t = \sigma(W_\alpha x_t)^{1/\tau}$ with $\tau = 8$, modulate slot updates. The layer performs two GLA passes:
- The “key–slot pass” uses $q_t$, $k_t$, $1-\alpha_t$ to produce an intermediate $o'_t$.
- The “value–slot pass” uses $\mathrm{softmax}(o'_t)$, $1-\alpha_t$, $v_t$ to yield the final $o_t$.
This enables efficient training through chunkwise matmul scheduling and constant-memory recurrent inference (Zhang et al., 2024).
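The input stage described above can be sketched in NumPy as follows. This is a minimal illustration, not the authors' implementation: the function name `gsa_inputs`, the weight names, the shapes, and the default `tau=8.0` are assumptions chosen for the example, and the gate is taken to be a sigmoid of a linear projection raised to the power `1/tau`.

```python
import numpy as np

def swish(x):
    # Swish / SiLU: x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gsa_inputs(x, W_q, W_k, W_v, W_alpha, tau=8.0):
    """Project inputs x of shape (T, d_model) into q, k, v of shape (T, d)
    and slot-wise forget gates alpha of shape (T, m).

    Hypothetical helper: names, shapes, and tau=8.0 are assumptions.
    """
    q = swish(x @ W_q)
    k = swish(x @ W_k)
    v = swish(x @ W_v)
    # Raising the sigmoid to 1/tau pushes gate values toward 1 (slower forgetting).
    alpha = sigmoid(x @ W_alpha) ** (1.0 / tau)
    return q, k, v, alpha

# Toy usage with random weights (illustrative sizes only).
rng = np.random.default_rng(0)
T, d_model, d, m = 6, 16, 8, 4
x = rng.normal(size=(T, d_model))
W_q, W_k, W_v = (rng.normal(size=(d_model, d)) / np.sqrt(d_model) for _ in range(3))
W_alpha = rng.normal(size=(d_model, m)) / np.sqrt(d_model)
q, k, v, alpha = gsa_inputs(x, W_q, W_k, W_v, W_alpha)
```

Each token gets one gate value per slot, so `alpha` has one column per slot rather than one per feature.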
2. Update Equations and Slot Memory Dynamics
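GSA maintains two slot matrices $\tilde{K}_t, \tilde{V}_t \in \mathbb{R}^{m \times d}$, where $m$ is the number of slots and $d$ the embedding dimension. Slot updates and memory reading are governed by per-slot gates $\alpha_t \in (0,1)^m$ (all vectors are columns):
$$\tilde{K}_t = \mathrm{Diag}(\alpha_t)\,\tilde{K}_{t-1} + (1-\alpha_t)\,k_t^\top, \qquad \tilde{V}_t = \mathrm{Diag}(\alpha_t)\,\tilde{V}_{t-1} + (1-\alpha_t)\,v_t^\top, \qquad o_t = \tilde{V}_t^\top\,\mathrm{softmax}(\tilde{K}_t q_t)$$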
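A single recurrent step of the slot update can be sketched as below. The sketch assumes the recurrence "decay each slot row by its gate, write the new key or value with weight one minus the gate, then read the value slots with a softmax over the key-slot scores"; the function names are hypothetical, and normalization, multi-head handling, and output projections are omitted.

```python
import numpy as np

def softmax(z):
    z = z - z.max()          # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def gsa_step(K_mem, V_mem, q_t, k_t, v_t, alpha_t):
    """One recurrent step. K_mem, V_mem: (m, d); q_t, k_t, v_t: (d,); alpha_t: (m,)."""
    # Decay every slot row by its gate, then write the new key / value.
    K_mem = alpha_t[:, None] * K_mem + np.outer(1.0 - alpha_t, k_t)
    V_mem = alpha_t[:, None] * V_mem + np.outer(1.0 - alpha_t, v_t)
    # Read: softmax over the m slot scores, then mix the value slots.
    o_t = V_mem.T @ softmax(K_mem @ q_t)
    return K_mem, V_mem, o_t

def gsa_recurrent(q, k, v, alpha):
    """Run the recurrence over a sequence; the state stays (m, d) throughout."""
    T, d = q.shape
    m = alpha.shape[1]
    K_mem, V_mem = np.zeros((m, d)), np.zeros((m, d))
    outs = np.empty((T, d))
    for t in range(T):
        K_mem, V_mem, outs[t] = gsa_step(K_mem, V_mem, q[t], k[t], v[t], alpha[t])
    return outs, K_mem, V_mem

rng = np.random.default_rng(0)
T, d, m = 5, 8, 4
q, k, v = (rng.normal(size=(T, d)) for _ in range(3))
alpha = rng.uniform(0.5, 0.99, size=(T, m))
outs, K_mem, V_mem = gsa_recurrent(q, k, v, alpha)
```

Two limiting cases are a useful sanity check: with all gates at 0 every slot is overwritten by the current token, so the output equals `v[t]`; with all gates at 1 nothing is ever written and the state stays at its initial value.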
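Alternatively, GSA can be formulated as two GLA passes, with the arguments of $\mathrm{GLA}$ ordered as query, key, value, key-side gate, value-side gate:
$$\{o'_t\} = \mathrm{GLA}(\{q_t,\ k_t,\ 1-\alpha_t,\ \mathbf{1},\ \alpha_t\}), \qquad \{o_t\} = \mathrm{GLA}(\{\mathrm{softmax}(o'_t),\ 1-\alpha_t,\ v_t,\ \alpha_t,\ \mathbf{1}\})$$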
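Within a GLA pass at step $t$, with the pass's own query $\hat{q}_t$, key $\hat{k}_t$, value $\hat{v}_t$, key-side gate $a_t$, and value-side gate $b_t$:
$$S_t = (a_t b_t^\top) \odot S_{t-1} + \hat{k}_t \hat{v}_t^\top, \qquad \hat{o}_t = S_t^\top \hat{q}_t$$
In GSA the input gate $1-\alpha_t$ fills the key or value role and the forget gate $\alpha_t$ fills the $a_t$ or $b_t$ role, with both determined contextually. The gates enable precise, data-driven control over slot memory persistence and overwrite.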
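The equivalence between the two-pass GLA view and the direct slot recurrence can be checked numerically. The sketch below assumes a generic gated pass of the form "state = (outer product of key-side and value-side gates) times previous state, plus outer product of key and value; output = state transposed times query". The function names and the argument order are illustrative assumptions, not an API from the paper or from any library.

```python
import numpy as np

def softmax_rows(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def gla_pass(query, key, value, key_gate, value_gate):
    """Generic gated linear-attention pass (sequential form).
    key, key_gate: (T, d_k); value, value_gate: (T, d_v); query: (T, d_k)."""
    T, d_k = key.shape
    d_v = value.shape[1]
    S = np.zeros((d_k, d_v))
    out = np.empty((T, d_v))
    for t in range(T):
        S = np.outer(key_gate[t], value_gate[t]) * S + np.outer(key[t], value[t])
        out[t] = S.T @ query[t]
    return out

def gsa_two_pass(q, k, v, alpha):
    # Key-slot pass: keys are k, "values" are the input gates 1 - alpha,
    # and the decay acts on the slot (value-side) dimension.
    o_prime = gla_pass(q, k, 1.0 - alpha, np.ones_like(k), alpha)
    # Value-slot pass: queries are softmax(o'), "keys" are 1 - alpha,
    # and the decay acts on the slot (key-side) dimension.
    return gla_pass(softmax_rows(o_prime), 1.0 - alpha, v, alpha, np.ones_like(v))

def gsa_direct(q, k, v, alpha):
    # Reference: explicit slot memories of shape (m, d).
    T, d = q.shape
    m = alpha.shape[1]
    K_mem, V_mem = np.zeros((m, d)), np.zeros((m, d))
    out = np.empty((T, d))
    for t in range(T):
        K_mem = alpha[t][:, None] * K_mem + np.outer(1.0 - alpha[t], k[t])
        V_mem = alpha[t][:, None] * V_mem + np.outer(1.0 - alpha[t], v[t])
        out[t] = V_mem.T @ softmax_rows(K_mem @ q[t])
    return out

rng = np.random.default_rng(1)
T, d, m = 7, 8, 4
q, k, v = (rng.normal(size=(T, d)) for _ in range(3))
alpha = rng.uniform(0.5, 0.99, size=(T, m))
two_pass_out = gsa_two_pass(q, k, v, alpha)
direct_out = gsa_direct(q, k, v, alpha)
```

Expressing both passes through one `gla_pass` routine is what lets a single optimized GLA kernel serve the whole layer; only the roles of the arguments change between passes.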
3. Context-Aware Reading and Adaptive Forgetting Mechanisms
GSA’s memory operations consist of two complementary mechanisms:
- Context-aware reading: In the first GLA pass, $q_t$ attends over the decayed key-slots $\tilde{K}_t$, producing $o'_t$, a context-integrated representation aggregating relevant historical information. The subsequent softmax sharpens attention in the second pass, enhancing selective retrieval and mitigating the “attention dilution” common to purely linear kernels.
- Adaptive forgetting: The slot-wise gate $\alpha_t$ modulates the decay of each slot $i$ at every step. Small $\alpha_{t,i}$ values result in rapid forgetting (overwrite) by the new $k_t$ and $v_t$, while larger values promote memory retention. Since $\alpha_t$ is computed as a function of $x_t$, GSA dynamically adapts its memory span per token.
This dual mechanism provides an implicit memory capacity that can exceed the explicit slot count $m$, leveraging the exponential separation properties of the two-layer structure with softmax (cf. modern Hopfield networks). Despite this, recurrent state size remains bounded at $O(md)$ per head (Zhang et al., 2024).
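The effect of the gate on memory span can be made concrete with a little arithmetic. Under the simplifying assumption that a slot's gate stays at a constant value `alpha` (in the model it varies per token), a value written with weight `1 - alpha` retains weight `(1 - alpha) * alpha**n` after `n` further steps. The helper names below are hypothetical.

```python
import numpy as np

def remaining_weight(alpha, n):
    """Weight of a write after n further steps under a constant gate alpha."""
    return (1.0 - alpha) * alpha ** n

def half_life(alpha):
    """Number of steps after which a stored contribution has halved."""
    return np.log(0.5) / np.log(alpha)

def simulate_slot(alpha, n):
    """Scalar slot: write a unit value into an empty slot, then n steps of zeros."""
    s = alpha * 0.0 + (1.0 - alpha) * 1.0
    for _ in range(n):
        s = alpha * s + (1.0 - alpha) * 0.0
    return s

for a in (0.5, 0.9, 0.99):
    print(f"alpha={a}: half-life ~ {half_life(a):.1f} steps")
```

A gate near 1 therefore keeps information for many steps but admits each new token only weakly, while a small gate overwrites the slot almost immediately. This is the trade-off that the data-dependent gate resolves per token and per slot.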
4. Role of Softmax in T2R (Transformer-to-RNN) Finetuning
GSA retains a softmax operation between its two GLA passes, which is crucial in T2R finetuning:
- T2R finetuning involves initializing a linear-RNN model from pretrained Transformer weights, minimizing retraining data requirements (∼1–3% of full data).
- Linear attention models such as RetNet and GLA lack a softmax, causing inductive bias mismatches when adapted from softmax-based Transformer weights.
- By preserving a single softmax between GLA passes, GSA and ABC bridge this gap, retaining the underlying attention inductive bias.
Empirical results with Mistral-7B indicate that GSA, when finetuned with 20B tokens, achieves ∼53.9% average accuracy, surpassing RetNet (47.1%) and GLA (52.5%) and approaching the accuracy of much larger RNNs (∼56.9% with 100B tokens) (Zhang et al., 2024). This suggests that GSA is especially effective in resource-constrained T2R regimes.
5. Hardware-Efficiency and Empirical Evaluation
Each GSA pass utilizes GLA, permitting efficient training and inference:
- Training efficiency: The chunkwise parallel scan algorithm (from FLA) leads to per-step memory that does not grow with sequence length and to linear scaling in sequence length $T$ for both time and memory. On an NVIDIA H800, GSA achieves ∼44K tokens/sec throughput with 16K-token batches, nearly matching GLA’s throughput, with marginally higher peak memory (∼38 GiB vs ∼36 GiB).
- Inference efficiency: The model supports fully recurrent (constant-memory) inference with a smaller state size (128 for GSA vs 256 for GLA, in the same state-size units), resulting in 10–15% faster autoregressive decoding.
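The chunkwise idea behind the training path can be illustrated on the value-slot pass alone. The sketch below is not the FLA kernel: it is a plain NumPy illustration that takes the per-token read weights `p` (the softmax output of the first pass) as given, processes the sequence in chunks using matrix multiplications, and carries only an `(m, d)` state between chunks. Function names and the chunk size are assumptions.

```python
import numpy as np

def slot_read_sequential(p, alpha, v):
    """Reference: S_t = Diag(alpha_t) S_{t-1} + (1 - alpha_t) v_t^T, o_t = S_t^T p_t."""
    T, m = alpha.shape
    d = v.shape[1]
    S = np.zeros((m, d))
    out = np.empty((T, d))
    for t in range(T):
        S = alpha[t][:, None] * S + np.outer(1.0 - alpha[t], v[t])
        out[t] = S.T @ p[t]
    return out

def slot_read_chunkwise(p, alpha, v, chunk=4):
    """Same computation, one chunk at a time with matmuls inside each chunk."""
    T, m = alpha.shape
    d = v.shape[1]
    S = np.zeros((m, d))                      # state carried across chunks
    out = np.empty((T, d))
    for s in range(0, T, chunk):
        a, pc, vc = alpha[s:s + chunk], p[s:s + chunk], v[s:s + chunk]
        B = np.cumprod(a, axis=0)             # (C, m): decay from chunk start to each step
        Qd = pc * B                           # read weights scaled by accumulated decay
        # Dividing by B is numerically fragile for long chunks or small gates;
        # it is kept here only to keep the sketch short.
        Kd = (1.0 - a) / B                    # write weights with their own decay divided out
        scores = np.tril(Qd @ Kd.T)           # (C, C) causal intra-chunk interactions
        out[s:s + chunk] = Qd @ S + scores @ vc   # inter-chunk term + intra-chunk term
        S = B[-1][:, None] * S + (Kd * B[-1]).T @ vc  # state at the end of the chunk
    return out

rng = np.random.default_rng(2)
T, m, d = 10, 4, 8
p = rng.dirichlet(np.ones(m), size=T)         # stand-in for softmax(o'_t), rows sum to 1
alpha = rng.uniform(0.5, 0.99, size=(T, m))
v = rng.normal(size=(T, d))
seq_out = slot_read_sequential(p, alpha, v)
chunk_out = slot_read_chunkwise(p, alpha, v, chunk=4)
```

Within a chunk the work is dominated by dense matrix products, which is what makes this form hardware-friendly, while the cost across chunks grows linearly with the sequence length.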
Empirical results demonstrate competitive or superior performance:
- Recall-intensive benchmarks: On associative recall (MQAR: 512-length, 64-pair, slot size 128), GSA achieves ∼99% accuracy, exceeding GLA and Mamba (∼98.7%).
- Real-world retrieval & QA: On tasks such as FDA, SWDE, SQuAD, NQ, TriviaQA, and DROP, GSA (1.3B model) attains ∼31.8% average, ahead of GLA (31.4%), RetNet (28.4%), and Mamba (27.7%).
- Language modeling & zero-shot: GSA matches or slightly underperforms leading gated baselines (HGRN2) on Lambada/WikiText, but with half the recurrent state size.
6. Summary, Capacity, and Implications
Gated Slot Attention combines ABC’s softmax-linked, two-pass architecture with GLA’s data-driven slot forgetting. The result is a recurrent neural network that:
- Maintains inductive biases vital for Transformer-to-RNN finetuning.
- Dynamically regulates memory retention and erasure in bounded slots.
- Exploits hardware-optimized matrix multiplication routines.
- Narrows the performance gap with full Transformers in recall-heavy tasks, while achieving substantial reductions in inference state size and memory usage (Zhang et al., 2024).
A plausible implication is that the exponential memory separation afforded by GSA’s two-pass, softmax-linked gating structure supports strong in-context recall without incurring the resource burden of conventional Transformer architectures.