MTLA: Multi-head Temporal Latent Attention
- The paper introduces MTLA, a self-attention variant that compresses the key-value cache temporally to reduce memory usage and speed up inference.
- MTLA leverages a hyper-network for dynamic temporal merging and employs a stride-aware causal mask to ensure consistent training and inference.
- Empirical results demonstrate up to 8× GPU memory reduction and 5× speedup in applications such as speech translation, recognition, and text summarisation.
Multi-head Temporal Latent Attention (MTLA) is a self-attention variant designed to address the inference-time memory and computational bottlenecks of Transformer architectures. MTLA advances the compression paradigm introduced by Multi-Head Latent Attention (MLA) by reducing the Key-Value (KV) cache size along the temporal axis, leading to substantial improvements in inference speed and GPU memory usage without significant degradation in model quality. MTLA features dynamic temporal merging via a hyper-network and utilizes a stride-aware causal mask to reconcile temporal compression with parallel training and consistent inference behaviour. Empirical evaluations demonstrate MTLA's efficacy in tasks such as speech translation, speech recognition, spoken language understanding, and text summarisation (2505.13544).
1. Architectural Foundation and Motivation
MTLA is conceptually rooted in the Transformer's multi-head attention (MHA) mechanism. MHA maintains per-head Key and Value representations for each time-step, resulting in a KV cache of size $O(T \cdot h \cdot d_h)$, where $T$ is the sequence length, $h$ the number of heads, and $d_h$ the head dimension. MLA introduced low-rank latent compression by mapping the input $x_t$ to a latent $c_t \in \mathbb{R}^{d_c}$ (with $d_c \ll h \cdot d_h$) and reconstructing keys and values via learned up-projections.
MTLA goes further by compressing temporally: each block of $s$ consecutive latent vectors is merged by a hyper-network, producing $\lceil T/s \rceil$ cache entries and reducing cache complexity to $O((T/s) \cdot d_c)$. This design directly targets the growth of the autoregressive KV cache, decreasing both temporal storage and per-step attention cost. The rationale is that adjacent tokens in long sequences (e.g., speech) carry redundant information, so merging preserves the key semantics at drastically reduced resource cost.
2. Latent Space Factorization
MTLA first maps the hidden state $x_t$ to a latent vector using a learned down-projection:

$$c_t = W^{D} x_t, \qquad c_t \in \mathbb{R}^{d_c}$$

At inference, $c_t$ is stored as the compressed KV cache. For attention computation, the latent is up-projected into per-head keys and values:

$$k_t = W^{UK} c_t, \qquad v_t = W^{UV} c_t$$

More generally, any latent dimension $d_c \ll h \cdot d_h$ yields a compressed cache; this reduces KV cache storage by a factor of $h \cdot d_h / d_c$ relative to MHA.
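A minimal NumPy sketch of this factorization follows; the dimensions and weight names (`W_down`, `W_up_k`, `W_up_v`) are illustrative assumptions, not the paper's configuration:

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (not taken from the paper): model dim d, h heads of
# d_h dims each, latent dim d_c << h * d_h.
d, h, d_h, d_c = 512, 8, 64, 128

W_down = rng.standard_normal((d_c, d)) / np.sqrt(d)          # x_t -> c_t
W_up_k = rng.standard_normal((h * d_h, d_c)) / np.sqrt(d_c)  # c_t -> keys
W_up_v = rng.standard_normal((h * d_h, d_c)) / np.sqrt(d_c)  # c_t -> values

x_t = rng.standard_normal(d)          # hidden state at one time-step
c_t = W_down @ x_t                    # only this d_c-vector is cached
k_t = (W_up_k @ c_t).reshape(h, d_h)  # reconstructed per-head keys
v_t = (W_up_v @ c_t).reshape(h, d_h)  # reconstructed per-head values

# Caching c_t instead of (k_t, v_t) shrinks storage by 2*h*d_h/d_c = 8x here.
print(c_t.shape, k_t.shape, v_t.shape)
```

The up-projections can be folded into the query/output projections at inference, so storing only $c_t$ incurs no extra per-step matrix multiplications beyond the standard attention path.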
3. Hyper-network Temporal Merging
The core distinguishing mechanism is dynamic temporal merging. MTLA deploys a compact MLP hyper-network that generates per-time-step merge weights from the latent input and positional embeddings. Specifically, for the block $j = \lceil i/s \rceil$ containing step $i$, the merge weight for $c_i$ is

$$w_i = \sigma\big(\big((W_a c_i + b_a) \odot (W_b\, \mathrm{pe}_j + b_b)\big)^{\top} v\big)$$

where $\sigma$ is the sigmoid, $\mathrm{pe}_j$ is the positional embedding of block $j$, and $v$ projects the elementwise product to a scalar. During inference, each merged cache entry $\hat{c}_j$ is updated incrementally:
```
for i in 1..T:
    j = ceil(i / s)
    α = W_a c_i + b_a
    β = W_b pe_j + b_b
    w_i = sigmoid((α ⊙ β) · v)
    if i mod s == 1:
        ĉ_j ← w_i · c_i        # first token of block j: initialise
    else:
        ĉ_j ← ĉ_j + w_i · c_i  # later tokens of block j: accumulate
```
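The incremental update can be run directly; the following NumPy version mirrors it with randomly initialised hyper-network parameters (all sizes and initialisations here are illustrative assumptions):

```python
import math
import numpy as np

rng = np.random.default_rng(1)
d_c, s, T = 4, 2, 5  # tiny illustrative sizes

# Hyper-network parameters (random here purely for illustration)
W_a, b_a = rng.standard_normal((d_c, d_c)), rng.standard_normal(d_c)
W_b, b_b = rng.standard_normal((d_c, d_c)), rng.standard_normal(d_c)
v = rng.standard_normal(d_c)

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

latents = rng.standard_normal((T, d_c))                  # c_1 .. c_T
pos_emb = rng.standard_normal((math.ceil(T / s), d_c))   # pe_j per block

cache = {}  # j -> merged latent ĉ_j
for i in range(1, T + 1):
    j = math.ceil(i / s)
    c_i = latents[i - 1]
    alpha = W_a @ c_i + b_a
    beta = W_b @ pos_emb[j - 1] + b_b
    w_i = sigmoid((alpha * beta) @ v)   # scalar merge weight in (0, 1)
    if i % s == 1:                      # first token of block j: initialise
        cache[j] = w_i * c_i
    else:                               # later tokens: accumulate in place
        cache[j] = cache[j] + w_i * c_i

print(len(cache))  # ceil(5 / 2) = 3 cached entries instead of 5
```

Note that the cache entry for the current, still-incomplete block is overwritten in place at each step, so memory never exceeds $\lceil T/s \rceil$ latent vectors.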
4. Stride-aware Causal Masking
Temporal compression introduces cache positions that exist only at block boundaries. The standard causal mask ($\mathrm{mask}[m,n] = 0$ if $n \le m$, else $-\infty$) is incompatible with MTLA's blockwise cache. The stride-aware mask instead constrains query-to-key connectivity:

$$\mathrm{mask}[m,n] = \begin{cases} 0 & \text{if } n \cdot s \le m \\ -\infty & \text{otherwise} \end{cases}$$
In practice:
```
for m in 1..T:               # query time-steps
    for n in 1..ceil(T/s):   # compressed cache positions
        if n * s ≤ m:
            mask[m, n] = 0
        else:
            mask[m, n] = -∞
```
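A runnable NumPy version of the mask construction above (the function name is ours; `T` and `s` values are illustrative):

```python
import math
import numpy as np

def stride_aware_mask(T, s):
    """Build the T x ceil(T/s) additive mask: query step m may attend to
    compressed entry n only once block n is complete (n * s <= m)."""
    t = math.ceil(T / s)
    mask = np.full((T, t), -np.inf)
    for m in range(1, T + 1):
        for n in range(1, t + 1):
            if n * s <= m:
                mask[m - 1, n - 1] = 0.0
    return mask

print(stride_aware_mask(T=5, s=2))
# Query step 4 sees blocks 1 and 2; block 3 (steps 5-6) is still incomplete.
```

During training this mask is added to the attention logits for all query positions at once, reproducing exactly the cache visibility that incremental inference would produce.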
5. Training and Inference Procedure
MTLA is trained end-to-end with standard cross-entropy (e.g., for translation, summarisation) or CTC+CE loss (for ASR). All components (query/key projections, low-rank mapping, hyper-network weights, stride-aware mask logic) are optimized jointly. Hyper-network gradients propagate through the attention calculation, requiring no auxiliary losses. During inference, cache updates and merging are performed incrementally on receipt of new tokens, directly paralleling the decomposed, causal structure imposed by blockwise compression. In parallel/batched training, all temporal merges and masking are executed in vectorized fashion, simulating the inference attention pattern and cache visibility.
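The vectorized training-time merge can be sketched as a scatter-sum over block indices; in this hedged NumPy sketch the `weighted` array stands in for the per-step products $w_i c_i$, and all sizes are illustrative:

```python
import math
import numpy as np

# Training-time equivalent of the incremental merge: weighted latents are
# scatter-summed into their block index j = ceil(i/s), all steps at once.
T, s, d_c = 5, 2, 4
rng = np.random.default_rng(2)
weighted = rng.standard_normal((T, d_c))   # stand-in for w_i * c_i

blocks = np.array([math.ceil(i / s) - 1 for i in range(1, T + 1)])  # 0,0,1,1,2
merged = np.zeros((math.ceil(T / s), d_c))
np.add.at(merged, blocks, weighted)        # ĉ_j = Σ_{i in block j} w_i c_i

# The result matches what the step-by-step inference loop accumulates.
assert np.allclose(merged[0], weighted[0] + weighted[1])
assert np.allclose(merged[2], weighted[4])
```

`np.add.at` performs an unbuffered scatter-add, so repeated block indices accumulate correctly; in a deep-learning framework the same effect is obtained with a segment/index-add primitive, keeping the whole merge differentiable.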
6. Computational Complexity and Resource Efficiency
A direct big-O comparison reveals substantial efficiency gains:
- MHA: time $O(T^2 \cdot h \cdot d_h)$; memory $O(T \cdot h \cdot d_h)$
- MLA: time $O(T^2 \cdot d_c)$ (smaller constants); memory $O(T \cdot d_c)$
- MTLA (stride $s$): time per step $O((T/s) \cdot d_c)$; memory $O((T/s) \cdot d_c)$
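The memory terms above can be made concrete with a toy calculation (the dimensions below are assumptions chosen for illustration, not the paper's configuration):

```python
import math

def kv_floats(T, h, d_h, d_c, s):
    """Per-layer cache sizes (in floats) implied by the memory terms above."""
    return {
        "MHA":  2 * T * h * d_h,           # separate K and V per head
        "MLA":  T * d_c,                   # one shared latent per time-step
        "MTLA": math.ceil(T / s) * d_c,    # one latent per block of s steps
    }

sizes = kv_floats(T=1000, h=8, d_h=64, d_c=256, s=2)  # illustrative dims
print(sizes)
print(sizes["MHA"] / sizes["MTLA"])  # 8.0 with these (assumed) dimensions
```

Doubling the stride $s$ halves the MTLA cache again, which is why memory savings in the table below grow monotonically with $s$.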
In speech translation (English–German), per-layer KV cache sizes follow the expressions above:
- MHA KV cache: $O(T \cdot h \cdot d_h)$ floats
- MLA: $O(T \cdot d_c)$ floats
- MTLA ($s=2$): $O(\lceil T/2 \rceil \cdot d_c)$ floats ($1152$ in the reported configuration), substantially less than MHA
Empirical resource usage and quality for MuST-C En–De (BLEU, time, GPU memory):
| Model | Quality (BLEU) | Time (s) | Speedup | GPU (MiB) | Mem Factor |
|---|---|---|---|---|---|
| MHA | 23.18 | 281.3 | 1.00× | 18646 | 1.00× |
| MLA | 22.97 | 97.0 | 2.90× | 5065 | 3.68× |
| MTLA (s=2) | 23.28 | 65.6 | 4.29× | 2835 | 6.58× |
| MTLA (s=3) | 23.25 | 52.7 | 5.34× | 2251 | 8.28× |
| MTLA (s=4) | 23.05 | 48.7 | 5.78× | 1921 | 9.71× |
7. Empirical Evaluation Across Modalities
MTLA demonstrates competitive or superior task performance to baseline MHA and MLA across diverse tasks:
- Speech Translation (MuST-C En–De): BLEU parity or improvement, $4.29\times$–$5.78\times$ decoding speedup, and $6.58\times$–$9.71\times$ GPU memory reduction across strides $s=2$–$4$ (see table above).
- Text Summarisation (XSum): ROUGE-1/2/L competitive with MHA ($28.83/9.67/23.33$), with decoding speedup and lower memory.
- Speech Recognition (AMI ASR): WER on par with MHA, with faster decoding and reduced memory.
- Spoken Language Understanding (SLURP intent classification): accuracy comparable to MHA, with faster inference and reduced memory.
This suggests that temporal latent compression, as instantiated by MTLA, maintains semantic representation effectiveness while providing substantial engineering and resource efficiencies. A plausible implication is that further exploration of dynamic merge strategies or hierarchy-aware masking may yield additional benefits for long-context or low-latency applications.