Grokking Dynamics in Neural Networks
- Grokking dynamics describe the phenomenon in which overparameterized neural networks shift from perfect memorization of the training set to strong generalization under prolonged training with regularization.
- The Main Embedding Diff (MED) is introduced as a precise progress measure, tracking changes in embedding uniformity that closely align with the onset of generalization.
- Successful generalization depends on comprehensive training data coverage across label equivalence classes, ensuring that embedding uniformity effectively translates to test performance.
Grokking dynamics refer to the characteristic temporal evolution of learning in overparameterized neural networks (and, more broadly, in regularized models) where prolonged perfect memorization of the training set is followed by an abrupt transition to strong test set generalization. This phenomenon, which appears across architectures and loss functions, has prompted considerable theoretical, mechanistic, and empirical investigation. Central to recent work is the explicit recognition that dramatically delayed generalization, often arriving after millions or billions of optimization steps, arises from interactions between regularization, solution geometry, dataset structure, and embedding representations.
1. Embedding Uniformity, Weight Decay, and the Geometry of Generalization
A foundational insight is that, once a model achieves perfect training accuracy and the loss is dominated by the weight decay regularizer, the optimization trajectory promotes uniformity in the embedding space. Under continued training with strong regularization, the embeddings of tokens (in, for example, modular addition tasks with Transformers) are driven toward a configuration where all tokens collapse to a maximally uniform arrangement, typically distributed over a low-dimensional manifold in the embedding space. This behavior is formalized in Theorem 3.2, which establishes the uniform limiting configuration, and refined by Theorem 3.1, which shows that for solutions of an algebraic operation task, the summed embeddings within each equivalence class of outputs (indexed by the operation result modulo the prime) cluster near a class-specific anchor.
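To make the notion of embedding uniformity concrete, the sketch below (an illustration under assumed definitions, not the paper's code) measures the spread of pairwise distances between token embeddings; a spread near zero indicates the kind of highly regular, near-uniform arrangement that weight decay is argued to produce.

```python
# Illustrative sketch (assumed definitions, not the paper's code): measure how
# uniform a set of token embeddings is via the spread of pairwise distances.
# A spread near zero means all tokens sit in a highly regular arrangement.
import numpy as np

def pairwise_distance_spread(embeddings: np.ndarray) -> float:
    """embeddings: (num_tokens, dim) array of token embedding vectors.
    Returns the coefficient of variation of all pairwise L2 distances."""
    n = embeddings.shape[0]
    dists = [np.linalg.norm(embeddings[i] - embeddings[j])
             for i in range(n) for j in range(i + 1, n)]
    dists = np.asarray(dists)
    return float(dists.std() / (dists.mean() + 1e-12))

rng = np.random.default_rng(0)
print(pairwise_distance_spread(rng.normal(size=(32, 64))))  # random: large spread
print(pairwise_distance_spread(np.eye(32)))                 # regular simplex: ~0.0
```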
However, the uniformity of the embedding space induced by weight decay is not sufficient for grokking. Full generalization requires an additional algebraic or combinatorial property: the training dataset must have adequate coverage of the label equivalence classes. If the training set lacks proximity (in the relevant metric, e.g., Manhattan distance in index space) to some test points, the model cannot generalize to those points regardless of embedding geometry. The critical ratio of training-set coverage (Theorem 3.4) depends on the coverage radius and the prime modulus (for modular tasks); an insufficient ratio guarantees that perfect test generalization is impossible.
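This coverage condition can be illustrated with a simple check. The sketch below is a hypothetical reading of the condition for modular addition: it asks whether every pair (a, b) in the full p x p grid has a same-label training pair within Manhattan distance r in index space. The function and variable names, the non-wrapping distance, and the radius are assumptions for illustration only.

```python
# Hypothetical illustration of the coverage condition for (a + b) mod p: does
# every pair in the full p x p grid have a training pair with the same label
# within Manhattan distance r in index space? Names, the non-wrapping
# distance, and the radius r are assumptions, not the paper's code.
from itertools import product
import random

def covers_all_classes(train_pairs, p: int, r: int) -> bool:
    by_label = {}
    for a, b in train_pairs:
        by_label.setdefault((a + b) % p, []).append((a, b))
    for a, b in product(range(p), repeat=2):
        neighbors = by_label.get((a + b) % p, [])
        if not any(abs(a - ta) + abs(b - tb) <= r for ta, tb in neighbors):
            return False
    return True

p = 17
all_pairs = list(product(range(p), repeat=2))
train = random.sample(all_pairs, k=len(all_pairs) // 2)  # 50% train fraction
print(covers_all_classes(train, p, r=3))
```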
2. Progress Measures: Indirectness and the Main Embedding Diff (MED)
Previous research on grokking introduced a variety of post hoc progress measures—based on, e.g., tracking Fourier coefficients of weights, local complexity, spectral properties, or sharpness—to monitor the internal progress of training before the external test loss phase transition. This work critiques such measures on several grounds:
- Indirectness: These metrics do not directly reflect the optimization geometry but correlate with it as symptoms.
- Complexity: Many progress measures, notably those relying on Fourier decomposition or spectral analysis, are computationally intensive and impractical for everyday monitoring.
- Lack of Theoretical Justification: The explanatory power of these measures is contingent on their correlation with the emergence of embedding uniformity, rather than a fundamental connection to the underlying mechanism of grokking.
To address this, a new progress measure, the Main Embedding Diff (MED), is introduced. Computed directly from the embedding mapping at each training epoch, it quantifies how far the token embeddings are from the uniform configuration predicted by the theory. Empirical results demonstrate that MED tracks the generalization transition with much higher fidelity and minimal computational overhead: when the MED approaches zero, the network is at or near the grokking threshold.
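Because the exact MED formula (Def. 4.2) is not reproduced here, the sketch below uses the pairwise-distance spread from the earlier snippet as an illustrative stand-in, logged once per epoch from a model's embedding table. The `model.embedding` attribute and the hook's interface are assumptions.

```python
# Sketch of a per-epoch monitoring hook (assumed interface, not the paper's
# code). The exact MED formula (Def. 4.2) is not reproduced here; the
# pairwise-distance spread from the earlier snippet stands in as an
# illustrative geometric proxy. Assumes `model.embedding` is an nn.Embedding.
import torch

def log_med_proxy(model, epoch: int, history: list) -> float:
    with torch.no_grad():
        emb = model.embedding.weight.detach().cpu().numpy()
    med = pairwise_distance_spread(emb)  # defined in the earlier sketch
    history.append((epoch, med))
    # A sustained drop toward zero signals that the test-accuracy transition
    # is imminent.
    return med
```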
3. Dataset Structure, Coverage, and Limits to Generalization
The transition from memorization to generalization is not dictated by the decay of the weight norm or mere regularizer minimization, but is governed by the structural match between the embedding uniformity and the coverage of the training set over the label equivalence classes of the task. Theoretical analysis and experiments demonstrate:
- For every test sample, perfect generalization requires a nearby training sample in the same equivalence class, a statement made formal in Theorem 3.3.
- The critical coverage ratio expresses the minimum proportion of the combinatorial space that must be captured by the training set.
- Tasks with more complex structure demand higher train-set coverage for grokking to manifest.
Under random or poorly structured sampling, the maximal achievable test accuracy saturates strictly below 1, regardless of embedding regularity.
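Under the reading of Theorem 3.3 sketched above, the achievable test accuracy can be bounded by the fraction of held-out points that satisfy the proximity condition. The snippet below (continuing the earlier coverage sketch, with the same assumed names and radius) computes that ceiling for a random 50% train split; it is an illustration, not the paper's analysis code.

```python
# Assumption-laden reading of Theorem 3.3 (not the paper's code): bound the
# achievable test accuracy by the fraction of held-out pairs that have a
# same-label training pair within Manhattan radius r. Reuses p, all_pairs,
# and train from the coverage sketch above.
def test_accuracy_ceiling(train_pairs, test_pairs, p: int, r: int) -> float:
    by_label = {}
    for a, b in train_pairs:
        by_label.setdefault((a + b) % p, []).append((a, b))
    covered = sum(
        any(abs(a - ta) + abs(b - tb) <= r
            for ta, tb in by_label.get((a + b) % p, []))
        for a, b in test_pairs)
    return covered / max(len(test_pairs), 1)

train_set = set(train)
test = [pair for pair in all_pairs if pair not in train_set]
print(test_accuracy_ceiling(train, test, p, r=3))  # < 1.0 when coverage has gaps
```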
4. Experimental Evidence: Modular Arithmetic, ResNet-18, and Beyond
The conclusions are supported by systematic experimental evaluation:
- Transformer on Modular Addition: Single-layer Transformers trained on modular addition undergo delayed test-accuracy improvement after prolonged embedding uniformization, with the precise onset determined by train set coverage (see the training sketch after this list).
- Grokking is absent if training data leaves gaps in equivalence class representation.
- More complex tasks and smaller primes require proportionally more comprehensive data coverage.
- Main Embedding Diff Tracking: Across various primes, the drop in MED aligns nearly exactly with the phase transition in test loss.
- ResNet-18 Task: By constructing an analogous structured classification task for images (with labels as sums of quadrant class indices modulo a prime), the dynamics and theoretical predictions are shown to transfer to more conventional deep convolutional architectures.
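The following sketch outlines a minimal grokking-style run on modular addition with a small transformer, strong weight decay, and per-epoch logging of test accuracy alongside the MED proxy from the earlier snippet. The architecture, hyperparameters, and 50% train fraction are assumptions for illustration and do not reproduce the paper's exact experimental setup.

```python
# Minimal grokking-style run on (a + b) mod p with a small transformer and
# strong weight decay, logging test accuracy and the MED proxy per epoch.
# Architecture, hyperparameters, and the 50% train split are assumptions for
# illustration; this does not reproduce the paper's exact setup.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

p = 97
pairs = torch.tensor([(a, b) for a in range(p) for b in range(p)])
labels = (pairs[:, 0] + pairs[:, 1]) % p
perm = torch.randperm(len(pairs))
train_idx, test_idx = perm[: len(perm) // 2], perm[len(perm) // 2:]

class TinyGrokker(nn.Module):
    def __init__(self, p: int, dim: int = 128):
        super().__init__()
        self.embedding = nn.Embedding(p, dim)
        self.encoder = nn.TransformerEncoderLayer(
            d_model=dim, nhead=4, dim_feedforward=4 * dim,
            dropout=0.0, batch_first=True)
        self.head = nn.Linear(dim, p)

    def forward(self, x):                     # x: (batch, 2) token indices
        h = self.encoder(self.embedding(x))   # (batch, 2, dim)
        return self.head(h.mean(dim=1))       # pool both positions

model = TinyGrokker(p)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)
loss_fn = nn.CrossEntropyLoss()
loader = DataLoader(TensorDataset(pairs[train_idx], labels[train_idx]),
                    batch_size=512, shuffle=True)

med_history = []
for epoch in range(10_000):                   # grokking requires long runs
    for xb, yb in loader:
        opt.zero_grad()
        loss_fn(model(xb), yb).backward()
        opt.step()
    if epoch % 100 == 0:
        with torch.no_grad():
            preds = model(pairs[test_idx]).argmax(dim=-1)
            test_acc = (preds == labels[test_idx]).float().mean().item()
        print(epoch, test_acc, log_med_proxy(model, epoch, med_history))
```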
5. Mathematical Summary and Central Formulas
The main contribution is the clean formalization of both the sufficient statistics for monitoring and predicting grokking, and the conditions under which it can or cannot manifest:
| Phenomenon | Description | Reference |
|---|---|---|
| Embedding uniformity under weight decay | Token embeddings driven toward a maximally uniform, low-dimensional configuration | Eq. 7 |
| Main Embedding Diff (progress measure) | Geometric deviation of the embeddings from the uniform configuration, computed per epoch | Def. 4.2 |
| Coverage for test generalization | Critical train-set ratio as a function of the coverage radius and the prime | Thm. 3.4 |
6. Broader Implications and Unified Conceptual Framework
A unified theory emerges: grokking is a consequence of weight decay inducing maximal embedding uniformity, which can only yield global generalization if the structure of the training set allows coverage of the full space of label equivalence classes. The MED provides a minimal and direct geometric progress measure, superseding earlier ad hoc approaches.
These findings offer several practical implications:
- Prediction and Monitoring: MED offers a robust tool for diagnosing, monitoring, and understanding the approach to the grokking threshold in training runs.
- Architecture Generality: The theory holds across Transformer and CNN architectures (e.g., ResNet-18), not limited to specialized models or tasks.
- Training Data Design: Guarantees about generalization and the onset of grokking are fundamentally constrained by the combinatorial structure of the training data, rather than model-internal measures such as weight norm.
Collectively, this line of research establishes that genuine grokking is a joint property of geometric regularization (embedding uniformity via weight decay) and algebraic dataset coverage—a necessary and sufficient relationship that provides new clarity on the mechanisms dictating delayed generalization in modern machine learning systems.