Cottention: Linear Transformers with Cosine Attention
- Cottention is a cosine-based attention mechanism that normalizes queries and keys to achieve linear memory complexity.
- It replaces quadratic softmax attention with cosine similarity, enabling efficient long-sequence and causal decoding.
- Empirical benchmarks on BERT and GPT-J architectures demonstrate near-parity in accuracy with substantially improved memory scaling and, for long sequences, lower latency.
Cottention denotes an attention mechanism for transformers that replaces the softmax-based scoring kernel with a cosine similarity kernel and exploits the resulting associativity to achieve inference-time memory that is natively linear in sequence length (and, for causal decoding, constant). Developed as an alternative to traditional softmax attention, whose quadratic memory complexity limits scalability on long sequences, Cottention demonstrates comparable expressivity on standard benchmarks while offering substantial memory savings and potential computational savings. The concept and architecture were systematically presented and evaluated in "Cottention: Linear Transformers With Cosine Attention" (Mongaras et al., 2024).
1. Motivation: Limitations of Softmax Attention in Transformers
Transformers leveraging self-attention have achieved state-of-the-art results across natural language processing and related domains, in part owing to the expressivity of the softmax-normalized dot-product attention kernel:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right)V,$$

where $Q, K, V \in \mathbb{R}^{B \times H \times S \times d}$ for batch size $B$, number of heads $H$, sequence length $S$, and key/value dimensionality $d$. The $O(S^2 d)$ time and, more critically, $O(S^2)$ memory cost of this mechanism, due to explicit storage of the $S \times S$ attention map for every head, becomes impractical for large $S$, particularly during inference.
Cottention addresses this bottleneck by dispensing with the softmax normalization in favor of a cosine similarity kernel, enabling algebraic rearrangements that directly yield resource-efficient computation, crucial for long-sequence or streaming contexts (Mongaras et al., 2024).
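To make the bottleneck concrete, the following minimal PyTorch sketch (illustrative sizes, not from the paper) materializes the full attention map that softmax attention requires:

```python
import torch

B, H, S, d = 1, 8, 4096, 64                     # illustrative sizes
Q, K, V = (torch.randn(B, H, S, d) for _ in range(3))

scores = Q @ K.transpose(-1, -2) / d ** 0.5     # (B, H, S, S): quadratic in S
out = torch.softmax(scores, dim=-1) @ V         # (B, H, S, d)
print(f"{scores.numel() * scores.element_size() / 2**20:.0f} MiB for the attention maps")
```

At these sizes the score tensor alone occupies 512 MiB, and it grows quadratically as $S$ increases.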
2. Mathematical Formulation and Core Algorithm
Cottention replaces the softmax attention kernel by computing row-normalized queries and keys, followed by matrix multiplication:

$$\hat{Q} = \frac{Q}{\lVert Q \rVert_2}, \qquad \hat{K} = \frac{K}{\lVert K \rVert_2}, \qquad \mathrm{Cottention}(Q, K, V) = \hat{Q}\hat{K}^\top V.$$

Here, each query and key vector is $L_2$-normalized row-wise, so cosine similarity is computed as the dot product of unit vectors: $\hat{q}_i^\top \hat{k}_j = \cos(q_i, k_j) \in [-1, 1]$. To mitigate the scale growth of summed similarities with sequence length, a scalar parameter $s$ is trained per head; the output is stabilized by dividing through by $S^{\sigma(s)}$ (where $\sigma$ is the sigmoid function), yielding:

$$\mathrm{Cottention}(Q, K, V) = \frac{\hat{Q}\hat{K}^\top V}{S^{\sigma(s)}}.$$
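A minimal PyTorch sketch of this quadratic-form computation follows; the function name, tensor shapes, and the placement of the $S^{\sigma(s)}$ stabilizer are illustrative assumptions rather than the authors' code:

```python
import torch
import torch.nn.functional as F

def cottention_quadratic(Q, K, V, s):
    """Cosine attention in its naive O(S^2) form.
    Q, K, V: (B, H, S, d); s: (H,) learned per-head stabilizer parameters."""
    S = Q.shape[-2]
    Q_hat = F.normalize(Q, dim=-1)               # row-wise L2 normalization
    K_hat = F.normalize(K, dim=-1)
    scores = Q_hat @ K_hat.transpose(-1, -2)     # cosine similarities in [-1, 1]
    norm = S ** torch.sigmoid(s)                 # per-head stabilizer S^{sigmoid(s)}
    return (scores @ V) / norm.view(1, -1, 1, 1)
```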
By associativity, one can compute $\hat{K}^\top V$ first (shape $d \times d$), then multiply by $\hat{Q}$, bypassing storage of the $S \times S$ attention map and reducing memory to $O(Sd + d^2)$. For bidirectional attention, this yields linear scaling in $S$.
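The linear-memory rearrangement only reorders the same matrix products; a sketch under the same assumptions as above:

```python
import torch
import torch.nn.functional as F

def cottention_linear(Q, K, V, s):
    """Associativity-exploiting form: K_hat^T V is only d x d, so no S x S map
    is ever materialized; memory is O(Sd + d^2) instead of O(S^2)."""
    S = Q.shape[-2]
    Q_hat = F.normalize(Q, dim=-1)
    K_hat = F.normalize(K, dim=-1)
    KV = K_hat.transpose(-1, -2) @ V             # (B, H, d, d), independent of S
    norm = S ** torch.sigmoid(s)
    return (Q_hat @ KV) / norm.view(1, -1, 1, 1)

# Agrees with the quadratic form above up to floating-point error:
# torch.allclose(cottention_quadratic(Q, K, V, s), cottention_linear(Q, K, V, s), atol=1e-5)
```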
3. Causal Masking, RNN Reformulation, and Inference Efficiency
For autoregressive (causal) attention, direct factorization is blocked by the triangular mask. The Cottention algorithm circumvents this by reformulating causal attention computation as a recurrent neural network:
- The hidden state at step $t$ is $H_t = \sum_{i \le t} \hat{k}_i v_i^\top \in \mathbb{R}^{d \times d}$, tracked by the recurrence $H_t = H_{t-1} + \hat{k}_t v_t^\top$ (with $H_0 = 0$).
- The output for token $t$ is $o_t = \hat{q}_t^\top H_t$, stabilized as in the bidirectional case.

For streaming or stepwise inference, only $H_t$ need be stored and updated, so total memory remains $O(d^2)$, independent of $S$. This property eliminates the need for storing or recomputing the full past $K$, $V$ tensors ("kv-caching") required by softmax attention.
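The recurrence translates directly into a constant-memory decoding loop. The following PyTorch sketch is illustrative; in particular, the per-position stabilizer $t^{\sigma(s)}$ is an assumption mirroring the bidirectional normalizer, not a detail confirmed here:

```python
import torch
import torch.nn.functional as F

class CottentionDecoder:
    """Constant-memory causal decoding: only the d x d hidden state per head
    is kept between steps (a sketch of the RNN view, not the paper's kernel)."""

    def __init__(self, num_heads, d, s):
        self.H = torch.zeros(num_heads, d, d)    # H_t = sum_{i<=t} k_hat_i v_i^T
        self.t = 0
        self.s = s                               # (num_heads,) stabilizer parameters

    def step(self, q_t, k_t, v_t):
        """One decoding step; q_t, k_t, v_t: (num_heads, d). Returns (num_heads, d)."""
        q_hat = F.normalize(q_t, dim=-1)
        k_hat = F.normalize(k_t, dim=-1)
        self.H += k_hat.unsqueeze(-1) @ v_t.unsqueeze(-2)   # rank-1 update, O(d^2)
        self.t += 1
        norm = self.t ** torch.sigmoid(self.s)   # per-position stabilizer t^{sigmoid(s)} (assumed)
        out = q_hat.unsqueeze(-2) @ self.H       # (num_heads, 1, d)
        return out.squeeze(-2) / norm.unsqueeze(-1)
```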
A custom CUDA kernel implements this algorithm with one thread block per head-row and per-step accumulations, storing only $d^2$ floats per head, enabling low-latency inference.
4. Computational Complexity Analysis
Cottention’s memory and time complexity are outlined in the following table:
| Mechanism | Training Memory | Inference Memory (causal) | Time per step |
|---|---|---|---|
| Softmax attention | $O(S^2)$ | $O(Sd)$ (kv-cache) | $O(Sd)$ |
| Cottention (bidirectional) | $O(Sd + d^2)$ | — | — |
| Cottention (causal, inf.) | — | $O(d^2)$ (const.) | $O(d^2)$ (const.) |
Bidirectional Cottention provides memory linear in $S$; in causal (autoregressive) inference, the memory footprint is constant in $S$, whereas softmax attention always requires an $O(Sd)$ kv-cache.
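A back-of-the-envelope calculation illustrates the gap at fp32 precision (sizes chosen arbitrarily for illustration):

```python
# fp32 memory for 8 heads at S = 16384, d = 64 (illustrative sizes):
S, d, heads, bytes_fp32 = 16384, 64, 8, 4
softmax_maps = heads * S * S * bytes_fp32        # explicit S x S attention maps
cottention_state = heads * d * d * bytes_fp32    # recurrent d x d states
print(f"softmax: {softmax_maps / 2**30:.1f} GiB vs. Cottention: {cottention_state / 2**10:.0f} KiB")
# softmax: 8.0 GiB vs. Cottention: 128 KiB
```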
5. Empirical Evaluation and Benchmarking
Cottention was benchmarked as a drop-in replacement for softmax attention in both BERT (bidirectional) and GPT-J (causal) architectures. Empirical results show:
- On GLUE with BERT, Cottention attains scores within approximately 1.3 points of standard softmax attention on average, indicating near-parity in downstream task accuracy.
- In GPT-J next-token prediction experiments on The Pile, both 300M and 1.2B parameter models achieve final perplexities nearly identical to softmax attention (e.g., 1.2B: softmax 9.5, Cottention 9.6).
- Empirical measurements on A100 GPUs confirm the predicted linear/constant scaling of memory usage with sequence length for Cottention, versus quadratic for softmax.
- Wall-clock times favor Cottention for long sequences (when $S \gg d$), though for large $d$ and short $S$, softmax's lower multiplicative work can yield slightly lower training times.
The learned stabilization exponents $\sigma(s)$ converge to $0.1$–$0.2$ per head after training from an initialization of $0.5$ (i.e., $s = 0$), indicating reduced reliance on normalization at convergence.
6. Implementation and Practical Details
Mongaras et al. provide a fully detailed CUDA kernel for Cottention, exploiting fused operations and memory locality to minimize both peak memory and compute time. Backpropagation is handled via a closed-form reversal of the forward cumulative-sum steps for the $Q$, $K$, and $V$ gradients. No intermediate $S \times S$ attention maps or per-step $H_t$ arrays are stored; only the minimal recurrent state is maintained throughout.
This design supports easy integration into existing transformer codebases as a drop-in replacement for standard attention modules.
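A hypothetical sketch of such a drop-in module, wrapping the linear bidirectional form behind the standard $(B, S, D)$ self-attention interface (layer names, projection layout, and initialization are assumptions, not the authors' implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CottentionSelfAttention(nn.Module):
    """Bidirectional Cottention behind the usual (B, S, D) self-attention
    interface; a sketch, not the reference implementation."""

    def __init__(self, dim, num_heads):
        super().__init__()
        self.h, self.d = num_heads, dim // num_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.out = nn.Linear(dim, dim)
        self.s = nn.Parameter(torch.zeros(num_heads))  # sigmoid(0) = 0.5 at init

    def forward(self, x):                              # x: (B, S, D)
        B, S, D = x.shape
        q, k, v = self.qkv(x).view(B, S, 3, self.h, self.d).permute(2, 0, 3, 1, 4)
        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
        kv = k.transpose(-1, -2) @ v                   # (B, h, d, d): no S x S map
        y = (q @ kv) / (S ** torch.sigmoid(self.s)).view(1, -1, 1, 1)
        return self.out(y.transpose(1, 2).reshape(B, S, D))
```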
7. Implications and Future Directions
Cottention is distinguished by its ability to match the modeling capacity of softmax attention while reducing memory scaling, especially at inference where constant memory enables long-context generation and streaming. The RNN perspective suggests synergies with continual or online transformers and potentially hybrid architectures incorporating LSTM- or GRU-like gating on the incremental state.
Future work includes scaling Cottention to models beyond 10B parameters, optimizing kernel-level compute to close any remaining throughput gaps versus specialized fast-attention implementations (e.g., FlashAttention), experimenting with alternative normalization schedules, and exploiting Cottention's algebraic structure for low-rank or matrix-factorized key–value pathways.
This reconceptualization of the attention mechanism paves the way to more efficient transformer architectures, especially for resource-constrained or real-time sequence modeling scenarios (Mongaras et al., 2024).