Recurrent Weighted Average (RWA)
- Recurrent Weighted Average (RWA) is an RNN architecture that computes a dynamic running weighted average over past inputs, integrating attention into its recurrence.
- It updates cumulative numerator and denominator in constant time (O(1)) per step, enabling efficient handling of long-range dependencies without costly retrospective attention.
- Empirical evaluations show RWA converges faster and uses fewer parameters than LSTM on sequential tasks, though it faces challenges when tasks require revising past contributions.
The Recurrent Weighted Average (RWA) is a recurrent neural network (RNN) architecture that integrates an attention-like running weighted average directly into the recurrence, permitting efficient, $O(1)$-per-step processing with access to the full historical context. Conventional RNNs such as LSTM or GRU propagate information through a strictly sequential, one-step-at-a-time mechanism. The RWA model instead makes each new hidden state a direct, dynamically reweighted summary of all previously observed inputs, which gives it a powerful mechanism for capturing long-range dependencies in sequential data (Ostmeyer et al., 2017, Maginnis et al., 2017).
1. Motivation and Conceptual Distinction
Traditional RNNs, including LSTM and GRU, update their hidden state exclusively from the immediately preceding state and the current input. Information from earlier in the sequence must traverse a chain of potentially hundreds or thousands of steps, which exposes training to vanishing and exploding gradients. Attention mechanisms, introduced to address this, compute a weighted sum over all intermediate states, but they are typically applied post hoc and entail $O(T^2)$ computational cost for sequences of length $T$.
The central insight of the RWA architecture is to "bake in" the attention mechanism within the recurrence relation. Each processing step maintains a running weighted average across historical feature encodings, parameterized by dynamically computed attention weights. The essential innovation is that this computation can be maintained as two running sums, yielding $O(1)$ update cost per step. This matches the efficiency of classical RNN units while offering a representational capacity strongly reminiscent of global attention (Ostmeyer et al., 2017, Maginnis et al., 2017).
2. Mathematical Formulation
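Given an input sequence $x_1, \dots, x_T$, initial state $h_0 = f(s_0)$ with $n_0 = 0$, $d_0 = 0$ and learned $s_0$, the recurrent update equations at timestep $t$ are as follows:

$$
\begin{aligned}
u_t &= W_u x_t + b_u \\
g_t &= W_g [x_t, h_{t-1}] + b_g \\
a_t &= W_a [x_t, h_{t-1}] \\
z_t &= u_t \odot \tanh(g_t) \\
n_t &= n_{t-1} + z_t \odot \exp(a_t) \\
d_t &= d_{t-1} + \exp(a_t) \\
h_t &= f(n_t / d_t)
\end{aligned}
$$

Here $[\cdot, \cdot]$ denotes feature concatenation, $\odot$ indicates element-wise multiplication (the division $n_t / d_t$ is likewise element-wise), and all attention is performed through the exponentiated attention logits $\exp(a_t)$. This formulation enables the computation of $h_t$ as a nonlinear function of a weighted average over all prior signed encodings $z_1, \dots, z_t$, with weights determined by locally computed attention scores.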
All architectural parameters are standard dense layer weights and biases: $W_u$, $b_u$, $W_g$, $b_g$, and $W_a$, with the hidden dimension $n$ typically set to 250 hidden units (Ostmeyer et al., 2017, Maginnis et al., 2017).
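The recurrence can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' reference implementation: it takes $f = \tanh$, an attention logit without a bias term, and zero-initialized running sums. The function names `rwa_step` and `rwa_forward` and the parameter tuple layout are hypothetical.

```python
import numpy as np

def rwa_step(x_t, h_prev, n_prev, d_prev, params):
    """One RWA update (assumed form); returns (h_t, n_t, d_t)."""
    W_u, b_u, W_g, b_g, W_a = params
    xh = np.concatenate([x_t, h_prev])   # [x_t, h_{t-1}]
    u = W_u @ x_t + b_u                  # encoding of the current input
    g = W_g @ xh + b_g                   # gate pre-activation
    a = W_a @ xh                         # attention logits (no bias assumed)
    z = u * np.tanh(g)                   # signed encoding
    w = np.exp(a)                        # unnormalized attention weight
    n_t = n_prev + z * w                 # running numerator
    d_t = d_prev + w                     # running denominator
    h_t = np.tanh(n_t / d_t)             # f applied to the weighted average
    return h_t, n_t, d_t

def rwa_forward(xs, params, s0):
    """Run the recurrence over a sequence and return the final hidden state."""
    n_hidden = s0.shape[0]
    h = np.tanh(s0)                      # h_0 = f(s_0)
    n = np.zeros(n_hidden)
    d = np.zeros(n_hidden)
    for x_t in xs:
        h, n, d = rwa_step(x_t, h, n, d, params)
    return h
```

The state carried between steps is only `(h, n, d)`, three vectors of the hidden size, regardless of how many inputs have been consumed. This sketch exponentiates the logits directly, so it can overflow for large logits.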
3. Algorithmic Properties and Computational Complexity
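Unlike explicit attention mechanisms that require retaining and reaccessing all past activations ($O(t)$ per step, $O(T^2)$ over a full sequence), the RWA model's numerator and denominator can be updated incrementally:

$$
n_t = n_{t-1} + z_t \odot \exp(a_t), \qquad d_t = d_{t-1} + \exp(a_t), \qquad h_t = f(n_t / d_t),
$$

where $z_t$ is the signed encoding and $a_t$ the attention logit at step $t$.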
Thus, the per-step runtime and memory do not scale with the number of processed timesteps, but are constant and determined solely by the hidden dimension $n$. This makes RWA highly suitable for very long or streaming sequences (Ostmeyer et al., 2017, Maginnis et al., 2017).
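A short check makes the equivalence concrete: for any fixed sequence of encodings and logits, the two running sums reproduce the explicit attention-style average over the whole prefix at every step. The helper names `running_average` and `explicit_average` are hypothetical, and the encodings and logits are passed in directly rather than computed from a hidden state, which is a simplification for illustration.

```python
import numpy as np

def running_average(zs, logits):
    """O(1) work per step: update two running sums, then divide."""
    n = np.zeros(zs.shape[1])
    d = np.zeros(zs.shape[1])
    out = []
    for z, a in zip(zs, logits):
        w = np.exp(a)
        n += z * w
        d += w
        out.append(n / d)
    return np.array(out)

def explicit_average(zs, logits):
    """O(t) work per step: revisit every past encoding at each step."""
    out = []
    for t in range(1, len(zs) + 1):
        w = np.exp(logits[:t])
        out.append((zs[:t] * w).sum(axis=0) / w.sum(axis=0))
    return np.array(out)
```

Both functions return the same array up to floating-point rounding, but only the second needs to keep the full history in memory.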
To ensure numerical stability for large exponents or long sequences (where $\exp(a_t)$ may overflow), it is recommended to maintain a running maximum of $a_t$ and rescale the numerator and denominator in lockstep (see Appendix B in Ostmeyer et al., 2017).
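One way to realize this rescaling is sketched below. It follows the description above (a running maximum of the logits, with both sums rescaled by the same factor) but is not copied from the paper's appendix; the function name `stable_update` and the choice of `-inf` as the initial maximum are assumptions.

```python
import numpy as np

def stable_update(z_t, a_t, n_prev, d_prev, a_max_prev):
    """Update running sums stored relative to exp(a_max).

    n and d are kept as n_true / exp(a_max) and d_true / exp(a_max),
    so their ratio n / d is unchanged by the rescaling.
    """
    a_max = np.maximum(a_max_prev, a_t)      # new running maximum
    scale = np.exp(a_max_prev - a_max)       # <= 1, rescales the old sums
    w = np.exp(a_t - a_max)                  # <= 1, weight of the new term
    n_t = n_prev * scale + z_t * w
    d_t = d_prev * scale + w
    return n_t, d_t, a_max

# Usage: logits near 1000 would overflow exp() if exponentiated directly.
zs = np.array([[1.0], [-2.0], [0.5]])
logits = np.array([[1000.0], [1001.0], [999.0]])
n, d, a_max = np.zeros(1), np.zeros(1), np.full(1, -np.inf)
for z, a in zip(zs, logits):
    n, d, a_max = stable_update(z, a, n, d, a_max)
average = n / d                              # finite weighted average
```

Because every exponent passed to `np.exp` is at most zero, no term can overflow, and the weighted average is recovered from the ratio alone.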
4. Empirical Performance on Benchmark Tasks
The RWA model has been systematically benchmarked against LSTM across a suite of synthetic and real-world sequence tasks, always keeping both models at identical hidden size (250 units), initialization, and optimization settings (Ostmeyer et al., 2017):
| Task | Metric/Result | RWA Performance | LSTM Performance |
|---|---|---|---|
| Artificial Grammar | Steps to 100% accuracy | ~600 | ~1,000 |
| Sequence Length | Steps to near-perfect accuracy | ~100 | ~2,000 |
| Variable Copy (T=100) | Steps to beat baseline | ~1,000 | ~10,000 |
| Variable Copy (T=1000) | Steps to beat baseline | ~3,000 | barely after ~50,000 |
| Adding Problem (T=100) | Steps to beat baseline MSE | ~1,000 | ~3,000 |
| Adding Problem (T=1000) | Steps to beat baseline MSE | ~1,000 | ~15,000 |
| MNIST Sequential | Test Accuracy after 250k steps | 98.1% (unpermuted), 93.5% (permuted) | 99.0% (unpermuted), 93.6% (permuted) |
Across multiple experiments, RWA demonstrates accelerated convergence and strong performance on long-range dependency tasks, and it matches or exceeds LSTM in final accuracy or error in most settings where a single output summarizes the whole sequence. It consistently uses approximately 25% fewer parameters per hidden unit than LSTM, with an identical $O(1)$ per-step cost with respect to sequence length (Ostmeyer et al., 2017).
On multi-copy or character-level modeling tasks, which require the ability to forget or revise the influence of earlier inputs, RWA shows marked limitations: it fails to learn the multi-copy task and scores worse than LSTM and GRU in bits per character (bpc) on the Wikipedia character-prediction task (Maginnis et al., 2017).
5. Relationship to Attention Mechanisms
Classical attention computes a weighted sum over a non-recurrent collection of encoder outputs and must revisit the entire sequence for every decoding step, which is computationally infeasible for long or streaming data. RWA replaces this paradigm with a recurrent, running normalization: attention weights are integrated into the state update, and the relevance of past steps is cumulative, not retroactively adjustable.
Specifically, RWA's log-attention $a_t$ is computed on the fly from the current input and previous hidden state, then exponentiated and used to reweight the signed encoding $z_t$ in the global average. While this allows direct, fast feedback from every previous step, it also means that once an input's weight is accumulated, it cannot be discounted or revised. This contrasts sharply with classical attention, which is flexible but incurs high cost (Ostmeyer et al., 2017, Maginnis et al., 2017).
6. Architecture, Initialization, and Implementation
- Input handling: each $x_t$ is a fixed-length feature vector; supports one-hot or real-valued encodings.
- Activations: $f = \tanh$ for the nonlinearity; gating with $\tanh(g_t)$; signed encoding via elementwise product.
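- Initialization: $s_0$ drawn from a zero-mean normal distribution, biases set to zero, other weights drawn from a uniform distribution over a symmetric range set by each layer's fan-in and fan-out.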
- Optimization: Adam, batch size 100; gradient clipping was not required in the original work but is suggested if instability is encountered.
- Output: fully connected layer on the final hidden state $h_T$ or on each per-step state $h_t$, with cross-entropy for classification and MSE for regression (Ostmeyer et al., 2017).
7. Limitations and Subsequent Developments
By construction, RWA cannot forget or revise the weights assigned to earlier timesteps; the running denominator $d_t$ is strictly increasing, and altering the influence of early points would require exponentially larger weights for later points. This limitation prevents RWA from effectively handling tasks involving multiple, consecutive sub-tasks or those where local, recent context is essential (e.g., repeated copy, next-symbol prediction in language modeling).
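A small worked example illustrates the growth. If the running denominator is $d$ and a new input receives weight $e^{a}$, earlier inputs retain the fraction $d / (d + e^{a})$ of the total weight; pushing that fraction down to $\varepsilon$ requires $a = \log\big(d\,(1 - \varepsilon)/\varepsilon\big)$. The sketch below (scalar case, hypothetical helper name `logit_to_dominate`, and $\varepsilon = 0.01$ chosen arbitrarily) repeats such a reset five times.

```python
import math

def logit_to_dominate(d_prev, eps):
    """Logit a at which earlier inputs keep exactly a fraction eps of the weight."""
    return math.log(d_prev * (1.0 - eps) / eps)

d = 1.0            # running denominator after one input with logit 0
required = []      # logit needed for each successive reset
for _ in range(5):
    a = logit_to_dominate(d, eps=0.01)
    required.append(a)
    d += math.exp(a)   # the denominator can only grow
```

Each reset multiplies the denominator by $1/\varepsilon$, so the required logit rises by $\log(1/\varepsilon)$ every time and the corresponding weight $e^{a}$ grows geometrically.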
To address these drawbacks, the Recurrent Discounted Attention (RDA) unit extends RWA by introducing a discount gate, permitting the model to actively reduce the impact of previous states and support complex, multi-output or forgetting-required scenarios with higher efficiency and accuracy (Maginnis et al., 2017).
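The effect of a discount can be sketched as follows, assuming it acts as a per-unit factor $\gamma_t \in [0, 1]$ that multiplies both running sums before the new term is added. How the gate itself is computed in Maginnis et al. (2017) is not reproduced here; `gamma_t` is simply passed in, and the function name `discounted_update` is hypothetical.

```python
import numpy as np

def discounted_update(z_t, a_t, gamma_t, n_prev, d_prev):
    """Running-sum update with an assumed multiplicative discount on the past."""
    w = np.exp(a_t)                    # weight of the new encoding
    n_t = gamma_t * n_prev + z_t * w   # discounted numerator
    d_t = gamma_t * d_prev + w         # discounted denominator
    return n_t, d_t
```

With `gamma_t` equal to one this reduces to the plain RWA update, and with `gamma_t` equal to zero the average collapses to the current encoding, so the denominator is no longer forced to grow monotonically.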
Further open directions for RWA include bidirectional architectures, stacked hierarchies, autoencoding, and applications in real-world domains such as NLP, genomics, and generative music, though systematic evaluation on large-scale, naturally occurring datasets remains to be undertaken (Ostmeyer et al., 2017).