RNNFormer Block Architecture
- RNNFormer Block is a neural module that combines transformer self-attention with recurrent gating, enabling block-parallel processing with linear computational complexity.
- It processes blocks of tokens in parallel using vertical self- and cross-attention while employing horizontal recurrence over state vectors to capture long-range dependencies.
- Empirical results demonstrate improved language modeling performance and reduced bits-per-token on benchmarks, making it a scalable alternative to standard transformers.
The RNNFormer block, also known as the Block-Recurrent Transformer cell, is a neural architecture designed to combine the strengths of transformer layers and recurrent architectures for sequence modeling. Unlike conventional transformers, which have quadratic complexity in sequence length, and standard LSTMs, which process one token at a time, the RNNFormer processes blocks of tokens in parallel while maintaining a recurrent state over blocks. This approach yields linear complexity with respect to sequence length and is capable of efficiently leveraging accelerator hardware. The RNNFormer block utilizes both self-attention and cross-attention mechanisms within and across token blocks, as well as LSTM-style or highway-style gating for state updates, resulting in enhanced modeling capacity over long contexts and improved empirical performance on language modeling tasks (Hutchins et al., 2022).
1. RNNFormer Block Architecture
The RNNFormer block operates on non-overlapping blocks of $W$ tokens. At each block index $t$, the model maintains a set of $S$ state vectors $s_t \in \mathbb{R}^{S \times d}$ and receives the block's token embeddings $e_t \in \mathbb{R}^{W \times d}$. The architecture executes two principal computational passes within every block:
- Vertical pass (within-block, parallel over tokens):
- Self-attention among token embeddings.
- Cross-attention from tokens to the persistent state vectors.
- Concatenation of both attention outputs, followed by projection and a small MLP.
- (Optionally) A vertical gate and residual connection to update the token outputs.
- Horizontal pass (across blocks, recurrent over state):
- Self-attention among the state vectors.
- Cross-attention from states to the current token block embeddings.
- Concatenation, projection, MLP, and an LSTM-style or highway-style gate to produce the next block's state $s_{t+1}$.
This block is structurally derived by rotating the transformer layer architecture 90°, replacing residuals in the horizontal direction with gates, and unrolling across blocks to propagate long-term dependencies.
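The input layout described in this section can be made concrete with a minimal NumPy sketch. It assumes toy sizes and a zero-initialized first state (both are assumptions, as is the hypothetical helper `partition_into_blocks`). The sketch splits a length-$N$ embedding sequence into $N/W$ non-overlapping blocks $e_t \in \mathbb{R}^{W \times d}$ and pairs them with a state $s_1 \in \mathbb{R}^{S \times d}$.

```python
import numpy as np


def partition_into_blocks(embeddings: np.ndarray, block_len: int) -> np.ndarray:
    """Split an (N, d) sequence into non-overlapping blocks of shape (N // W, W, d).

    Assumes N is a multiple of W; a real pipeline would pad the last block.
    """
    n, d = embeddings.shape
    if n % block_len != 0:
        raise ValueError("sequence length must be a multiple of the block length")
    return embeddings.reshape(n // block_len, block_len, d)


# Toy sizes (hypothetical): N = 16 tokens, W = 4, S = 4 state vectors, d = 8.
N, W, S, d = 16, 4, 4, 8
rng = np.random.default_rng(0)
e = rng.standard_normal((N, d))
blocks = partition_into_blocks(e, block_len=W)  # blocks[t] plays the role of e_{t+1}
s1 = np.zeros((S, d))  # initial state s_1; zero initialization is an assumption
print(blocks.shape, s1.shape)  # (4, 4, 8) (4, 8)
```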
2. Precise Formulation and Gating Mechanisms
2.1 Attention Computations
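Let $d$ denote the hidden dimension. The model shares key and value projections between the vertical and horizontal passes but maintains separate query projections. Token keys and values $K_e = e_t W^e_K$, $V_e = e_t W^e_V$ and state keys and values $K_s = s_t W^s_K$, $V_s = s_t W^s_V$ are computed once per block and reused by both passes, while the vertical queries $Q^v = e_t W^v_Q$ and horizontal queries $Q^h = s_t W^h_Q$ are separate. At each block $t$, the following operations are performed (see equations (1)–(4) in (Hutchins et al., 2022)):
- Token self-attention (vertical): $A^v_t = \mathrm{softmax}\left(\frac{Q^v K_e^\top + B}{\sqrt{d}}\right) V_e \in \mathbb{R}^{W \times d}$, where $B$ is the relative position bias (plus a causal mask for language modeling)
- Token → State cross-attention (vertical): $C^v_t = \mathrm{softmax}\left(\frac{Q^v K_s^\top}{\sqrt{d}}\right) V_s \in \mathbb{R}^{W \times d}$
- State self-attention (horizontal): $A^h_t = \mathrm{softmax}\left(\frac{Q^h K_s^\top}{\sqrt{d}}\right) V_s \in \mathbb{R}^{S \times d}$
- State → Token cross-attention (horizontal): $C^h_t = \mathrm{softmax}\left(\frac{Q^h K_e^\top}{\sqrt{d}}\right) V_e \in \mathbb{R}^{S \times d}$
The outputs are concatenated, $[A^v_t; C^v_t] \in \mathbb{R}^{W \times 2d}$ and $[A^h_t; C^h_t] \in \mathbb{R}^{S \times 2d}$, then projected back to dimension $d$ and passed to the MLPs.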
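The four attention computations can be sketched in NumPy. The sketch assumes a single head, no layer normalization, and a causal mask standing in for the relative position bias $B$. The projection-matrix names in `p` are hypothetical. The point of the sketch is the sharing pattern: $K_e, V_e, K_s, V_s$ are computed once and reused by both passes, and only the queries differ.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    ex = np.exp(x)
    return ex / ex.sum(axis=-1, keepdims=True)


def attend(q, k, v, bias=None):
    # softmax((q k^T + bias) / sqrt(d)) v, with rows as vectors.
    scores = q @ k.T
    if bias is not None:
        scores = scores + bias
    return softmax(scores / np.sqrt(q.shape[-1])) @ v


def block_attention(e_t, s_t, p, bias=None):
    """Return ([A^v; C^v], [A^h; C^h]) for one block.

    Keys/values are computed once per input and reused by both passes;
    only the queries differ between the vertical and horizontal passes.
    """
    k_e, v_e = e_t @ p["Wk_e"], e_t @ p["Wv_e"]  # token keys/values (shared)
    k_s, v_s = s_t @ p["Wk_s"], s_t @ p["Wv_s"]  # state keys/values (shared)
    q_v = e_t @ p["Wq_v"]  # vertical queries, one per token
    q_h = s_t @ p["Wq_h"]  # horizontal queries, one per state vector
    a_v = attend(q_v, k_e, v_e, bias)  # token self-attention
    c_v = attend(q_v, k_s, v_s)  # token -> state cross-attention
    a_h = attend(q_h, k_s, v_s)  # state self-attention
    c_h = attend(q_h, k_e, v_e)  # state -> token cross-attention
    return np.concatenate([a_v, c_v], axis=-1), np.concatenate([a_h, c_h], axis=-1)


rng = np.random.default_rng(0)
W, S, d = 4, 3, 8
p = {name: rng.standard_normal((d, d)) / np.sqrt(d)
     for name in ["Wk_e", "Wv_e", "Wk_s", "Wv_s", "Wq_v", "Wq_h"]}
e_t = rng.standard_normal((W, d))
s_t = rng.standard_normal((S, d))
causal = np.triu(np.full((W, W), -np.inf), k=1)  # token i attends only to tokens <= i
tok_out, state_out = block_attention(e_t, s_t, p, bias=causal)
print(tok_out.shape, state_out.shape)  # (4, 16) (3, 16)
```

With the causal mask, the first token can only attend to itself, so its self-attention output is exactly its own value vector.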
2.2 Gating Updates
The core state update applies either an LSTM-style gating mechanism or a fixed highway-gate:
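- LSTM-style gating (Eq. 5): $z_t = \tanh(W_z h_t + b_z)$, $i_t = \sigma(W_i h_t + b_i)$, $f_t = \sigma(W_f h_t + b_f)$, $c_{t+1} = c_t \odot f_t + z_t \odot i_t$,
where $h_t$ is the output of the horizontal attention/MLP sublayer, $c_t = s_t$ is the current state (so $c_{t+1} = s_{t+1}$), $\sigma$ is the sigmoid, and $\odot$ is elementwise multiplication.
- Fixed gate (highway-style, Eq. 6): $z_t = W_z h_t + b_z$, $g = \sigma(b_g)$, $c_{t+1} = c_t \odot g + z_t \odot (1 - g)$, where $b_g$ is a learned bias, so the gate value does not depend on the input.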
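Both gate variants can be sketched in NumPy. Because rows are state vectors, $W h_t$ is written as `h_t @ W`. The parameter dictionary keys are hypothetical, and the sketch follows the equations as written in this section rather than any particular codebase.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def lstm_gate(h_t: np.ndarray, c_t: np.ndarray, p: dict) -> np.ndarray:
    """LSTM-style update c_{t+1} = c_t * f_t + z_t * i_t (gates computed from h_t only)."""
    z = np.tanh(h_t @ p["Wz"] + p["bz"])
    i = sigmoid(h_t @ p["Wi"] + p["bi"])
    f = sigmoid(h_t @ p["Wf"] + p["bf"])
    return c_t * f + z * i


def fixed_gate(h_t: np.ndarray, c_t: np.ndarray, p: dict) -> np.ndarray:
    """Highway-style update with an input-independent gate g = sigmoid(b_g)."""
    z = h_t @ p["Wz"] + p["bz"]
    g = sigmoid(p["bg"])
    return c_t * g + z * (1.0 - g)


rng = np.random.default_rng(0)
S, d = 3, 8
h_t, c_t = rng.standard_normal((S, d)), rng.standard_normal((S, d))
zero = {"Wz": np.zeros((d, d)), "bz": np.zeros(d), "Wi": np.zeros((d, d)),
        "bi": np.zeros(d), "Wf": np.zeros((d, d)), "bf": np.zeros(d)}
# With all-zero parameters z_t = 0 and i_t = f_t = 0.5, so the LSTM gate halves the state.
print(np.allclose(lstm_gate(h_t, c_t, zero), 0.5 * c_t))  # True
# A large b_g pushes g toward 1, so the fixed gate mostly keeps the old state.
keep = {"Wz": rng.standard_normal((d, d)), "bz": np.zeros(d), "bg": np.full(d, 10.0)}
print(np.allclose(fixed_gate(h_t, c_t, keep), c_t, atol=1e-2))  # True
```

The fixed gate trades expressiveness for simplicity. Its mixing ratio between old state and new input is learned but constant. The LSTM-style gate can decide per step how much to keep.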
2.3 Recurrence Relation
The overall cell recurrence is $(y_t, s_{t+1}) = \mathrm{Cell}(e_t, s_t)$, for $t = 1, \dots, N/W$.
3. Algorithmic Workflow
The RNNFormer cell is processed over a sequence partitioned into blocks, with pseudo-code as follows:
```
Inputs: token blocks e₁, …, e_{N/W} ∈ ℝ^{W×d}; initial state s₁ ∈ ℝ^{S×d}
For t = 1 … N/W do
    # Keys/values computed once, shared by both passes
    K_e, V_e ← Proj_e(eₜ)
    K_s, V_s ← Proj_s(sₜ)
    # Vertical pass (parallel over W tokens)
    Q_v ← Query_v(eₜ)
    A_v = Softmax((Q_v·K_eᵀ + RelPosBias)/√d)·V_e     # causally masked
    C_v = Softmax((Q_v·K_sᵀ)/√d)·V_s
    yₜ = eₜ + Gate_v(MLP_v(Proj_v(concat[A_v, C_v])))  # Gate_v optional (identity if unused)
    # Horizontal (recurrent) pass
    Q_h ← Query_h(sₜ)
    A_h = Softmax((Q_h·K_sᵀ)/√d)·V_s
    C_h = Softmax((Q_h·K_eᵀ)/√d)·V_e
    uₜ = MLP_h(Proj_h(concat[A_h, C_h]))
    s_{t+1} = Gate_h(uₜ, sₜ)    # LSTM or fixed gate, with cₜ = sₜ
EndFor
Cache s_{N/W+1} for next segment
```
At inference, block-level processing enables autoregressive decoding with cached recurrent and transformer key–value states.
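The pseudocode can be turned into a small, runnable NumPy sketch of one recurrent layer. It is a simplification under stated assumptions: single-head attention, no layer normalization, a ReLU MLP with hidden width $d$, a causal mask in place of the relative position bias, the optional vertical gate omitted, and the LSTM-style gate for the state update. All parameter names are hypothetical. This is not the reference implementation.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    ex = np.exp(x)
    return ex / ex.sum(axis=-1, keepdims=True)


def attend(q, k, v, mask=None):
    # Single-head scaled dot-product attention.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores + mask
    return softmax(scores) @ v


def mlp(x, w1, w2):
    return np.maximum(x @ w1, 0.0) @ w2  # two-layer ReLU MLP (activation is an assumption)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def init_params(d, rng):
    # Hypothetical parameter set; all matrices act on row vectors (x @ W).
    def mat(rows, cols):
        return rng.standard_normal((rows, cols)) / np.sqrt(rows)
    square = ["Wk_e", "Wv_e", "Wk_s", "Wv_s", "Wq_v", "Wq_h",
              "Wv1", "Wv2", "Wh1", "Wh2", "Wz", "Wi", "Wf"]
    p = {name: mat(d, d) for name in square}
    p["Pv"], p["Ph"] = mat(2 * d, d), mat(2 * d, d)  # project [A; C] from 2d back to d
    p["bz"], p["bi"], p["bf"] = np.zeros(d), np.zeros(d), np.zeros(d)
    return p


def cell(e_t, s_t, p):
    """One block step: (y_t, s_{t+1}) = Cell(e_t, s_t)."""
    w = e_t.shape[0]
    # Keys/values computed once and shared by both passes.
    k_e, v_e = e_t @ p["Wk_e"], e_t @ p["Wv_e"]
    k_s, v_s = s_t @ p["Wk_s"], s_t @ p["Wv_s"]
    # Vertical pass: token self-attention (causal) and token -> state cross-attention.
    q_v = e_t @ p["Wq_v"]
    causal = np.triu(np.full((w, w), -np.inf), k=1)
    a_v, c_v = attend(q_v, k_e, v_e, causal), attend(q_v, k_s, v_s)
    y_t = e_t + mlp(np.concatenate([a_v, c_v], axis=-1) @ p["Pv"], p["Wv1"], p["Wv2"])
    # Horizontal pass: state self-attention and state -> token cross-attention.
    q_h = s_t @ p["Wq_h"]
    a_h, c_h = attend(q_h, k_s, v_s), attend(q_h, k_e, v_e)
    u_t = mlp(np.concatenate([a_h, c_h], axis=-1) @ p["Ph"], p["Wh1"], p["Wh2"])
    # LSTM-style gate with c_t = s_t replaces the residual connection.
    z = np.tanh(u_t @ p["Wz"] + p["bz"])
    i = sigmoid(u_t @ p["Wi"] + p["bi"])
    f = sigmoid(u_t @ p["Wf"] + p["bf"])
    return y_t, s_t * f + z * i


def run(e_blocks, s_1, p):
    s_t, ys = s_1, []
    for e_t in e_blocks:  # sequential over the N/W blocks, parallel within each block
        y_t, s_t = cell(e_t, s_t, p)
        ys.append(y_t)
    return np.stack(ys), s_t  # final s_t is s_{N/W+1}, cached for the next segment


rng = np.random.default_rng(0)
N, W, S, d = 32, 8, 4, 16  # toy sizes, far smaller than in the paper
params = init_params(d, rng)
e_blocks = rng.standard_normal((N, d)).reshape(N // W, W, d)
ys, s_cache = run(e_blocks, np.zeros((S, d)), params)
print(ys.shape, s_cache.shape)  # (4, 8, 16) (4, 16)
```

`cell` only sees `e_t` and the carried `s_t`, so the loop over blocks is the only sequential part. Everything inside a block is dense matrix algebra over $W$ tokens and $S$ state vectors.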
4. Computational Complexity
For a sequence of length $N$, block size $W$, and $S$ state vectors:
- Within each block:
- Token self-attention: $O(W^2 d)$
- State self-attention: $O(S^2 d)$
- Cross-attention (token ↔ state): $O(W S d)$
- Total for $N/W$ blocks: $O\left(\frac{N}{W}\,(W^2 + S^2 + W S)\,d\right)$
With $S = O(W)$, per-block complexity is $O(W^2 d)$; total complexity is $O(N W d)$, linear in $N$. For comparison, a standard transformer layer incurs $O(N^2 d)$ cost.
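A quick back-of-the-envelope check of the linear-versus-quadratic claim counts query–key pairs per layer; multiplying by $d$ gives the dominant FLOP term. The sizes $W = S = 512$ are an assumption chosen to match the 512-token window reported in Section 6, and the helper names are hypothetical.

```python
def block_recurrent_pairs(n: int, w: int, s: int) -> int:
    """Query-key pairs for one RNNFormer layer, following the per-block terms above.

    Per block: W^2 (token self-attn) + S^2 (state self-attn) + 2*W*S (cross-attn both ways).
    """
    if n % w != 0:
        raise ValueError("n must be a multiple of w")
    return (n // w) * (w * w + s * s + 2 * w * s)


def full_attention_pairs(n: int) -> int:
    """Query-key pairs for one full (unmasked) self-attention layer over n tokens."""
    return n * n


for n in (4096, 16384, 65536):
    rec, full = block_recurrent_pairs(n, w=512, s=512), full_attention_pairs(n)
    print(f"N={n:>6}: block-recurrent {rec:>13,}  full {full:>14,}  ratio {full // rec}")
```

With $S = W$, the ratio is $N^2 / (4NW) = N/(4W)$. It is 2, 8, and 32 for the three lengths above. Doubling $N$ doubles the block-recurrent count but quadruples the full-attention count.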
5. Comparison to Transformer and LSTM Architectures
| Model | Parallelism | Memory/State | Complexity |
|---|---|---|---|
| Transformer | Fully parallel | Key–value cache only | $O(N^2 d)$, quadratic |
| LSTM | Sequential | Single $d$-dim vector | $O(N d^2)$, linear, but compresses all history into one vector |
| RNNFormer Block | Block-parallel | $S$ vectors ($S \times d$) | $O(N W d)$, linear |
Parameter count and FLOPs for a 12-layer Block-Recurrent model are nearly equivalent to a 13-layer Transformer-XL, as the block-recurrent cell adds the same number of projection/MLP parameters as an additional transformer layer.
The RNNFormer block enables:
- Persistent state vectors carried across blocks ($S$ vectors of dimension $d$, far more memory than an LSTM's single $d$-dimensional state).
- Parallel token processing within blocks.
- Long-context modeling with computational cost linear in sequence length.
6. Empirical Performance and Scaling
On language modeling benchmarks (PG19, arXiv, GitHub), the "Rec:fixed:skip" Block-Recurrent configuration achieves lower bits-per-token (log2 perplexity) than the Transformer-XL baselines. It also outperforms the 13-layer sliding-window model (Slide:13L), which has the same step time and a comparable parameter budget.
Selected Results:
| Model | Segment length | Window | PG19 (bits/token) | arXiv (bits/token) | GitHub (bits/token) | Relative step time |
|---|---|---|---|---|---|---|
| XL:512 | 512 | 512 | 3.62 | 1.45 | 1.21 | - |
| XL:2048 | 2048 | 2048 | 3.58 | 1.31 | 1.01 | - |
| Slide:13L | 4096 | 512 | 3.58 | 1.42 | 1.17 | 1.00 |
| Rec:fixed:skip | 4096 | 512 | 3.53 | 1.24 | 0.976 | 1.00 |
Relative to the sliding-window baseline (Slide:13L), adding recurrence reduces bits-per-token from 3.58 to 3.53 on PG19, from 1.42 to 1.24 on arXiv, and from 1.17 to 0.976 on GitHub. All three gains come at the same relative step time and a comparable parameter count. Scaling studies indicate that recurrence affords improvements commensurate with doubling the model's parameter count across multiple model sizes (Hutchins et al., 2022).
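Bits-per-token is the base-2 logarithm of perplexity, so the table values convert back to perplexity via $\mathrm{PPL} = 2^{\text{bits}}$. The short Python sketch below does this for the Slide:13L and Rec:fixed:skip rows. The numbers are copied from the table, and the helper name is hypothetical.

```python
def perplexity(bits_per_token: float) -> float:
    """Convert bits-per-token (log2 perplexity) back to perplexity."""
    return 2.0 ** bits_per_token


# (Slide:13L, Rec:fixed:skip) bits-per-token, copied from the table above.
results = {"PG19": (3.58, 3.53), "arXiv": (1.42, 1.24), "GitHub": (1.17, 0.976)}

for name, (base, rec) in results.items():
    drop = 100.0 * (1.0 - perplexity(rec) / perplexity(base))
    print(f"{name}: perplexity {perplexity(base):.2f} -> {perplexity(rec):.2f} ({drop:.1f}% lower)")
```

A fixed gap of $\Delta$ bits always corresponds to the same perplexity ratio $2^{-\Delta}$, whatever the baseline. For example, the 0.05-bit PG19 gap is about a 3.4% reduction in perplexity.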
7. Contextual Impact and Implementation
The RNNFormer block achieves efficient long-context modeling without the quadratic overhead of full-sequence attention. Because it can be implemented by modifying conventional transformer layer code, it is a practical option for large-scale sequence modeling tasks that need persistent memory and linear scaling. The approach was shown to be effective on book, scientific paper, and source code corpora, and an open-source implementation is available (Hutchins et al., 2022).