How Transformers Learn Regular Language Recognition: A Theoretical Study on Training Dynamics and Implicit Bias
(2505.00926v3)
Published 2 May 2025 in cs.LG, cs.CL, and stat.ML
Abstract: Language recognition tasks are fundamental in NLP and have been widely used to benchmark the performance of LLMs. These tasks also play a crucial role in explaining the working mechanisms of transformers. In this work, we focus on two representative tasks in the category of regular language recognition, known as 'even pairs' and 'parity check', the aim of which is to determine whether the occurrences of certain subsequences in a given sequence are even. Our goal is to explore how a one-layer transformer, consisting of an attention layer followed by a linear layer, learns to solve these tasks by theoretically analyzing its training dynamics under gradient descent. While even pairs can be solved directly by a one-layer transformer, parity check needs to be solved by integrating Chain-of-Thought (CoT), either into the inference stage of a transformer well-trained for the even pairs task, or into the training of a one-layer transformer. For both problems, our analysis shows that the joint training of attention and linear layers exhibits two distinct phases. In the first phase, the attention layer grows rapidly, mapping data sequences into separable vectors. In the second phase, the attention layer becomes stable, while the linear layer grows logarithmically and approaches, in direction, a max-margin hyperplane that correctly separates the attention layer outputs into positive and negative samples, and the loss decreases at a rate of $O(1/t)$. Our experiments validate those theoretical results.
Summary
The paper shows a two-phase training process where the transformer first extracts task-relevant features via attention and then refines classification through margin maximization.
It details how token scores and attention weights evolve differently for the even pairs and parity check tasks, highlighting specific dynamics in early training.
The study introduces two Chain-of-Thought strategies for parity check: one reuses a transformer trained on even pairs to solve parity check zero-shot, and the other trains a transformer with CoT directly, underscoring the value of task decomposition.
This paper theoretically studies how a simple one-layer transformer, consisting of an attention layer followed by a linear layer, learns to solve two representative regular language recognition tasks: even pairs and parity check. The goal is to understand the underlying working mechanisms by analyzing the training dynamics under gradient descent.
The tasks involve determining whether the total number of occurrences of specific subsequences in a binary string over {a, b} is even. Even pairs counts occurrences of ab and ba, while parity check counts occurrences of b. Although even pairs looks more complex, it is equivalent to checking whether the first and last tokens are the same, which takes O(1) work. Parity check, by contrast, generally requires scanning the whole sequence, which takes O(L) work.
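The O(1) shortcut for even pairs can be confirmed by brute force on short strings. The sketch below is illustrative only: it assumes that "occurrences of ab or ba" means adjacent pairs, and the function names are hypothetical.

```python
from itertools import product


def even_pairs(seq: str) -> bool:
    """True if the number of adjacent 'ab' plus 'ba' pairs is even (O(L) scan)."""
    changes = sum(1 for x, y in zip(seq, seq[1:]) if x != y)
    return changes % 2 == 0


def even_pairs_shortcut(seq: str) -> bool:
    """Equivalent O(1) test: only the first and last tokens matter."""
    return seq[0] == seq[-1]


def parity_check(seq: str) -> bool:
    """True if the number of 'b' tokens is even; no shortcut, a full scan is needed."""
    return seq.count("b") % 2 == 0


# Exhaustively confirm the equivalence on all strings over {a, b} up to length 10.
for length in range(1, 11):
    for chars in product("ab", repeat=length):
        s = "".join(chars)
        assert even_pairs(s) == even_pairs_shortcut(s)
```

The equivalence holds because every ab or ba marks a change of symbol, and the number of changes between the first and last token is even exactly when those two tokens are equal. No such endpoint shortcut exists for parity check: "aba" has an even number of pairs but an odd number of b's.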
The model takes a sequence of token embeddings as input. A crucial aspect of the embedding strategy is that token embeddings at different positions are orthogonal. The one-layer transformer computes attention weights from the interaction of token embeddings with the learned query and key matrices ($W_q$ and $W_k$), and then applies a learned value matrix ($W_v$) and a linear layer ($W_u$). For simplicity, the authors reparameterize $W_u W_v$ as $u$ and $W_k W_q$ as $W$, so the model output is $T_\theta(X) = u^\top \sum_{\ell=1}^{L} x_\ell \varphi_\ell$, where $\varphi_\ell$ are the softmax attention weights, which depend on $W$. The model is trained with the logistic loss for binary classification.
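The forward pass can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' code: it uses the orthogonal embedding $E_\ell^a = e_{2\ell-1}$, $E_\ell^b = e_{2\ell}$ given in the paper, assumes the last token $x_L$ acts as the attention query, and omits the attention scaling factor $\lambda$. The helpers `embed`, `forward`, and `logistic_loss` are hypothetical names.

```python
import numpy as np


def embed(seq: str, L_max: int) -> np.ndarray:
    """Columns are one-hot embeddings E_l^a = e_{2l-1}, E_l^b = e_{2l} (1-indexed l).

    Every (position, token) pair gets its own basis vector of R^{2 L_max}, so
    embeddings at different positions are orthogonal.
    """
    X = np.zeros((2 * L_max, len(seq)))
    for l, tok in enumerate(seq):  # l is 0-indexed here
        X[2 * l if tok == "a" else 2 * l + 1, l] = 1.0
    return X


def forward(u: np.ndarray, W: np.ndarray, X: np.ndarray):
    """Return T(X) = u^T sum_l x_l phi_l together with the attention weights phi.

    Assumption: the last token x_L is the query, so phi = softmax_l(x_l^T W x_L).
    """
    logits = X.T @ W @ X[:, -1]
    logits = logits - logits.max()  # subtract the max for numerical stability
    phi = np.exp(logits) / np.exp(logits).sum()
    return u @ (X @ phi), phi


def logistic_loss(output: float, y: int) -> float:
    """Logistic loss log(1 + exp(-y * T(X))) for a label y in {+1, -1}."""
    return float(np.log1p(np.exp(-y * output)))


# Usage: at initialization (u = 0, W = 0) attention is uniform and the loss is log 2.
L_max = 4
u0, W0 = np.zeros(2 * L_max), np.zeros((2 * L_max, 2 * L_max))
out, phi = forward(u0, W0, embed("abba", L_max))
```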
Training is performed using a two-phase gradient descent (GD) strategy. An early phase ($t \le t_0$) uses learning rate $\eta$ for $u$ and $\eta\lambda$ for $W$, while a later phase ($t > t_0$) uses $\eta$ for both, where $\lambda$ is the attention scaling factor. This schedule is presented as an approximation of Adam's behavior in early versus later training steps.
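The schedule itself reduces to a small function of the iteration count. A sketch with a hypothetical helper name; the GD update that consumes these rates is left out.

```python
def learning_rates(t: int, t0: int, eta: float, lam: float) -> tuple[float, float]:
    """Return (lr_u, lr_W) for GD iteration t under the two-phase schedule.

    Early phase (t <= t0): u uses eta and W uses eta * lam.
    Later phase (t > t0): both u and W use eta.
    """
    if t <= t0:
        return eta, eta * lam
    return eta, eta


# Usage: eta = 0.5, lambda = 8, switching after t0 = 100 iterations.
print(learning_rates(50, 100, 0.5, 8.0))   # (0.5, 4.0)
print(learning_rates(150, 100, 0.5, 8.0))  # (0.5, 0.5)
```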
Even Pairs Problem
For the even pairs task, the paper identifies two distinct training phases:
Phase 1 (Rapid Growth): Both the linear layer (u) and attention layer (W) parameters grow rapidly.
Linear Layer ($u$) Dynamics: The "token score" $\langle u_t, E_1^w \rangle$ for the first token grows quickly ($\Theta(\eta t)$), while the scores $\langle u_t, E_\ell^w \rangle$ for later tokens ($\ell \ge 2$) become increasingly negative ($-\Theta(\eta^2 t^2)$). The score for the second token becomes more negative than those of the other non-leading tokens ($\ell \ge 3$). This reflects an early focus on the beginning of the sequence, driven by the simple length-1 positive examples and by the distinct roles of the first two positions in longer sequences for this task.
Attention Layer ($W$) Dynamics: The attention scores $\langle E_i^w, W_t E_L^{w'} \rangle$ evolve such that the attention weight $\varphi_1$ on the first token increases in "positive" samples (where the first and last tokens are the same) and decreases in "negative" samples. Conversely, the attention weights on non-leading tokens ($\ell \ge 2$) increase in negative samples, and attention shifts specifically toward the second token rather than later non-leading tokens ($\ell \ge 3$). This indicates that the attention mechanism learns to separate samples by whether the first and last tokens match, attending mainly to the second token when they do not.
Outcome of Phase 1: By the end of Phase 1, the attention layer maps the input sequences into a feature space where positive and negative samples are linearly separable; the linear layer $u = E_1^a + E_1^b - E_2^a - E_2^b$ separates these attention outputs.
Phase 2 (Margin Maximization and Implicit Bias): Training continues to refine the linear layer while the attention layer becomes relatively stable.
Linear Layer ($u$) Dynamics: The norm $\|u_t\|$ continues to grow logarithmically ($\Omega(\log t)$). The direction of $u_t$ converges to the max-margin hyperplane that separates the attention-layer outputs obtained at the end of Phase 1. This is the implicit bias of gradient descent on linearly separable data.
Attention Layer ($W$) Dynamics: The attention weights $\varphi_\ell^{(n,t)}$ change negligibly: owing to the scaling factor $\lambda$, the total change is bounded, $\|W_t - W_{t_0}\| \le O(1)$. The attention patterns learned in Phase 1 (focus on the first or second token) persist.
Loss Convergence: The logistic loss converges to its global minimum sublinearly, at a rate of $O(1/t)$. A sufficiently large $\lambda$ (specifically $\Omega(L_{\max}^2/\epsilon^3)$ for target loss $\epsilon$) is needed to ensure low loss.
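The Phase 2 behavior of the linear layer is the standard implicit bias of gradient descent on the logistic loss with separable data, and it can be reproduced on a toy problem. In the hypothetical 2-D dataset below, the max-margin direction through the origin is $(1, 0)$, while the first gradient step points along $(4, 3)$; as training proceeds, $\|u_t\|$ keeps growing, the loss keeps falling, and the direction of $u_t$ rotates toward $(1, 0)$. This is a generic sketch, not a model of the attention outputs analyzed in the paper.

```python
import numpy as np


def train_logistic_gd(Z, y, eta=0.5, steps=20_000, checkpoints=(200, 20_000)):
    """Plain GD on the mean logistic loss of a homogeneous linear classifier u^T z.

    Returns {t: (u_t, loss_t)} at the requested checkpoints.
    """
    u = np.zeros(Z.shape[1])
    history = {}
    for t in range(1, steps + 1):
        margins = y * (Z @ u)
        sig = 1.0 / (1.0 + np.exp(margins))  # sigmoid(-margin_i)
        # Step along the negative gradient of mean_i log(1 + exp(-margin_i)).
        u = u + eta * (sig * y) @ Z / len(y)
        if t in checkpoints:
            loss = np.mean(np.log1p(np.exp(-y * (Z @ u))))
            history[t] = (u.copy(), loss)
    return history


def cosine(u, v):
    """Cosine of the angle between u and v."""
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))


# Hypothetical separable data; the max-margin direction through the origin is (1, 0).
Z = np.array([[1.0, 0.0], [3.0, 3.0], [-1.0, 0.0], [-3.0, -3.0]])
y = np.array([1.0, 1.0, -1.0, -1.0])
u_star = np.array([1.0, 0.0])

history = train_logistic_gd(Z, y)
for t, (u_t, loss_t) in sorted(history.items()):
    print(t, np.linalg.norm(u_t), cosine(u_t, u_star), loss_t)
```

Only the point $(1, 0)$ is a support vector here; the contribution of $(3, 3)$ to the gradient dies out quickly, so the growing first coordinate gradually dominates the direction.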
In practice, this suggests that for tasks requiring comparing specific tokens (like first and last), a single attention layer can learn to perform this comparison and route information accordingly, while the linear layer then acts as a simple classifier on this compressed, task-relevant representation. The two phases highlight a separation of concerns: attention learns features, and the linear layer learns to classify based on those features, driven by margin maximization.
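To see why the Phase 1 direction $u = E_1^a + E_1^b - E_2^a - E_2^b$ suffices as a classifier, consider a hypothetical, idealized endpoint of Phase 1 in which attention is one-hot: all weight on position 1 when the first and last tokens match, all weight on position 2 otherwise. The learned attention is soft, so this only illustrates the geometry. Under that assumption, $u$ scores every positive sample $+1$ and every negative sample $-1$; `embed` and `idealized_attention` are hypothetical helpers.

```python
from itertools import product

import numpy as np


def embed(seq: str, L_max: int) -> np.ndarray:
    """One-hot embeddings E_l^a = e_{2l-1}, E_l^b = e_{2l}, stored as columns."""
    X = np.zeros((2 * L_max, len(seq)))
    for l, tok in enumerate(seq):  # l is 0-indexed here
        X[2 * l if tok == "a" else 2 * l + 1, l] = 1.0
    return X


def idealized_attention(seq: str) -> np.ndarray:
    """Hypothetical one-hot pattern: position 1 if w_1 == w_L, else position 2."""
    phi = np.zeros(len(seq))
    phi[0 if seq[0] == seq[-1] else 1] = 1.0  # negative samples always have L >= 2
    return phi


L_max = 6
# u = E_1^a + E_1^b - E_2^a - E_2^b, i.e. 0-indexed rows 0, 1 positive and 2, 3 negative.
u = np.zeros(2 * L_max)
u[[0, 1]] = 1.0
u[[2, 3]] = -1.0

margins = []
for length in range(1, L_max + 1):
    for chars in product("ab", repeat=length):
        s = "".join(chars)
        y = 1 if s[0] == s[-1] else -1  # even pairs label
        score = u @ (embed(s, L_max) @ idealized_attention(s))
        margins.append(y * score)
```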
Parity Check Problem
The parity check problem is harder for a single transformer pass. The paper proposes two approaches leveraging the connection between parity check and even pairs via Chain-of-Thought (CoT):
Approach 1: Inference via Truncated CoT: This method uses a transformer already trained for the even pairs task without any CoT training. It performs iterative inference:
Given a sequence $X = w_1 \cdots w_L$, at each step $t = 1, \dots, L-1$, the transformer checks the even pairs condition on the current sequence (which changes over steps).
The predicted label ('a' for even, 'b' for odd parity) is appended to the sequence.
The first token of the sequence is removed to maintain length (truncated CoT).
After L−1 steps, the label appended at the final step is the predicted parity of the original sequence.
This approach demonstrates that a transformer trained for a simpler task (even pairs) can perform a more complex one (parity check) in a zero-shot manner (no extra training) when combined with a clever iterative inference strategy inspired by finite automata.
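The inference loop of Approach 1 can be simulated directly. In this sketch, a hypothetical oracle that returns 'a' exactly when the first and last tokens match stands in for a well-trained even-pairs transformer, and the length-1 case (which needs no CoT steps) is handled separately as an assumption. The exhaustive check confirms that the truncated CoT procedure described above recovers the parity of every short string.

```python
from itertools import product


def even_pairs_oracle(seq: str) -> str:
    """Hypothetical stand-in for a well-trained even-pairs transformer.

    Returns 'a' (even number of ab/ba pairs) iff the first and last tokens match.
    """
    return "a" if seq[0] == seq[-1] else "b"


def parity_via_truncated_cot(seq: str, model=even_pairs_oracle) -> str:
    """Predict parity ('a' = even number of b's, 'b' = odd) with truncated CoT."""
    if len(seq) == 1:
        return seq  # assumption: a single token is its own parity label
    current, label = seq, ""
    for _ in range(len(seq) - 1):      # steps t = 1, ..., L - 1
        label = model(current)         # even pairs check on the current window
        current = current[1:] + label  # append the label, drop the first token
    return label


# Check against direct counting for every string over {a, b} up to length 10.
for length in range(1, 11):
    for chars in product("ab", repeat=length):
        s = "".join(chars)
        assert parity_via_truncated_cot(s) == ("a" if s.count("b") % 2 == 0 else "b")
```

The window keeps length $L$ throughout; at step $t$ its first token is $w_t$ and its last token is the previous label, so each step folds one more input token into a running parity.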
Approach 2: Training with CoT under Teacher Forcing: This approach trains a one-layer transformer end-to-end to perform the CoT steps.
Training Data Generation: For an input sequence $X = w_1 \cdots w_{L_0}$, CoT steps $t = 1, \dots, L_0 - 1$ are generated. At step $t$, the input is $X_t = w_1 \cdots w_{L_0+t-1}$, and the label $y_t$ is the parity label of the pair $w_t$ and $w_{L_0+t-1}$ ('a' if they are equal, 'b' otherwise). The sequence for the next step, $X_{t+1}$, is $X_t$ with this label appended as $w_{L_0+t}$. This yields training pairs $(X_t, y_t)$.
Training Objective: The total loss combines a CoT loss ($\mathcal{L}_{\mathrm{CoT}}$) on sequences of length $L_0 \le L \le 2L_0 - 1$ (labeled with CoT steps) and a regularization loss ($\mathcal{L}_{\mathrm{Reg}}$) on sequences of length $L < L_0$ (labeled for even pairs). The regularization loss is found to stabilize training and, by exploiting the similarity between the two tasks, to steer the parameters toward a configuration from which parity check can be learned effectively.
Training Dynamics: The training also exhibits two phases, similar to the even pairs case.
Phase 1: The dynamics of the linear layer $u$ are similar to the even pairs case (first-token score positive, others negative, second token most negative among the non-leading tokens), driven by the regularization loss on shorter sequences. The attention layer $W$ learns a different pattern for sequences with $L \ge L_0$: it focuses on the token at position $L - L_0 + 1$ (the one that must be compared with the last token $w_L$) and on the last token $w_L$ itself.
Phase 2: Similar to even pairs, the attention layer stabilizes, and the linear layer's direction converges to a max-margin solution for separating the attention layer outputs on the combined CoT and regularization datasets. The total loss decays sublinearly.
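The teacher-forcing data generation of Approach 2 can be sketched for a single input as follows. This is an illustration with a hypothetical function name; it builds the pairs $(X_t, y_t)$ for $t = 1, \dots, L_0 - 1$ and checks that the last label equals the parity of the original sequence.

```python
from itertools import product


def cot_training_pairs(seq: str) -> list[tuple[str, str]]:
    """Teacher-forced CoT pairs (X_t, y_t), t = 1..L0-1, for one input w_1...w_L0."""
    L0 = len(seq)
    X = seq
    pairs = []
    for t in range(1, L0):
        # Compare w_t (X[t - 1], with 1-indexed t) against the last token w_{L0+t-1}.
        y = "a" if X[t - 1] == X[-1] else "b"
        pairs.append((X, y))
        X = X + y  # teacher forcing: the true label becomes w_{L0+t}
    return pairs


# The final label should equal the parity of the original sequence.
for length in range(2, 9):
    for chars in product("ab", repeat=length):
        s = "".join(chars)
        assert cot_training_pairs(s)[-1][1] == ("a" if s.count("b") % 2 == 0 else "b")
```

In each pair, the token being compared with the last one sits at position $t = L - L_0 + 1$ of an input of length $L = L_0 + t - 1$, which is the position the attention layer learns to focus on.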
Implementation Considerations:
Model Simplicity: The analysis uses a one-layer transformer, keeping the theoretical analysis tractable. Real-world applications use deeper models, but the insights into attention learning specific token relationships and linear layers acting as classifiers on these features could generalize.
Embedding: The specific orthogonal embedding $E_\ell^a = e_{2\ell-1}$, $E_\ell^b = e_{2\ell}$ is crucial for the theoretical analysis. Practical systems use learned embeddings, which might introduce complexity but could potentially learn similar orthogonal-like representations for positions/tokens.
Scaling Parameter λ: The analysis shows λ plays a critical role, especially in Phase 2, in stabilizing the attention layer dynamics and ensuring loss convergence. Choosing an appropriate λ is important for stable training and achieving low error.
Training Data: For theoretical guarantees, the analysis often assumes training on the full dataset of all sequences up to length $L_{\max}$ (or specific relevant subsets for CoT). Real-world data is sampled, requiring consideration of stochastic gradient descent and generalization.
Computational Requirements: The analysis is for a shallow model and synthetic data, suggesting low computational needs for the experiments presented. Scaling to complex tasks and deeper models would require significant resources typical of LLMs.
CoT Implementation: The two CoT approaches present different deployment strategies. Inference-only CoT leverages a pre-trained model flexibly but requires careful design of the iterative inference process. Trained CoT requires generating specific training data for the intermediate steps, potentially increasing training data size and complexity.
Experimental Validation: Experiments on synthetic data validate the theoretical findings, showing the predicted loss decay rate, the growth/decay patterns of token scores at different positions, and the dynamics of attention scores on relevant tokens (first/second for Even Pairs, relevant CoT tokens for Parity Check). Additional experiments confirm the two-phase dynamics with constant learning rates and in a larger model (NanoGPT), suggesting the phenomenon is not merely an artifact of the chosen LR schedule or the small synthetic task.
In summary, this paper provides a detailed theoretical account of how a one-layer transformer learns specific regular language tasks, highlighting a two-phase training process where attention learns relevant token relationships and the linear layer performs margin maximization. It also introduces novel CoT-based strategies for the more challenging parity check task, demonstrating how task decomposition can enable simpler models or inference procedures. The results offer insights into the implicit biases and feature learning capabilities of transformers on structured sequential data.