Tensor Product Attention Is All You Need (2501.06425v4)

Published 11 Jan 2025 in cs.CL, cs.AI, and cs.LG

Abstract: Scaling LLMs to handle longer input sequences typically necessitates large key-value (KV) caches, resulting in substantial memory overhead during inference. In this paper, we propose Tensor Product Attention (TPA), a novel attention mechanism that uses tensor decompositions to represent queries, keys, and values compactly, substantially shrinking the KV cache size at inference time. By factorizing these representations into contextual low-rank components and seamlessly integrating with Rotary Position Embedding (RoPE), TPA achieves improved model quality alongside memory efficiency. Based on TPA, we introduce the Tensor Product Attention Transformer (T6), a new model architecture for sequence modeling. Through extensive empirical evaluation on language modeling tasks, we demonstrate that T6 surpasses or matches the performance of standard Transformer baselines, including Multi-Head Attention (MHA), Multi-Query Attention (MQA), Grouped-Query Attention (GQA), and Multi-Head Latent Attention (MLA) across various metrics, including perplexity and a range of established evaluation benchmarks. Notably, TPA's memory efficiency and computational efficiency at the decoding stage enable processing longer sequences under fixed resource constraints, addressing a critical scalability challenge in modern LLMs. The code is available at https://github.com/tensorgi/T6.

Summary

  • The paper introduces Tensor Product Attention (TPA), a novel mechanism using tensor decomposition to factorize Q, K, and V, achieving over 10x reduction in LLM KV cache size during inference.
  • TPA enables the T6 sequence modeling architecture and integrates seamlessly with RoPE, allowing direct replacement of MHA layers in existing LLMs like LLaMA and Gemma.
  • Experiments show TPA converges faster and achieves lower loss than MHA/GQA baselines, demonstrating improved performance and scalability for LLMs.

The paper introduces Tensor Product Attention (TPA), a novel attention mechanism designed to mitigate the memory overhead associated with key-value (KV) caches in LLMs during inference. The core idea involves factorizing queries ($\mathbf{Q}$), keys ($\mathbf{K}$), and values ($\mathbf{V}$) using tensor decompositions, thereby enabling a compact representation of these entities and a significant reduction in KV cache size. The authors introduce the Tensor Product Attention Transformer (T6) model architecture based on TPA for sequence modeling.

TPA employs contextual low-rank factorization, where queries, keys, and values are decomposed into contextual low-rank components. This dynamic factorization of activations, as opposed to static weights, constructs low-rank representations that substantially reduce KV cache memory usage. TPA is natively compatible with rotary positional embeddings (RoPE), allowing for a direct replacement of multi-head attention (MHA) layers in existing LLM architectures like LLaMA and Gemma.

The authors summarize their primary contributions as follows:

  • Proposing TPA, a mechanism that factorizes the $\mathbf{Q}$, $\mathbf{K}$, and $\mathbf{V}$ activations using contextual tensor decompositions to achieve a $10\times$ or greater reduction in inference-time KV cache size relative to standard attention mechanisms, with improved performance compared to previous methods such as MHA, MQA, GQA, and MLA.
  • Proposing the Tensor Product Attention Transformer (T6), a new TPA-based model architecture for sequence modeling. In language modeling experiments, T6 consistently improves validation perplexity and downstream evaluation performance with a reduced KV cache size.
  • Showing that TPA integrates seamlessly with RoPE, facilitating easy adoption in popular foundation model architectures such as LLaMA and Gemma.

The paper also provides background on scaled dot-product attention, MHA, multi-query attention (MQA), grouped-query attention (GQA), RoPE, and multi-head latent attention (MLA). Notation: bold uppercase letters denote matrices, bold lowercase letters denote vectors, and italic uppercase letters denote learnable parameter matrices. The tensor product of two vectors $\mathbf{a} \in \mathbb{R}^m$ and $\mathbf{b} \in \mathbb{R}^n$ is defined as $\mathbf{a} \otimes \mathbf{b} = \mathbf{C} \in \mathbb{R}^{m \times n}$ with $C_{ij} = a_i b_j$, and the vectorization of a matrix $\mathbf{C} \in \mathbb{R}^{m \times n}$ is defined as $\operatorname{vec}(\mathbf{C}) = \mathbf{d} \in \mathbb{R}^{mn}$ with $d_{i \cdot n + j} = C_{ij}$. Scaled dot-product attention is given by:

$\operatorname{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \operatorname{Softmax}\Bigl(\tfrac{\mathbf{Q} \mathbf{K}^{\top}}{\sqrt{d_k}}\Bigr)\,\mathbf{V},$

where $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{n \times d_k}$.
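For reference, here is a minimal PyTorch sketch of this operation, following the shapes in the definition above (an illustration, not the paper's implementation):

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    # Q, K, V: tensors of shape (n, d_k); returns an (n, d_k) tensor
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5  # (n, n) similarity logits
    weights = F.softmax(scores, dim=-1)            # row-wise attention weights
    return weights @ V
```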

MHA computes each head $i$ for token embedding $\mathbf{x}_t \in \mathbb{R}^{d_{\text{model}}}$ as:

$\mathbf{Q}_{t,i} = (W_i^Q)^{\top} \,\mathbf{x}_t \in \mathbb{R}^{d_h}, \quad \mathbf{K}_{t,i} = (W_i^K)^{\top} \,\mathbf{x}_t \in \mathbb{R}^{d_h}, \quad \mathbf{V}_{t,i} = (W_i^V)^{\top} \,\mathbf{x}_t \in \mathbb{R}^{d_h},$

where $W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_h}$ are learnable projection matrices.
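A small sketch of these per-head projections, with illustrative sizes and weight names (not taken from the paper's code):

```python
import torch

d_model, h, d_h = 512, 8, 64                         # illustrative dimensions
W_Q = [torch.randn(d_model, d_h) for _ in range(h)]  # one projection per head
W_K = [torch.randn(d_model, d_h) for _ in range(h)]
W_V = [torch.randn(d_model, d_h) for _ in range(h)]

x_t = torch.randn(d_model)                           # token embedding
Q_t = torch.stack([W.T @ x_t for W in W_Q])          # (h, d_h)
K_t = torch.stack([W.T @ x_t for W in W_K])          # (h, d_h)
V_t = torch.stack([W.T @ x_t for W in W_V])          # (h, d_h)
```

At inference, MHA caches the full $\mathbf{K}_t$ and $\mathbf{V}_t$ for every token, i.e. $2 h d_h$ numbers per token, which is what MQA, GQA, MLA, and TPA all aim to shrink.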

MQA shares keys and values across heads, expressed as:

$\mathbf{Q}_{i} = \mathbf{X} W^Q_{i}, \quad \mathbf{K}_{\text{shared}} = \mathbf{X} W^K_{\text{shared}}, \quad \mathbf{V}_{\text{shared}} = \mathbf{X} W^V_{\text{shared}},$

with $W^Q_{i} \in \mathbb{R}^{d_{\text{model}} \times d_k}$ and $W^K_{\text{shared}}, W^V_{\text{shared}} \in \mathbb{R}^{d_{\text{model}} \times d_k}$.

GQA partitions the $h$ total heads into $G$ groups, each with a single set of keys and values:

$\mathbf{K}_{g(i)} = \mathbf{X}\, W^K_{g(i)}, \quad \mathbf{V}_{g(i)} = \mathbf{X}\, W^V_{g(i)}, \quad \mathbf{Q}_{i} = \mathbf{X}\, W^Q_{i},$

where $W^K_{g}, W^V_{g} \in \mathbb{R}^{d_{\text{model}} \times d_k}$ and $W^Q_{i} \in \mathbb{R}^{d_{\text{model}} \times d_k}$.
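The difference between MQA and GQA is only how many key/value projections exist and how query heads map onto them. A hedged sketch under illustrative sizes (the head-to-group assignment `g` is one common choice, not necessarily the paper's):

```python
import torch

d_model, d_k, h, G = 512, 64, 8, 4        # illustrative sizes
X = torch.randn(16, d_model)              # a sequence of 16 token embeddings

# MQA: a single shared K/V projection serves all h query heads
W_K_shared = torch.randn(d_model, d_k)
W_V_shared = torch.randn(d_model, d_k)
K_shared, V_shared = X @ W_K_shared, X @ W_V_shared   # cached once per token

# GQA: one K/V projection per group; query head i reads group g(i)
W_K_group = [torch.randn(d_model, d_k) for _ in range(G)]
W_V_group = [torch.randn(d_model, d_k) for _ in range(G)]

def g(i):
    return i * G // h                      # maps heads 0..h-1 onto G groups

K_for_head_3 = X @ W_K_group[g(3)]         # keys seen by query head 3
```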

RoPE uses a rotation operator $\mathbf{T}_t \in \mathbb{R}^{d_h \times d_h}$ corresponding to the $t$-th position, with $\operatorname{RoPE}(\mathbf{Q}_t) \triangleq \mathbf{Q}_t \mathbf{T}_t$, where $\mathbf{Q}_t \in \mathbb{R}^{h \times d_h}$.
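For reference, a minimal RoPE rotation helper is sketched below; it applies the position-$t$ rotation row-wise using the common "rotate-half" pairing convention (an assumption, since implementations differ in how dimensions are paired):

```python
import torch

def rope_rotate(M, t, base=10000.0):
    # M: (rows, d_h); rotates each row by the position-t RoPE angles
    rows, d_h = M.shape
    half = d_h // 2
    freqs = base ** (-torch.arange(half, dtype=torch.float32) * 2 / d_h)
    angles = t * freqs                       # (half,) rotation angles
    cos, sin = angles.cos(), angles.sin()
    m1, m2 = M[:, :half], M[:, half:]        # rotate-half pairing of dimensions
    return torch.cat([m1 * cos - m2 * sin, m1 * sin + m2 * cos], dim=-1)
```

Because the same rotation is applied to every row, this helper works equally well on an $(h \times d_h)$ query or key slice or, as used later, on a factor matrix with $d_h$ columns.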

MLA introduces a low-rank compression of the keys and values to reduce the Key-Value (KV) caching cost at inference.

$\mathbf{C}^{KV} = \mathbf{X}\, W^{DKV} \quad (W^{DKV} \in \mathbb{R}^{d_{\text{model}} \times d_c}), \qquad \operatorname{Concat}\bigl(\mathbf{K}^{C}_{1}, \mathbf{K}^{C}_{2}, \ldots, \mathbf{K}^{C}_{h}\bigr) = \mathbf{K}^{C} = \mathbf{C}^{KV}\, W^{UK} \quad (W^{UK} \in \mathbb{R}^{d_c \times d_h h}).$

In TPA, the hidden-state vector $\mathbf{x}_t \in \mathbb{R}^{d_{\text{model}}}$ denotes the $t$-th token in a sequence of length $T$. TPA factorizes each of $\mathbf{Q}_{t}$, $\mathbf{K}_{t}$, $\mathbf{V}_{t}$ into a sum of tensor products:

$\mathbf{Q}_{t} = \frac{1}{R_Q} \sum_{r=1}^{R_Q} \mathbf{a}^{Q}_{r}(\mathbf{x}_t) \otimes \mathbf{b}^{Q}_{r}(\mathbf{x}_t), \quad \mathbf{K}_{t} = \frac{1}{R_K} \sum_{r=1}^{R_K} \mathbf{a}^{K}_{r}(\mathbf{x}_t) \otimes \mathbf{b}^{K}_{r}(\mathbf{x}_t), \quad \mathbf{V}_{t} = \frac{1}{R_V} \sum_{r=1}^{R_V} \mathbf{a}^{V}_{r}(\mathbf{x}_t) \otimes \mathbf{b}^{V}_{r}(\mathbf{x}_t),$

where $\mathbf{a}^{Q}_{r}(\mathbf{x}_t), \mathbf{a}^{K}_{r}(\mathbf{x}_t), \mathbf{a}^{V}_{r}(\mathbf{x}_t) \in \mathbb{R}^{h}$ and $\mathbf{b}^{Q}_{r}(\mathbf{x}_t), \mathbf{b}^{K}_{r}(\mathbf{x}_t), \mathbf{b}^{V}_{r}(\mathbf{x}_t) \in \mathbb{R}^{d_h}$.
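A hedged sketch of these contextual factors and the reassembly of $\mathbf{Q}_t$ from rank-one components (illustrative shapes and tensor names; the released code may organize this differently):

```python
import torch

d_model, h, d_h, R_Q = 512, 8, 64, 6         # illustrative sizes and query rank
W_aQ = torch.randn(R_Q, h, d_model)          # maps x_t to head-dimension factors
W_bQ = torch.randn(R_Q, d_h, d_model)        # maps x_t to token-dimension factors

x_t = torch.randn(d_model)
a_Q = torch.einsum('rhd,d->rh', W_aQ, x_t)   # (R_Q, h) contextual factors
b_Q = torch.einsum('rkd,d->rk', W_bQ, x_t)   # (R_Q, d_h) contextual factors

# Q_t is the average of R_Q rank-one tensor products a_r ⊗ b_r
Q_t = torch.einsum('rh,rk->hk', a_Q, b_Q) / R_Q   # (h, d_h)
```

$\mathbf{K}_t$ and $\mathbf{V}_t$ are factorized the same way with ranks $R_K$ and $R_V$; at inference, only their small $\mathbf{a}$ and $\mathbf{b}$ factors need to be cached.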

The latent factor maps are given by:

$\mathbf{a}^{Q}_{r}(\mathbf{x}_t) = W^{a^Q}_{r}\,\mathbf{x}_t \in \mathbb{R}^{h}, \quad \mathbf{b}^{Q}_{r}(\mathbf{x}_t) = W^{b^Q}_{r}\,\mathbf{x}_t \in \mathbb{R}^{d_h}.$

After $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$ are factorized, multi-head attention proceeds as in standard Transformers, with:

$\operatorname{head}_i = \operatorname{Softmax}\Bigl( \tfrac{1}{\sqrt{d_h}} \,\mathbf{Q}_{i} \, (\mathbf{K}_{i})^{\top} \Bigr) \,\mathbf{V}_{i},$

where $\mathbf{Q}_{i}, \mathbf{K}_{i}, \mathbf{V}_{i} \in \mathbb{R}^{T \times d_h}$ are the slices along the head dimension.
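The sketch below (illustrative shapes and example ranks, not the paper's code) shows how cached key/value factors over a sequence can be expanded back into per-head $\mathbf{K}$, $\mathbf{V}$ and fed into standard attention:

```python
import torch
import torch.nn.functional as F

T, h, d_h, R_K, R_V = 16, 8, 64, 2, 2        # illustrative sizes and ranks
# Cached factors for T tokens (the only K/V state kept at inference)
a_K, b_K = torch.randn(T, R_K, h), torch.randn(T, R_K, d_h)
a_V, b_V = torch.randn(T, R_V, h), torch.randn(T, R_V, d_h)

# Expand K, V on the fly from the cached factors
K = torch.einsum('trh,trk->thk', a_K, b_K) / R_K   # (T, h, d_h)
V = torch.einsum('trh,trk->thk', a_V, b_V) / R_V   # (T, h, d_h)

Q = torch.randn(T, h, d_h)                          # queries are not cached
scores = torch.einsum('thk,shk->hts', Q, K) / d_h ** 0.5             # (h, T, T)
heads = torch.einsum('hts,shk->thk', F.softmax(scores, dim=-1), V)   # (T, h, d_h)
```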

For RoPE integration, the paper suggests pre-rotating the token-dimension factors:

$\tilde{\mathbf{B}}_K(\mathbf{x}_t) \;\longleftarrow\; \operatorname{RoPE}_t\bigl(\mathbf{B}_K(\mathbf{x}_t)\bigr).$

A key theorem states that RoPE's relative translational property is preserved in TPA: if $\mathbf{Q}_t$ is factorized by TPA, then

$\operatorname{RoPE}(\mathbf{Q}_t) = \frac{1}{R_Q}\, \mathbf{A}_{Q}(\mathbf{x}_t)^{\top} \,\widetilde{\mathbf{B}}_{Q}(\mathbf{x}_t), \quad \text{where} \quad \widetilde{\mathbf{B}}_{Q}(\mathbf{x}_t) = \operatorname{RoPE}_t\bigl(\mathbf{B}_{Q}(\mathbf{x}_t)\bigr).$
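The theorem can be checked numerically with the illustrative helpers above: rotating the assembled $\mathbf{Q}_t$ gives the same result as assembling it from pre-rotated token-dimension factors, because RoPE acts linearly and row-wise (a sketch reusing the `rope_rotate` helper defined earlier):

```python
import torch

h, d_h, R_Q, t = 8, 64, 6, 5                     # illustrative sizes and position
a_Q, b_Q = torch.randn(R_Q, h), torch.randn(R_Q, d_h)

Q_t = torch.einsum('rh,rk->hk', a_Q, b_Q) / R_Q  # assemble Q_t from its factors
lhs = rope_rotate(Q_t, t)                        # rotate after assembly
rhs = torch.einsum('rh,rk->hk', a_Q, rope_rotate(b_Q, t)) / R_Q  # pre-rotate factors
assert torch.allclose(lhs, rhs, atol=1e-5)       # both orders agree
```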

The memory cost per token in TPA is $(R_K + R_V)(h + d_h)$, which can be significantly lower than the standard caching cost of $2 h d_h$.
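A quick worked example with illustrative LLaMA-like sizes and small example ranks shows the scale of the savings:

```python
h, d_h, R_K, R_V = 32, 128, 2, 2     # illustrative head count, head dim, and ranks
standard = 2 * h * d_h               # 8192 cached numbers per token (MHA)
tpa = (R_K + R_V) * (h + d_h)        # 640 cached numbers per token (TPA)
print(standard / tpa)                # 12.8x smaller KV cache per token
```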

The paper demonstrates how MHA, MQA, and GQA can be unified as non-contextual variants of TPA. Specifically, standard MHA can be viewed as a specific instance of TPA in which: 1) the rank is set equal to the number of heads; 2) the head-dimension factor is non-contextual; 3) the token-dimension factor is a linear function of $\mathbf{x}_t$.

In MQA, all heads share a single set of keys/values, corresponding to $R_K = R_V = 1$ along the head dimension, while GQA partitions the $h$ heads into $G$ groups, each sharing keys/values within the group.

The Tensor Product Attention Transformer (T6) architecture, which uses TPA in place of standard MHA or GQA, is also detailed. The feed-forward network (FFN) adopts a SwiGLU layer, and RoPE is applied to the queries and keys.
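For context, a SwiGLU feed-forward block is a gated two-branch MLP; a minimal sketch (illustrative, not the paper's exact module) is:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    # Gated FFN: down( SiLU(gate(x)) * up(x) ), as used in LLaMA-style blocks
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.gate = nn.Linear(d_model, d_ff, bias=False)
        self.up = nn.Linear(d_model, d_ff, bias=False)
        self.down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(F.silu(self.gate(x)) * self.up(x))
```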

Experiments were conducted on the FineWeb-Edu 100B dataset, comparing T6 against the baseline Llama architecture with SwiGLU activation and RoPE embeddings, as well as Llama variants that replace MHA with MQA, GQA, or MLA. Models were trained at small (124M parameters), medium (353M), and large (773M) scales using the AdamW optimizer.

Results indicate that TPA and its variant TPA-KVonly converge as fast as or faster than the baselines while achieving visibly lower final losses. In downstream evaluations on standard benchmarks, TPA generally ties or outperforms all competing methods.

The paper concludes that TPA offers a flexible, memory-efficient alternative to standard multi-head attention, advancing the scalability of modern LLMs.


HackerNews

  1. Tensor Product Attention Is All You Need (160 points, 103 comments)