
MTLA: Multi-head Temporal Latent Attention

Updated 26 January 2026
  • The paper introduces MTLA, a self-attention variant that compresses the key-value cache temporally to reduce memory usage and speed up inference.
  • MTLA leverages a hyper-network for dynamic temporal merging and employs a stride-aware causal mask to ensure consistent training and inference.
  • Empirical results demonstrate up to 8× GPU memory reduction and 5× speedup in applications such as speech translation, recognition, and text summarisation.

Multi-head Temporal Latent Attention (MTLA) is a self-attention variant designed to address the inference-time memory and computational bottlenecks of Transformer architectures. MTLA advances the compression paradigm introduced by Multi-Head Latent Attention (MLA) by reducing the Key-Value (KV) cache size along the temporal axis, leading to substantial improvements in inference speed and GPU memory usage without significant degradation in model quality. MTLA features dynamic temporal merging via a hyper-network and uses a stride-aware causal mask to reconcile temporal compression with parallel training and consistent inference behaviour. Empirical evaluations demonstrate MTLA's efficacy in tasks such as speech translation, speech recognition, spoken language understanding, and text summarisation (2505.13544).

1. Architectural Foundation and Motivation

MTLA is conceptually rooted in the Transformer's multi-head attention (MHA) mechanism. MHA maintains per-head Key and Value representations for each time-step, resulting in a KV cache of size $O(T \cdot n_h \cdot d_h)$, where $T$ is the sequence length, $n_h$ the number of heads, and $d_h$ the head dimension. MLA introduced low-rank latent compression by mapping the input $X \in \mathbb{R}^{T \times d}$ to a latent $C \in \mathbb{R}^{T \times r}$ ($r \ll n_h d_h$) and reconstructing $K, V$ via learned projections.

MTLA takes further steps by compressing $C$ temporally: consecutive blocks of $s$ latent vectors are merged using a hyper-network, producing $\lceil T/s \rceil$ cache entries and achieving cache complexity $O((T/s) \cdot r)$. This design directly targets the autoregressive inference KV cache growth, decreasing both temporal storage and per-step attention cost. The rationale is that adjacent tokens in long sequences (e.g., speech) carry redundant information, so merging preserves key semantics with drastically reduced resource consumption.

2. Latent Space Factorization

MTLA first maps $X$ to a latent space using

$$C = X W_r, \qquad W_r \in \mathbb{R}^{d \times r}, \qquad r \ll n_h d_h.$$

At inference, $C$ is stored as the compressed KV cache. For attention computation, the latent $C$ is up-projected:

$$K \approx C W_K, \qquad V \approx C W_V, \qquad W_K, W_V \in \mathbb{R}^{r \times n_h d_h}.$$

More generally,

$$K \approx Q_K L_K^T, \qquad Q_K = C, \qquad L_K^T = W_K$$
$$V \approx Q_V L_V^T, \qquad Q_V = C, \qquad L_V^T = W_V$$

Common parameter choices include $r = 4 d_h$ and $n_h d_h = d$. This compression reduces KV cache storage by a factor of $r/(n_h d_h)$ relative to MHA.
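The factorization above can be sketched in plain Python. This is a minimal illustration, not the paper's implementation: toy sizes and random stand-in weights, with names mirroring the text.

```python
# Toy sketch of low-rank latent KV compression: X -> C = X W_r is cached,
# and keys are re-expanded on the fly as K = C W_K.  Weights are random
# stand-ins; only the shapes matter here.
import random

def matmul(A, B):
    """(m x k) @ (k x n) -> (m x n) for lists of lists."""
    k, n = len(B), len(B[0])
    return [[sum(row[t] * B[t][j] for t in range(k)) for j in range(n)]
            for row in A]

random.seed(0)
T, d, n_h, d_h = 6, 64, 8, 8           # toy sizes with n_h * d_h = d
r = 4 * d_h                             # common choice r = 4 d_h -> 32

X   = [[random.gauss(0, 1) for _ in range(d)] for _ in range(T)]
W_r = [[random.gauss(0, 1) for _ in range(r)] for _ in range(d)]
W_K = [[random.gauss(0, 1) for _ in range(n_h * d_h)] for _ in range(r)]

C = matmul(X, W_r)                      # cached latent: T x r
K = matmul(C, W_K)                      # up-projected keys: T x (n_h * d_h)

# Cache shrinks by r / (n_h * d_h) relative to caching K (and V) directly.
print(len(C[0]) / len(K[0]))            # 32 / 64 = 0.5
```

With these toy sizes the latent cache is half the width of the expanded keys, matching the $r/(n_h d_h)$ factor stated above.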

3. Hyper-network Temporal Merging

The core distinguishing mechanism is dynamic temporal merging. MTLA deploys a compact MLP hyper-network that generates per-time-step merge weights from the latent input and positional embeddings. Specifically, for block $j = \lceil i/s \rceil$, the merge weight for $c_i$ is

$$\begin{aligned}
\alpha_i &= W_a c_i + b_a \\
\beta_j &= W_b\, pe_j + b_b \\
w_i &= \sigma\big((\alpha_i \odot \beta_j) \cdot v\big)
\end{aligned}$$

where $\sigma$ is the sigmoid, $\odot$ is the elementwise product, and $v$ projects the elementwise product to a scalar. During inference, each merged cache entry $\hat{c}_j$ is updated incrementally:

for i in 1..T:
    j = ceil(i / s)
    α_i = W_a c_i + b_a
    β_j = W_b pe_j + b_b
    w_i = sigmoid((α_i ⊙ β_j) · v)
    if i mod s == 1:
        ĉ_j ← w_i · c_i          # first token of block j starts a new entry
    else:
        ĉ_j ← ĉ_j + w_i · c_i    # later tokens accumulate into the entry
During parallel training, all weights $w_i$ are generated in batch and combined with chunk masking, ensuring each latent is merged strictly within its temporal block.
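The incremental update loop above can be made concrete. A minimal runnable sketch, with a scalar `gate` standing in for the full hyper-network ($W_a$, $W_b$, $pe_j$ and $v$ are folded into it, an assumption for brevity):

```python
# Incremental temporal merge: fold each block of s latent vectors into one
# gated cache entry, as at inference time.  `gate` is a toy stand-in for
# the hyper-network's sigmoid-gated merge weight.
import math
import random

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def merge_cache(latents, s, gate):
    """Return ceil(T/s) merged cache entries for T latent vectors."""
    cache = []
    for i, c_i in enumerate(latents, start=1):
        j = (i + s - 1) // s                 # block index ceil(i/s)
        w_i = gate(i, j)
        scaled = [w_i * x for x in c_i]
        if i % s == 1 or s == 1:             # first token of block j
            cache.append(scaled)
        else:                                # accumulate into current entry
            cache[-1] = [a + b for a, b in zip(cache[-1], scaled)]
    return cache

random.seed(1)
T, r, s = 7, 4, 3
latents = [[random.gauss(0, 1) for _ in range(r)] for _ in range(T)]
gate = lambda i, j: sigmoid(0.1 * i - 0.05 * j)   # arbitrary stand-in
cache = merge_cache(latents, s, gate)
print(len(cache))                            # ceil(7/3) = 3 entries
```

Note that the cache length grows only once every $s$ tokens, which is exactly what bounds the KV memory at $\lceil T/s \rceil$ entries.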

4. Stride-aware Causal Masking

Temporal compression introduces cache positions that only exist at block boundaries. The standard causal mask ($M[m, n] = 0$ if $n \leq m$, else $-\infty$) is incompatible with MTLA's blockwise cache. The stride-aware mask constrains query-to-key connectivity so that a query sees only the final vector of each completed block and the partially merged vector at its own position:

$$M_{\mathrm{stride}}[m, n] = \begin{cases} 0 & \text{if } n \leq m \text{ and } \big((n \bmod s) = 0 \text{ or } n = m\big) \\ -\infty & \text{otherwise} \end{cases}$$

In practice:

for m in 1..T:
    for n in 1..T:
        if n <= m and (n mod s == 0 or n == m):
            mask[m, n] = 0
        else:
            mask[m, n] = -inf

This design ensures that queries only attend to visible, fully (or partially) merged block vectors, maintaining causal consistency during both training and incremental inference.
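A small runnable builder for this mask, under the indexing convention reconstructed from the text (an assumption): keys stay at full length $T$ during training, and query $m$ sees block-final positions $n$ with $n \bmod s = 0$ at or before $m$, plus its own partially merged position $n = m$.

```python
# Stride-aware causal mask (assumed convention, see lead-in): 0 marks a
# visible key, -inf a masked one, matching additive attention masking.
NEG_INF = float("-inf")

def stride_aware_mask(T, s):
    mask = [[NEG_INF] * T for _ in range(T)]
    for m in range(1, T + 1):                # 1-based positions, as in text
        for n in range(1, m + 1):            # causal: n <= m
            if n % s == 0 or n == m:         # completed block, or own merge
                mask[m - 1][n - 1] = 0.0
    return mask

mask = stride_aware_mask(6, 2)
# Every query can attend to at least its own partial entry.
assert all(0.0 in row for row in mask)
```

Adding this mask to the attention logits before the softmax reproduces the cache visibility that incremental inference would produce step by step.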

5. Training and Inference Procedure

MTLA is trained end-to-end with standard cross-entropy (e.g., for translation, summarisation) or CTC+CE loss (for ASR). All components (query/key projections, low-rank mapping, hyper-network weights, stride-aware mask logic) are optimized jointly. Hyper-network gradients propagate through the attention calculation, requiring no auxiliary losses. During inference, cache updates and merging are performed incrementally on receipt of new tokens, directly paralleling the decomposed, causal structure imposed by blockwise compression. In parallel/batched training, all temporal merges and masking are executed in vectorized fashion, simulating the inference attention pattern and cache visibility.
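The equivalence this training setup relies on, that batched gate weights reduced with a chunk (block) mask reproduce the incremental inference-time merge, can be checked on a scalar toy example. The gate values below are arbitrary stand-ins, not the learned hyper-network.

```python
# Parallel (training) vs incremental (inference) temporal merge: both must
# produce identical merged cache entries.  Scalar toy latents for clarity.
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

T, s = 7, 3
c = [float(i) for i in range(1, T + 1)]            # toy scalar latents
w = [sigmoid(0.3 * (i + 1)) for i in range(T)]     # stand-in gate weights

# Parallel form: a block mask groups positions by chunk, then weighted-sum.
n_blocks = -(-T // s)                               # ceil(T/s)
parallel = [sum(w[i] * c[i] for i in range(T) if i // s == j)
            for j in range(n_blocks)]

# Incremental form: running accumulation per block, as at inference time.
incremental = []
for i in range(T):
    if i % s == 0:
        incremental.append(w[i] * c[i])
    else:
        incremental[-1] += w[i] * c[i]

assert all(abs(a - b) < 1e-12 for a, b in zip(parallel, incremental))
print(n_blocks)                                     # 3 merged entries
```

In the real model the same check applies per head and per latent dimension; the vectorized training path must match the stepwise cache updates exactly for training and inference to stay consistent.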

6. Computational Complexity and Resource Efficiency

A direct big-O comparison reveals substantial efficiency gains:

  • MHA: Time $O(T^2 d_h n_h)$; Memory $O(T n_h d_h)$
  • MLA: Time $O(T^2 d_h)$ (smaller constants); Memory $O(T r)$
  • MTLA (stride $s$): Time per step $O((T/s) d_h n_h)$; Memory $O((T/s) r)$

In speech translation (English–German) with $s=2$, $d_h=64$, $n_h=8$, and $l=9$ layers:

  • MHA KV cache: $2 \times 64 \times 8 \times 9 = 9216$ floats per token
  • MLA ($r = 4 d_h$): $4 \times 64 \times 9 = 2304$ floats per token
  • MTLA ($s=2$): $1152$ floats per token, $8\times$ less than MHA
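The arithmetic above can be verified in a few lines:

```python
# Floats cached per decoded token, summed over l decoder layers, for the
# configuration quoted in the text (s = 2, d_h = 64, n_h = 8, l = 9).
d_h, n_h, l = 64, 8, 9
r = 4 * d_h

mha  = 2 * d_h * n_h * l      # K and V for every head          -> 9216
mla  = r * l                  # only the latent C is cached     -> 2304
mtla = mla // 2               # stride s = 2 halves the entries -> 1152

print(mha, mla, mtla, mha // mtla)   # 9216 2304 1152 8
```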

Empirical resource usage and quality for MuST-C En–De (BLEU, time, GPU memory):

| Model | Quality (BLEU) | Time (s) | Speedup | GPU (MiB) | Mem Factor |
|---|---|---|---|---|---|
| MHA | 23.18 | 281.3 | 1.00× | 18646 | 1.00× |
| MLA | 22.97 | 97.0 | 2.90× | 5065 | 3.68× |
| MTLA (s=2) | 23.28 | 65.6 | 4.29× | 2835 | 6.58× |
| MTLA (s=3) | 23.25 | 52.7 | 5.34× | 2251 | 8.28× |
| MTLA (s=4) | 23.05 | 48.7 | 5.78× | 1921 | 9.71× |

7. Empirical Evaluation Across Modalities

MTLA demonstrates competitive or superior task performance to baseline MHA and MLA across diverse tasks:

  • Speech Translation (MuST-C En–De): BLEU parity or improvement, $4.3\times$–$5.8\times$ speedup, $6.6\times$–$9.7\times$ GPU memory reduction.
  • Text Summarisation (XSum): ROUGE-1/2/L $= 29.14/9.79/23.60$ vs MHA $28.83/9.67/23.33$, with $3.35\times$ speedup and $7.34\times$ lower memory.
  • Speech Recognition (AMI ASR): WER $12.66\%$ (MHA $12.98\%$), $3.75\times$ faster, $7.4\times$ memory reduction.
  • Spoken Language Understanding (SLURP IC): accuracy $86.80\%$ (MHA $86.83\%$), $2.53\times$ faster, $7.01\times$ memory reduction.

This suggests that temporal latent compression, as instantiated by MTLA, maintains semantic representation effectiveness while providing substantial engineering and resource efficiencies. A plausible implication is that further exploration of dynamic merge strategies or hierarchy-aware masking may yield additional benefits for long-context or low-latency applications.

(2505.13544)
