Jamba-7B: Hybrid LLM with Mamba and MoE
- Jamba-7B is a hybrid large language model that combines Transformer self-attention layers with Mamba state-space layers and sparse MoE modules to optimize performance and scalability.
- It achieves state-of-the-art accuracy on academic benchmarks and long-context tasks, processing up to 256K tokens on an 80GB GPU with up to 3× inference speed improvements.
- Its innovative design, featuring interleaved layers and integrated MoE with RMSNorm stabilization, significantly reduces KV-cache memory requirements and boosts efficiency.
Jamba-7B is an LLM with a hybrid architecture that interleaves Transformer self-attention layers with Mamba state-space layers and adds sparse Mixture-of-Experts (MoE) modules. In its principal configuration, a "7B" base backbone yields 12 billion active parameters per forward pass. The full MoE expert pool brings the total to 52 billion available parameters. Jamba-7B achieves high throughput, supports 256,000-token contexts on a single 80GB GPU, and demonstrates state-of-the-art accuracy on standard academic language-modeling benchmarks and long-context evaluations (Lieber et al., 2024).
1. Architectural Composition
Jamba-7B consists of four repeated "Jamba blocks". Each block has eight layers with a fixed interleaving pattern: one Transformer (self-attention plus MLP) layer followed by seven Mamba (state-space model plus MLP) layers. The total depth is therefore 32 layers: 4 Transformer and 28 Mamba. MoE is applied in every other layer, replacing that layer's standard feedforward MLP with a sparse MoE of 16 experts. A standard top-2 router dispatches each token to two experts, with load-balancing and gating techniques following the Switch Transformer approach. "Active parameters" are those used in a single forward pass, including the dynamically selected experts, while "total (available)" counts the full expert pool:
| Component | Active parameters | Total (available) parameters |
|---|---|---|
| Jamba-7B (overall) | 12B | 52B |
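The layer schedule described above can be sketched in a few lines of Python. This is an illustrative sketch, not the released configuration: the helper `jamba_layout` is hypothetical. Two choices are assumptions: attention sits at offset 0 in each block, following the "one Transformer layer followed by seven Mamba layers" description, and MoE sits on the odd-indexed layers.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LayerSpec:
    index: int
    mixer: str  # "attention" or "mamba"
    mlp: str    # "moe" or "dense"


def jamba_layout(n_blocks: int = 4, layers_per_block: int = 8,
                 attn_offset: int = 0, moe_period: int = 2,
                 moe_offset: int = 1) -> list[LayerSpec]:
    """Hypothetical helper: list the mixer and MLP type of every layer."""
    layers = []
    for i in range(n_blocks * layers_per_block):
        # One attention layer per block, at position attn_offset (assumed 0).
        mixer = "attention" if i % layers_per_block == attn_offset else "mamba"
        # MoE replaces the dense MLP in every moe_period-th layer (assumed odd layers).
        mlp = "moe" if i % moe_period == moe_offset else "dense"
        layers.append(LayerSpec(i, mixer, mlp))
    return layers


layout = jamba_layout()
n_attn = sum(layer.mixer == "attention" for layer in layout)
n_moe = sum(layer.mlp == "moe" for layer in layout)
print(len(layout), n_attn, len(layout) - n_attn, n_moe)  # 32 4 28 16
```

As a rough consistency check, assume essentially all of the 52B − 12B = 40B gap consists of idle expert weights. With 16 MoE layers and 14 unused experts per token in each, that gap spans 224 expert MLPs, or about 40B / 224 ≈ 0.18B parameters per expert.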
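MoE gating follows conventional routing. For a token representation $x$, the router computes logits $g = W_r x$ over the $n = 16$ experts $E_1, \dots, E_n$. It selects the top-$K$ entries (here, $K = 2$) and normalizes them with a softmax. The token's output is the gate-weighted sum of the selected experts' outputs (Lieber et al., 2024):

$$
y = \sum_{i \in \mathcal{T}} \frac{\exp(g_i)}{\sum_{j \in \mathcal{T}} \exp(g_j)} \, E_i(x), \qquad \mathcal{T} = \operatorname{TopK}(g, K).
$$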
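The routing rule described above can be sketched for a single token in plain Python. This is a minimal illustration under stated assumptions. The function `top_k_moe`, the `router_w` weight matrix, and the toy experts are hypothetical. The example uses 4 experts rather than 16 for brevity. The sketch also omits the load-balancing loss and any per-expert capacity limits that a production router would use.

```python
import math
from typing import Callable, Sequence

Vector = list[float]
Expert = Callable[[Vector], Vector]


def top_k_moe(x: Vector, router_w: Sequence[Vector],
              experts: Sequence[Expert], k: int = 2) -> Vector:
    """Sketch of top-k MoE routing for one token representation x."""
    # Router logits g = W_r x: one score per expert.
    logits = [sum(w * xi for w, xi in zip(row, x)) for row in router_w]
    # Keep the k highest-scoring experts.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax over the selected logits only (max subtracted for stability).
    m = max(logits[i] for i in top)
    weights = {i: math.exp(logits[i] - m) for i in top}
    z = sum(weights.values())
    # Gate-weighted sum of the selected experts; unselected experts never run.
    out = [0.0] * len(x)
    for i in top:
        y = experts[i](x)
        for d in range(len(out)):
            out[d] += (weights[i] / z) * y[d]
    return out


# Toy example: 4 hypothetical experts, expert i scales its input by (i + 1).
experts = [lambda v, s=s: [s * t for t in v] for s in (1.0, 2.0, 3.0, 4.0)]
router_w = [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]
out = top_k_moe([1.0, 0.0], router_w, experts, k=2)
```

In the toy example the logits are 0, 1, 2, 3, so only experts 3 and 2 (scaling by 4 and 3) are evaluated. Their outputs are blended with softmax weights of about 0.73 and 0.27.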
2. Computational and Memory Efficiency
Jamba-7B exhibits markedly reduced KV-cache memory requirements and high inference throughput at long sequence lengths. At a context of 256,000 tokens (16-bit storage), Jamba-7B requires only 4 GB of KV-cache, compared to 128 GB for Llama-2 (6.7B), 32 GB for Mistral (7.2B), and 32 GB for Mixtral (46.7B total, 12.9B active). This efficiency arises primarily from replacing most attention layers with Mamba SSM layers. These layers process the sequence with a linear-time recurrence over a fixed-size state, so they need no cache that grows with sequence length. Only the four attention layers keep a KV-cache.
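The KV-cache figures above follow from a standard size formula: 2 (keys and values) × attention layers × KV heads × head dimension × tokens × bytes per element. The sketch below reproduces them under several assumptions. The attention shapes are taken from the public model configurations rather than from this article: 8 KV heads and head dimension 128 for Jamba's attention layers, 32 layers × 32 KV heads × 128 for Llama-2 7B, and 32 × 8 × 128 for Mistral 7B. "256K" is taken as 262,144 tokens, and results are reported in GiB. Mamba layers are excluded because their recurrent state has a fixed size.

```python
def kv_cache_gib(n_attn_layers: int, n_kv_heads: int, head_dim: int,
                 seq_len: int, bytes_per_elem: int = 2) -> float:
    """KV-cache size in GiB; bytes_per_elem=2 corresponds to 16-bit storage."""
    # Factor 2: one key and one value vector per KV head, per layer, per token.
    total_bytes = 2 * n_attn_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem
    return total_bytes / 2**30


SEQ_LEN = 256 * 1024  # "256K" taken as 262,144 tokens (assumption)

# (attention layers, KV heads, head dim); these shapes are assumptions.
CONFIGS = {
    "Jamba": (4, 8, 128),
    "Llama-2 7B": (32, 32, 128),
    "Mistral 7B": (32, 8, 128),
}

for name, (layers, kv_heads, head_dim) in CONFIGS.items():
    print(f"{name}: {kv_cache_gib(layers, kv_heads, head_dim, SEQ_LEN):.0f} GiB")
```

With these assumed shapes, the script prints 4, 128, and 32 GiB. Mixtral uses the same attention shape as Mistral 7B, which is why both need 32 GB. The 8× gap between Jamba and Mistral comes entirely from having 4 attention layers instead of 32.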
In throughput benchmarks (single 80GB A100, INT8, 8K context, 512 tokens generated), Jamba-7B enables larger batch sizes and achieves up to 3× the throughput of Mixtral-8×7B. At 128K-token context, it remains 3× faster than Mixtral, with Llama-2 70B unable to fit the context at all (Lieber et al., 2024).
3. Training Methodology and Hyperparameters
Training utilizes a proprietary, in-house corpus containing web text, books, and code, processed with quality and deduplication filters. The tokenizer employs 64K BPE units, splits digits into separate tokens, and omits dummy leading spaces.
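The two tokenizer conventions mentioned above (every digit as its own token, no dummy leading space) can be illustrated with a toy pre-tokenization step. This is a hypothetical sketch and not Jamba's tokenizer. The function `pretokenize` and its `add_dummy_prefix` flag are assumptions, and a real pipeline would apply the 64K BPE vocabulary to the non-digit chunks afterward.

```python
import re

# Each digit is matched on its own; any run of non-digits stays together for BPE.
_DIGIT_OR_RUN = re.compile(r"\d|\D+")


def pretokenize(text: str, add_dummy_prefix: bool = False) -> list[str]:
    """Hypothetical pre-tokenizer: split digits individually.

    add_dummy_prefix=True mimics tokenizers that prepend a space to the input;
    the convention described above corresponds to the default False.
    """
    if add_dummy_prefix:
        text = " " + text
    return _DIGIT_OR_RUN.findall(text)


print(pretokenize("pi is 3.14"))  # ['pi is ', '3', '.', '1', '4']
```

Splitting digits this way gives every number a consistent per-digit representation, instead of letting BPE merge arbitrary digit groups such as "14" or "2024".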
Training infrastructure includes NVIDIA H100 GPUs with Fully-Sharded Data Parallel (FSDP), tensor, sequence, and expert parallelism to support large-scale hybrid and MoE architectures. The paper does not provide optimizer specifics or learning-rate schedules, but notes that RMS normalization (RMSNorm) is critical for stabilizing Mamba modules during training; insertion of RMSNorm inside each Mamba block's internal activations was essential to avoid loss spikes (Lieber et al., 2024).
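RMSNorm rescales a vector to unit root-mean-square and then applies a learned per-channel gain $g$: $\mathrm{RMSNorm}(x)_i = g_i \, x_i / \sqrt{\tfrac{1}{d}\sum_j x_j^2 + \epsilon}$. Unlike LayerNorm, it subtracts no mean and adds no bias. The sketch below implements that formula for one activation vector. The exact activations inside the Mamba block where Jamba inserts it are not specified here, so applying `rms_norm` to a generic intermediate vector is an assumption.

```python
import math
from typing import Optional, Sequence


def rms_norm(x: Sequence[float], weight: Optional[Sequence[float]] = None,
             eps: float = 1e-6) -> list[float]:
    """RMSNorm: rescale x to unit root-mean-square, then apply a learned gain."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    gain = weight if weight is not None else [1.0] * len(x)
    return [g * v / rms for g, v in zip(gain, x)]


small = rms_norm([3.0, 4.0])
large = rms_norm([30.0, 40.0])  # 10x larger input, (almost) identical output
```

The output is nearly invariant to the scale of its input. This property is why such a normalization can keep internal activations from drifting to large magnitudes and causing loss spikes.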
4. Performance on Benchmarks
Jamba-7B is competitive with models of comparable or much greater size on standard academic benchmarks. On zero-shot and few-shot tasks, it matches or exceeds Mixtral-8×7B (46.7B total, 12.9B active) and Llama-2 70B on most of the tasks below, and outperforms Gemma 7B on all of them, while maintaining higher efficiency. It trails Llama-2 70B only on MMLU. It trails Mixtral on MMLU and, narrowly, on GSM8K and BoolQ:
| Model | Active Params | HellaSwag | WinoGrande | PIQA | MMLU | GSM8K | BoolQ |
|---|---|---|---|---|---|---|---|
| Jamba-7B | 12B | 87.1 | 82.5 | 83.2 | 67.4 | 59.9 | 88.2 |
| Mixtral | 12.9B | 86.7 | 81.2 | 83.0 | 70.6 | 60.4 | 88.4 |
| Llama-2 70B | 70B | 85.3 | 80.2 | 82.8 | 69.8 | 55.3 | 85.0 |
| Gemma 7B | 7B | 81.2 | 72.3 | 81.2 | 64.3 | 54.5 | 87.2 |
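The comparisons in the table can be checked with a few lines of arithmetic over the reported scores. The unweighted six-task mean and the helper `wins` are illustrative choices made here, not metrics reported with the benchmark results.

```python
TASKS = ["HellaSwag", "WinoGrande", "PIQA", "MMLU", "GSM8K", "BoolQ"]
SCORES = {
    "Jamba-7B": [87.1, 82.5, 83.2, 67.4, 59.9, 88.2],
    "Mixtral": [86.7, 81.2, 83.0, 70.6, 60.4, 88.4],
    "Llama-2 70B": [85.3, 80.2, 82.8, 69.8, 55.3, 85.0],
    "Gemma 7B": [81.2, 72.3, 81.2, 64.3, 54.5, 87.2],
}


def mean(row: list[float]) -> float:
    return sum(row) / len(row)


def wins(model: str, baseline: str) -> list[str]:
    """Tasks on which `model` scores strictly higher than `baseline`."""
    return [t for t, a, b in zip(TASKS, SCORES[model], SCORES[baseline]) if a > b]


for name, row in SCORES.items():
    print(f"{name}: unweighted mean {mean(row):.2f}")
for other in ("Mixtral", "Llama-2 70B", "Gemma 7B"):
    print(f"Jamba-7B beats {other} on: {', '.join(wins('Jamba-7B', other))}")
```

On these six tasks, Jamba-7B's unweighted mean (78.05) is close to Mixtral's (about 78.38), above Llama-2 70B's (76.40), and well above Gemma 7B's (73.45). Jamba-7B leads Mixtral on HellaSwag, WinoGrande, and PIQA, and leads Llama-2 70B on every task except MMLU.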
In long-context "needle-in-a-haystack" tasks, Jamba-7B retains over 95% exact-match recall at up to 256K token positions despite having only four attention layers. For naturalistic QA (3-shot, context lengths up to 62K), it slightly outperforms Mixtral on aggregate F1 (0.44 vs. 0.43) while operating at triple the speed (Lieber et al., 2024).
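A needle-in-a-haystack evaluation hides one target sentence at a chosen relative depth in long filler text, asks the model to retrieve it, and scores the answers by exact match. The sketch below shows only the construction and scoring steps. The model call is omitted, and the function names, the space-joined filler, and the case- and whitespace-insensitive match are assumptions, not the evaluation harness used for Jamba.

```python
def build_haystack(filler: list[str], needle: str, depth: float) -> str:
    """Insert `needle` at relative position `depth` (0.0 = start, 1.0 = end)."""
    if not 0.0 <= depth <= 1.0:
        raise ValueError("depth must be in [0, 1]")
    pos = round(depth * len(filler))
    return " ".join(filler[:pos] + [needle] + filler[pos:])


def exact_match(prediction: str, answer: str) -> bool:
    """Case- and whitespace-insensitive exact match (normalization is an assumption)."""
    return prediction.strip().lower() == answer.strip().lower()


def recall(predictions: list[str], answers: list[str]) -> float:
    """Fraction of trials whose prediction exactly matches the expected answer."""
    hits = sum(exact_match(p, a) for p, a in zip(predictions, answers))
    return hits / len(answers)


context = build_haystack(["a.", "b.", "c.", "d."], "The code is 42.", depth=0.5)
```

Sweeping `depth` and the filler length over a grid, and computing `recall` in each cell, produces the familiar depth-by-length heatmap used to report these results.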
5. Architectural Ablations and Insights
Ablation studies indicate that the hybridization strategy is critical for both efficiency and quality. At small (1.3B) and moderate (7B) scale, hybrids with 1:3 and 1:7 attention-to-Mamba ratios outperform pure-attention and pure-Mamba models on the reported metrics. Adding MoE (16 experts, top-2, every second layer) yields consistent accuracy gains across tasks (e.g., HellaSwag 66.0 with MoE vs. 62.5 without). Pure Mamba models struggle on in-context learning (ICL)-sensitive tasks, such as following the answer format of a few-shot prompt. The hybrid recovers this ability, which is attributed in part to induction-head patterns that emerge in its few attention layers.
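Induction heads are usually described as implementing a simple copying rule: if the sequence contains "[A] [B] ... [A]", predict "[B]" next. The symbolic sketch below implements that rule directly to show why it helps with few-shot format adherence. It is an illustration of the pattern, not a model of Jamba's attention weights, and the function `induction_predict` is hypothetical.

```python
from typing import Hashable, Optional, Sequence


def induction_predict(tokens: Sequence[Hashable]) -> Optional[Hashable]:
    """Find the most recent earlier occurrence of the current token and
    return the token that followed it (None if there is no earlier match)."""
    if not tokens:
        return None
    current = tokens[-1]
    # Scan backwards, skipping the final position itself.
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == current:
            return tokens[i + 1]
    return None


# A few-shot prompt ends with "Label :", and the rule copies the earlier label slot.
prompt = ["Label", ":", "positive", "\n", "Label", ":"]
```

Here `induction_predict(prompt)` returns `"positive"`, that is, a token in the label format established by the demonstration. A fixed-size recurrent state has no direct mechanism for looking up an exact earlier position like this, whereas a single attention layer can do so.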
Additional findings include:
- Internal normalization (RMSNorm) is required inside Mamba blocks for loss stability in large-scale hybrid models.
- Explicit positional encodings (e.g., RoPE) provide negligible benefit, presumably because the Mamba layers supply implicit position information (Lieber et al., 2024).
6. Context and Significance
Jamba-7B represents the first large-scale, production-grade hybrid of Transformer attention, Mamba SSM layers, and sparse MoE. Its accuracy is comparable to that of Llama-2 70B on the benchmarks above. It accommodates extreme context lengths within attainable hardware budgets (256K tokens on a single 80GB GPU) and provides up to 3× inference throughput on long-context generation workloads. Its strong results on both standard and long-context benchmarks, together with an easily tunable modular design, position Jamba-7B as an efficient, flexible foundation for broad LLM research. This is especially true in settings where throughput and context capacity are the limiting constraints. Checkpoints for multiple ablation configurations have been released to encourage further analysis and architectural refinement (Lieber et al., 2024).