nGPT: Normalized Transformer with Representation Learning on the Hypersphere (2410.01131v1)

Published 1 Oct 2024 in cs.LG and cs.AI

Abstract: We propose a novel neural network architecture, the normalized Transformer (nGPT) with representation learning on the hypersphere. In nGPT, all vectors forming the embeddings, MLP, attention matrices and hidden states are unit norm normalized. The input stream of tokens travels on the surface of a hypersphere, with each layer contributing a displacement towards the target output predictions. These displacements are defined by the MLP and attention blocks, whose vector components also reside on the same hypersphere. Experiments show that nGPT learns much faster, reducing the number of training steps required to achieve the same accuracy by a factor of 4 to 20, depending on the sequence length.

The paper "nGPT: Normalized Transformer with Representation Learning on the Hypersphere" (Loshchilov et al., 1 Oct 2024) introduces a modified Transformer architecture where key vector representations are constrained to reside on the surface of a unit hypersphere. This constraint is enforced through explicit $L_2$ normalization applied to token embeddings, hidden states, and the vectors constituting the weight matrices of the attention and MLP blocks (along the embedding dimension). The core hypothesis is that operating on this manifold simplifies the optimization landscape and leads to faster convergence.

Normalization Strategy and Architectural Changes

The central modification in nGPT is the pervasive application of $L_2$ normalization. Unlike standard Transformers, which rely on LayerNorm or RMSNorm applied before the attention or MLP blocks, nGPT removes these entirely. Instead, normalization is applied after computations and updates:

  1. Matrix Normalization: After each optimizer step, all weight matrices ($W_{\text{input}}$, $W_{\text{output}}$, $W_q, W_k, W_v, W_o$, $W_u, W_\nu, W_{o\text{MLP}}$) are normalized such that the $L_2$ norm of the vectors along the embedding dimension ($d_{\text{model}}$) is 1. For a matrix $W \in \mathbb{R}^{d_1 \times d_2}$ with $d_2 = d_{\text{model}}$, this means normalizing each of the $d_1$ row vectors (or the column vectors if the embedding dimension comes first, depending on convention) to unit norm. This normalization occurs outside the forward/backward pass, directly modifying the weights after the gradient update.
  2. Hidden State Normalization: The hidden state vector $h \in \mathbb{R}^{d_{\text{model}}}$ is explicitly normalized to unit norm after the attention and MLP block updates within the forward pass. This ensures the "information carrier" always remains on the hypersphere (a minimal normalization helper is sketched after this list).
  3. Removal of Standard Normalization: LayerNorm and RMSNorm layers are completely removed from the architecture.
  4. Removal of Weight Decay: Because matrix rows/columns are constantly renormalized to unit norm, their magnitude is controlled. Consequently, weight decay (L2 regularization) is deemed unnecessary and removed (setting weight_decay=0 in the optimizer, making Adam equivalent to AdamW).
  5. Removal of Learning Rate Warmup: The paper reports successful training without learning rate warmup schedules.
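
For reference, here is a minimal sketch of the unit-norm projection that items 1 and 2 rely on; the function name and the epsilon guard are illustrative choices rather than code from the paper.

    import torch

    def unit_norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-8) -> torch.Tensor:
        # Norm(x) = x / ||x||_2 along `dim`, with a small guard against division by zero.
        return x / x.norm(p=2, dim=dim, keepdim=True).clamp_min(eps)

    # Hidden states: normalize along the embedding dimension.
    h = unit_norm(torch.randn(2, 16, 512), dim=-1)

    # Weight matrices: normalize each vector along the d_model axis,
    # e.g. the rows of a [d_out, d_model] projection matrix.
    W = unit_norm(torch.randn(2048, 512), dim=-1)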

Modified Update Rule and Optimization Perspective

The standard residual connection $h \leftarrow h + \text{SubLayer}(h)$ is replaced with a modified update rule that incorporates learnable per-dimension scaling and explicit normalization. This frames the layer update as a step on the hypersphere.

Let $h$ be the input hidden state to a block (Attention or MLP), also assumed to be normalized ($\|h\|_2 = 1$). Let $h_{\text{suggestion}} = \text{Norm}(\text{SubLayer}(h))$ be the normalized output of the sub-layer's core computation (e.g., $\text{ATTN}(h)$ or $\text{MLP}(h)$). The update rule is:

$h \leftarrow \text{Norm}(h + \bm{\alpha} \odot (h_{\text{suggestion}} - h))$

Here:

  • $\text{Norm}(x) = x / \|x\|_2$ denotes $L_2$ normalization.
  • $\bm{\alpha}$ is a learnable vector of size $d_{\text{model}}$ (distinct for Attention, $\bm{\alpha}_{\text{A}}$, and MLP, $\bm{\alpha}_{\text{M}}$). These are termed "eigen learning rates".
  • $\odot$ denotes element-wise multiplication.
  • $(h_{\text{suggestion}} - h)$ represents the update direction suggested by the sub-layer.
  • The $\bm{\alpha}$ vector scales the contribution of this update direction along each dimension.
  • The final $\text{Norm}(\cdot)$ projects the result of the scaled update back onto the unit hypersphere, acting as a retraction step in manifold optimization.

This formulation can be interpreted as a variable-metric optimization step on the hypersphere, where $\bm{\alpha}$ represents the diagonal elements of a metric tensor that adapts the step size along different dimensions. The paper suggests this constrained optimization on the manifold contributes to the observed faster convergence.
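
Written in manifold-optimization notation, the same update can be restated compactly (this is only a rephrasing of the rule above, not an additional result from the paper):

$h \leftarrow R_h(\bm{\alpha} \odot d), \qquad d = h_{\text{suggestion}} - h, \qquad R_h(u) = \frac{h + u}{\|h + u\|_2},$

where $R_h$ is the retraction mapping the scaled displacement back onto the unit hypersphere and $\mathrm{diag}(\bm{\alpha})$ acts as the diagonal (variable) metric controlling the per-dimension step size.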

Scaling Factors for Degrees of Freedom

Since normalization forces all vectors to have unit magnitude, potentially losing important scaling information, nGPT introduces several learnable scaling factors to restore these degrees of freedom:

  1. Logit Scaling ($s_z$): A learnable scalar $s_z$ scales the final output logits before the softmax function: $\text{logits} = s_z \cdot (h_{\text{final}} W_{\text{output}}^T)$. This controls the sharpness or confidence of the final probability distribution.
  2. Query-Key Scaling ($s_{qk}$): In standard attention, the scores are $\text{softmax}\!\left(\frac{(h W_q^T)(h W_k^T)^T}{\sqrt{d_k}}\right)$. With normalized $W_q, W_k$, the magnitudes of the projected Q and K vectors are bounded. nGPT optionally normalizes the Q and K vectors themselves after projection and introduces a learnable scalar $s_{qk}$ that scales them (i.e., $s_{qk} \cdot Q$ and $s_{qk} \cdot K$); the exact placement of these factors varies slightly between the paper's main text and appendix. An important detail is that the softmax input is scaled up by $\sqrt{d_k}$ rather than down by $1/\sqrt{d_k}$, effectively reversing the standard scaling; this is justified by the bounded cosine similarities produced by normalized vectors (see the sketch after this list).
  3. MLP Scaling ($s_u, s_{\nu}$): Within the SwiGLU MLP variant ($W_u, W_\nu$ projections), learnable scalars $s_u, s_{\nu}$ scale the outputs before the SiLU activation and element-wise multiplication: $\text{MLP}(h) = \big(s_u \cdot h W_u^T \odot \sigma(s_{\nu} \cdot h W_{\nu}^T)\big) W_{o\text{MLP}}^T$. A fixed scaling factor of $\sqrt{d_{\text{model}}}$ is also applied to the input of the SiLU activation ($\sigma$) to ensure its argument has sufficient variance to operate in its non-linear regime, as $h W_{\nu}^T$ involves normalized vectors.
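
To make the role of these scalars concrete, the sketch below shows one plausible way to wire them into the attention scores, the SwiGLU MLP, and the output logits. Function names, tensor shapes, and the exact placement of $s_{qk}$ are assumptions consistent with the description above, not the paper's reference implementation.

    import math
    import torch
    import torch.nn.functional as F

    def ngpt_attention_scores(q, k, s_qk):
        # q, k: [batch, heads, seq, d_k]; re-normalized after projection.
        d_k = q.size(-1)
        q = s_qk * F.normalize(q, dim=-1)
        k = s_qk * F.normalize(k, dim=-1)
        # Dot products of unit vectors are bounded cosines, so the softmax input
        # is scaled *up* by sqrt(d_k), reversing the usual 1/sqrt(d_k) factor.
        return (q @ k.transpose(-2, -1)) * math.sqrt(d_k)

    def ngpt_swiglu(h, W_u, W_nu, W_o, s_u, s_nu):
        # h: [..., d_model]; W_u, W_nu: [d_ff, d_model]; W_o: [d_model, d_ff].
        d_model = h.size(-1)
        u = s_u * (h @ W_u.T)
        # The gate input is additionally scaled by sqrt(d_model) so that SiLU
        # operates in its non-linear regime despite normalized inputs.
        v = s_nu * (h @ W_nu.T) * math.sqrt(d_model)
        return (u * F.silu(v)) @ W_o.T

    def ngpt_logits(h_final, W_output, s_z):
        # s_z controls the sharpness of the output distribution before softmax.
        return s_z * (h_final @ W_output.T)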

These scaling factors are learned during training alongside other parameters. Ablation studies show that while beneficial, simplifying or fixing some of these scales (e.g., fixing $s_{qk}$, $s_u$, $s_{\nu}$ to 1, or using a single scalar $s_z$) results in only minor performance degradation.

Implementation Considerations and Computational Cost

Implementing nGPT requires modifying standard Transformer code:

  • Add normalization functions for vectors/matrices. Matrix normalization typically happens after optimizer.step().
    import torch

    def normalize_matrix_rows(W):
        # W shape: [dim_out, dim_in]; normalize each row (the vector along dim_in,
        # i.e. the embedding dimension) to unit L2 norm, in place.
        norm = torch.linalg.vector_norm(W, dim=1, keepdim=True)
        W.data /= norm
        return W

    # After optimizer.step(), e.g. for a GPT-style module layout:
    # with torch.no_grad():
    #     normalize_matrix_rows(model.transformer.wte.weight)
    #     normalize_matrix_rows(model.lm_head.weight)
    #     for block in model.transformer.h:
    #         normalize_matrix_rows(block.attn.c_attn.weight)  # assuming fused QKV
    #         normalize_matrix_rows(block.attn.c_proj.weight)
    #         normalize_matrix_rows(block.mlp.c_fc1.weight)    # assuming fused Up/Gate in SwiGLU
    #         normalize_matrix_rows(block.mlp.c_proj.weight)
  • Replace residual connections with the hypersphere update rule, incorporating learnable $\bm{\alpha}$ vectors.
    # Inside a Transformer block's forward pass (F is torch.nn.functional).
    # h is the input state (assumed normalized).
    # self.alpha_A and self.alpha_M are learnable nn.Parameter vectors of size d_model.

    attn_output = self.attn(h)  # core attention computation
    h_suggestion_A = F.normalize(attn_output, dim=-1)
    h = F.normalize(h + self.alpha_A * (h_suggestion_A - h), dim=-1)

    # Similarly for the MLP, using self.alpha_M
    mlp_output = self.mlp(h)  # core MLP computation
    h_suggestion_M = F.normalize(mlp_output, dim=-1)
    h = F.normalize(h + self.alpha_M * (h_suggestion_M - h), dim=-1)

    return h
  • Remove LayerNorm/RMSNorm layers.
  • Initialize and learn the scaling factors ($s_z, s_{qk}, s_u, s_{\nu}$) and the $\bm{\alpha}$ vectors.
  • Configure the optimizer with weight_decay=0 and potentially remove the learning rate warmup schedule (a minimal configuration sketch follows this list).
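
A minimal, runnable illustration of the optimizer-side changes (weight decay disabled, no warmup, post-step renormalization) is given below; the toy model, learning rate, and betas are placeholders rather than the paper's hyperparameters.

    import torch
    import torch.nn as nn

    # Toy stand-in for the Transformer; only the optimizer wiring matters here.
    model = nn.Linear(512, 512, bias=False)
    num_steps = 1000

    # Weight decay removed (AdamW with weight_decay=0 behaves like plain Adam);
    # no warmup phase, just a cosine schedule (or a constant learning rate).
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3,
                                  betas=(0.9, 0.95), weight_decay=0.0)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)

    x = torch.randn(8, 512)
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad(set_to_none=True)

    # After each step, project the weight vectors back onto the unit hypersphere
    # (here: the rows of the [d_out, d_model] matrix, along the embedding dimension).
    with torch.no_grad():
        model.weight.div_(model.weight.norm(p=2, dim=1, keepdim=True))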

A significant practical consideration is the computational overhead. The explicit normalization operations (especially matrix normalization after each step and hidden state normalizations within each layer) add computational cost. The paper reports a 60-80% increase in time per training step compared to a baseline GPT, depending on context length. This overhead arises from the normalization calls and memory transfers, which are currently not heavily optimized in standard deep learning frameworks. The authors suggest that fused kernels could mitigate this, and the relative overhead might decrease for very large models where matrix multiplications dominate computation time. However, the drastically reduced number of required training steps is argued to outweigh this per-step cost in terms of total training time and compute.

Experimental Results and Performance

The primary claim of nGPT is significantly accelerated convergence. Experiments on the OpenWebText dataset using 0.5B and 1B parameter models show:

  • Faster Convergence: nGPT reaches target validation loss levels using substantially fewer training steps (and tokens processed) compared to a baseline GPT model:
    • 4x fewer steps for 1k context length.
    • 10x fewer steps for 4k context length.
    • 20x fewer steps for 8k context length.
  • Downstream Performance: This faster convergence translates to faster achievement of comparable performance on downstream tasks (ARC-E, HellaSwag, WinoGrande, etc.).
  • Numerical Stability: nGPT matrices (embeddings, attention, MLP projections) exhibit significantly lower condition numbers compared to the baseline GPT, suggesting better-behaved, less degenerate representations and potentially improved numerical stability during training (a simple diagnostic for this is sketched after this list).
  • Length Extrapolation: When evaluated on the PG19 dataset with sequences longer than the training context length (8k), nGPT showed more stable perplexity compared to the baseline GPT, without requiring specific positional encoding modifications like RoPE adjustments typically needed for extrapolation.
  • Learned Parameters: The eigen learning rates ($\bm{\alpha}_{\text{A}}$, $\bm{\alpha}_{\text{M}}$) converge to modest average values (around 0.2-0.37), indicating controlled step sizes. The scaling factors ($s_z, s_{qk}, s_u, s_{\nu}$) also converge to non-trivial values, confirming their role in restoring necessary scale information.
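
The condition-number comparison can be reproduced with a generic diagnostic like the one below (not code from the paper): compute the singular values of each weight matrix and take the ratio of the largest to the smallest.

    import torch

    def condition_number(W: torch.Tensor) -> float:
        # Ratio of the largest to the smallest singular value; lower values
        # indicate better-conditioned, less degenerate weight matrices.
        s = torch.linalg.svdvals(W.detach().float())
        return (s.max() / s.min().clamp_min(1e-12)).item()

    # Example usage on an arbitrary [d_out, d_model] projection matrix:
    W = torch.randn(2048, 512)
    print(condition_number(W))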

While the wall-clock time per step is higher, the substantial reduction in the number of steps needed makes nGPT potentially much faster overall for achieving a target performance level, especially at longer sequence lengths.

Ablation Studies

Ablations confirmed the utility of the introduced components:

  • Removing or simplifying scaling factors led to slight performance drops, indicating their usefulness but also suggesting potential for simplification (e.g., using fixed scales or removing optional Q/K normalization).
  • The hypersphere update mechanism and matrix normalizations were crucial for the observed speedups.

Conclusion

nGPT proposes a modification to the Transformer architecture centered around explicit $L_2$ normalization of representations and weight matrices, framing the learning process as optimization on a hypersphere. This approach demonstrably accelerates convergence by a significant factor (4x-20x fewer steps) across different context lengths, albeit with an increased computational cost per step. The improved numerical stability and length extrapolation capabilities are additional potential benefits. The core trade-off lies between the reduced number of training iterations and the increased cost per iteration.

Authors (4)
  1. Ilya Loshchilov
  2. Cheng-Ping Hsieh
  3. Simeng Sun
  4. Boris Ginsburg