- The paper introduces NeuralGrok, a learned gradient transformation applied by an auxiliary "neural-amplifier" that significantly accelerates grokking in transformer models.
- It employs a bilevel optimization framework that optimizes both the transformer and the auxiliary MLP, achieving up to a 4.67x speedup on challenging arithmetic tasks.
- It demonstrates that learned gradient rescaling stabilizes generalization and reduces model complexity, as evidenced by lower absolute gradient entropy scores.
This paper introduces NeuralGrok, a method designed to accelerate the "grokking" phenomenon in transformer models trained on arithmetic tasks (2504.17243). Grokking refers to delayed generalization that appears long after a model has reached near-perfect training accuracy (i.e., after it has effectively memorized the training set). NeuralGrok tackles this by learning an optimal transformation of the gradients during training.
Core Idea: Learned Gradient Transformation
The central concept is to use an auxiliary neural network, termed the "neural-amplifier," to modify the gradients of the main transformer model before they are used for parameter updates. This transformation is learned dynamically to prioritize gradient components that contribute most effectively to generalization.
Implementation: Bilevel Optimization
NeuralGrok employs a bilevel optimization framework (a training-loop sketch follows this list):
- Inner Loop:
  - The main transformer model $\mathcal{M}(\theta)$ is trained on a subset of the training data ($\mathcal{D}_{\text{inner}}$).
  - Stochastic gradients $g_t = \nabla_\theta \mathcal{L}(\theta_t, B_t)$ are computed for a mini-batch $B_t \subset \mathcal{D}_{\text{inner}}$.
  - The neural-amplifier $\mathcal{G}(\varphi)$, parameterized by $\varphi$, transforms these gradients: $g_t' = \mathcal{G}(g_t, \varphi_t)$.
  - The transformer parameters are updated using the transformed gradients: $\theta_{t+1} \leftarrow \mathrm{Opt}_{\mathcal{M}}(\theta_t, g_t', \eta_{\theta,t})$.
- Outer Loop:
  - Performed every $T$ inner-loop steps.
  - The goal is to optimize the neural-amplifier parameters $\varphi$.
  - This is done by minimizing the loss of the updated transformer model (after a hypothetical inner-loop step using the current amplifier) on a separate validation subset ($\mathcal{D}_{\text{outer}}$).
  - The gradient for the amplifier is $\nabla_\varphi \mathcal{L}\big(\theta - \eta_\theta \, \mathcal{G}(\varphi, g_\theta), \mathcal{D}_{\text{outer}}\big)$.
  - The amplifier parameters are updated: $\varphi_{t+1} \leftarrow \mathrm{Opt}_{\mathcal{G}}(\varphi_t, \nabla_\varphi \mathcal{L}, \eta_{\varphi,t})$.
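The sketch below is a minimal, illustrative rendering of this loop, not the authors' implementation. It assumes a PyTorch `model` (the transformer), an `amplifier` module acting on the flattened gradient, a shared `loss_fn`, and separate optimizers `opt_model` / `opt_amp`; it uses `torch.func.functional_call` to differentiate the outer loss through a hypothetical inner update.

```python
# Minimal bilevel-loop sketch (illustrative only; names and structure are assumptions).
import math
import torch
from torch.func import functional_call


def neuralgrok_step(model, amplifier, batch_in, batch_out, loss_fn,
                    opt_model, opt_amp, lr_model, do_outer):
    params = dict(model.named_parameters())
    shapes = [p.shape for p in params.values()]

    # Inner gradient g_t on a mini-batch from D_inner.
    x, y = batch_in
    loss_in = loss_fn(model(x), y)
    grads = torch.autograd.grad(loss_in, list(params.values()))
    flat_g = torch.cat([g.reshape(-1) for g in grads])

    if do_outer:
        # Outer step: gradient of L(theta - lr * G(g, phi), D_outer) w.r.t. phi.
        g_prime = amplifier(flat_g)  # keeps the graph through the amplifier's parameters
        updated, off = {}, 0
        for (name, p), shape in zip(params.items(), shapes):
            n = math.prod(shape)
            updated[name] = p - lr_model * g_prime[off:off + n].reshape(shape)
            off += n
        xo, yo = batch_out
        loss_out = loss_fn(functional_call(model, updated, (xo,)), yo)
        opt_amp.zero_grad()
        loss_out.backward()
        opt_amp.step()

    # Inner update of the transformer using the (now frozen) amplifier.
    with torch.no_grad():
        g_prime = amplifier(flat_g)
        off = 0
        for p, shape in zip(params.values(), shapes):
            n = math.prod(shape)
            p.grad = g_prime[off:off + n].reshape(shape).clone()
            off += n
    opt_model.step()
    opt_model.zero_grad()
```

A training loop would call `neuralgrok_step` on each mini-batch from $\mathcal{D}_{\text{inner}}$ and set `do_outer=True` every $T$ steps, drawing `batch_out` from $\mathcal{D}_{\text{outer}}$.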
Neural-Amplifier Architecture and Transformation
- The neural-amplifier is implemented as a simple Multi-Layer Perceptron (MLP). The paper uses a 3-layer MLP for simpler tasks and notes that the architecture can be adjusted to task complexity.
- The transformation involves two steps (Equation 2 in the paper):
  - Rotation: the MLP's output is passed through a softmax to obtain a probability distribution $p = \mathrm{softmax}(\mathrm{MLP}_\varphi(g))$, which re-weights the original gradient element-wise ($p \odot g$).
  - Rescaling: the re-weighted gradient is normalized and scaled by a constant factor $c$: $g' = c \cdot \dfrac{p \odot g}{\|p \odot g\|_2}$. The default value is $c = 1.0$.
PyTorch-like pseudocode for the amplifier's forward pass (adapted from the figure in the paper's Appendix):
```python
import torch
import torch.nn as nn

class NeuralGrok(nn.Module):
    # Simplified version based on the paper's description.
    def __init__(self, input_dim, hidden_dim=32, n_layers=3, c=1.0):
        super().__init__()
        self.c = c
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(n_layers - 2):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
        layers.append(nn.Linear(hidden_dim, input_dim))  # output dim matches input grad dim
        self.mlp = nn.Sequential(*layers)
        self.softmax = nn.Softmax(dim=-1)  # assumes the gradient is a flattened vector

    def forward(self, grad):
        # grad: flattened gradient vector of the transformer's parameters
        mlp_out = self.mlp(grad)
        p = self.softmax(mlp_out)                 # "rotation": re-weighting distribution
        transformed_grad = p * grad               # element-wise re-weighting
        norm = torch.norm(transformed_grad, p=2)  # "rescaling": normalize, then scale by c
        if norm > 1e-8:                           # avoid division by zero
            g_prime = self.c * transformed_grad / norm
        else:
            g_prime = torch.zeros_like(grad)
        return g_prime
```
Note: the code excerpt in the paper processes each gradient entry independently (`nn.Linear(1, ...)`), implying the MLP operates element-wise on scalar gradient entries (or that gradients are reshaped accordingly). The pseudocode above instead assumes a standard MLP operating on the flattened gradient vector; a sketch of the element-wise variant follows.
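For completeness, here is a minimal sketch of such an element-wise variant; this is a hypothetical rendering of that reading, not the paper's exact code:

```python
# Hypothetical element-wise amplifier: each gradient entry is fed as a
# 1-dimensional input to a shared MLP, and the softmax is taken across entries.
import torch
import torch.nn as nn

class ElementwiseAmplifier(nn.Module):
    def __init__(self, hidden_dim=32, c=1.0):
        super().__init__()
        self.c = c
        self.mlp = nn.Sequential(
            nn.Linear(1, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, grad):                                # grad: flattened 1-D tensor
        scores = self.mlp(grad.unsqueeze(-1)).squeeze(-1)   # one score per gradient entry
        p = torch.softmax(scores, dim=-1)                   # distribution over entries
        weighted = p * grad
        return self.c * weighted / weighted.norm(p=2).clamp_min(1e-8)
```

Under this reading, the amplifier's parameter count is independent of the transformer's size, since the same small MLP is shared across all gradient entries.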
Experimental Setup and Results
Across the arithmetic tasks studied, NeuralGrok consistently accelerates generalization relative to standard training, with speedups of up to 4.67x on the most challenging arithmetic tasks (see the summary above).
Complexity Analysis with Absolute Gradient Entropy (AGE)
- The paper proposes Absolute Gradient Entropy (AGE) as a metric to track learning complexity: $H(G) = -\sum_{g_i \in G} |g_i| \ln |g_i|$ (a short computation sketch follows this list).
- AGE increases during the memorization phase and decreases during the generalization phase, correlating well with the grokking transition.
- NeuralGrok results in lower AGE (and Absolute Weight Entropy, AWE) scores compared to baselines, indicating it promotes lower complexity solutions. The transformed gradients produced by NeuralGrok have lower AGE than the original gradients.
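As a quick illustration, AGE can be computed directly from flattened gradients; this is a sketch, and the small `eps` added for numerical stability is an assumption rather than part of the paper's definition:

```python
import torch

def absolute_gradient_entropy(grad, eps=1e-12):
    """H(G) = -sum_i |g_i| * ln|g_i| over all gradient entries."""
    a = grad.abs()
    return -(a * torch.log(a + eps)).sum()

# Example: aggregate AGE over all parameters of a model after backward().
# age = sum(absolute_gradient_entropy(p.grad.reshape(-1))
#           for p in model.parameters() if p.grad is not None)
```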
Practical Considerations and Limitations
- Implementation Complexity: Requires implementing the bilevel optimization loop and the neural-amplifier module. Careful management of the two optimizers and data splits is necessary.
- Computational Overhead: Adds the cost of the amplifier's forward and backward passes, plus the outer loop optimization step.
- Hyperparameter Tuning: Introduces new hyperparameters like the amplifier architecture, the outer loop frequency T, and the rescaling factor c.
- Regularization Insights: Suggests gradient norm rescaling is a potent regularizer for grokking in arithmetic tasks. NeuralGrok combines this with a learned directional change.
- Transferability: The learned gradient transformations are found to be highly task-specific and do not transfer well even between similar arithmetic operations (e.g., addition to subtraction).
- Scope: Experiments are limited to synthetic arithmetic tasks. Applicability to broader domains like LLMs is proposed as future work.
- Weight Decay: The paper notes that standard weight decay can sometimes destabilize or impede learning on challenging arithmetic tasks, especially at larger values; combining a small weight decay with gradient normalization appeared effective (see the sketch below).
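As an illustration of that last observation, a configuration sketch might combine a small weight decay with gradient-norm rescaling before each optimizer step; the hyperparameter values here are assumptions for exposition, not the paper's settings:

```python
import torch

def make_optimizer(model):
    # Illustrative only: small weight decay on the transformer parameters.
    return torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

def rescale_gradient_norm(model, c=1.0, eps=1e-8):
    # Rescale the full gradient vector to norm c before calling optimizer.step().
    flat = torch.cat([p.grad.reshape(-1) for p in model.parameters()
                      if p.grad is not None])
    scale = c / flat.norm(p=2).clamp_min(eps)
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(scale)
```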
In summary, NeuralGrok offers a practical method to accelerate generalization in grokking scenarios by learning to adaptively transform gradients using a bilevel optimization framework and an auxiliary MLP. Its implementation involves managing two optimization loops and an extra network module, but it demonstrates significant speedups and improved stability on arithmetic tasks, providing insights into the role of gradient manipulation and complexity reduction in achieving generalization.