RecurrentGemma-2B: Griffin Hybrid Language Model
- RecurrentGemma-2B is defined by a Griffin architecture that interleaves a linear recurrent layer, local self-attention, and an MLP for efficient long-sequence processing.
- It achieves competitive results on language and safety benchmarks despite being trained on fewer tokens, while its fixed-size state keeps inference cost low on long sequences.
- The model undergoes a two-phase pre-training followed by instruction tuning and RLHF, enhancing both safety and instruction compliance.
RecurrentGemma-2B is an open LLM representing a departure from the Transformer-centric paradigm, utilizing the Griffin architecture that combines linear recurrence and local self-attention. The model dispenses with global attention mechanisms and leverages a fixed-size state to facilitate efficient inference on long sequences. Despite being trained on fewer tokens than comparable Transformer-based models, RecurrentGemma-2B achieves competitive results across a range of natural language and safety benchmarks, and supports both pre-trained and instruction-tuned variants (Botev et al., 2024).
1. Griffin Hybrid Architecture
RecurrentGemma-2B is constructed from Griffin blocks, which interleave three core sublayers: a linear recurrent layer (RG-LRU), a windowed local self-attention mechanism, and a multi-layer perceptron (MLP). The RG-LRU maintains a fixed-size hidden state across arbitrary-length contexts, and local attention is restricted to a sliding window over the most recent tokens (with window size $W = 2048$).
- Linear Recurrence (RG-LRU): For each layer $\ell$ and time step $t$, the recurrent state update is:
  $$h_t^{(\ell)} = A^{(\ell)} h_{t-1}^{(\ell)} + B^{(\ell)} x_t^{(\ell)},$$
  where $h_t^{(\ell)}$ is the hidden state, $x_t^{(\ell)}$ is the input at layer $\ell$ and time $t$, and $A^{(\ell)}, B^{(\ell)} \in \mathbb{R}^{d \times d}$ are trained matrices. Practical implementations include layer normalization and residual connections.
- Local Self-Attention: After the recurrence, each layer attends over a limited context window of the most recent $W$ positions:
  $$\mathrm{Attn}(q_t, K, V) = \mathrm{softmax}\!\left(\frac{q_t K^\top}{\sqrt{d_{\text{head}}}}\right) V, \qquad K, V \in \mathbb{R}^{W \times d_{\text{head}}},$$
  with $q_t$, $K$, and $V$ obtained from the input via learned linear projections.
- Fixed-Size State: The inference state comprises one recurrent hidden vector per layer and, per layer, a $W$-length ring buffer for key/value pairs. Total per-sequence memory scales as $O\!\big(L\,(d + W d)\big)$, independent of the sequence length $T$.
- MLP Sublayer: A two-layer feed-forward network with a hidden width of $3d$ implements the nonlinearity via the GELU activation:
  $$\mathrm{MLP}(y) = W_2\,\mathrm{GELU}(W_1 y + b_1) + b_2, \qquad W_1 \in \mathbb{R}^{3d \times d},\; W_2 \in \mathbb{R}^{d \times 3d}.$$
  Standard residual and layer-normalization connections are applied.
- Embedding Scaling: Input token embeddings are scaled by $\sqrt{d}$ prior to the first layer. Output embeddings are tied but not rescaled.
This architectural design eliminates the need for a length-growing key/value cache, enabling constant per-token inference cost with long prompts, unlike classic Transformers (Botev et al., 2024).
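To make the block structure concrete, here is a minimal, self-contained numpy sketch of a single layer combining the three sublayers described above: the linear recurrence, a sliding-window attention over a ring buffer, and the GELU MLP. It is an illustration under simplifying assumptions (a diagonal recurrence matrix, a single attention head, no layer normalization, random weights), not the released RG-LRU implementation.

```python
import numpy as np

def gelu(x):
    # tanh approximation of the GELU activation
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def softmax(x):
    x = x - x.max()
    e = np.exp(x)
    return e / e.sum()

class GriffinLayer:
    """One simplified Griffin-style layer with a fixed-size inference state:
    a recurrent vector plus a W-slot key/value ring buffer. Memory is
    O(d + W*d), independent of how many tokens have been processed."""

    def __init__(self, d, window, rng):
        self.h = np.zeros(d)                       # recurrent hidden state h_t
        self.keys = np.zeros((window, d))          # key ring buffer
        self.values = np.zeros((window, d))        # value ring buffer
        self.filled = 0                            # valid slots in the buffer
        self.pos = 0                               # next slot to overwrite
        # Illustrative parameters: diagonal recurrence A (|a_i| < 1), dense maps.
        self.a = rng.uniform(0.9, 0.999, size=d)
        self.B = rng.normal(0.0, d ** -0.5, size=(d, d))
        self.Wq = rng.normal(0.0, d ** -0.5, size=(d, d))
        self.Wk = rng.normal(0.0, d ** -0.5, size=(d, d))
        self.Wv = rng.normal(0.0, d ** -0.5, size=(d, d))
        self.W1 = rng.normal(0.0, d ** -0.5, size=(3 * d, d))        # MLP up
        self.W2 = rng.normal(0.0, (3 * d) ** -0.5, size=(d, 3 * d))  # MLP down

    def step(self, x):
        # 1) Linear recurrence  h_t = A h_{t-1} + B x_t  (A diagonal here).
        self.h = self.a * self.h + self.B @ x
        y = x + self.h                                        # residual

        # 2) Local self-attention over at most `window` most recent tokens.
        q, k, v = self.Wq @ y, self.Wk @ y, self.Wv @ y
        self.keys[self.pos], self.values[self.pos] = k, v
        self.pos = (self.pos + 1) % self.keys.shape[0]
        self.filled = min(self.filled + 1, self.keys.shape[0])
        K, V = self.keys[:self.filled], self.values[:self.filled]
        y = y + softmax(K @ q / np.sqrt(q.shape[0])) @ V      # residual

        # 3) Two-layer MLP with 3x hidden width and GELU.
        return y + self.W2 @ gelu(self.W1 @ y)                # residual

rng = np.random.default_rng(0)
layer = GriffinLayer(d=64, window=8, rng=rng)
for t in range(20):                     # stream 20 tokens through the layer
    out = layer.step(rng.normal(size=64))
print(out.shape)                        # (64,); state size did not grow with t
```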
2. Model Specification
RecurrentGemma-2B incorporates the following configuration parameters:
| Parameter | Value |
|---|---|
| Total parameters | 2.7B |
| Non-embedding params | 2.0B |
| Embedding params | 0.7B |
| Vocabulary size | 256,000 |
| Model width ($d$) | 2560 |
| Depth ($L$) | 26 |
| MLP expansion factor | 3 (hidden: 7680) |
| Attention heads | 10 (head dim: 256) |
| Local window ($W$) | 2048 |
Input and output embeddings are tied, no weight decay is applied to the RG-LRU parameters, and gradient norms through the recurrence multiplier are clipped to a maximum of 1000 for numerical stability. The tokenizer is SentencePiece with a 256k vocabulary.
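For reference, the configuration above can be collected into a small configuration object; the field names below are illustrative, not those of the released codebase.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class RecurrentGemmaConfig:
    """Configuration mirroring the table above (illustrative field names)."""
    vocab_size: int = 256_000
    width: int = 2_560            # model width d
    depth: int = 26               # number of Griffin blocks L
    mlp_expansion: int = 3        # MLP hidden width = 3 * width
    num_heads: int = 10           # local-attention heads
    head_dim: int = 256
    local_window: int = 2_048     # sliding-attention window W
    tied_embeddings: bool = True

    @property
    def mlp_hidden(self) -> int:
        return self.mlp_expansion * self.width

cfg = RecurrentGemmaConfig()
assert cfg.num_heads * cfg.head_dim == cfg.width
print(cfg.mlp_hidden)  # 7680
```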
3. Pre-training Regimen
Pre-training follows a two-phase approach identical to Gemma-2B, with RecurrentGemma-2B exposed to 2 trillion tokens (vs. 3T for Gemma-2B):
- Phase 1: Training on a broad mixture (web, code, math) sampled from filtered English corpora, with sequences of up to 8192 tokens.
- Phase 2: Continued training on a smaller, higher-quality corpus subset.
Key settings include the 256k-vocabulary tokenizer, sequence lengths of up to 8192 tokens, and input-embedding scaling. As noted above, RG-LRU parameters are exempt from weight decay, and gradients through the recurrence multiplier are clipped.
Training details such as batch size, learning rate, and optimizer follow the Gemma-2B protocol (referred to as "Gemma 2024"). This reduced data regime—2T tokens—underscores the model’s efficiency, as it achieves competitive downstream metrics relative to Transformer baselines trained on more data (Botev et al., 2024).
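As a hedged sketch of how such a two-phase schedule might be wired up, the snippet below switches the data mixture once a token budget is exhausted. The phase split, mixture weights, and tokens-per-step figure are placeholder assumptions; only the two-phase structure and the 2T-token total come from the report.

```python
# Hypothetical two-phase schedule; the phase split, mixture weights, and
# tokens-per-step value are placeholders, not the actual Gemma recipe.
PHASES = [
    {"name": "phase1_broad_mixture", "token_budget": 1_800_000_000_000,
     "mixture": {"web": 0.60, "code": 0.25, "math": 0.15}},
    {"name": "phase2_high_quality", "token_budget": 200_000_000_000,
     "mixture": {"curated_subset": 1.00}},
]   # budgets sum to the reported 2T-token total

def phase_for_step(step: int, tokens_per_step: int) -> dict:
    """Return the phase that is active at a given optimizer step."""
    tokens_seen = step * tokens_per_step
    for phase in PHASES:
        if tokens_seen < phase["token_budget"]:
            return phase
        tokens_seen -= phase["token_budget"]
    return PHASES[-1]

# Example: 8192-token sequences at an assumed batch size of 512.
print(phase_for_step(step=100_000, tokens_per_step=8192 * 512)["name"])
```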
4. Instruction Tuning and RLHF
Post pre-training, RecurrentGemma-2B undergoes:
- Supervised Fine-Tuning (SFT): On human-written instruction–response pairs.
- RLHF: Using a preference model trained on human feedback and the RLHF algorithm from Gemma 2024.
A strict dialogue format is enforced via control tokens:
```
<start_of_turn>user ... <end_of_turn>
<start_of_turn>model ... <end_of_turn>
```
For example:
```
User:  <start_of_turn>user Knock knock.<end_of_turn>
Model: <start_of_turn>model Who's there?<end_of_turn>
```
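For programmatic use, a small helper can assemble prompts in this control-token format. The token strings follow the format shown above; the function name and the exact newline placement are assumptions.

```python
START, END = "<start_of_turn>", "<end_of_turn>"

def format_dialogue(turns):
    """Render (role, text) pairs into the control-token dialogue format.
    Control tokens match the format above; newline placement is an assumption."""
    rendered = [f"{START}{role}\n{text}{END}\n" for role, text in turns]
    # Open a fresh model turn so sampling continues as the assistant.
    return "".join(rendered) + f"{START}model\n"

prompt = format_dialogue([
    ("user", "Knock knock."),
    ("model", "Who's there?"),
    ("user", "Recurrent."),
])
print(prompt)
```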
5. Inference Efficiency
The fixed-size state architecture confers significant inference speed and memory advantages versus standard Transformer models:
- Memory: The inference state is bounded at $O\!\big(L\,(d + W d)\big)$ regardless of sequence length, instead of the $O(L \cdot T \cdot d)$ key/value cache required by a Transformer with full attention (where $T$ is the sequence length).
- Time Complexity: Generating each new token costs a constant amount of work per layer (attention over at most $W$ cached positions plus the recurrent update), independent of $T$, compared to $O(T)$ attention per layer for full-attention Transformers (see the worked comparison below).
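A worked comparison makes the memory claim concrete. The snippet below plugs the Section 2 configuration into both expressions, assuming bfloat16 storage and keys/values of full model width per layer (an assumption; multi-query attention would shrink both figures).

```python
# Back-of-the-envelope memory comparison using the Section 2 configuration.
# Assumes bfloat16 (2 bytes per value) and keys/values of full model width per
# layer; multi-query attention would shrink both figures, so treat these as
# illustrative upper bounds.
BYTES = 2
d, L, W = 2560, 26, 2048                     # width, depth, local window

def transformer_kv_bytes(T: int) -> int:
    """Full-length K/V cache: grows linearly with sequence length T."""
    return L * T * 2 * d * BYTES             # 2 = one key + one value per token

def griffin_state_bytes() -> int:
    """Fixed state: a recurrent vector plus a W-slot K/V ring buffer per layer."""
    return L * (d + 2 * W * d) * BYTES

for T in (2_048, 8_192, 65_536):
    print(f"T={T:>6}: transformer {transformer_kv_bytes(T) / 2**20:9.1f} MiB"
          f" | griffin {griffin_state_bytes() / 2**20:9.1f} MiB")
```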
Benchmarks on TPUv5e (for prompt encoding up to 8k tokens in batch) report throughput of approximately 40,000 tokens/second for both Gemma-2B and RecurrentGemma-2B. For autoregressive sampling from a 2k prompt:
- RecurrentGemma-2B achieves roughly 6,000 tokens/sec, invariant to prompt length.
- Gemma-2B starts at 4,000 tokens/sec with prompt length-dependent degradation due to cache overhead.
A plausible implication is that on GPU or PyTorch backends, users should expect a similar relative 1.5–2× speedup for long-sequence inference, though with lower absolute throughput (Botev et al., 2024).
6. Empirical Evaluation
Performance across a spectrum of downstream and safety benchmarks shows that RecurrentGemma-2B matches or nearly matches similarly sized Transformer models despite being trained on fewer tokens.
A. Language and Reasoning Tasks (selected results)
| Task | Gemma-2B | RecurrentGemma-2B |
|---|---|---|
| MMLU (5-shot) | 42.3 | 38.4 |
| HellaSwag | 71.4 | 71.0 |
| PIQA | 77.3 | 78.5 |
| SIQA | 49.7 | 51.8 |
| BoolQ | 69.4 | 71.3 |
| Average (full benchmark suite) | 45.0 | 44.6 |
RecurrentGemma-2B performs comparably to Gemma-2B, occasionally outperforming on tasks such as PIQA, SIQA, and BoolQ.
B. Safety Benchmarks
| Benchmark | Pre-trained | Instr.-Tuned |
|---|---|---|
| RealToxicity (avg) | 9.8 | 7.6 |
| BOLD | 39.3 | 52.3 |
| TruthfulQA | 35.1 | 42.7 |
| Toxigen | 56.7 | 50.0 |
Instruction tuning improves safety and bias/fairness metrics, reducing toxicity and increasing fairness and truthfulness scores.
C. Human A/B Testing vs. Mistral 7B v0.2
| Task | RG-2B-IT Win Rate |
|---|---|
| Instruction following | 43.7% (95% CI [41.8, 45.6]) |
| Safety | 59.8% (95% CI [57.1, 62.6]) |
These results suggest RecurrentGemma-2B-IT attains strong human-evaluated safety and instruction-following performance even against a model with substantially more parameters (Botev et al., 2024).
7. Context and Implications
RecurrentGemma-2B demonstrates that the Griffin hybrid architecture—marrying linear recurrence for prefix compression with sliding-window attention—yields competitive results to standard Transformers while significantly reducing context memory requirements and inference cost. The model requires only two-thirds the pre-training data of its Transformer baseline and achieves a substantial inference acceleration on long prompts, supporting high-throughput applications. A plausible implication is potential for this class of models in deployment scenarios where memory and latency constraints dominate, or where long-context capabilities are needed efficiently (Botev et al., 2024).