Multi-Token Attention (2504.00927v1)
Abstract: Soft attention is a critical mechanism powering LLMs to locate relevant parts within a given context. However, individual attention weights are determined by the similarity of only a single query and key token vector. This "single token attention" bottlenecks the amount of information used in distinguishing a relevant part from the rest of the context. To address this issue, we propose a new attention method, Multi-Token Attention (MTA), which allows LLMs to condition their attention weights on multiple query and key vectors simultaneously. This is achieved by applying convolution operations over queries, keys and heads, allowing nearby queries and keys to affect each other's attention weights for more precise attention. As a result, our method can locate relevant context using richer, more nuanced information that can exceed a single vector's capacity. Through extensive evaluations, we demonstrate that MTA achieves enhanced performance on a range of popular benchmarks. Notably, it outperforms Transformer baseline models on standard language modeling tasks, and on tasks that require searching for information within long contexts, where our method's ability to leverage richer information proves particularly beneficial.
Summary
- The paper introduces a novel mechanism using convolution over queries and keys to capture multi-token patterns beyond standard attention.
- It details key-query convolution and head mixing operations that synthesize local context and facilitate cross-head interactions.
- The integration of Group Normalization with depth scaling helps keep training stable while adding minimal parameter overhead.
Multi-Token Attention (MTA) is an attention mechanism designed to enhance the ability of LLMs to locate relevant information within a context by addressing limitations inherent in standard single-token attention (2504.00927). Traditional attention mechanisms compute weights based on the similarity between individual query and key vectors, potentially creating an information bottleneck when relevance depends on patterns spanning multiple tokens or the combination of information detected by different attention heads. MTA introduces convolutional operations over queries, keys, and heads to allow attention weights to be conditioned on multiple vectors simultaneously, thereby capturing richer contextual information.
Methodology of Multi-Token Attention
MTA extends the standard multi-head attention framework by incorporating convolutional layers to process attention logits or weights. The primary components are Key-Query Convolution and Head Mixing Convolution, often supplemented by Group Normalization with Depth Scaling.
Key-Query Convolution
This component allows the attention score for a query-key pair $(q_i, k_j)$ to be influenced by neighboring query vectors (e.g., $q_{i-1}, q_{i-2}, \dots$) and neighboring key vectors (e.g., $k_{j-1}, k_{j+1}, \dots$) within the same attention head. The goal is to capture local dependencies and multi-token patterns that define relevance.
A 2D convolution with a learnable kernel $\theta$ and kernel sizes $c_q$ (query dimension) and $c_k$ (key dimension) is applied either to the attention logits $\hat{A} = QK^\top/\sqrt{d}$ before the softmax (pre-softmax convolution, the default configuration) or to the attention weights $A = \mathrm{Softmax}(\hat{A})$ after the softmax (post-softmax convolution).
Pre-softmax Convolution Equation (Conceptual):
$$A = \mathrm{Softmax}\big(\mathrm{Conv2d}_\theta(\hat{A})\big)$$
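Crucially, causality must be maintained in decoder settings. The convolution is masked to prevent query $q_i$ from attending to keys $k_j$ where $j > i$. Along the query dimension, the convolution is also masked so that only preceding queries (relative to $q_i$) influence the output for $q_i$. In practice the mask is applied twice. Future logits are first set to zero before the convolution, so they contribute nothing to neighboring positions. They are then set to $-\infty$ after the convolution, so they receive no weight in the softmax. This ensures that information does not flow from future positions. Zero-padding is used for boundary conditions.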
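The masked pre-softmax convolution can be written out element-wise as below. This is one consistent way to write it, assuming the key offsets are centered as in the $c_k = 5$ example that follows. Here $\mathrm{Mask}_0$ zeroes logits with $j > i$ before the convolution, and $\mathrm{Mask}_{-\infty}$ sets them to $-\infty$ afterwards. Indices outside the sequence contribute zero, which corresponds to zero padding.

```latex
\hat{A} = \frac{QK^\top}{\sqrt{d}}, \qquad
A = \mathrm{Softmax}\Big(\mathrm{Mask}_{-\infty}\big(\mathrm{Conv2d}_\theta(\mathrm{Mask}_0(\hat{A}))\big)\Big)

\tilde{a}_{ij} = \sum_{i'=0}^{c_q-1} \sum_{j'=-\lfloor c_k/2 \rfloor}^{\lceil c_k/2 \rceil - 1}
  \theta_{i',j'} \, \mathbb{1}[\, j - j' \le i - i' \,] \, \frac{q_{i-i'} \, k_{j-j'}^\top}{\sqrt{d}},
\qquad
a_{ij} = \frac{\exp(\tilde{a}_{ij})}{\sum_{m \le i} \exp(\tilde{a}_{im})} \quad (j \le i), \qquad a_{ij} = 0 \quad (j > i)
```

Setting $\theta_{0,0} = 1$ and every other kernel entry to zero recovers standard causal attention.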
The effect is that the attention weight $a_{ij}$ becomes a function of a neighborhood of query-key dot products around position $(i, j)$, rather than just the single dot product $q_i \cdot k_j$. For example, with $c_q = 4$ and $c_k = 5$, the computation for $a_{ij}$ can incorporate information from $q_i, q_{i-1}, q_{i-2}, q_{i-3}$ and $k_{j-2}, k_{j-1}, k_j, k_{j+1}, k_{j+2}$ (subject to causality constraints). Each head learns its own kernel parameters $\theta$, enabling specialization in detecting different local patterns. Learned kernels can exhibit specific structures, like diagonal patterns, indicating a focus on matching short sequences rather than just individual tokens.
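The neighborhood computation described above can be checked with a deliberately naive, single-head NumPy sketch. Two things here are assumptions made for readability and are not the paper's code: the loop-based `causal_kq_conv_attention` helper, and the kernel layout `theta[u, v]` (query offset $i' = u$, key offset $j' = v - \lfloor c_k/2 \rfloor$). Future logits are zeroed before the convolution and set to $-\infty$ after it. An identity kernel reproduces ordinary causal softmax attention.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax; entries equal to -inf get exactly zero weight.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def causal_kq_conv_attention(q, k, theta):
    """Attention weights for one head with a pre-softmax key-query convolution.

    q, k:  (T, d) query and key vectors of a single head
    theta: (c_q, c_k) kernel; theta[u, v] weights query offset i' = u and
           key offset j' = v - c_k // 2 (an assumed layout for this sketch)
    """
    T, d = q.shape
    c_q, c_k = theta.shape
    causal = np.tril(np.ones((T, T), dtype=bool))
    # Zero the logits of future keys so they cannot leak through the convolution.
    logits = np.where(causal, q @ k.T / np.sqrt(d), 0.0)
    conv = np.full((T, T), -np.inf)  # future entries stay at -inf (no softmax weight)
    for i in range(T):
        for j in range(i + 1):
            total = 0.0
            for u in range(c_q):          # past queries q_i, q_{i-1}, ...
                for v in range(c_k):      # keys around k_j
                    ii, jj = i - u, j - (v - c_k // 2)
                    if ii >= 0 and 0 <= jj < T:  # zero padding outside the sequence
                        total += theta[u, v] * logits[ii, jj]
            conv[i, j] = total
    return softmax(conv, axis=-1)


# Example: an identity kernel (only theta[0, c_k // 2] = 1) gives standard causal attention.
rng = np.random.default_rng(0)
q, k = rng.normal(size=(6, 4)), rng.normal(size=(6, 4))
theta = np.zeros((3, 5))
theta[0, 5 // 2] = 1.0
A = causal_kq_conv_attention(q, k, theta)
```

With a random kernel, row $i$ of the result depends only on $q_0, \dots, q_i$ and $k_0, \dots, k_i$. Changing later queries or keys leaves the earlier rows untouched.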
Head Mixing Convolution
This component facilitates interaction and information sharing across different attention heads before the final output projection. It allows the model to combine evidence detected by separate heads at the attention weight level.
Heads are grouped (e.g., into groups of size $c_h$), and a convolution is applied across the head dimension within each group. For non-overlapping groups, this is equivalent to applying a small fully-connected layer to the attention logits or weights of the heads in that group. The same layer is shared across all query-key positions.
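Grouped head mixing can be sketched in NumPy as below. The `mix_heads` helper and its weight layout (one $c_h \times c_h$ matrix per group) are hypothetical. The same function applies to post-softmax weights or to pre-softmax logits. For a group of two heads, it computes $w_{11}A^1 + w_{12}A^2$ and $w_{21}A^1 + w_{22}A^2$ at every query-key position.

```python
import numpy as np


def mix_heads(attn, w, group_size):
    """Mix attention maps across heads within non-overlapping groups of heads.

    attn: (H, T_q, T_k) attention weights (post-softmax) or logits (pre-softmax)
    w:    (H // group_size, group_size, group_size) mixing weights, one small
          matrix per group (a hypothetical layout chosen for this sketch)
    """
    H = attn.shape[0]
    out = np.empty_like(attn)
    for g in range(H // group_size):
        s = slice(g * group_size, (g + 1) * group_size)
        # A_new^m = sum_n w[g, m, n] * A^n, applied identically at every (i, j)
        out[s] = np.einsum("mn,nij->mij", w[g], attn[s])
    return out


# Two heads in one group, with arbitrary weights w11=0.7, w12=0.3, w21=0.2, w22=0.8.
rng = np.random.default_rng(0)
A = rng.random((2, 3, 3))
w = np.array([[[0.7, 0.3],
               [0.2, 0.8]]])
mixed = mix_heads(A, w, group_size=2)
```

Each group's matrix acts only on the head axis. The whole operation is therefore a multiplication by a block-diagonal $H \times H$ matrix, and heads in different groups never mix.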
Post-softmax Head Mixing Example (Conceptual, 2 heads):
Let $A^1$ and $A^2$ be the attention weights for heads 1 and 2 in a group. The mixed weights $A^1_{\text{new}}, A^2_{\text{new}}$ are computed as:
$$A^1_{\text{new}} = w_{11}A^1 + w_{12}A^2$$
$$A^2_{\text{new}} = w_{21}A^1 + w_{22}A^2$$
where $w$ are learned kernel weights. Pre-softmax mixing applies the same linear combination principle to the logits $\hat{A}^1, \hat{A}^2$.
This allows the model to synthesize findings. For instance, if one head focuses on occurrences of "term A" and another on "term B", head mixing can amplify attention weights in positions where both terms are indicated as relevant by the respective heads, effectively performing a localized logical AND operation at the attention level. This can be interpreted as increasing the rank and expressive capacity of the attention mechanism.
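The "logical AND" intuition can be illustrated with invented toy logits for two heads. Head 1 scores positions mentioning term A, and head 2 scores positions mentioning term B. The mixing weights below are hypothetical choices. The sketch compares post-softmax mixing (averaging the weights) with pre-softmax mixing (adding the logits).

```python
import numpy as np


def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()


# Toy logits over 5 context positions (illustrative values only).
# Head 1 fires on "term A" (positions 1 and 3); head 2 on "term B" (positions 3 and 4).
logits_h1 = np.array([0.0, 3.0, 0.0, 3.0, 0.0])
logits_h2 = np.array([0.0, 0.0, 0.0, 3.0, 3.0])
a1, a2 = softmax(logits_h1), softmax(logits_h2)  # each head splits its weight over 2 positions

# Post-softmax mixing: average the two heads' weights (w = 0.5, 0.5).
post = 0.5 * a1 + 0.5 * a2
# Pre-softmax mixing: add the logits (w = 1, 1), then normalize.
pre = softmax(logits_h1 + logits_h2)

for name, row in [("head 1", a1), ("head 2", a2), ("post", post), ("pre", pre)]:
    print(f"{name:>6}: {np.round(row, 3)}")
```

In this toy, both variants make position 3 (the only one where both terms occur) the largest weight. Adding logits before the softmax is much sharper, putting roughly 90% of the weight there, because the joint evidence is exponentiated.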
Integration and Normalization
The default MTA configuration applies key-query convolution pre-softmax and head mixing post-softmax. However, other arrangements are possible, including applying both pre-softmax (potentially via a single 3D convolution over key, query, and head dimensions) or both post-softmax.
To improve training stability, especially in deep networks, Group Normalization with layer-dependent depth scaling is applied to the output of each head, following approaches like ReLoG (2404.19466). This normalization is applied after the attention weights (potentially mixed) have been computed and used to weight the value vectors.
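The normalization step can be sketched in NumPy, assuming one normalization group per head (the layout that `torch.nn.GroupNorm(n_heads, d_model)` gives on concatenated head outputs). The depth-dependent factor is treated as a plain input because its exact formula is not given here. The `headwise_group_norm` name and its `depth_scale` parameter are hypothetical, and GroupNorm's learnable affine parameters are omitted.

```python
import numpy as np


def headwise_group_norm(out, n_heads, depth_scale, eps=1e-5):
    """Per-head normalization of attention outputs followed by a depth scale.

    out:         (T, n_heads * d_head) concatenated head outputs, i.e. (weights @ V)
    depth_scale: scalar for this layer; how it depends on depth is left to the
                 caller (hypothetical parameter, not a formula from the paper)
    The learnable per-channel affine parameters of GroupNorm are omitted.
    """
    T, D = out.shape
    x = out.reshape(T, n_heads, D // n_heads)           # split the heads apart
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x = (x - mean) / np.sqrt(var + eps)                 # one normalization group per head
    return (x * depth_scale).reshape(T, D)              # scale, then re-concatenate


rng = np.random.default_rng(1)
head_outputs = rng.normal(loc=3.0, scale=5.0, size=(4, 3 * 8))  # T=4, 3 heads, d_head=8
normed = headwise_group_norm(head_outputs, n_heads=3, depth_scale=0.5)
```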
Comparison with Single-Token Attention
MTA differs significantly from the standard multi-head attention mechanism:
| Feature | Standard Single-Token Attention | Multi-Token Attention (MTA) |
|---|---|---|
| Attention basis | Based on similarity of a single query $q_i$ and key $k_j$. | Based on patterns across multiple queries ($q_{i-i'}$) and keys ($k_{j-j'}$) via key-query convolution. |
| Information bottleneck | High: $q_i$ and $k_j$ must encode all necessary context. | Reduced: relevance emerges from local multi-token context and patterns captured by convolutional kernels. |
| Cross-head interaction | Occurs late, only after weighted value aggregation via $W_O$. | Occurs earlier, directly on attention logits/weights via head mixing convolution, allowing synergistic combinations. |
| Mechanism | Dot product, softmax, masking. | Dot product, convolution (key-query, head), softmax, masking, optional group norm + scaling. |
| Parameter cost | Dominated by $W_Q, W_K, W_V, W_O$ projections. | Adds parameters for convolutional kernels ($\theta$ for key-query, $w$ for head mixing), but these are typically small. |
| Computational cost | Generally lower, often leveraging highly optimized kernels. | Higher due to convolutions, especially if not fused or optimized; the cost depends on kernel sizes $c_q, c_k, c_h$. |
| Capability | Excels at point-wise similarity matching. | Designed to excel at matching local sequence patterns and combining evidence from multiple conceptual detectors (heads). |
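The parameter and computational cost rows can be made concrete with back-of-the-envelope arithmetic. The helper below is a hypothetical sketch that counts parameters and multiply-adds per attention-map entry. The configuration values are illustrative assumptions, not figures taken from the paper.

```python
def mta_overhead(d_model, n_heads, c_q, c_k, c_h):
    """Back-of-the-envelope per-layer cost of MTA's kernels (biases ignored)."""
    d_head = d_model // n_heads
    proj_params = 4 * d_model * d_model          # W_Q, W_K, W_V, W_O
    kernel_params = n_heads * c_q * c_k          # one c_q x c_k key-query kernel per head
    mix_params = n_heads * c_h                   # each head mixes the c_h heads of its group
    # Multiply-adds per (query, key) entry of the attention map, summed over heads:
    qk_macs = n_heads * d_head                   # computing Q K^T
    conv_macs = n_heads * (c_q * c_k + c_h)      # key-query convolution + head mixing
    return {
        "extra_params": kernel_params + mix_params,
        "param_ratio": (kernel_params + mix_params) / proj_params,
        "conv_to_qk_macs": conv_macs / qk_macs,
    }


# Illustrative configuration (assumed values, not taken from the paper's tables).
cost = mta_overhead(d_model=2048, n_heads=16, c_q=6, c_k=11, c_h=2)
print(cost)
```

With these numbers the kernels add 1,088 parameters against about 16.8 million in the projections. However, the convolutions perform about half as many multiply-adds per attention entry as computing $QK^\top$ itself. They also operate on the full $T \times T$ logit map, which fused attention kernels normally avoid materializing, so the runtime overhead in practice can exceed this simple count.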
Enhanced Information Leverage
MTA's design directly targets the limitations of single-token attention, enabling it to leverage richer information sources:
- Multi-Token Pattern Recognition: The key-query convolution moves beyond simple token-to-token similarity. By applying learned kernels over the matrix of dot products, it can detect and assign higher attention scores based on the presence of specific n-gram-like patterns or relative positional arrangements between the query context and the key context. This is beneficial for tasks involving phrasal matching, syntactic dependencies, or identifying locations where multiple specific items co-occur within a short window. A toy task demonstrated MTA's ability to identify blocks containing two specific characters, where standard attention failed.
- Cross-Head Synthesis: Head mixing allows the model to explicitly combine information streams represented by different heads during the attention calculation. If heads specialize (e.g., one detects semantic category, another detects syntactic role), head mixing can compute attention weights that reflect the conjunction or other logical combination of these specialized detections, leading to more precise targeting of contextually relevant information.
- Improved Contextual Disambiguation: By considering neighboring tokens (via key-query conv) and parallel analyses from different heads (via head mixing), the attention decision becomes less reliant on the potentially overloaded representations within single query/key vectors. The surrounding context and outputs from multiple specialized detectors contribute directly to the weight calculation, improving robustness and precision.
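The multi-token pattern idea can be reproduced with a hand-built toy. This is an illustration only, not the paper's toy task. It uses one-hot "embeddings", a two-token query "a b", and a $2 \times 2$ diagonal kernel that adds the previous query-key match to the current one. Single-token attention from the query "b" cannot tell the three "b" positions apart. The diagonal kernel singles out the "b" that follows an "a". Causal masking is omitted because every key precedes the query.

```python
import numpy as np

# Toy vocabulary of one-hot "embeddings"; keys spell the context "bacabcb".
vocab = {ch: np.eye(3)[n] for n, ch in enumerate("abc")}
context = "bacabcb"
K = np.stack([vocab[ch] for ch in context])           # (T_k, 3)
Q = np.stack([vocab["a"], vocab["b"]])                # last two query tokens: "a", "b"

L = Q @ K.T                                           # (2, T_k) single-token similarities

# Single-token attention from the last query ("b") scores every "b" equally.
single = L[1]

# A 2x2 diagonal kernel: score(i, j) = L[i, j] + L[i-1, j-1], i.e. "previous query
# matched the previous key AND current query matches the current key".
diag = L[1].copy()
diag[1:] += L[0, :-1]                                 # key j-1 does not exist for j = 0

print(single, diag)
```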
The paper reports that MTA yields performance improvements over baseline Transformer models on standard language modeling benchmarks. It shows particular strengths on tasks requiring information localization within long contexts, where its capacity to utilize richer, multi-token information proves advantageous (2504.00927).
Implementation Considerations
Implementing MTA involves substituting the standard attention module with one incorporating the described convolutions.
- Convolution Implementation: Standard deep learning library functions (e.g., `torch.nn.Conv2d`) can be used. Careful implementation of masking is required for causality in decoder models, especially for the key-query convolution applied pre-softmax. Padding is necessary to handle sequence boundaries.
- Computational Overhead: The convolutions add computational cost compared to standard dot-product attention. Like standard attention, the overhead grows with the square of the sequence length $T$. It also grows with the kernel sizes: roughly $T^2 c_q c_k$ multiply-adds per head for the key-query convolution and $T^2 c_h$ for head mixing. For long sequences this can be significant, especially since current implementations may not yet benefit from the level of optimization applied to standard attention kernels. Fusing operations or using specialized kernels could mitigate this.
- Parameter Overhead: The number of additional parameters from the convolutional kernels is generally small compared to the parameters in the projection matrices ($W_Q, W_K, W_V, W_O$).
- Hyperparameters: The kernel sizes ($c_q, c_k$) for key-query convolution and the group size ($c_h$) for head mixing are key hyperparameters influencing the receptive field and degree of cross-head interaction. The choice of pre- vs. post-softmax application for each convolution type also affects behavior.
- Training Stability: The inclusion of Group Normalization with depth scaling suggests that, like other complex attention variants, MTA might benefit from specific normalization strategies to ensure stable training, particularly in deeper models.
A simplified PyTorch sketch of a causal MTA self-attention layer follows. It is illustrative rather than the authors' code. The identity initialization of the kernels and the fixed `depth_scale` factor are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MTAAttention(nn.Module):
    """Causal multi-head self-attention with MTA-style convolutions (simplified sketch)."""

    def __init__(self, d_model, n_heads, kq_kernel_size=(3, 3), head_group_size=2,
                 depth_scale=1.0):
        super().__init__()
        assert d_model % n_heads == 0 and n_heads % head_group_size == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.c_q, self.c_k = kq_kernel_size
        self.head_group_size = head_group_size
        # Hypothetical fixed per-layer factor standing in for depth-dependent scaling.
        self.depth_scale = depth_scale

        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)

        # Key-query convolution: one (c_q x c_k) kernel per head (depthwise conv).
        # No built-in padding; forward() pads causally along the query axis.
        self.kq_conv = nn.Conv2d(n_heads, n_heads, kernel_size=(self.c_q, self.c_k),
                                 groups=n_heads, bias=False)
        # Head mixing: a 1x1 grouped conv over the head (channel) axis, i.e. a small
        # linear map across the heads of each group, shared by all (i, j) positions.
        self.head_mix = nn.Conv2d(n_heads, n_heads, kernel_size=1,
                                  groups=n_heads // head_group_size, bias=False)
        # One GroupNorm group per head: each head's output is normalized separately.
        self.group_norm = nn.GroupNorm(n_heads, d_model)
        self._init_identity()

    @torch.no_grad()
    def _init_identity(self):
        # Assumption: start both convolutions as identity maps, so the layer
        # initially behaves like standard causal attention (plus GroupNorm).
        self.kq_conv.weight.zero_()
        self.kq_conv.weight[:, 0, self.c_q - 1, (self.c_k - 1) // 2] = 1.0
        self.head_mix.weight.zero_()
        for h in range(self.n_heads):
            self.head_mix.weight[h, h % self.head_group_size, 0, 0] = 1.0

    def forward(self, x):
        B, T, _ = x.shape
        # Project and split heads: (B, n_heads, T, d_head)
        q = self.Wq(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = self.Wk(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = self.Wv(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        # 1. Raw attention logits: (B, n_heads, T, T)
        logits = q @ k.transpose(-2, -1) / self.d_head ** 0.5
        causal = torch.ones(T, T, dtype=torch.bool, device=x.device).tril()

        # 2. Key-query convolution (pre-softmax). Zero future logits first so the
        #    kernel cannot pull information from keys j > i into position (i, j).
        logits = logits.masked_fill(~causal, 0.0)
        pad_left = (self.c_k - 1) // 2
        # F.pad order is (key_left, key_right, query_top, query_bottom); the query
        # axis is padded only on the past side, so output row i sees rows <= i.
        logits = F.pad(logits, (pad_left, self.c_k - 1 - pad_left, self.c_q - 1, 0))
        logits = self.kq_conv(logits)                      # back to (B, n_heads, T, T)
        logits = logits.masked_fill(~causal, float("-inf"))

        # 3. Softmax over keys
        weights = F.softmax(logits, dim=-1)

        # 4. Head mixing (post-softmax) within each group of heads
        weights = self.head_mix(weights)

        # 5. Weighted sum of values: (B, n_heads, T, d_head)
        out = weights @ v

        # 6. Per-head GroupNorm and depth scaling, then concatenate heads and project
        out = out.transpose(1, 2).reshape(B * T, self.d_model)
        out = self.group_norm(out).reshape(B, T, self.d_model) * self.depth_scale
        return self.Wo(out)


if __name__ == "__main__":
    attn = MTAAttention(d_model=64, n_heads=4, kq_kernel_size=(3, 5), head_group_size=2)
    y = attn(torch.randn(2, 10, 64))
    print(y.shape)  # torch.Size([2, 10, 64])
```
Conclusion
Multi-Token Attention offers a modification to the standard attention mechanism by incorporating learnable convolutions over keys, queries, and heads. This allows attention weights to depend on local neighborhoods of tokens and combinations of signals from different heads, addressing the single-token information bottleneck. While introducing computational overhead, MTA demonstrates potential for improved performance, especially in tasks sensitive to multi-token patterns and long-range information retrieval, by enabling a richer, more contextually informed attention process.