Provable Scaling Laws of Feature Emergence from Learning Dynamics of Grokking (2509.21519v3)
Abstract: While the phenomenon of grokking, i.e., delayed generalization, has been studied extensively, it remains an open problem whether there is a mathematical framework that characterizes what kind of features will emerge, how and under which conditions this happens, and how it relates to the gradient dynamics of training, for complex structured inputs. We propose a novel framework, named Li₂, that captures three key stages of the grokking behavior of 2-layer nonlinear networks: (I) Lazy learning, (II) independent feature learning, and (III) interactive feature learning. In the lazy learning stage, the top layer overfits to the random hidden representations and the model appears to memorize. Thanks to lazy learning and weight decay, the backpropagated gradient G_F from the top layer then carries information about the target label, with a specific structure that enables each hidden node to learn its representation independently. Interestingly, these independent dynamics follow exactly the gradient ascent of an energy function E, whose local maxima are precisely the emerging features. We study whether these local-optima-induced features are generalizable, their representation power, and how they change with sample size, in group arithmetic tasks. When hidden nodes start to interact in the later stage of learning, we provably show how G_F changes to focus on missing features that still need to be learned. Our study sheds light on the roles played by key hyperparameters such as weight decay, learning rate, and sample size in grokking, leads to provable scaling laws of feature emergence, memorization, and generalization, and reveals, from first principles of gradient dynamics, why recent optimizers such as Muon can be effective. Our analysis can be extended to multi-layer architectures.
Top Community Prompts
Explain it Like I'm 14
What this paper is about (big picture)
This paper tries to explain grokking: a strange training behavior where a neural network first seems to memorize the training data (does great on training, poor on test), and only much later suddenly learns the real rule and starts to generalize (does great on new, unseen data). The authors build a simple, math-based framework, called Li₂, to show how and why this happens, what kinds of internal “features” the network learns, and how training settings (like weight decay, learning rate, width, and data size) control when memorization turns into real understanding.
What questions the paper asks
In simple terms, the paper asks:
- Which patterns (features) does a neural network learn during grokking?
- How exactly do those features appear during training?
- What training conditions make useful features more likely to emerge?
- How much data is needed before the network switches from memorizing to generalizing?
- Why do certain optimizer choices (like Muon) help, and how do they encourage diverse features?
How the authors approach the problem (methods, in everyday language)
The authors look closely at training dynamics of a small, two-layer neural network (input → hidden layer → output). They split learning into three stages and analyze what the feedback signal (“gradient”) tells each part of the network to do over time.
Think of the network as a team:
- The output layer is like a fast note-taker that tries to fit the answers from whatever “random sketches” the hidden layer gives it.
- The hidden layer is like a group of students trying to discover real patterns—but they need helpful guidance (signal) from the output layer to do that.
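Before the stages, here is a minimal sketch (my own illustrative code, not the authors') of the kind of 2-layer network being analyzed: a nonlinear hidden layer, a linear output layer, gradient descent with weight decay, and the backpropagated signal that reaches the hidden layer (the quantity the abstract calls G_F). The toy data, squared loss, and layer sizes are placeholder assumptions.

```python
# Minimal sketch (my own illustrative code, not the authors'): a 2-layer network of the
# kind analyzed in the paper, trained by gradient descent with weight decay. The toy
# data, squared loss, and layer sizes are placeholder assumptions.
import numpy as np

rng = np.random.default_rng(0)
d_in, d_hidden, d_out, n = 20, 64, 10, 200

X = rng.standard_normal((n, d_in))                 # toy inputs
Y = rng.standard_normal((n, d_out))                # toy targets (placeholder task)

W1 = rng.standard_normal((d_hidden, d_in)) / np.sqrt(d_in)       # hidden ("feature") layer
W2 = rng.standard_normal((d_out, d_hidden)) / np.sqrt(d_hidden)  # top (output) layer

lr, weight_decay = 1e-2, 1e-3
for step in range(1000):
    H = np.maximum(X @ W1.T, 0.0)                  # hidden activations (ReLU)
    P = H @ W2.T                                   # network outputs
    err = P - Y                                    # residual under squared loss
    # Backpropagated signal that reaches the hidden layer: the structure of this
    # quantity (G_F in the abstract) is what drives feature learning in the analysis.
    G_hidden = (err @ W2) * (H > 0)
    W2 -= lr * (err.T @ H / n + weight_decay * W2)
    W1 -= lr * (G_hidden.T @ X / n + weight_decay * W1)

print("training loss at last step:", float(np.mean(err ** 2)))
```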
Here are the three stages in the Li₂ framework:
1) Lazy learning (memorization)
- At first, the output layer quickly fits the training data using the hidden layer’s random features. This looks like memorizing answers rather than understanding rules.
- A key detail: a regularization trick called weight decay slightly “leaks” useful signal backwards from the output to the hidden layer. Without this leak, the hidden layer wouldn’t get the hints it needs to learn real patterns.
2) Independent feature learning
- Once the output layer has memorized, the backprop signal it sends to the hidden layer starts carrying information about the labels (thanks to weight decay).
- Each hidden neuron now learns independently—like students working on their own projects—guided by a simple score the authors define, called an energy function E.
- E is just a number that gets higher when a hidden neuron’s feature aligns well with the task. Training nudges each neuron uphill on E, so the neuron discovers a feature that boosts this score. A toy sketch of this uphill climb appears right after this list.
3) Interactive feature learning
- Later, hidden neurons start noticing each other. If two neurons learn the same feature, the math pushes them apart (they “repel”) so the network doesn’t waste capacity duplicating the same idea.
- The feedback signal also changes to focus on what’s missing: if some important features are already learned, the signal points the remaining neurons toward the features that still need to be discovered. This is like a teacher saying, “We’ve got addition, now learn subtraction.”
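Here is the toy sketch of the Stage II picture promised above. The energy function is my own illustrative construction, not the one derived in the paper: its peaks sit at a few fixed "feature" directions, and each neuron, climbing the energy independently from a random start, lands on one of those peaks.

```python
# Toy illustration of Stage II (my own construction, not the paper's exact energy):
# the energy E(w) = sum_k (u_k . w)^4 has peaks at a few fixed "feature" directions
# u_k, and each neuron climbs E independently from a random start, landing on one peak.
import numpy as np

rng = np.random.default_rng(1)
d, n_features, n_neurons = 16, 4, 8
U, _ = np.linalg.qr(rng.standard_normal((d, n_features)))   # orthonormal feature directions

def grad_energy(w):
    # gradient of E(w) = sum_k (u_k . w)^4
    return 4.0 * U @ ((U.T @ w) ** 3)

W = rng.standard_normal((n_neurons, d))
W /= np.linalg.norm(W, axis=1, keepdims=True)
for _ in range(500):
    for i in range(n_neurons):                     # each neuron updates independently
        w = W[i] + 0.1 * grad_energy(W[i])         # gradient *ascent* on the energy
        W[i] = w / np.linalg.norm(w)               # keep the weight on the unit sphere

align = np.abs(W @ U)                              # |cosine| with each feature direction
print("feature found by each neuron:", np.argmax(align, axis=1))
print("alignment strength:", np.max(align, axis=1).round(3))   # close to 1.0
```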
To make the math concrete, the authors study “group arithmetic” tasks (like modular addition: add numbers but wrap around at a fixed size M). That’s a common testbed for grokking. They also run experiments to check that the theory matches what actually happens.
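For concreteness, here is a minimal sketch of this standard modular-addition testbed (the group size, train fraction, and one-hot encoding are illustrative choices, not necessarily the paper's exact settings):

```python
# Sketch of the standard modular-addition testbed (group size, train fraction, and
# one-hot encoding here are illustrative choices, not necessarily the paper's settings):
# inputs are pairs (a, b), the label is (a + b) mod M, and only a fraction of all
# M*M pairs is used for training.
import numpy as np

rng = np.random.default_rng(2)
M = 23                                             # group size (toy value)
pairs = np.array([(a, b) for a in range(M) for b in range(M)])
labels = (pairs[:, 0] + pairs[:, 1]) % M

perm = rng.permutation(len(pairs))
n_train = int(0.4 * len(pairs))                    # keep only a fraction of all M^2 pairs
train_idx, test_idx = perm[:n_train], perm[n_train:]

def one_hot_pair(a, b, M):
    # concatenated one-hot encoding of the two operands (a common input format)
    x = np.zeros(2 * M)
    x[a] = 1.0
    x[M + b] = 1.0
    return x

X_train = np.stack([one_hot_pair(a, b, M) for a, b in pairs[train_idx]])
y_train = labels[train_idx]
print(X_train.shape, y_train.shape)                # (n_train, 2M) and (n_train,)
```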
What they found and why it matters
Here are the main takeaways, with short explanations of why they’re important:
- Weight decay makes grokking possible
- Without weight decay, the hidden layer doesn’t get useful hints after the output layer memorizes. With weight decay, the hints “leak through,” and the hidden layer can start discovering real features. This explains why grokking shows up when regularization is used.
- Three-stage learning: memorize first, then understand
- The network first uses random hidden features to fit the training set (Stage I).
- Then each hidden neuron independently climbs the energy function E to learn a specific, meaningful feature (Stage II).
- Finally, neurons coordinate: they spread out to cover different features and focus on what’s still missing (Stage III).
- Features = peaks of the energy function
- The features that emerge are exactly the “local peaks” of E—patterns that score highly for the task.
- In tasks like modular addition, these features match known mathematical building blocks (like specific frequencies, similar to notes in music), which are efficient and generalize well. A quick numerical check of this appears after this list.
- How much data do you need? A scaling law
- You don’t need all possible training examples to learn the real rule. For a task of size M, about M log M examples are enough to keep the “good” feature peaks stable so the model finds them. That’s far fewer than all M² possible pairs.
- In plain terms: as the problem gets bigger, the fraction of data you need actually shrinks (roughly like log M / M), even though the total number of examples grows. A short back-of-the-envelope calculation appears after this list.
- When and why memorization wins
- If the dataset is too small or unbalanced (e.g., it mostly shows one type of example), the easiest peaks in E become “memorization features.” The model then learns to remember specific pairs instead of discovering the general rule.
- Around the boundary between enough and not-enough data, small learning rates help the model land in the “good” basin (true features), while large learning rates can push it into a memorization basin.
- Optimizer choices can improve feature diversity
- An optimizer called Muon rebalances gradients to avoid piling too many neurons onto the same feature. This encourages neurons to discover different features, which helps the model cover all the needed pieces faster and generalize better. A small sketch of this rebalancing idea appears after this list.
- Wider isn’t always better
- If the hidden layer is extremely wide, the model can act “too lazy” and stick with its initial random features instead of learning new ones. There’s a sweet spot where the layer is wide enough to learn, but not so wide that it never changes.
- Extends beyond two layers
- The same ideas can be carried to deeper networks: useful signals reach earlier layers first, those layers learn initial features, and then higher layers build on top of them. Residual connections help by making these signals clearer.
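First, the quick numerical check (my own illustration) of why frequency-based, "Fourier"-style features fit modular addition so well: by the angle-addition identity, per-operand cosine and sine features combine exactly into a feature of the answer, and the cosine's periodicity handles the wrap-around for free.

```python
# Quick numerical check (my own illustration) that frequency features fit modular
# addition: by the angle-addition identity, per-operand cosine/sine features combine
# exactly into the cosine feature of the answer, and periodicity handles the mod M.
import numpy as np

M, k = 23, 3                                       # group size and an arbitrary frequency
a, b = 17, 11                                      # (a + b) mod M == 5
lhs = np.cos(2 * np.pi * k * (a + b) / M)          # feature of the answer
rhs = (np.cos(2 * np.pi * k * a / M) * np.cos(2 * np.pi * k * b / M)
       - np.sin(2 * np.pi * k * a / M) * np.sin(2 * np.pi * k * b / M))
print(np.isclose(lhs, rhs))                        # True
```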
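Next, the back-of-the-envelope look at the M log M scaling law (the constants hidden inside the big-O are ignored, so the absolute numbers are only indicative): the needed fraction of all M² pairs shrinks roughly like log M / M as the task grows.

```python
# Back-of-the-envelope illustration of the M log M scaling law (constants hidden by
# the big-O are ignored, so absolute numbers are only indicative): the needed fraction
# of all M^2 pairs shrinks roughly like log(M) / M as the task grows.
import math

for M in [23, 59, 97, 211]:
    all_pairs = M * M
    needed = M * math.log(M)                       # ~ M log M examples
    print(f"M={M:4d}  all pairs={all_pairs:6d}  ~needed={needed:6.0f}  "
          f"fraction={needed / all_pairs:.3f}")
```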
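Finally, a minimal sketch of the rebalancing idea behind Muon (for clarity this uses an exact SVD; the actual optimizer approximates the same orthogonalization with a Newton-Schulz iteration and adds momentum): the weight-matrix gradient is replaced by a version whose singular values are all equal, so no single direction, and hence no single feature, dominates the update.

```python
# Minimal sketch of the rebalancing idea behind Muon (for clarity this uses an exact
# SVD; the actual optimizer approximates the same orthogonalization with a
# Newton-Schulz iteration and adds momentum): the weight-matrix gradient is replaced
# by a version whose singular values are all equal, so no single direction, and hence
# no single feature, dominates the update.
import numpy as np

def orthogonalized_update(grad):
    U, _, Vt = np.linalg.svd(grad, full_matrices=False)
    return U @ Vt                                  # same directions, equal strength each

rng = np.random.default_rng(3)
G = rng.standard_normal((8, 5)) * np.array([10.0, 3.0, 1.0, 0.1, 0.01])   # very uneven
print(np.linalg.svd(G, compute_uv=False).round(2))                         # skewed spectrum
print(np.linalg.svd(orthogonalized_update(G), compute_uv=False).round(2))  # all ~1.0
```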
Why this matters going forward (implications)
- Better training recipes: The paper gives concrete guidance on settings that trigger grokking in a good way—use weight decay to “leak” helpful signals, pick learning rates carefully near the memorization/generalization boundary, and consider optimizers like Muon to encourage diverse features.
- Predictable data needs: The scaling law (roughly M log M examples) helps plan how much data you need for a task of size M to learn rules, not just memorize.
- Feature-first thinking: Seeing features as peaks of an energy function makes “emergent features” less mysterious. It suggests ways to design architectures and losses that steer learning toward the right peaks and away from memorization.
- Beyond toy tasks: Although tested on group arithmetic like modular addition, the framework is general and can be extended to deeper networks and more complex problems, helping us understand and shape when models truly “get it.”
In short, this paper turns the story of grokking into a step-by-step process with clear signals, simple rules, and provable patterns—making it easier to train models that actually learn the underlying rules instead of just memorizing examples.