
Multi-Query Multi-Head Attention (MQMHA)

Updated 22 May 2026
  • MQMHA is a family of efficient attention architectures that generalizes classical multi-head attention through flexible sharing of key and value projections.
  • It includes variants such as multi-query and grouped-query attention, which substantially reduce decoding latency while retaining near-MHA quality after minimal uptraining.
  • MQMHA is applied in language modeling, where optimized grouping and conversion strategies reduce memory bandwidth and FLOPs, and in speaker verification, where it serves as an attention-pooling layer.

Multi-Query Multi-Head Attention (MQMHA) is a family of efficient attention architectures that generalizes classical multi-head attention (MHA) by allowing flexible sharing of key and value projections across query heads. Originally introduced to address the memory-bandwidth and latency bottlenecks of autoregressive Transformer inference, MQMHA, particularly in its grouped-query forms, balances quality retention against hardware efficiency. Architectures in this family include standard MHA ($h$ query heads and $h$ separate key-value heads), multi-query attention (MQA; $h$ query heads but a single shared key-value head), and grouped-query attention (GQA; $h$ query heads and $g$ intermediate key-value heads, $1 < g < h$). MQMHA is also instantiated in pooling modules for sequence embedding, particularly in speaker verification, where multiple learnable queries per head are used to aggregate feature statistics.

1. Formalization and Mathematical Structure

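MQMHA denotes the number of query heads by $h$ and the number of key-value heads by $g$ ($1 \leq g \leq h$). For token representations $X \in \mathbb{R}^{n \times d}$, standard MHA computes:

$$Q = XW^Q, \quad K = XW^K, \quad V = XW^V$$

with $W^Q, W^K, W^V \in \mathbb{R}^{d \times d}$, each result reshaped into $h$ sets of $d_k$-dimensional vectors per head ($d_k = d/h$), giving $Q_i, K_i, V_i \in \mathbb{R}^{n \times d_k}$. Each head computes $\mathrm{head}_i = \mathrm{softmax}\left(Q_i K_i^\top / \sqrt{d_k}\right) V_i$, and the outputs are concatenated and linearly projected: $\mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O$.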
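As a concrete reference point, the MHA computation above can be sketched in NumPy for a single unmasked sequence. This is a minimal illustration rather than a production kernel; the function and argument names (`mha`, `W_q`, `W_o`) are our own, and the weights are random placeholders.

```python
import numpy as np

# Illustrative sketch: names and random weights are placeholders, not a library API.


def softmax(z, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def mha(X, W_q, W_k, W_v, W_o, h):
    """Standard multi-head attention over one unmasked sequence.

    X: (n, d); W_q, W_k, W_v, W_o: (d, d); h must divide d.
    """
    n, d = X.shape
    d_k = d // h
    # Project, then split the last axis into h heads of width d_k: (h, n, d_k).
    Q = (X @ W_q).reshape(n, h, d_k).transpose(1, 0, 2)
    K = (X @ W_k).reshape(n, h, d_k).transpose(1, 0, 2)
    V = (X @ W_v).reshape(n, h, d_k).transpose(1, 0, 2)
    # head_i = softmax(Q_i K_i^T / sqrt(d_k)) V_i, batched over all heads.
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)  # (h, n, n)
    heads = softmax(scores) @ V  # (h, n, d_k)
    # Concatenate the heads back into (n, d) and apply the output projection.
    concat = heads.transpose(1, 0, 2).reshape(n, d)
    return concat @ W_o


rng = np.random.default_rng(0)
n, d, h = 5, 16, 4
X = rng.standard_normal((n, d))
W_q, W_k, W_v, W_o = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(4))
out = mha(X, W_q, W_k, W_v, W_o, h)
print(out.shape)  # (5, 16)
```

With $h = 1$ the function reduces to ordinary single-head scaled dot-product attention followed by $W^O$.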

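In MQA, all query heads keep distinct projections ($Q_i = XW_i^Q$ with $W_i^Q \in \mathbb{R}^{d \times d_k}$) but share a single key and value projection, $W^K, W^V \in \mathbb{R}^{d \times d_k}$:

$$\mathrm{head}_i = \mathrm{softmax}\left(\frac{XW_i^Q\,(XW^K)^\top}{\sqrt{d_k}}\right) XW^V, \quad i = 1, \dots, h$$

GQA generalizes MQA by partitioning the query heads into $g$ disjoint groups, each sharing its own key-value projection. With group assignment $\gamma(i) \in \{1, \dots, g\}$ for head $i$:

$$K_j = XW_j^K, \quad V_j = XW_j^V, \quad W_j^K, W_j^V \in \mathbb{R}^{d \times d_k}, \quad j = 1, \dots, g$$

$$\mathrm{head}_i = \mathrm{softmax}\left(\frac{Q_i K_{\gamma(i)}^\top}{\sqrt{d_k}}\right) V_{\gamma(i)}, \quad i = 1, \dots, h$$

Special cases: $g = 1$ yields MQA; $g = h$ with $\gamma(i) = i$ recovers standard MHA (Ainslie et al., 2023, Chen et al., 2024).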
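The special cases can be checked numerically. The sketch below is a minimal NumPy version of grouped-query attention for one unmasked sequence, assuming the $g$ key/value heads are stored as contiguous column blocks of `W_k` and `W_v` and that `gamma` is a 0-based list mapping each query head to its group; all names are our own. It confirms that GQA equals an MHA whose heads carry copies of their group's key/value weights, and runs MQA as the $g = 1$ case.

```python
import numpy as np

# Illustrative sketch; function names, shapes, and the 0-based gamma are assumptions.


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def gqa(X, W_q, W_k, W_v, W_o, h, g, gamma):
    """Grouped-query attention over one unmasked sequence.

    W_q, W_o: (d, d) as in MHA; W_k, W_v: (d, g * d_k) hold g key/value heads.
    gamma[i] in {0, ..., g - 1} is the group of query head i.
    """
    n, d = X.shape
    d_k = d // h
    Q = (X @ W_q).reshape(n, h, d_k).transpose(1, 0, 2)  # (h, n, d_k)
    K = (X @ W_k).reshape(n, g, d_k).transpose(1, 0, 2)  # (g, n, d_k)
    V = (X @ W_v).reshape(n, g, d_k).transpose(1, 0, 2)  # (g, n, d_k)
    heads = []
    for i in range(h):
        j = gamma[i]  # query head i reads the keys/values of its group
        weights = softmax(Q[i] @ K[j].T / np.sqrt(d_k))
        heads.append(weights @ V[j])
    return np.concatenate(heads, axis=-1) @ W_o  # concat to (n, d), then W^O


def head_cols(W, j, d_k):
    # Columns of W that belong to KV head j.
    return W[:, j * d_k:(j + 1) * d_k]


rng = np.random.default_rng(0)
n, d, h, g = 6, 32, 8, 2
d_k = d // h
X = rng.standard_normal((n, d))
W_q, W_o = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(2))
W_k, W_v = (rng.standard_normal((d, g * d_k)) / np.sqrt(d) for _ in range(2))
gamma = [i // (h // g) for i in range(h)]  # heads 0-3 -> group 0, heads 4-7 -> group 1
y_gqa = gqa(X, W_q, W_k, W_v, W_o, h, g, gamma)

# The same computation as MHA (g = h) in which each head gets a copy of its group's K/V.
W_k_full = np.hstack([head_cols(W_k, gamma[i], d_k) for i in range(h)])
W_v_full = np.hstack([head_cols(W_v, gamma[i], d_k) for i in range(h)])
y_mha = gqa(X, W_q, W_k_full, W_v_full, W_o, h, h, list(range(h)))
print(np.allclose(y_gqa, y_mha))  # True

# MQA is the g = 1 case: every head maps to the single shared KV head.
y_mqa = gqa(X, W_q, W_k[:, :d_k], W_v[:, :d_k], W_o, h, 1, [0] * h)
```

The equivalence check makes the parameter saving explicit: GQA behaves like MHA with tied key/value weights, so only $g$ rather than $h$ key/value heads need to be stored and cached.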

2. Derivation and Algorithmic Workflow

The essential workflow of grouped-query attention consists of:

  1. Query Projection: For each head $i = 1, \dots, h$, $Q_i$ is computed via $Q_i = XW_i^Q$.
  2. Key-Value Projection and Grouping: Compute $g$ distinct key-value pairs as $K_j = XW_j^K$, $V_j = XW_j^V$ for $j = 1, \dots, g$. Assign each query head to a group; the mapping $\gamma: \{1, \dots, h\} \to \{1, \dots, g\}$ can be fixed (neighbor grouping into even-sized groups) or data-informed (asymmetric, activation-informed grouping).
  3. Attention Computation: Each head $i$ attends to its group's keys and values, $K_{\gamma(i)}$ and $V_{\gamma(i)}$, via the standard scaled dot-product formula.
  4. Aggregation and Output: The $h$ head outputs are concatenated and passed through the output projection $W^O$.
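Steps 2 to 4 can be sketched in a vectorized form, assuming the projections are already computed as arrays of shape (heads, tokens, $d_k$). Gathering each head's group keys and values turns the rest into ordinary batched MHA; with neighbor grouping, that gather is the same as repeating each KV head $h/g$ times. The copy here is only for clarity (optimized kernels typically index the shared keys and values directly), and the helper names are hypothetical.

```python
import numpy as np

# Hypothetical helpers illustrating the GQA workflow; not a library API.


def neighbor_groups(h, g):
    """Fixed neighbor grouping: consecutive blocks of h // g query heads share a KV head."""
    if h % g != 0:
        raise ValueError("even-sized neighbor groups require g to divide h")
    return np.arange(h) // (h // g)


def grouped_attention(Q, K, V, gamma):
    """Steps 3-4 (before W^O) given projected Q: (h, n, d_k) and K, V: (g, n, d_k).

    gamma is an integer array of length h mapping each query head to its group,
    so any grouping (neighbor or activation-informed) can be passed in.
    """
    h, n, d_k = Q.shape
    # Gather each head's group K/V into (h, n, d_k) so the rest is plain batched MHA.
    K_h, V_h = K[gamma], V[gamma]
    scores = np.einsum("hqd,hkd->hqk", Q, K_h) / np.sqrt(d_k)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    A = np.exp(scores)
    A /= A.sum(axis=-1, keepdims=True)  # softmax over key positions
    heads = np.einsum("hqk,hkd->hqd", A, V_h)  # (h, n, d_k)
    return heads.transpose(1, 0, 2).reshape(n, h * d_k)  # concatenated heads


rng = np.random.default_rng(1)
h, g, n, d_k = 8, 2, 5, 4
Q = rng.standard_normal((h, n, d_k))
K = rng.standard_normal((g, n, d_k))
V = rng.standard_normal((g, n, d_k))
print(neighbor_groups(h, g))  # [0 0 0 0 1 1 1 1]
concat = grouped_attention(Q, K, V, neighbor_groups(h, g))
# A hypothetical asymmetric assignment (group sizes 3 and 5) uses the same code path.
concat_asym = grouped_attention(Q, K, V, np.array([0, 1, 1, 1, 1, 1, 0, 0]))
```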

Activation-informed grouping (e.g., AsymGQA) builds group assignments from head-activation similarity over a calibration set, relying on stochastic search and brief fine-tuning passes to find groupings that maximize downstream accuracy (Chen et al., 2024).
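To convey the intuition only, the sketch below measures cosine similarity between per-head activations on calibration data and greedily places mutually similar heads in the same group. The greedy rule, the cosine metric, and the helper names are our assumptions: unlike the method described above, it keeps groups even-sized and omits the stochastic search and fine-tuning passes.

```python
import numpy as np

# Hypothetical greedy heuristic; not the AsymGQA search procedure.


def head_similarity(acts):
    """Cosine similarity between heads; acts is (h, m), one flattened activation row per head."""
    unit = acts / np.linalg.norm(acts, axis=1, keepdims=True)
    return unit @ unit.T


def greedy_groups(sim, g):
    """Greedily split h heads into g even-sized groups of mutually similar heads."""
    h = sim.shape[0]
    if h % g != 0:
        raise ValueError("g must divide h for even-sized groups")
    size = h // g
    gamma = np.full(h, -1)
    free = set(range(h))
    for j in range(g):
        seed = min(free)  # lowest-index unassigned head starts the next group
        # Add the free heads most similar to the seed until the group is full.
        ranked = sorted(free - {seed}, key=lambda k: sim[seed, k], reverse=True)
        members = [seed] + ranked[: size - 1]
        gamma[members] = j
        free -= set(members)
    return gamma


rng = np.random.default_rng(0)
pattern = rng.standard_normal((2, 100))
# Synthetic calibration activations: heads 0 and 2 follow pattern 0, heads 1 and 3 pattern 1.
acts = pattern[[0, 1, 0, 1]] + 0.01 * rng.standard_normal((4, 100))
gamma = greedy_groups(head_similarity(acts), g=2)
print(gamma)  # [0 1 0 1]
```

Neighbor grouping would have paired heads 0 and 1 here; the similarity-based assignment instead pairs the heads whose activations nearly coincide, which is the kind of structure activation-informed grouping aims to exploit.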

3. Uptraining and Conversion Methods

For LLMs already trained with MHA, conversion to GQA or MQA is achieved by:

  • Mean-Pooling Conversion: For each GQA group, mean-pool the original per-head key and value projections $W_i^K$ and $W_i^V$ of the heads in that group to initialize the shared $W_j^K$ and $W_j^V$. For MQA, mean-pool over all heads.
  • Minimal Uptraining: Continue pre-training on the same data and schedule for a small fraction $\alpha$ (e.g., $\alpha = 0.05$) of the original steps. This restores the quality lost in the direct conversion.
  • Preservation of Output and Query Projections: $W^Q$ and $W^O$ remain unchanged (Ainslie et al., 2023).

Empirically, mean-pooling outperforms both selecting a single head's projection and random initialization for the grouped KV projections (Ainslie et al., 2023).
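The mean-pooling conversion amounts to a reshape and an average. This minimal sketch assumes each MHA projection is stored as a $(d, h \cdot d_k)$ matrix with heads in contiguous column blocks and that groups are formed from neighboring heads; the same (hypothetical) function applies to $W^V$, while $W^Q$ and $W^O$ are left untouched.

```python
import numpy as np

# Hypothetical helper; assumes contiguous per-head column blocks and neighbor grouping.


def mean_pool_kv(W, h, g):
    """Initialize g grouped KV heads from an MHA key (or value) projection.

    W: (d, h * d_k), head i occupying columns [i * d_k, (i + 1) * d_k).
    Consecutive (neighbor) heads form each group; returns (d, g * d_k).
    """
    if h % g != 0:
        raise ValueError("g must divide h for even-sized groups")
    d, cols = W.shape
    d_k = cols // h
    per_head = W.reshape(d, h, d_k)  # split columns into per-head blocks
    pooled = per_head.reshape(d, g, h // g, d_k).mean(axis=2)  # average within each group
    return pooled.reshape(d, g * d_k)


d, h, d_k = 8, 4, 2
W_k = np.arange(d * h * d_k, dtype=float).reshape(d, h * d_k)  # stand-in for trained weights
W_k_gqa = mean_pool_kv(W_k, h, g=2)  # GQA-2: heads {0, 1} and {2, 3} are averaged
W_k_mqa = mean_pool_kv(W_k, h, g=1)  # MQA: average of all four heads
print(W_k_gqa.shape, W_k_mqa.shape)  # (8, 4) (8, 2)
```

Setting $g = h$ returns the original matrix, so the conversion is a no-op for MHA.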

4. Empirical Benchmarking and Trade-Off Analysis

Extensive experiments on LLMs (e.g., T5-XXL) and speaker verification systems demonstrate the efficiency-quality frontier:

  • Decoder Latency: MQA achieves up to 12× reduction in per-token inference time compared to MHA, with typical BLEU drops ≤ 0.2 (Shazeer, 2019).
  • GQA (intermediate $g$): With $g = 8$ (for $h = 64$), GQA recovers ≈99% of MHA quality but requires only 20–30% of full KV bandwidth. MQA reaches minimum bandwidth but with greater quality loss.
  • Pooling Applications: In speaker verification, MQMHA pooling (multiple heads, each with several learnable queries) yields a 6% relative EER reduction over baseline statistics pooling, and up to 14% when combined with margin-based losses (Zhao et al., 2021).
  • Group Size Recommendation: A moderate number of KV heads (e.g., $g = 8$ for $h = 64$, the GQA-8 configuration in the table below) achieves the best trade-off; reducing $g$ further yields diminishing hardware savings and increasing quality degradation. AsymGQA (activation-informed assignments) can halve KV cost with <1% downstream accuracy loss (Chen et al., 2024).
  • Uptraining Fraction: Increasing uptraining from 5% to 10% further closes the gap between GQA/MQA and full MHA (Ainslie et al., 2023).

| Architecture | KV heads ($g$) | Relative inference time | Quality (avg. score) |
|---|---|---|---|
| MHA-XXL | 64 | 2.531 | 47.2 |
| GQA-8-XXL | 8 | 0.514 | 47.1 |
| MQA-XXL | 1 | 0.489 | 46.6 |
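The trade-offs in the table can be expressed as ratios relative to MHA with a few lines of Python. The values are copied from the table; the speedup, retained-quality, and KV-cache percentages are plain arithmetic on them, and the helper name is our own.

```python
# (architecture, KV heads g, relative inference time, average score), copied from the table above.
ROWS = [
    ("MHA-XXL", 64, 2.531, 47.2),
    ("GQA-8-XXL", 8, 0.514, 47.1),
    ("MQA-XXL", 1, 0.489, 46.6),
]


def tradeoffs(rows):
    """Express each variant relative to the first (MHA) row."""
    _, g_ref, t_ref, q_ref = rows[0]
    out = {}
    for name, g, t, q in rows:
        out[name] = {
            "speedup": t_ref / t,  # times faster than MHA
            "quality_pct": 100 * q / q_ref,  # share of MHA's average score kept
            "kv_cache_pct": 100 * g / g_ref,  # KV-cache size relative to MHA
        }
    return out


for name, m in tradeoffs(ROWS).items():
    print(f"{name:10s} {m['speedup']:.2f}x  {m['quality_pct']:.1f}%  {m['kv_cache_pct']:.1f}%")
```

For GQA-8 this gives roughly a 4.9× speedup over MHA while keeping about 99.8% of its average score with 12.5% of its KV cache; MQA is about 5.2× faster but keeps about 98.7%.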

5. Implementation and Hardware Considerations

Grouping query heads reduces the key-value projection parameter count and FLOPs, and especially the memory bandwidth needed to read the KV cache during incremental inference. With $g$ KV heads (a group size of $h/g$ query heads per KV head), KV projection parameters and cache size are reduced by a factor of $h/g$. For deployment:

  • No Kernel Change: Standard GQA can use existing attention kernel implementations with minor bookkeeping for group assignment.
  • Cache Access Patterns: Sharing or grouping of K/V greatly reduces buffer reads/writes per decoding step.
  • Parallelizability: Group-finding search is embarrassingly parallel across layers, and fine-tuning with low-rank adapters (LoRA) offers efficient convergence (Chen et al., 2024).
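The factor-of-$h/g$ saving in cache size can be checked with a back-of-the-envelope calculator. The model dimensions below are hypothetical, chosen only for illustration, and the formula ignores implementation overheads such as padding or allocator metadata.

```python
def kv_cache_bytes(batch, seq_len, n_layers, g, d_k, bytes_per_elem=2):
    """Bytes needed to cache keys and values during incremental decoding.

    Each layer stores K and V, each of shape (batch, g, seq_len, d_k);
    bytes_per_elem=2 assumes fp16/bf16 storage.
    """
    return 2 * n_layers * batch * g * seq_len * d_k * bytes_per_elem


# Hypothetical decoder: 32 layers, h = 32 query heads of width d_k = 128.
GiB = 2**30
for label, g in [("MHA", 32), ("GQA-8", 8), ("MQA", 1)]:
    size = kv_cache_bytes(batch=1, seq_len=4096, n_layers=32, g=g, d_k=128)
    print(f"{label:6s} g={g:2d}: {size / GiB:.4f} GiB")
```

Under these assumptions the cache shrinks from 2 GiB (MHA) to 0.5 GiB (GQA-8) and 0.0625 GiB (MQA). Because the size is linear in `seq_len`, the absolute savings grow with input length.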

In pooling, MQMHA runs in time linear in the number of frames $T$ per utterance, and its parameter count does not grow with $T$, supporting efficient embedding extraction for long sequences (Zhao et al., 2021).
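A minimal NumPy sketch of MQMHA-style pooling follows, assuming each head scores frames by a scaled dot product with its own learnable query vectors. The scoring function, shapes, and names are our assumptions and may differ from Zhao et al. (2021). Each (head, query) pair produces a weighted mean and a weighted standard deviation, which are concatenated into the utterance embedding.

```python
import numpy as np

# Illustrative sketch; dot-product scoring with learnable query vectors is an assumption.


def mqmha_pool(frames, queries):
    """Pool frame-level features into one utterance embedding.

    frames: (T, d); queries: (n_heads, n_q, d // n_heads) learnable query vectors.
    Returns a vector of length 2 * n_q * d (weighted means, then weighted stds).
    """
    T, d = frames.shape
    n_heads, n_q, d_h = queries.shape
    parts = frames.reshape(T, n_heads, d_h).transpose(1, 0, 2)  # (heads, T, d_h)
    # One score per (head, query, frame); the cost grows linearly with T.
    scores = np.einsum("hqd,htd->hqt", queries, parts) / np.sqrt(d_h)
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)  # softmax over frames
    mean = np.einsum("hqt,htd->hqd", w, parts)  # weighted mean
    second = np.einsum("hqt,htd->hqd", w, parts**2)
    std = np.sqrt(np.clip(second - mean**2, 1e-8, None))  # weighted std
    return np.concatenate([mean.ravel(), std.ravel()])


rng = np.random.default_rng(0)
T, d, n_heads, n_q = 200, 64, 4, 2
frames = rng.standard_normal((T, d))
queries = rng.standard_normal((n_heads, n_q, d // n_heads))  # the only parameters here
emb = mqmha_pool(frames, queries)
print(emb.shape)  # (256,)
```

In this sketch the learnable parameters are the `queries` array ($n_q \cdot d$ values), whose size is independent of $T$. With all-zero queries the attention weights become uniform and the pooling reduces to plain mean and standard-deviation statistics pooling.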

6. Hyperparameter Selection and Practical Guidelines

The optimal configuration depends on both required throughput and task-specific accuracy.

  • If maximum quality is required: Set $g = h$ (full MHA).
  • If KV bandwidth is limiting: Use the smallest $g$ consistent with acceptable accuracy loss (e.g., $g = 8$ for the $h = 64$ XXL models in the table above).
  • Mean-pooling is the preferred conversion method for all $g < h$, including MQA (Ainslie et al., 2023).
  • Pool concatenation: In pooling, always concatenate both the weighted mean and the weighted standard deviation features.
  • Pooling heads and queries: In MQMHA pooling, the head and query-per-head counts used by Zhao et al. (2021) are empirically robust.
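The "smallest acceptable $g$" guideline can be written as a small rule. The helper below is hypothetical: it takes measured average scores keyed by $g$ (for the same model and task) and a tolerated drop relative to the best score. The example reuses the XXL scores from the benchmarking table.

```python
def smallest_g(scores_by_g, max_drop):
    """Return the fewest KV heads whose score is within max_drop of the best score."""
    best = max(scores_by_g.values())
    acceptable = [g for g, score in scores_by_g.items() if best - score <= max_drop]
    return min(acceptable)


# Average scores of the XXL variants in the benchmarking table, keyed by KV heads g.
xxl_scores = {64: 47.2, 8: 47.1, 1: 46.6}
print(smallest_g(xxl_scores, max_drop=0.2))  # 8
print(smallest_g(xxl_scores, max_drop=1.0))  # 1
```

A strict budget lands on GQA-8; relaxing it to a one-point drop admits MQA.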

On long-input tasks (summarization, QA), the bandwidth reduction is most pronounced, because the KV cache read at every decoding step grows linearly with input length.

7. Extensions and Future Research

Ongoing research explores dynamic, activation-informed grouping (AsymGQA), which refines group assignments based on pre-trained activation similarity for better quality at fixed hardware cost (Chen et al., 2024). Other directions include:

  • Differentiable Grouping: Using Gumbel-softmax to relax grouping into differentiable parameters for joint training.
  • Layer-wise Adaptive Grouping: Choosing $g$ per layer to match local capacity demand or to exploit layer-level KV redundancy.
  • Combination with Sparsity/Low Rank: Integrating GQA with sparse or low-rank projection techniques for further memory savings.
  • Pool-based Sequence Embedding: Extending MQMHA pooling to additional domains beyond speaker verification.
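To make the differentiable-grouping direction concrete, here is a forward-pass sketch assuming a learnable $(h, g)$ logit matrix whose Gumbel-softmax samples mix the $g$ group keys into per-head keys. The names and this parameterization are our assumptions, not a published method. NumPy computes no gradients, so in practice the same expression would be written in an autodiff framework so that gradients reach the logits.

```python
import numpy as np

# Hypothetical parameterization of soft head-to-group assignment; forward pass only.


def gumbel_softmax(logits, tau, rng):
    """Relaxed one-hot sample along the last axis: softmax((logits + Gumbel noise) / tau)."""
    u = rng.uniform(1e-10, 1.0, size=logits.shape)
    y = (logits - np.log(-np.log(u))) / tau  # -log(-log u) is standard Gumbel noise
    y -= y.max(axis=-1, keepdims=True)
    e = np.exp(y)
    return e / e.sum(axis=-1, keepdims=True)


def soft_grouped_keys(K_groups, assign_logits, tau, rng):
    """Per-head keys as convex mixtures of the g group keys.

    K_groups: (g, n, d_k); assign_logits: (h, g). As tau -> 0 each row of the
    mixing matrix approaches a one-hot vector, i.e. a hard assignment gamma(i).
    """
    A = gumbel_softmax(assign_logits, tau, rng)  # (h, g), each row sums to 1
    return np.einsum("hg,gnd->hnd", A, K_groups), A


rng = np.random.default_rng(0)
K_groups = rng.standard_normal((2, 5, 4))  # g = 2 groups, n = 5 tokens, d_k = 4
# Logits strongly preferring heads {0, 1} -> group 0 and heads {2, 3} -> group 1.
logits = np.array([[20.0, 0.0], [20.0, 0.0], [0.0, 20.0], [0.0, 20.0]])
K_heads, A = soft_grouped_keys(K_groups, logits, tau=0.5, rng=rng)
print(A.argmax(axis=1))  # [0 0 1 1]
```

Annealing `tau` toward zero during training would push the soft mixture toward a discrete grouping that standard GQA kernels can then use.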

The MQMHA family, parameterized by the tunable number of KV heads $g$, subsumes both classical MHA and aggressive MQA, enabling a controlled spectrum of quality-hardware trade-offs for Transformer models and related architectures (Ainslie et al., 2023, Chen et al., 2024, Shazeer, 2019, Zhao et al., 2021).
