Tensor Product Attention Is All You Need (2501.06425v4)
Abstract: Scaling LLMs to handle longer input sequences typically necessitates large key-value (KV) caches, resulting in substantial memory overhead during inference. In this paper, we propose Tensor Product Attention (TPA), a novel attention mechanism that uses tensor decompositions to represent queries, keys, and values compactly, substantially shrinking the KV cache size at inference time. By factorizing these representations into contextual low-rank components and seamlessly integrating with Rotary Position Embedding (RoPE), TPA achieves improved model quality alongside memory efficiency. Based on TPA, we introduce the Tensor Product Attention Transformer (T6), a new model architecture for sequence modeling. Through extensive empirical evaluation on language modeling tasks, we demonstrate that T6 surpasses or matches the performance of standard Transformer baselines, including Multi-Head Attention (MHA), Multi-Query Attention (MQA), Grouped-Query Attention (GQA), and Multi-Head Latent Attention (MLA), across various metrics, including perplexity and a range of established evaluation benchmarks. Notably, TPA's memory efficiency and computational efficiency at the decoding stage enable processing longer sequences under fixed resource constraints, addressing a critical scalability challenge in modern LLMs. The code is available at https://github.com/tensorgi/T6.
Summary
- The paper introduces Tensor Product Attention (TPA), a novel mechanism using tensor decomposition to factorize Q, K, and V, achieving over 10x reduction in LLM KV cache size during inference.
- TPA enables the T6 sequence modeling architecture and integrates seamlessly with RoPE, allowing direct replacement of MHA layers in existing LLMs like LLaMA and Gemma.
- Experiments show TPA converges faster and achieves lower loss than MHA/GQA baselines, demonstrating improved performance and scalability for LLMs.
The paper introduces Tensor Product Attention (TPA), a novel attention mechanism designed to mitigate the memory overhead of key-value (KV) caches in LLMs during inference. The core idea is to factorize queries ($\Qb$), keys ($\Kb$), and values ($\Vb$) using tensor decompositions, which yields a compact representation of these quantities and a significant reduction in KV cache size. Based on TPA, the authors introduce the Tensor Product Attention Transformer (T6) model architecture for sequence modeling.
TPA employs contextual low-rank factorization: queries, keys, and values are decomposed into low-rank components that are functions of the current token's hidden state. Because the factorization is applied to activations rather than to static weights, these low-rank representations can be cached in place of full keys and values, substantially reducing KV cache memory usage. TPA is natively compatible with rotary positional embeddings (RoPE), which allows it to directly replace multi-head attention (MHA) layers in existing LLM architectures such as LLaMA and Gemma.
The authors summarize their primary contributions as follows:
- Proposing TPA, a mechanism that factorizes $\Qb$, $\Kb$, and $\Vb$ activations using contextual tensor decompositions to achieve a 10× or greater reduction in inference-time KV cache size relative to standard attention, with improved performance compared to previous methods such as MHA, MQA, GQA, and MLA.
- Proposing the Tensor Product Attention Transformer (T6), a new TPA-based model architecture for sequence modeling. In language modeling experiments, T6 consistently improves validation perplexity and downstream evaluation performance while reducing KV cache size.
- Showing that TPA integrates seamlessly with RoPE, facilitating easy adoption in popular foundation model architectures such as LLaMA and Gemma.
The paper also provides background on scaled dot-product attention, MHA, multi-query attention (MQA), grouped-query attention (GQA), RoPE, and multi-head latent attention (MLA). The notation uses bold uppercase letters for matrices, bold lowercase letters for vectors, and italic uppercase letters for learnable parameter matrices. The tensor product of two vectors $\ab\in\RR^m$ and $\bbb\in \RR^n$ is defined as $\ab\otimes\bbb=\Cb\in \RR^{m\times n}$ with $\Cb_{ij}=a_i b_j$, and the vectorization of a matrix $\Cb\in \RR^{m\times n}$ is defined as $\text{vec}(\Cb)=\db\in\RR^{mn}$ with $d_{i\cdot n+j}=\Cb_{ij}$ (indices starting from 0). Scaled dot-product attention is given by:
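A minimal NumPy sketch of the two notational definitions, assuming zero-based indices so that row-major flattening matches $d_{i\cdot n+j}=\Cb_{ij}$. The function names `tensor_product` and `vec` are illustrative, not from the paper.

```python
import numpy as np

def tensor_product(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return C = a ⊗ b, where C[i, j] = a[i] * b[j]."""
    return np.outer(a, b)

def vec(c: np.ndarray) -> np.ndarray:
    """Row-major vectorization: d[i * n + j] = C[i, j] with zero-based i, j."""
    return c.reshape(-1)  # NumPy reshapes in C (row-major) order by default

a = np.array([1.0, 2.0])        # a in R^m, m = 2
b = np.array([3.0, 4.0, 5.0])   # b in R^n, n = 3
C = tensor_product(a, b)        # [[3, 4, 5], [6, 8, 10]]
d = vec(C)                      # [3, 4, 5, 6, 8, 10]
```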
$\operatorname{Attention}(\Qb, \Kb, \Vb) = \operatorname{Softmax}\Bigl(\tfrac{\Qb \Kb^{\top}}{\sqrt{d_k}}\Bigr)\,\Vb,$
where $\Qb, \Kb, \Vb \in \RR^{n \times d_k}$.
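The formula translates almost directly into NumPy. This is a sketch for a single unmasked attention call (no causal mask, no batching), with random inputs standing in for real projections.

```python
import numpy as np

def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtracting the row max keeps exp() stable and does not change the result
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Softmax(Q K^T / sqrt(d_k)) V for Q, K, V of shape (n, d_k)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)   # (n, n) scaled similarities
    return softmax(scores, axis=-1) @ V

rng = np.random.default_rng(0)
n, d_k = 4, 8
Q, K, V = (rng.standard_normal((n, d_k)) for _ in range(3))
out = attention(Q, K, V)              # (n, d_k)
```

Each output row is a convex combination of the rows of $\Vb$, since every softmax row sums to 1; if all keys are identical, every query attends uniformly and receives the mean of $\Vb$.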
MHA computes each head $i$ for a token embedding $\xb_t \in \mathbb{R}^{d_{\text{model}}}$ as:
$\Qb_{t,i} = (\bW_i^Q)^{\top} \,\xb_t \in\mathbb{R}^{d_h}, \quad \Kb_{t,i} = (\bW_i^K)^{\top} \,\xb_t \in\mathbb{R}^{d_h}, \quad \Vb_{t,i} = (\bW_i^V)^{\top} \,\xb_t \in\mathbb{R}^{d_h},$
where $\bW_i^Q, \bW_i^K, \bW_i^V \in \mathbb{R}^{d_{\text{model}} \times d_h}$ are learnable projection matrices.
MQA shares keys and values across heads, expressed as:
$\Qb_{i} = \Xb\bW^Q_{i}, \quad \Kb_{\text{shared}} = \Xb\bW^K_{\text{shared}}, \quad \Vb_{\text{shared}} = \Xb\bW^V_{\text{shared}},$
with $\bW^Q_{i} \in \mathbb{R}^{d_{\text{model}} \times d_k}$ and $\bW^K_{\text{shared}}, \bW^V_{\text{shared}} \in \mathbb{R}^{d_{\text{model}} \times d_k}$.
GQA partitions the $h$ heads into $G$ groups, each of which shares a single set of keys and values:
$\Kb_{g(i)} = \Xb\,\bW^K_{g(i)}, \quad \Vb_{g(i)} = \Xb\,\bW^V_{g(i)}, \quad \Qb_{i} = \Xb\,\bW^Q_{i},$
where $g(i)$ denotes the group of head $i$, $\bW^K_{g}, \bW^V_{g} \in \mathbb{R}^{d_{\text{model}} \times d_k}$, and $\bW^Q_{i} \in \mathbb{R}^{d_{\text{model}} \times d_k}$.
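MHA, MQA, and GQA differ only in how many key/value projections exist, so one sketch covers all three: $G = h$ gives MHA and $G = 1$ gives MQA. The contiguous group assignment $g(i) = \lfloor i / (h/G) \rfloor$ is an assumption of this example, and `grouped_query_attention` and `kv_cache_per_token` are hypothetical helper names.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(X, W_Q, W_K, W_V):
    """X: (T, d_model); W_Q: (h, d_model, d_k); W_K, W_V: (G, d_model, d_k).

    Assumes contiguous groups, g(i) = i // (h // G). G = h recovers MHA, G = 1 recovers MQA.
    """
    h, G, d_k = W_Q.shape[0], W_K.shape[0], W_Q.shape[-1]
    if h % G != 0:
        raise ValueError("h must be divisible by G")
    heads = []
    for i in range(h):
        g = i // (h // G)                  # group index g(i)
        Q_i = X @ W_Q[i]                   # (T, d_k), one query projection per head
        K_g, V_g = X @ W_K[g], X @ W_V[g]  # (T, d_k), shared by every head in group g
        heads.append(softmax(Q_i @ K_g.T / np.sqrt(d_k)) @ V_g)
    return np.concatenate(heads, axis=-1)  # (T, h * d_k)

def kv_cache_per_token(G, d_k):
    # One key and one value vector of size d_k per group, per token, per layer
    return 2 * G * d_k

rng = np.random.default_rng(0)
T, d_model, h, d_k, G = 5, 16, 4, 8, 2
X = rng.standard_normal((T, d_model))
W_Q = rng.standard_normal((h, d_model, d_k))
W_K = rng.standard_normal((G, d_model, d_k))
W_V = rng.standard_normal((G, d_model, d_k))
out = grouped_query_attention(X, W_Q, W_K, W_V)
```

Passing a single key/value projection (MQA) gives the same result as copying that projection into all $h$ slots, which is exactly why MQA only needs to cache one key/value pair per token.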
RoPE uses a rotation operator $\Tb_t \in \RR^{d_h \times d_h}$ corresponding to the $t$-th position, and $\operatorname{RoPE}(\Qb_t) \triangleq \Qb_t\Tb_t$, where $\Qb_t \in \RR^{h \times d_h}$.
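One concrete choice of $\Tb_t$ is a block-diagonal matrix of $2\times 2$ rotations with angles $t\,\theta_j$. The sketch below assumes the common frequency schedule $\theta_j = 10000^{-2j/d_h}$ and pairs adjacent channels; real implementations differ in pairing layout, so treat `rope_matrix` as a hypothetical helper. It checks the relative property: shifting both positions by the same amount leaves the query-key scores unchanged.

```python
import numpy as np

def rope_matrix(t: int, d_h: int, base: float = 10000.0) -> np.ndarray:
    """Block-diagonal T_t in R^{d_h x d_h}: one 2x2 rotation per pair of channels."""
    if d_h % 2 != 0:
        raise ValueError("d_h must be even")
    T_t = np.zeros((d_h, d_h))
    for j in range(d_h // 2):
        theta = t * base ** (-2.0 * j / d_h)  # angle grows linearly with position t
        c, s = np.cos(theta), np.sin(theta)
        # Acts on row vectors: [x, y] @ block = [x*c - y*s, x*s + y*c]
        T_t[2 * j:2 * j + 2, 2 * j:2 * j + 2] = [[c, s], [-s, c]]
    return T_t

def rope(Q_t: np.ndarray, t: int) -> np.ndarray:
    """RoPE(Q_t) = Q_t T_t for Q_t of shape (h, d_h); every head row is rotated."""
    return Q_t @ rope_matrix(t, Q_t.shape[-1])

rng = np.random.default_rng(0)
h, d_h = 4, 8
q, k = rng.standard_normal((h, d_h)), rng.standard_normal((h, d_h))
scores_a = rope(q, 3) @ rope(k, 1).T    # query at position 3, key at position 1
scores_b = rope(q, 10) @ rope(k, 8).T   # both positions shifted by 7
```

Because $\Tb_s \Tb_t^{\top}$ depends only on $s - t$, `scores_a` and `scores_b` agree, and each $\Tb_t$ is orthogonal, so rotation preserves vector norms.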
MLA introduces a low-rank compression of the keys and values to reduce the KV caching cost at inference:
$\mathbf{C}^{KV} = \mathbf{X}\bW^{DKV}, \quad \bW^{DKV} \in \mathbb{R}^{d_{\text{model}} \times d_c},$
$\operatorname{Concat}\bigl(\mathbf{K}_{1}^{C}, \mathbf{K}_{2}^{C}, \ldots, \mathbf{K}_{h}^{C}\bigr) = \mathbf{K}^{C} = \mathbf{C}^{KV}\bW^{UK}, \quad \bW^{UK} \in \mathbb{R}^{d_c \times d_h h}.$
In **TPA**, let $\xb_t \in \mathbb{R}^{d_{\text{model}}}$ denote the hidden-state vector for the $t$-th token in a sequence of length $T$. **TPA** factorizes each of $\Qb_{t}, \Kb_{t}, \Vb_{t}$ into a sum of tensor products with ranks $R_Q$, $R_K$, and $R_V$:
$\Qb_{t} = \frac{1}{R_Q} \sum_{r=1}^{R_Q} \ab^{Q}_{r}(\xb_t) \otimes \bbb^{Q}_{r}(\xb_t), \quad \Kb_{t} = \frac{1}{R_K} \sum_{r=1}^{R_K} \ab^{K}_{r}(\xb_t) \otimes \bbb^{K}_{r}(\xb_t), \quad \Vb_{t} = \frac{1}{R_V} \sum_{r=1}^{R_V} \ab^{V}_{r}(\xb_t) \otimes \bbb^{V}_{r}(\xb_t),$
where $\ab^{Q}_{r}(\xb_t), \ab^{K}_{r}(\xb_t), \ab^{V}_{r}(\xb_t) \in \mathbb{R}^{h}$ and $\bbb^{Q}_{r}(\xb_t), \bbb^{K}_{r}(\xb_t), \bbb^{V}_{r}(\xb_t) \in \mathbb{R}^{d_h}$, so that $\Qb_t, \Kb_t, \Vb_t \in \mathbb{R}^{h \times d_h}$.
The latent factor maps are given by:
$\ab^{Q}_{r}(\xb_t) = \bW^{a^Q}_{r}\,\xb_t \in \mathbb{R}^{h}, \quad \bbb^{Q}_{r}(\xb_t) = \bW^{b^Q}_{r}\,\xb_t \in \mathbb{R}^{d_h},$
where $\bW^{a^Q}_{r} \in \mathbb{R}^{h \times d_{\text{model}}}$ and $\bW^{b^Q}_{r} \in \mathbb{R}^{d_h \times d_{\text{model}}}$ are learnable, and the factors for $\Kb$ and $\Vb$ are defined analogously. After $\Qb$, $\Kb$, and $\Vb$ are factorized, multi-head attention proceeds as in standard Transformers:
$\text{head}_i = \operatorname{Softmax}\Bigl( \tfrac{1}{\sqrt{d_h}} \,\Qb_{i} \, (\Kb_{i})^{\top} \Bigr) \,\Vb_{i},$
where $\Qb_{i}, \Kb_{i}, \Vb_{i} \in \mathbb{R}^{T \times d_h}$ are the slices along the head dimension.
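A minimal sketch of the factorization and the per-head attention that follows it, using random matrices as hypothetical stand-ins for the learned factor maps $\bW^{a}_r$ and $\bW^{b}_r$. The ranks below are illustrative only, and no causal mask is applied, matching the formula as written.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_model, h, d_h = 6, 32, 4, 8
R_Q, R_K, R_V = 6, 2, 2  # illustrative ranks, not taken from the paper's configs

def make_factor_weights(R):
    # Hypothetical random stand-ins for learned W^{a}_r (h x d_model) and W^{b}_r (d_h x d_model)
    W_a = rng.standard_normal((R, h, d_model)) / np.sqrt(d_model)
    W_b = rng.standard_normal((R, d_h, d_model)) / np.sqrt(d_model)
    return W_a, W_b

def tpa_factorize(x_t, W_a, W_b):
    """(1/R) * sum_r a_r(x_t) ⊗ b_r(x_t), returned with shape (h, d_h)."""
    A = W_a @ x_t                 # (R, h): row r is a_r(x_t)
    B = W_b @ x_t                 # (R, d_h): row r is b_r(x_t)
    return A.T @ B / A.shape[0]   # same as averaging the R outer products

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

X = rng.standard_normal((T, d_model))
weights = {"Q": make_factor_weights(R_Q), "K": make_factor_weights(R_K), "V": make_factor_weights(R_V)}
# Stack per-token factorizations: each of Q, K, V has shape (T, h, d_h)
Q, K, V = (np.stack([tpa_factorize(x, *weights[name]) for x in X]) for name in "QKV")
# head_i = Softmax(Q_i K_i^T / sqrt(d_h)) V_i on the (T, d_h) slice for head i
heads = [softmax(Q[:, i] @ K[:, i].T / np.sqrt(d_h)) @ V[:, i] for i in range(h)]
out = np.concatenate(heads, axis=-1)  # (T, h * d_h)
```

Writing the sum of outer products as $\frac{1}{R}\Ab^{\top}\Bb$ makes the rank bound visible: each per-token $\Kb_t$ has rank at most $R_K$.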
For RoPE integration, the paper suggests pre-rotating the token-dimension factors:
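$\widetilde{\Bb}_K(\xb_t) \;\longleftarrow\; \operatorname{RoPE}_t\bigl(\Bb_K(\xb_t)\bigr),$
where $\Ab_Q(\xb_t) \in \RR^{R_Q \times h}$ and $\Bb_Q(\xb_t) \in \RR^{R_Q \times d_h}$ (and likewise $\Ab_K$, $\Bb_K$) stack the factors $\ab^Q_r(\xb_t)$ and $\bbb^Q_r(\xb_t)$ as rows, so that $\Qb_t = \frac{1}{R_Q}\Ab_Q(\xb_t)^{\top}\Bb_Q(\xb_t)$. A key theorem states that RoPE's relative translational property is preserved in **TPA**: if $\Qb_t$ is factorized by **TPA**, then
$\operatorname{RoPE}(\Qb_t) = \frac{1}{R_Q} \Ab_{Q}(\xb_t)^{\top} \,\widetilde{\Bb}_{Q}(\xb_t), \quad \text{where } \widetilde{\Bb}_{Q}(\xb_t) = \operatorname{RoPE}_t\bigl(\Bb_{Q}(\xb_t)\bigr).$
This follows from associativity: $\operatorname{RoPE}(\Qb_t) = \Qb_t \Tb_t = \frac{1}{R_Q}\Ab_Q(\xb_t)^{\top}\bigl(\Bb_Q(\xb_t)\Tb_t\bigr)$, and right-multiplying by $\Tb_t$ rotates each row $\bbb^Q_r(\xb_t)$.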
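The pre-rotation identity can be checked numerically. This sketch reuses an adjacent-pair RoPE matrix with base 10000 (an assumed convention) and random factors in place of $\Ab_Q(\xb_t)$ and $\Bb_Q(\xb_t)$; it confirms that rotating the full $h \times d_h$ query equals rotating only the $R_Q$ rows of $\Bb_Q$.

```python
import numpy as np

def rope_matrix(t, d_h, base=10000.0):
    # Block-diagonal rotation acting on row vectors, one 2x2 block per channel pair
    T_t = np.zeros((d_h, d_h))
    for j in range(d_h // 2):
        theta = t * base ** (-2.0 * j / d_h)
        c, s = np.cos(theta), np.sin(theta)
        T_t[2 * j:2 * j + 2, 2 * j:2 * j + 2] = [[c, s], [-s, c]]
    return T_t

rng = np.random.default_rng(1)
h, d_h, R_Q, t = 4, 8, 3, 7
A_Q = rng.standard_normal((R_Q, h))    # rows: head-dimension factors a_r(x_t)
B_Q = rng.standard_normal((R_Q, d_h))  # rows: token-dimension factors b_r(x_t)

Q_t = A_Q.T @ B_Q / R_Q                   # (h, d_h)
rope_full = Q_t @ rope_matrix(t, d_h)     # rotate the full query: h rows
B_tilde = B_Q @ rope_matrix(t, d_h)       # pre-rotate only the R_Q factor rows
rope_factored = A_Q.T @ B_tilde / R_Q
```

The practical consequence is that the rotated $\widetilde{\Bb}_K$ can be cached directly, so no RoPE work is repeated on cached keys at decode time.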
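Because only the factors $\Ab_K(\xb_t)$, $\Bb_K(\xb_t)$, $\Ab_V(\xb_t)$, and $\Bb_V(\xb_t)$ need to be cached, the memory cost per token in TPA is $(R_K+R_V)(h+d_h)$, which can be significantly lower than the standard caching cost of $2hd_h$.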
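A quick worked comparison of the two per-token formulas. The head count and head dimension below are hypothetical, and the counts are numbers of cached scalars per token per layer (bytes depend on the storage precision).

```python
def mha_cache_per_token(h: int, d_h: int) -> int:
    # Full K_t and V_t, each h x d_h
    return 2 * h * d_h

def tpa_cache_per_token(h: int, d_h: int, R_K: int, R_V: int) -> int:
    # Factors A_K (R_K x h), B_K (R_K x d_h), A_V (R_V x h), B_V (R_V x d_h)
    return (R_K + R_V) * (h + d_h)

h, d_h = 32, 128  # hypothetical head count and head dimension
full = mha_cache_per_token(h, d_h)
for R in (1, 2, 4):
    tpa = tpa_cache_per_token(h, d_h, R, R)
    print(f"R_K = R_V = {R}: {tpa} values per token, {full / tpa:.1f}x smaller")
```

With these sizes the standard cache holds 8192 values per token, and the loop prints reductions of 25.6x, 12.8x, and 6.4x for ranks 1, 2, and 4. The saving grows with $h d_h / (h + d_h)$, so it is largest when both $h$ and $d_h$ are large relative to the ranks.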
The paper demonstrates how MHA, MQA, and GQA can be unified as non-contextual variants of TPA. Specifically, standard MHA can be viewed as a specific instance of TPA in which: 1) the rank is set equal to the number of heads; 2) the head dimension factor is non-contextual; 3) the token dimension factor is a linear function of $\xb_t$.
In MQA, all heads share a single set of keys and values, which corresponds to $R_K = R_V = 1$ along the head dimension, while GQA partitions the $h$ heads into $G$ groups, each sharing keys and values within that group.
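The reduction can be verified with fixed (non-contextual) head-dimension factors. In this sketch the scaling of each $\ab_r$ by $R$ is our own choice, made to cancel the $1/R$ normalization; contiguous GQA groups are also assumed. The linear token-dimension factors are $\bbb_r(\xb_t) = \bW_r^{\top}\xb_t$, as in the MHA equations.

```python
import numpy as np

rng = np.random.default_rng(2)
d_model, h, d_h, G = 16, 4, 8, 2
x_t = rng.standard_normal(d_model)

def tpa_from_factors(a_list, b_list):
    """(1/R) * sum_r a_r ⊗ b_r with R = len(a_list); result has shape (h, d_h)."""
    return sum(np.outer(a, b) for a, b in zip(a_list, b_list)) / len(a_list)

# MHA: R = h, a_r = h * e_r (fixed one-hot, scaled by h to cancel 1/R), b_r = (W_r)^T x_t
W = rng.standard_normal((h, d_model, d_h))
mha = np.stack([W[i].T @ x_t for i in range(h)])  # row i is (W_i)^T x_t
mha_as_tpa = tpa_from_factors(list(h * np.eye(h)), [W[r].T @ x_t for r in range(h)])

# MQA: R = 1, a = all-ones (every head gets the same row), b = (W_shared)^T x_t
W_shared = rng.standard_normal((d_model, d_h))
mqa = np.tile(W_shared.T @ x_t, (h, 1))
mqa_as_tpa = tpa_from_factors([np.ones(h)], [W_shared.T @ x_t])

# GQA: R = G, a_g = G * indicator of the heads in group g (contiguous groups assumed)
W_group = rng.standard_normal((G, d_model, d_h))
group_of = np.arange(h) // (h // G)  # g(i) for each head i
gqa = np.stack([W_group[group_of[i]].T @ x_t for i in range(h)])
gqa_as_tpa = tpa_from_factors(
    [G * (group_of == g).astype(float) for g in range(G)],
    [W_group[g].T @ x_t for g in range(G)],
)
```

In all three cases the head-dimension factors never depend on $\xb_t$; TPA's distinguishing step is to make them contextual.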
The Tensor Product Attention Transformer (T6) architecture, which uses TPA in place of standard MHA or GQA, is also detailed. Its feed-forward network (FFN) adopts a SwiGLU layer, and RoPE is applied to $\Qb$ and $\Kb$.
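For reference, a SwiGLU FFN gates one linear projection with the SiLU of another before projecting back to $d_{\text{model}}$. This is a generic sketch of that layer, not the T6 code: the weight names `W_gate`, `W_up`, `W_down` and the sizes are hypothetical, and biases, normalization, and residual connections are omitted.

```python
import numpy as np

def silu(z):
    # SiLU (Swish): z * sigmoid(z)
    return z / (1.0 + np.exp(-z))

def swiglu_ffn(x, W_gate, W_up, W_down):
    """SwiGLU FFN: (SiLU(x W_gate) * (x W_up)) W_down.

    x: (T, d_model); W_gate, W_up: (d_model, d_ff); W_down: (d_ff, d_model).
    """
    return (silu(x @ W_gate) * (x @ W_up)) @ W_down

rng = np.random.default_rng(0)
T, d_model, d_ff = 3, 16, 48  # hypothetical sizes
W_gate, W_up = rng.standard_normal((d_model, d_ff)), rng.standard_normal((d_model, d_ff))
W_down = rng.standard_normal((d_ff, d_model))
y = swiglu_ffn(rng.standard_normal((T, d_model)), W_gate, W_up, W_down)
```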
Experiments were conducted on the FineWeb-Edu 100B dataset, comparing T6 against the baseline Llama architecture with SwiGLU activation and RoPE embeddings, as well as Llama variants that replace MHA with MQA, GQA, or MLA. Models were trained at small (124M parameters), medium (353M), and large (773M) scales using the AdamW optimizer.
Results indicate that TPA and its variant TPA-KVonly converge as fast as or faster than the baselines while reaching lower final losses. In downstream evaluations on standard benchmarks, TPA generally ties or outperforms all competing methods.
The paper concludes that TPA offers a flexible, memory-efficient alternative to standard multi-head attention, advancing the scalability of modern LLMs.