GrokAlign: Geometric Characterisation and Acceleration of Grokking (2506.12284v1)

Published 14 Jun 2025 in cs.LG and stat.ML

Abstract: A key challenge for the machine learning community is to understand and accelerate the training dynamics of deep networks that lead to delayed generalisation and emergent robustness to input perturbations, also known as grokking. Prior work has associated phenomena like delayed generalisation with the transition of a deep network from a linear to a feature learning regime, and emergent robustness with changes to the network's functional geometry, in particular the arrangement of the so-called linear regions in deep networks employing continuous piecewise affine nonlinearities. Here, we explain how grokking is realised in the Jacobian of a deep network and demonstrate that aligning a network's Jacobians with the training data (in the sense of cosine similarity) ensures grokking under a low-rank Jacobian assumption. Our results provide a strong theoretical motivation for the use of Jacobian regularisation in optimizing deep networks -- a method we introduce as GrokAlign -- which we show empirically to induce grokking much sooner than more conventional regularizers like weight decay. Moreover, we introduce centroid alignment as a tractable and interpretable simplification of Jacobian alignment that effectively identifies and tracks the stages of deep network training dynamics. Accompanying webpage (https://thomaswalker1.github.io/blog/grokalign.html) and code (https://github.com/ThomasWalker1/grokalign).

Summary

  • The paper introduces GrokAlign, a Jacobian-regularization method that accelerates grokking by encouraging alignment of a network's Jacobians with the training data.
  • It shows that, under a low-rank Jacobian assumption, Jacobian alignment yields loss optimality and robustness to $\ell_2$ perturbations, supported by theoretical proofs and empirical results.
  • The study presents centroid alignment as a practical, interpretable proxy for monitoring and controlling the stages of deep network training dynamics.

Delayed generalization and emergent robustness to input perturbations, collectively known as grokking, pose a key challenge in training deep networks. This paper (2506.12284) investigates the geometric properties of deep networks during training to understand and accelerate this phenomenon.

The authors propose that grokking is realized through the alignment of a deep network's Jacobian matrices with the training data. Specifically, a deep network is considered Jacobian-aligned at a point $\mathbf{x}$ if its Jacobian $J_{\mathbf{x}}(f)$ is a rank-one matrix of the form $\mathbf{c}\mathbf{x}^\top$ for some vector $\mathbf{c}$. The paper proves that, under a constraint on the Frobenius norm of the Jacobian and a zero-bias assumption, a Jacobian-aligned network is optimal in terms of minimizing common loss functions like cross-entropy or mean-squared error. Furthermore, Jacobian alignment, particularly when combined with the low-rank bias often observed in deep network training dynamics, implies optimal robustness to $\ell_2$ perturbations according to Theorem 2.
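
For concreteness, here is a minimal PyTorch sketch (ours, not the authors') of what checking Jacobian alignment at a point could look like: compute the Jacobian, measure how close it is to rank one, and measure how well its dominant right singular direction agrees with $\mathbf{x}$. The toy architecture and function names are illustrative assumptions.

```python
# Illustrative check of Jacobian alignment at a point (not the paper's code).
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small zero-bias ReLU network, matching the zero-bias assumption in the theory.
f = nn.Sequential(nn.Linear(10, 32, bias=False), nn.ReLU(), nn.Linear(32, 3, bias=False))
x = torch.randn(10)

# Full (output x input) Jacobian at x; fine for small dimensions.
J = torch.autograd.functional.jacobian(f, x)            # shape (3, 10)

U, S, Vh = torch.linalg.svd(J)
rank_one_ratio = (S[0] ** 2 / S.pow(2).sum()).item()     # explained variance of the top component
alignment = torch.abs(torch.cosine_similarity(Vh[0], x, dim=0)).item()

print(f"rank-one explained variance: {rank_one_ratio:.3f}")
print(f"|cos(top right singular direction, x)|: {alignment:.3f}")
```

A Jacobian-aligned network would drive both quantities towards 1 at the training points.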

Based on this theoretical link between Jacobian alignment, optimality, and robustness, together with the empirical observation that training biases Jacobians towards low rank (as shown by the increasing explained variance of the first principal component of the Jacobian over training, Figure 1), the paper introduces GrokAlign. GrokAlign is a regularization method that explicitly encourages Jacobian alignment by adding a term proportional to the mean squared Frobenius norm of the Jacobians evaluated at the training data points to the training loss:

$$\mathcal{L}_{\text{GrokAlign}} = \mathcal{L}_{\text{task}} + \lambda_{\text{Jac}} \cdot \frac{1}{m} \sum_{p=1}^m \|J_{\mathbf{x}_p}(f)\|_F^2$$

where $\mathcal{L}_{\text{task}}$ is the original loss function, $m$ is the number of training samples, and $\lambda_{\text{Jac}}$ is a weighting coefficient. By constraining the Jacobian norm, GrokAlign ensures that optimizing the task loss naturally leads to Jacobian alignment, thereby inducing grokking.
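
A minimal sketch of this objective, assuming a standard PyTorch classifier and a cross-entropy task loss; the exact per-sample Jacobian below is only practical for small input and output dimensions, and the function and variable names are ours.

```python
# Minimal sketch of the GrokAlign objective with exact per-sample Jacobians.
import torch
import torch.nn as nn
import torch.nn.functional as F

def grokalign_loss(model, x, y, lambda_jac=0.1):
    logits = model(x)
    task_loss = F.cross_entropy(logits, y)

    # Mean squared Frobenius norm of the input-output Jacobian over the batch
    # (exact but expensive; see the approximation discussed below).
    jac_penalty = 0.0
    for xi in x:
        J = torch.autograd.functional.jacobian(model, xi.unsqueeze(0), create_graph=True)
        jac_penalty = jac_penalty + J.pow(2).sum()
    jac_penalty = jac_penalty / x.shape[0]

    return task_loss + lambda_jac * jac_penalty

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
x, y = torch.randn(8, 784), torch.randint(0, 10, (8,))
loss = grokalign_loss(model, x, y)
loss.backward()
```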

Implementing GrokAlign requires computing the Frobenius norm of the Jacobian matrices. Directly computing the full Jacobian can be computationally expensive, especially for high-dimensional inputs or wide networks. The paper notes that GrokAlign utilizes an approximation of the Frobenius norm for practical efficiency, referencing prior work on Jacobian regularization for robustness [hoffman_robust_2019].
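
A hedged sketch of such an approximation, in the spirit of the random-projection estimator of Hoffman et al.: project the outputs with random unit vectors, backpropagate to the inputs, and rescale by the output dimension. The paper's exact implementation may differ, and the helper name below is ours.

```python
# Random-projection estimate of the mean squared Jacobian Frobenius norm.
import torch
import torch.nn as nn
import torch.nn.functional as F

def jacobian_frobenius_sq_estimate(model, x, n_proj=1):
    """Estimate (1/m) * sum_p ||J_{x_p}(f)||_F^2 using random unit projections."""
    x = x.detach().requires_grad_(True)
    out = model(x)                                  # shape (m, C)
    m, C = out.shape
    est = 0.0
    for _ in range(n_proj):
        v = torch.randn_like(out)
        v = v / v.norm(dim=1, keepdim=True)         # random unit vectors in output space
        # Vector-Jacobian product v^T J for each sample, via one backward pass.
        (Jv,) = torch.autograd.grad(out, x, grad_outputs=v, create_graph=True)
        # For v uniform on the unit sphere, E[ C * ||v^T J||^2 ] = ||J||_F^2.
        est = est + C * Jv.pow(2).flatten(1).sum(dim=1).mean()
    return est / n_proj

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
x, y = torch.randn(64, 784), torch.randint(0, 10, (64,))
loss = F.cross_entropy(model(x), y) + 0.1 * jacobian_frobenius_sq_estimate(model, x)
loss.backward()
```

With `n_proj=1` this adds roughly the cost of one extra backward pass per training step, which is what makes the regularizer practical at scale.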

To efficiently monitor the training dynamics and the emergence of grokking, the paper introduces the concept of centroid alignment. For a continuous piecewise affine network (a broad class including ReLU networks), the functional geometry is a partition of the input space into linear regions. Each region is associated with parameters, including a centroid $\mu_{\mathbf{x}}$. The paper proves (Theorem 3) that the centroid is related to the Jacobian by $\mu_{\mathbf{x}} = (J_{\mathbf{x}}(f))^\top \mathbf{1}$, where $\mathbf{1}$ is a vector of ones. Crucially, centroids can be computed efficiently using Jacobian-vector products (JVPs), which are generally faster than computing full Jacobians. A network is centroid-aligned at $\mathbf{x}$ if $\mu_{\mathbf{x}} = c\mathbf{x}$ for some scalar $c$. Proposition 1 shows that Jacobian alignment implies centroid alignment, making centroid alignment a tractable and interpretable proxy metric for tracking the state of the network geometry.
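
A minimal sketch, assuming a standard PyTorch model, of computing centroids and the centroid-alignment cosine similarity: $\mu_{\mathbf{x}} = (J_{\mathbf{x}}(f))^\top \mathbf{1}$ is obtained by backpropagating a vector of ones to the input, so only a single automatic-differentiation pass per batch is needed. The helper names are ours.

```python
# Centroid computation and centroid-alignment monitoring (illustrative sketch).
import torch
import torch.nn as nn

def centroids(model, x):
    """Return mu_x = J_x(f)^T 1 for each sample in the batch."""
    x = x.detach().requires_grad_(True)
    out = model(x)
    ones = torch.ones_like(out)
    # Backpropagating a vector of ones yields J^T 1, i.e. the centroid, per sample.
    (mu,) = torch.autograd.grad(out, x, grad_outputs=ones)
    return mu

def centroid_alignment(model, x):
    """Cosine similarity between mu_x and x, averaged over the batch."""
    mu = centroids(model, x)
    return torch.cosine_similarity(mu.flatten(1), x.flatten(1), dim=1).mean()

model = nn.Sequential(nn.Linear(784, 256, bias=False), nn.ReLU(), nn.Linear(256, 10, bias=False))
x = torch.randn(32, 784)
print(f"centroid alignment: {centroid_alignment(model, x).item():.3f}")
print(f"mean centroid norm: {centroids(model, x).norm(dim=1).mean().item():.3f}")
```

Because this costs only one backward pass per batch, centroid alignment and centroid norms are cheap enough to log throughout training.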

The connection between centroid alignment and the feature learning regime is established via the Neural Tangent Kernel (NTK). For a two-layer scalar-output ReLU network, the time derivative of the inner product $\langle\mathbf{x}, \mu_{\mathbf{x}}\rangle$ between an input $\mathbf{x}$ and its centroid is shown to be a weighted sum of the NTK between $\mathbf{x}$ and the training data points (Theorem 5). A changing inner product implies a dynamic NTK, characteristic of the feature learning regime. Thus, centroid alignment dynamics reflect the feature learning process.
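
An illustrative sketch (not the paper's code) of the two quantities involved, for a small two-layer scalar-output ReLU network: the inner product $\langle\mathbf{x}, \mu_{\mathbf{x}}\rangle$ and the empirical NTK entries $k(\mathbf{x}, \mathbf{x}_p) = \langle\nabla_\theta f(\mathbf{x}), \nabla_\theta f(\mathbf{x}_p)\rangle$ between $\mathbf{x}$ and the training points. The theorem's exact weighting of these entries is not reproduced here.

```python
# Quantities appearing in the NTK connection (illustrative sketch).
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(20, 64, bias=False), nn.ReLU(), nn.Linear(64, 1, bias=False))

def param_grad(x):
    """Flattened gradient of the scalar output with respect to the parameters."""
    net.zero_grad()
    net(x.unsqueeze(0)).squeeze().backward()
    return torch.cat([p.grad.flatten() for p in net.parameters()])

x = torch.randn(20)
train = torch.randn(5, 20)

# For a scalar output, mu_x = J_x(f)^T 1 is simply the input gradient of f.
xr = x.clone().requires_grad_(True)
(mu_x,) = torch.autograd.grad(net(xr.unsqueeze(0)).squeeze(), xr)
print("<x, mu_x> =", torch.dot(x, mu_x).item())

# Empirical NTK entries between x and each training point.
gx = param_grad(x)
ntk = torch.stack([torch.dot(gx, param_grad(xp)) for xp in train])
print("NTK(x, x_p):", ntk.tolist())
```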

Practical applications and implementation insights are demonstrated through experiments:

  1. Centroid Alignment as a Monitor: On an MNIST binary classification task, the inner product $\langle\mathbf{x}, \mu_{\mathbf{x}}\rangle$ changes along with the NTK dynamics, confirming that centroid alignment tracks the feature learning regime (Figure 2). Monitoring centroid alignment can indicate when feature learning is occurring and potentially when training should continue.
  2. Accelerating Robustness: On a high-dimensional XOR task designed for grokking, centroid alignment tracks the onset of delayed robustness. Standard training with weight decay eventually leads to alignment and robustness after centroid norms decrease. However, applying GrokAlign significantly accelerates both centroid alignment and the emergence of robustness by directly regularizing the Jacobian norm (Figure 3). This highlights GrokAlign's ability to directly control the desired geometric property.
  3. Controlling Grokking Dynamics:
    • Inducing Robustness: Training a CNN on CIFAR10 with GrokAlign induces higher centroid alignment and significantly better adversarial robustness compared to weight decay alone (Figures 4 and 5). The plateauing of centroid alignment can signal convergence of the robustness property.
    • Inhibiting Grokking: By designing a custom GrokAlign term that constrains the Jacobian norm to remain high rather than low, grokking (generalization) can be inhibited. This shows the flexibility of Jacobian regularization in controlling dynamics, monitored effectively by centroid norms and alignment (Figure 4).
    • Accelerating Grokking: Compared to weight decay, Grokfast [lee_grokfast_2024], and adversarial training [tan_understanding_2024], GrokAlign achieves significantly faster grokking (measured by time/steps to reach 85% test accuracy) on the standard MNIST grokking setup (Table 1). GrokAlign achieves this by promoting Jacobian alignment, unlike adversarial training in this setup (Table 2).
  4. Application to Transformers: Although the core theory is for piecewise affine networks, centroid dynamics can still be analyzed for other architectures like Transformers. On a modular addition task where Transformers can learn either an algorithmic or classification solution, GrokAlign biases the model towards a classification-style solution (associated with lower Gini coefficients of embedding matrices), which is not the natural grokking path for this task (Figure 5). When the task is modified to favor the classification solution (by fixing embeddings), GrokAlign effectively accelerates grokking (Figure 6). This suggests GrokAlign is best suited when the desired solution aligns with the geometric properties it encourages.

Implementation Considerations:

  • Computational Cost: Calculating the full Jacobian for large networks and datasets is prohibitive. Practical implementations of GrokAlign rely on approximations of the Frobenius norm (e.g., randomized trace estimators) or on efficient JVP computations for related metrics like centroids. Libraries like PyTorch and TensorFlow support these through automatic differentiation (e.g., torch.autograd.grad(outputs, inputs, grad_outputs=vec)). The centroid $\mu_{\mathbf{x}} = (J_{\mathbf{x}}(f))^\top \mathbf{1}$ can be computed in a single backward pass by taking the gradient of the outputs with respect to the input $\mathbf{x}$ with grad_outputs set to a vector of ones, as in the centroid sketch above.
  • Hyperparameter Tuning: The regularization strength $\lambda_{\text{Jac}}$ is critical for GrokAlign's effectiveness and needs to be tuned based on the task, network architecture, and other hyperparameters (like learning rate and weight decay).
  • Network Architecture: While demonstrated on CNNs and Transformers, the strongest theoretical backing is for continuous piecewise affine networks. Applying GrokAlign to other architectures may require empirical validation and understanding how their Jacobians behave.
  • Bias Terms: The core theoretical results rely on a zero-bias assumption, whereas practical networks usually include biases. The experiments sometimes omit biases and sometimes retain them, suggesting the principles may hold more broadly, but the zero-bias assumption remains a theoretical limitation.
  • Centroid Monitoring: Centroid alignment (the cosine similarity between $\mu_{\mathbf{x}}$ and $\mathbf{x}$) and the centroid norm provide valuable insights into training dynamics and can guide decisions on hyperparameter tuning or training duration.

In summary, the paper provides a geometric explanation for grokking rooted in Jacobian alignment and introduces GrokAlign as a principled regularization method to enforce this alignment and accelerate grokking. The centroid alignment metric, efficiently computable via JVPs, offers a practical tool for monitoring these dynamics and understanding the transition to the grokked state.
