The paper introduces the Differential Transformer (Diff Transformer), an architecture designed to mitigate the issue of over-allocation of attention to irrelevant context in standard Transformers. The core innovation lies in the differential attention mechanism, which computes attention scores as the difference between two separate attention maps.
The differential attention mechanism partitions the query and key vectors into two groups and computes a separate softmax attention map for each group. The attention scores are then taken as the difference between these two maps. This subtraction cancels out common-mode noise and promotes sparse attention patterns, allowing the model to focus on relevant context.
The mathematical formulation of the differential attention operator is:
$$[Q_1; Q_2] = X W^Q, \quad [K_1; K_2] = X W^K, \quad V = X W^V,$$
$$\text{DiffAttn}(X) = \left( \operatorname{softmax}\!\left( \frac{Q_1 K_1^\top}{\sqrt{d}} \right) - \lambda \, \operatorname{softmax}\!\left( \frac{Q_2 K_2^\top}{\sqrt{d}} \right) \right) V,$$
where:
- $X \in \mathbb{R}^{N \times d_{\text{model}}}$ is the input
- $Q_1, Q_2, K_1, K_2 \in \mathbb{R}^{N \times d}$ are the query and key projections
- $V \in \mathbb{R}^{N \times 2d}$ is the value projection
- $W^Q, W^K, W^V \in \mathbb{R}^{d_{\text{model}} \times 2d}$ are parameter matrices
- $\lambda$ is a learnable scalar
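The operator can be sketched in a few lines of PyTorch. The snippet below is an illustrative simplification, not the reference implementation: causal masking is omitted, $\lambda$ is passed as a plain scalar (its re-parameterization is described next), and the function name `diff_attn` is ours.

```python
import math

import torch
import torch.nn.functional as F


def diff_attn(x: torch.Tensor, W_q: torch.Tensor, W_k: torch.Tensor,
              W_v: torch.Tensor, lam: float) -> torch.Tensor:
    """Single-head differential attention sketch (no causal mask, lambda as a plain scalar).

    x: (batch, N, d_model); W_q, W_k, W_v: (d_model, 2d); lam: the scalar lambda.
    """
    d = W_q.shape[1] // 2
    q1, q2 = (x @ W_q).chunk(2, dim=-1)   # split queries into two groups, each (batch, N, d)
    k1, k2 = (x @ W_k).chunk(2, dim=-1)   # split keys the same way
    v = x @ W_v                            # (batch, N, 2d)

    scale = 1.0 / math.sqrt(d)
    a1 = F.softmax(q1 @ k1.transpose(-1, -2) * scale, dim=-1)  # first attention map
    a2 = F.softmax(q2 @ k2.transpose(-1, -2) * scale, dim=-1)  # second attention map

    # The difference of the two maps cancels common-mode noise and yields sparser attention.
    return (a1 - lam * a2) @ v
```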
To stabilize learning dynamics, $\lambda$ is re-parameterized as $\lambda = \exp( \boldsymbol{\lambda}_{q_1} \cdot \boldsymbol{\lambda}_{k_1} ) - \exp( \boldsymbol{\lambda}_{q_2} \cdot \boldsymbol{\lambda}_{k_2} ) + \lambda_{\text{init}}$,
where:
- $\boldsymbol{\lambda}_{q_1}, \boldsymbol{\lambda}_{k_1}, \boldsymbol{\lambda}_{q_2}, \boldsymbol{\lambda}_{k_2} \in \mathbb{R}^{d}$ are learnable vectors
- $\lambda_{\text{init}} \in (0,1)$ is a constant used for initialization.
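A direct PyTorch translation of this re-parameterization might look as follows; the module name and the initialization scale of the learnable vectors are assumptions of this sketch.

```python
import torch
from torch import nn


class LambdaReparam(nn.Module):
    """Computes the differential-attention scalar lambda from learnable vectors (sketch)."""

    def __init__(self, d: int, lambda_init: float = 0.8):
        super().__init__()
        self.lambda_init = lambda_init  # constant in (0, 1)
        # Learnable vectors in R^d; the 0.1 initialization scale is an arbitrary choice here.
        self.lambda_q1 = nn.Parameter(0.1 * torch.randn(d))
        self.lambda_k1 = nn.Parameter(0.1 * torch.randn(d))
        self.lambda_q2 = nn.Parameter(0.1 * torch.randn(d))
        self.lambda_k2 = nn.Parameter(0.1 * torch.randn(d))

    def forward(self) -> torch.Tensor:
        # lambda = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init
        return (torch.exp(torch.dot(self.lambda_q1, self.lambda_k1))
                - torch.exp(torch.dot(self.lambda_q2, self.lambda_k2))
                + self.lambda_init)
```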
In multi-head differential attention, the output of each head is normalized with GroupNorm (RMSNorm applied to each head independently) and scaled by $(1 - \lambda_{\text{init}})$ to align gradients with the standard Transformer architecture.
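Assuming head outputs produced by a differential-attention operator like the sketch above, the per-head normalization and scaling could be expressed as follows, using `torch.nn.RMSNorm` (available in recent PyTorch releases) as a stand-in for the paper's per-head GroupNorm; the module and parameter names are illustrative.

```python
import torch
from torch import nn


class DiffHeadCombine(nn.Module):
    """Per-head normalization, (1 - lambda_init) scaling, and output projection (sketch)."""

    def __init__(self, num_heads: int, head_dim: int, d_model: int, lambda_init: float = 0.8):
        super().__init__()
        self.lambda_init = lambda_init
        # One RMSNorm per head: each head's output is normalized independently ("GroupNorm").
        self.norms = nn.ModuleList(nn.RMSNorm(head_dim) for _ in range(num_heads))
        self.W_o = nn.Linear(num_heads * head_dim, d_model, bias=False)

    def forward(self, heads: list[torch.Tensor]) -> torch.Tensor:
        # heads: one (batch, N, head_dim) tensor per differential-attention head.
        scaled = [(1.0 - self.lambda_init) * norm(h) for norm, h in zip(self.norms, heads)]
        return self.W_o(torch.cat(scaled, dim=-1))
```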
Experimental results demonstrate that Diff Transformer outperforms Transformer on language modeling across a range of settings. Scaling experiments indicate that Diff Transformer requires only about 65% of the model size or training tokens of Transformer to achieve comparable performance.
The paper presents results on downstream tasks, including long-context modeling, key information retrieval, hallucination mitigation, in-context learning, and reduction of activation outliers. Diff Transformer exhibits notable advantages in these practical applications. For instance, in key information retrieval, Diff Transformer shows superior accuracy in retrieving information from long contexts, particularly when the relevant information is located in the first half of the context. The paper also evaluates contextual hallucination in text summarization and question answering, finding that Diff Transformer mitigates hallucination compared to Transformer. For in-context learning, Diff Transformer enhances accuracy and demonstrates greater robustness to order permutations in demonstration examples.
Furthermore, Diff Transformer reduces outliers in model activations, offering potential benefits for quantization. Attention logits and hidden states exhibit lower top activation values compared to Transformer, indicating fewer activation outliers.
Ablation studies validate the design choices of Diff Transformer. Removing GroupNorm degrades performance, highlighting its importance in normalizing the diverse statistics across heads. Performance is robust to different initialization strategies for $\lambda_{\text{init}}$.