
NeuralGrok: Accelerate Grokking by Neural Gradient Transformation (2504.17243v2)

Published 24 Apr 2025 in cs.LG and cs.AI

Abstract: Grokking is proposed and widely studied as an intricate phenomenon in which generalization is achieved after a long-lasting period of overfitting. In this work, we propose NeuralGrok, a novel gradient-based approach that learns an optimal gradient transformation to accelerate the generalization of transformers in arithmetic tasks. Specifically, NeuralGrok trains an auxiliary module (e.g., an MLP block) in conjunction with the base model. This module dynamically modulates the influence of individual gradient components based on their contribution to generalization, guided by a bilevel optimization algorithm. Our extensive experiments demonstrate that NeuralGrok significantly accelerates generalization, particularly in challenging arithmetic tasks. We also show that NeuralGrok promotes a more stable training paradigm, constantly reducing the model's complexity, while traditional regularization methods, such as weight decay, can introduce substantial instability and impede generalization. We further investigate the intrinsic model complexity leveraging a novel Absolute Gradient Entropy (AGE) metric, which explains that NeuralGrok effectively facilitates generalization by reducing the model complexity. We offer valuable insights on the grokking phenomenon of Transformer models, which encourages a deeper understanding of the fundamental principles governing generalization ability.

Summary

  • The paper introduces a neural gradient transformation using a neural-amplifier to significantly accelerate grokking in transformer models.
  • It employs a bilevel optimization framework that optimizes both the transformer and an auxiliary MLP, achieving up to 4.67x speedup on challenging arithmetic tasks.
  • It demonstrates that learned gradient rescaling stabilizes generalization and reduces model complexity, as evidenced by lower absolute gradient entropy scores.

This paper introduces NeuralGrok, a method designed to accelerate the "grokking" phenomenon in transformer models trained on arithmetic tasks (2504.17243). Grokking refers to the delayed generalization observed long after a model has achieved near-perfect training accuracy (overfitting). NeuralGrok tackles this by learning an optimal transformation for the gradients during training.

Core Idea: Learned Gradient Transformation

The central concept is to use an auxiliary neural network, termed the "neural-amplifier," to modify the gradients of the main transformer model before they are used for parameter updates. This transformation is learned dynamically to prioritize gradient components that contribute most effectively to generalization.

Implementation: Bilevel Optimization

NeuralGrok employs a bilevel optimization framework (a schematic code sketch of this loop follows the list):

  1. Inner Loop:
    • The main transformer model $M(\theta)$ is trained on a subset of the training data $\mathcal{D}_{inner}$.
    • Stochastic gradients $g_t = \nabla_{\theta} \mathcal{L}(\theta_t, B_t)$ are computed for a mini-batch $B_t \subset \mathcal{D}_{inner}$.
    • The neural-amplifier $G(\varphi)$, parameterized by $\varphi$, transforms these gradients: $g'_t = G(g_t, \varphi_t)$.
    • The transformer parameters are updated using the transformed gradients: $\theta_{t+1} \leftarrow Opt_M(\theta_t, g'_t, \eta_{\theta,t})$.
  2. Outer Loop:
    • Performed every $T$ inner-loop steps.
    • The goal is to optimize the neural-amplifier parameters $\varphi$.
    • This is done by minimizing the loss of the updated transformer model (after a hypothetical inner-loop step using the current amplifier) on a separate validation subset $\mathcal{D}_{outer}$.
    • The gradient for the amplifier is $\nabla_{\varphi} \mathcal{L}(\theta - \eta_{\theta} G(\varphi, g_{\theta}), \mathcal{D}_{outer})$.
    • The amplifier parameters are updated: $\varphi_{t+1} \leftarrow Opt_G(\varphi_t, \nabla_{\varphi} \mathcal{L}, \eta_{\varphi,t})$.
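
The following is a minimal, hypothetical sketch of how these two loops could be wired together in PyTorch. The helper names (loss_fn, flat_grad, unflatten_like, inner_step, outer_step), the classification-style loss, and the use of torch.func.functional_call (PyTorch ≥ 2.0) to differentiate through the hypothetical update in the outer step are illustration choices, not the authors' implementation; the amplifier is treated here as a module mapping a flattened gradient vector to a transformed one.

import torch
import torch.nn as nn
from torch.func import functional_call  # PyTorch >= 2.0

def loss_fn(logits, targets):
    # Placeholder task loss; the paper's tasks are arithmetic-answer predictions.
    return nn.functional.cross_entropy(logits, targets)

def flat_grad(model, batch):
    # Ordinary forward/backward, flattening dL/dtheta into a single vector.
    inputs, targets = batch
    loss = loss_fn(model(inputs), targets)
    grads = torch.autograd.grad(loss, list(model.parameters()))
    return torch.cat([g.reshape(-1) for g in grads])

def unflatten_like(vec, params):
    # Split a flat vector back into tensors shaped like `params`.
    chunks, i = [], 0
    for p in params:
        chunks.append(vec[i:i + p.numel()].view_as(p))
        i += p.numel()
    return chunks

def inner_step(model, amplifier, opt_model, batch):
    # Inner loop: transform the gradient with the (frozen) amplifier,
    # write it into .grad, and take an ordinary optimizer step on theta.
    g = flat_grad(model, batch)
    with torch.no_grad():
        g_prime = amplifier(g)
    params = list(model.parameters())
    for p, gp in zip(params, unflatten_like(g_prime, params)):
        p.grad = gp
    opt_model.step()
    opt_model.zero_grad(set_to_none=True)

def outer_step(model, amplifier, opt_amp, inner_batch, outer_batch, eta):
    # Outer loop (every T inner steps): evaluate the outer loss at the
    # hypothetically updated parameters theta - eta * G(phi, g) and
    # backpropagate into the amplifier parameters phi only.
    names = [n for n, _ in model.named_parameters()]
    theta = [p.detach() for p in model.parameters()]
    g = flat_grad(model, inner_batch).detach()
    g_prime = amplifier(g)                              # depends on phi
    updated = {n: t - eta * gp
               for n, t, gp in zip(names, theta, unflatten_like(g_prime, theta))}
    inputs, targets = outer_batch
    outer_loss = loss_fn(functional_call(model, updated, (inputs,)), targets)
    opt_amp.zero_grad(set_to_none=True)
    outer_loss.backward()                               # gradients reach only phi
    opt_amp.step()

# Illustrative driver: T inner steps per outer update, with the training set
# split into D_inner and D_outer as described in the setup below.
# for step, inner_batch in enumerate(inner_loader):
#     inner_step(model, amplifier, opt_model, inner_batch)
#     if (step + 1) % T == 0:
#         outer_step(model, amplifier, opt_amp, inner_batch, next(outer_iter), eta)

In this sketch, the inner step never updates the amplifier, and the outer step never updates the transformer; only the gradient of the outer loss with respect to $\varphi$ is used, matching the bilevel scheme described above.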

Neural-Amplifier Architecture and Transformation

  • The neural-amplifier is implemented as a simple Multi-Layer Perceptron (MLP). The paper uses a 3-layer MLP for simpler tasks and potentially adjusts based on task complexity.
  • The transformation involves two steps (Equation 2):

    1. Rotation: The MLP outputs weights that are passed through a softmax function to obtain a probability distribution $p = \text{softmax}(\mathrm{MLP}_{\varphi}(g))$. This distribution $p$ re-weights the original gradient $g$ element-wise ($p \odot g$).
    2. Rescaling: The re-weighted gradient is normalized and then scaled by a constant factor $c$. The final transformed gradient is $g' = c \cdot \frac{p \odot g}{\| p \odot g \|_2}$. The default value is $c = 1.0$.

The PyTorch-like pseudocode for the amplifier's forward pass is provided in a figure in the paper's Appendix:

import torch
import torch.nn as nn

class NeuralGrok(nn.Module):
    # Simplified version based on the paper's description: an MLP maps the
    # (flattened) gradient to softmax weights, re-weights the gradient, and
    # rescales it to a constant L2 norm c.
    def __init__(self, input_dim, hidden_dim=32, n_layers=3, c=1.0):
        super().__init__()
        self.c = c
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(n_layers - 2):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
        layers.append(nn.Linear(hidden_dim, input_dim))  # output dim matches input grad dim
        self.mlp = nn.Sequential(*layers)
        self.softmax = nn.Softmax(dim=-1)  # assuming gradients are flattened

    def forward(self, grad):
        # grad: flattened gradient vector (or handled appropriately per dimension)
        mlp_out = self.mlp(grad)
        p = self.softmax(mlp_out)           # "rotation": softmax re-weighting
        transformed_grad = p * grad
        norm = torch.norm(transformed_grad, p=2)
        if norm > 1e-8:                     # avoid division by zero
            g_prime = self.c * transformed_grad / norm  # "rescaling" to norm c
        else:
            g_prime = torch.zeros_like(grad)
        return g_prime
Note: The paper's code example processes each gradient entry independently (nn.Linear(1, ...)), implying the MLP might operate element-wise or require reshaping gradients. The pseudocode above assumes a standard MLP operating on the flattened gradient vector.
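
For concreteness, here is a hedged sketch of that element-wise variant: each scalar gradient entry passes through the same small MLP (nn.Linear(1, ...)), and the softmax is taken across all entries of the flattened gradient. The class name ElementwiseNeuralGrok and the hidden width are illustrative choices, not taken from the paper's code.

import torch
import torch.nn as nn

class ElementwiseNeuralGrok(nn.Module):
    # Hypothetical element-wise variant: one shared MLP scores each gradient entry.
    def __init__(self, hidden_dim=32, c=1.0):
        super().__init__()
        self.c = c
        self.mlp = nn.Sequential(
            nn.Linear(1, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, grad):
        # grad: flat gradient vector of length N
        scores = self.mlp(grad.unsqueeze(-1)).squeeze(-1)  # per-entry score, shape (N,)
        p = torch.softmax(scores, dim=-1)                  # softmax across all entries
        weighted = p * grad
        norm = weighted.norm(p=2).clamp_min(1e-8)          # guard against zero norm
        return self.c * weighted / norm

This keeps the amplifier's parameter count independent of the size of the base model, which is one plausible reason for the per-entry design noted above.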

Experimental Setup and Results

  • Tasks: Modular arithmetic tasks of varying complexity (e.g., (a+b) mod 97, (a*c+b*d-e) mod 7).
  • Data Split: 50% train, 50% test. The training set is further split 49:1 into $\mathcal{D}_{inner} : \mathcal{D}_{outer}$.
  • Base Model: Decoder-only transformer (2 layers for simpler tasks, 4 layers for the complex one).
  • Baselines: Standard training (with weight decay), GrokFast-MA, GrokFast-EMA (2405.20233).
  • Key Findings:
    • Acceleration: NeuralGrok significantly reduces the steps needed to reach 95% test accuracy compared to baselines (e.g., up to 2.95x vs standard, 4.67x vs GrokFast-MA on the hardest task).
    • Stability: NeuralGrok (like GrokFast-MA) provides more stable generalization curves compared to standard training and GrokFast-EMA, which showed accuracy drops post-generalization.
    • Robustness: NeuralGrok performs well across different gradient rescaling factors $c$ (0.2 to 2.0), although larger $c$ can slightly delay generalization.
    • Gradient Rescaling: Simple gradient normalization (rescaling gradients to a constant norm $c$ without the learned rotation) also stabilizes standard training and accelerates grokking, suggesting it is an effective regularization technique for these tasks, potentially better than high weight decay alone (a minimal sketch follows this list).
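
As a hedged illustration of that normalization-only baseline (no learned rotation), one could rescale the full gradient to a constant L2 norm $c$ before each optimizer step; model, optimizer, and loss below are placeholders, and the function name is an assumption.

import torch

def normalized_step(model, optimizer, loss, c=1.0, eps=1e-8):
    # Backpropagate, rescale the concatenated gradient to L2 norm c, then step.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    total_norm = torch.norm(torch.cat([g.reshape(-1) for g in grads]), p=2)
    scale = c / (total_norm + eps)
    for g in grads:
        g.mul_(scale)
    optimizer.step()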

Complexity Analysis with Absolute Gradient Entropy (AGE)

  • The paper proposes Absolute Gradient Entropy (AGE) as a metric to track learning complexity: $H(\mathcal{G}) = -\sum_{g_i \in \mathcal{G}} |g_i| \ln |g_i|$ (a small code sketch follows this list).
  • AGE increases during the memorization phase and decreases during the generalization phase, correlating well with the grokking transition.
  • NeuralGrok results in lower AGE (and Absolute Weight Entropy, AWE) scores compared to baselines, indicating it promotes lower complexity solutions. The transformed gradients produced by NeuralGrok have lower AGE than the original gradients.
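
A small sketch of how the AGE metric defined above could be computed over a model's current gradients, assuming .grad has been populated by a backward pass; the function name and the eps clamp are illustrative choices.

import torch

def absolute_gradient_entropy(model, eps=1e-12):
    # H(G) = -sum_i |g_i| ln |g_i| over all gradient entries of the model.
    g = torch.cat([p.grad.reshape(-1) for p in model.parameters()
                   if p.grad is not None])
    a = g.abs().clamp_min(eps)  # avoid log(0)
    return -(a * a.log()).sum()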

Practical Considerations and Limitations

  • Implementation Complexity: Requires implementing the bilevel optimization loop and the neural-amplifier module. Careful management of the two optimizers and data splits is necessary.
  • Computational Overhead: Adds the cost of the amplifier's forward and backward passes, plus the outer loop optimization step.
  • Hyperparameter Tuning: Introduces new hyperparameters such as the amplifier architecture, the outer-loop frequency $T$, and the rescaling factor $c$.
  • Regularization Insights: Suggests gradient norm rescaling is a potent regularizer for grokking in arithmetic tasks. NeuralGrok combines this with a learned directional change.
  • Transferability: The learned gradient transformations are found to be highly task-specific and do not transfer well even between similar arithmetic operations (e.g., addition to subtraction).
  • Scope: Experiments are limited to synthetic arithmetic tasks. Applicability to broader domains like LLMs is proposed as future work.
  • Weight Decay: The paper notes that standard weight decay can sometimes destabilize or impede learning on challenging arithmetic tasks, especially with larger values. Combining small weight decay with gradient normalization seemed effective.

In summary, NeuralGrok offers a practical method to accelerate generalization in grokking scenarios by learning to adaptively transform gradients using a bilevel optimization framework and an auxiliary MLP. Its implementation involves managing two optimization loops and an extra network module, but it demonstrates significant speedups and improved stability on arithmetic tasks, providing insights into the role of gradient manipulation and complexity reduction in achieving generalization.
