Multi-Query Multi-Head Attention (MQMHA)
- MQMHA is a family of efficient attention architectures that generalizes classical multi-head attention through flexible sharing of key and value projections.
- It enables variations like multi-query and grouped-query attention, achieving significant latency reductions while retaining near MHA-quality performance with minimal uptraining.
- MQMHA is applied in language modeling and speaker verification, reducing memory bandwidth and FLOPs through optimized grouping and conversion strategies.
Multi-Query Multi-Head Attention (MQMHA) is a family of efficient attention architectures that generalizes classical multi-head attention (MHA) by allowing flexible sharing of key and value projections across query heads. Originally introduced to address the bandwidth and latency bottlenecks of autoregressive Transformer inference, MQMHA, particularly in its grouped-query forms, balances quality retention and hardware efficiency. Architectures in this family include standard MHA ($h$ query heads and $h$ separate key-value heads), multi-query attention (MQA; $h$ query heads, but a single shared key-value head), and grouped-query attention (GQA; $h$ query heads, $g$ intermediate key-value heads, $1 < g < h$). MQMHA is also instantiated in pooling modules for sequence embedding, particularly in speaker verification, where multiple learnable queries per head are used to aggregate feature statistics.
1. Formalization and Mathematical Structure
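MQMHA defines the number of query heads as $h$ and the number of key-value heads as $g$ ($1 \le g \le h$). For token representations $X \in \mathbb{R}^{n \times d}$, standard MHA involves:
$$Q = XW^Q, \quad K = XW^K, \quad V = XW^V$$
with $W^Q, W^K, W^V \in \mathbb{R}^{d \times d}$, reshaped into $h$ sets of $d_k$-dimensional vectors, one per head ($d_k = d/h$). Each head computes: $\mathrm{head}_i = \mathrm{softmax}\left(Q_i K_i^\top / \sqrt{d_k}\right) V_i$ and the outputs are concatenated and linearly projected.
In MQA, all query heads have distinct projections ($Q_i = XW_i^Q$, $i = 1, \dots, h$) but share key and value projections: $K = XW^K$, $V = XW^V$ with $W^K, W^V \in \mathbb{R}^{d \times d_k}$, so that
$$\mathrm{head}_i = \mathrm{softmax}\left(\frac{Q_i K^\top}{\sqrt{d_k}}\right) V$$
GQA generalizes MQA by partitioning query heads into $g$ disjoint groups, each sharing its own key-value projection. For group assignment $\pi(i) \in \{1, \dots, g\}$ for head $i$: $Q_i = XW_i^Q$,
$$K_{\pi(i)} = XW_{\pi(i)}^K, \quad V_{\pi(i)} = XW_{\pi(i)}^V$$
$$\mathrm{head}_i = \mathrm{softmax}\left(\frac{Q_i K_{\pi(i)}^\top}{\sqrt{d_k}}\right) V_{\pi(i)}$$
Special cases: $g = 1$ yields MQA; $g = h$ recovers standard MHA (Ainslie et al., 2023, Chen et al., 2024).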
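The definitions above can be illustrated with a minimal NumPy sketch. It assumes a single unbatched sequence with no causal mask, and the function name `grouped_query_attention`, the `group_of` list, and the weight layouts are illustrative choices rather than anything fixed by the cited papers.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(X, Wq, Wk, Wv, Wo, group_of):
    """Hypothetical layouts: X (n, d); Wq (h, d, d_k); Wk, Wv (g, d, d_k);
    Wo (h * d_k, d). group_of[i] is the KV group index of query head i."""
    h, _, d_k = Wq.shape
    K = np.einsum("nd,gdk->gnk", X, Wk)  # one key matrix per group: (g, n, d_k)
    V = np.einsum("nd,gdk->gnk", X, Wv)  # one value matrix per group: (g, n, d_k)
    heads = []
    for i in range(h):
        Q_i = X @ Wq[i]                       # (n, d_k)
        j = group_of[i]                       # group shared by this head
        scores = Q_i @ K[j].T / np.sqrt(d_k)  # (n, n)
        heads.append(softmax(scores) @ V[j])  # (n, d_k)
    return np.concatenate(heads, axis=-1) @ Wo  # (n, d)

rng = np.random.default_rng(0)
n, d, h, g = 5, 16, 4, 2
d_k = d // h
X = rng.normal(size=(n, d))
Wq = rng.normal(size=(h, d, d_k))
Wk = rng.normal(size=(g, d, d_k))
Wv = rng.normal(size=(g, d, d_k))
Wo = rng.normal(size=(h * d_k, d))
out = grouped_query_attention(X, Wq, Wk, Wv, Wo, group_of=[0, 0, 1, 1])
print(out.shape)  # (5, 16)
```

Passing `g = 1` with `group_of = [0] * h` gives MQA, and `g = h` with `group_of = list(range(h))` gives standard MHA, so the same routine covers the whole family.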
2. Derivation and Algorithmic Workflow
The essential workflow of grouped-query attention consists of:
- Query Projection: For each head $i$, $Q_i$ is computed via $Q_i = XW_i^Q$.
- Key-Value Projection and Grouping: Compute $g$ distinct pairs $(K_j, V_j)$ as $K_j = XW_j^K$, $V_j = XW_j^V$ for $j = 1, \dots, g$. Assign each query head to a group; the mapping $\pi$ from query heads to groups can be fixed (neighbor grouping, even-sized groups) or data-informed (asymmetric, activation-informed grouping).
- Attention Computation: Each head $i$ attends to its group’s key and value via the standard scaled dot-product formula.
- Aggregation and Output: The $h$ head outputs are stacked and passed through an output projection.
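As a small illustration of the fixed "neighbor grouping" option in the workflow above, the sketch below assigns adjacent query heads to even-sized groups. The helper name `neighbor_grouping` is hypothetical, and it assumes the number of groups divides the number of heads.

```python
def neighbor_grouping(h: int, g: int) -> list[int]:
    """Map each of h query heads to one of g equal-sized groups of adjacent heads."""
    if g < 1 or h % g != 0:
        raise ValueError("g must be a positive divisor of h")
    size = h // g  # query heads per key-value head
    return [i // size for i in range(h)]

print(neighbor_grouping(8, 2))  # [0, 0, 0, 0, 1, 1, 1, 1]
print(neighbor_grouping(8, 8))  # [0, 1, 2, 3, 4, 5, 6, 7]  (standard MHA)
print(neighbor_grouping(8, 1))  # [0, 0, 0, 0, 0, 0, 0, 0]  (MQA)
```

A data-informed scheme would replace this function with a search over assignments, while the rest of the workflow stays the same.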
Activation-informed grouping (e.g., AsymGQA) uses head activation similarity over a calibration set to construct groupings that maximize downstream accuracy, using a stochastic search and brief fine-tuning passes (Chen et al., 2024).
3. Uptraining and Conversion Methods
For LLMs already trained with MHA, conversion to GQA or MQA is achieved by:
- Mean-Pooling Conversion: For each GQA group, mean-pool the original per-head $W_i^K$ and $W_i^V$ weights within the group to initialize the group's $W_j^K$ and $W_j^V$. For MQA, mean-pool over all heads.
- Minimal Uptraining: Continue pre-training on the same data and schedule for a fraction $\alpha$ (e.g., $\alpha = 0.05$) of the original steps. This restores quality lost in the direct conversion.
- Preservation of Output and Query Projections: $W^Q$ and $W^O$ remain unchanged (Ainslie et al., 2023).
Empirically, mean-pooling outperforms selection or random initialization for grouped KV projections.
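The mean-pooling initialization can be sketched in a few lines of NumPy. This is a minimal illustration that assumes per-head weights are stored as an array of shape `(h, d, d_k)` and that groups consist of adjacent heads; the function name `mean_pool_kv` and the storage layout are assumptions, not the checkpoint format of any particular model.

```python
import numpy as np

def mean_pool_kv(W, g):
    """W: (h, d, d_k) per-head key or value projection weights from an MHA model.
    Returns (g, d, d_k): the average over each group of h // g adjacent heads."""
    h, d, d_k = W.shape
    if h % g != 0:
        raise ValueError("g must divide h")
    # Split the head axis into (group, head-within-group) and average within groups.
    return W.reshape(g, h // g, d, d_k).mean(axis=1)

rng = np.random.default_rng(0)
W_k = rng.normal(size=(8, 32, 4))   # hypothetical MHA key weights, h = 8
W_k_gqa = mean_pool_kv(W_k, g=2)    # GQA initialization with 2 key-value heads
W_k_mqa = mean_pool_kv(W_k, g=1)    # MQA initialization: mean over all heads
print(W_k_gqa.shape, W_k_mqa.shape)  # (2, 32, 4) (1, 32, 4)
```

The same call is applied to the value weights; query and output projections are left as they are before uptraining begins.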
4. Empirical Benchmarking and Trade-Off Analysis
Extensive experiments on LLMs (e.g., T5-XXL) and speaker verification systems demonstrate the efficiency-quality frontier:
- Decoder Latency: MQA achieves up to 12× reduction in per-token inference time compared to MHA, with typical BLEU drops ≤ 0.2 (Shazeer, 2019).
- GQA (Intermediate $g$): With $g = 8$ (for $h = 64$), GQA recovers ≈99% of MHA quality but requires only 20–30% of full KV bandwidth. MQA reaches minimum bandwidth but with greater quality loss.
- Pooling Applications: In speaker verification, MQMHA pooling (multiple heads, each with multiple learnable queries) yields a 6% relative EER reduction over baseline statistics pooling, and up to 14% when combined with margin-based losses (Zhao et al., 2021).
- Group Size Recommendation: A small group size $h/g$ (query heads per key-value head) achieves the best trade-off, with diminishing hardware savings and increasing quality degradation as the group size increases further. AsymGQA (activation-informed assignments) can halve KV cost with <1% downstream accuracy loss (Chen et al., 2024).
- Uptraining Fraction: Increasing uptraining from 5% to 10% further closes the gap between GQA/MQA and full MHA (Ainslie et al., 2023).
| Architecture | KV Heads ($g$) | Relative Inference Time | Quality (Avg. Score) |
|---|---|---|---|
| MHA-XXL | 64 | 2.531 | 47.2 |
| GQA-8-XXL | 8 | 0.514 | 47.1 |
| MQA-XXL | 1 | 0.489 | 46.6 |
5. Implementation and Hardware Considerations
Grouping of query heads reduces key-value projection parameter count, FLOPs, and especially memory-bandwidth requirements for caching during incremental inference. At group size $s = h/g$ (query heads per key-value head), KV projections and cache usage are reduced by a factor of $s$. For deployment:
- No Kernel Change: Standard GQA can use existing attention kernel implementations with minor bookkeeping for group assignment.
- Cache Access Patterns: Sharing or grouping of K/V greatly reduces buffer reads/writes per decoding step.
- Parallelizability: Group-finding search is embarrassingly parallel across layers, and fine-tuning with low-rank adapters (LoRA) offers efficient convergence (Chen et al., 2024).
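To make the cache reduction concrete, the following sketch estimates KV-cache size as a function of the number of key-value heads. The model dimensions (24 layers, head dimension 64, 2048 cached tokens, 2 bytes per element) are hypothetical values chosen for illustration and are not taken from the cited experiments.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_k, seq_len, batch=1, bytes_per_elem=2):
    # The leading factor 2 accounts for storing both keys and values.
    return 2 * n_layers * n_kv_heads * d_k * seq_len * batch * bytes_per_elem

MIB = 2 ** 20
for name, kv_heads in [("MHA", 64), ("GQA-8", 8), ("MQA", 1)]:
    size = kv_cache_bytes(n_layers=24, n_kv_heads=kv_heads, d_k=64, seq_len=2048)
    print(f"{name}: {size / MIB:.1f} MiB")
# MHA: 768.0 MiB
# GQA-8: 96.0 MiB
# MQA: 12.0 MiB
```

Under these assumptions the cache shrinks by exactly the ratio of query heads to key-value heads, which is also the reduction in bytes read per decoding step.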
In pooling, MQMHA runs in time linear in the number of frames per utterance, with a parameter count that does not depend on utterance length, supporting efficient embedding extraction for long sequences (Zhao et al., 2021).
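A rough sketch of multi-query multi-head attentive statistics pooling is given below. It assumes that frame features are split channel-wise into heads and that attention scores are plain dot products with learnable query vectors; the scoring function, the `queries` layout, and the name `mqmha_pooling` are simplifying assumptions and may differ from the formulation in Zhao et al. (2021).

```python
import numpy as np

def mqmha_pooling(feats, queries, eps=1e-8):
    """feats: (T, d) frame-level features.
    queries: (H, Q, d // H) learnable query vectors (hypothetical layout).
    Returns a vector of size 2 * Q * d: weighted means, then weighted stds."""
    T, d = feats.shape
    H, Q, d_h = queries.shape
    assert H * d_h == d, "feature dimension must split evenly across heads"
    x = feats.reshape(T, H, d_h)                    # split channels into H heads
    scores = np.einsum("thc,hqc->hqt", x, queries)  # (H, Q, T)
    scores -= scores.max(axis=-1, keepdims=True)    # stabilize the softmax
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)              # softmax over time
    mean = np.einsum("hqt,thc->hqc", w, x)          # weighted mean: (H, Q, d_h)
    second = np.einsum("hqt,thc->hqc", w, x ** 2)   # weighted second moment
    std = np.sqrt(np.clip(second - mean ** 2, 0.0, None) + eps)
    return np.concatenate([mean.ravel(), std.ravel()])

rng = np.random.default_rng(0)
feats = rng.normal(size=(200, 16))    # 200 frames, 16 channels
queries = rng.normal(size=(4, 2, 4))  # 4 heads, 2 queries per head
emb = mqmha_pooling(feats, queries)
print(emb.shape)  # (64,)
```

Each step is a single pass over the time axis, so the cost grows linearly with the number of frames, and the only learnable parameters in this sketch are the query vectors.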
6. Hyperparameter Selection and Practical Guidelines
The optimal configuration depends on both required throughput and task-specific accuracy.
- If maximum quality is required: Set $g = h$ (full MHA).
- If KV bandwidth is limiting: Use the smallest $g$ consistent with acceptable accuracy loss on the target LLM workload.
- Mean-pooling is the preferred conversion method for all $g < h$ (Ainslie et al., 2023).
- Pool concatenation: In pooling, always concatenate both weighted mean and standard deviation features.
- Pooling heads and queries: In MQMHA pooling, the head and query counts reported by Zhao et al. (2021) are empirically robust.
On long-input tasks (summarization, QA), the bandwidth reduction is most pronounced.
7. Extensions and Future Research
Ongoing research explores dynamic, activation-informed grouping (AsymGQA), which refines group assignments based on pre-trained activation similarity for better quality at fixed hardware cost (Chen et al., 2024). Other directions include:
- Differentiable Grouping: Using Gumbel-softmax to relax grouping into differentiable parameters for joint training.
- Layer-wise Adaptive Grouping: Choosing $g$ per layer to match local capacity demand or identify layer-level KV redundancy.
- Combination with Sparsity/Low Rank: Integrating GQA with sparse or low-rank projection techniques for further memory savings.
- Pool-based Sequence Embedding: Extending MQMHA pooling to additional domains beyond speaker verification.
The MQMHA family, abstracted through the tunable group parameter $g$, subsumes both classical MHA and aggressive MQA, enabling a controlled spectrum of quality-hardware trade-offs for Transformer models and related architectures (Ainslie et al., 2023, Chen et al., 2024, Shazeer, 2019, Zhao et al., 2021).