- The paper investigates how Transformer language models track state in permutation tasks, finding that they learn specific algorithms rather than performing step-by-step simulation.
- Two mechanisms are identified: the Associative Algorithm (AA), which composes permutations hierarchically in parallel, and the Parity-Associative Algorithm (PAA), which combines a parity heuristic with an associative scan.
- Mechanism choice depends on architecture and pre-training, with AA models generalizing better, though the findings may not extend beyond permutation tasks.
The paper investigates how Transformer LMs track the state of an evolving world, focusing on the task of permutation composition. The central question is whether LMs simulate state evolution step-by-step, rely on heuristics, or merely give the illusion of state tracking. The study uses permutation composition as a model system, where LMs predict the final position of objects after a series of swaps. This task is relevant because many state tracking tasks can be reduced to permutation tracking.
The authors analyze LMs by:
- Introducing state tracking problems and the permutation composition task.
- Reviewing interpretability tools for analyzing LM computations.
- Describing potential algorithms LMs might use and their expected signatures.
- Presenting experimental findings on state tracking mechanisms.
The paper identifies two state tracking mechanisms that LMs consistently learn:
- Associative Algorithm (AA): Resembles the associative scan construction.
- Parity-Associative Algorithm (PAA): Uses permutation parity to prune the output space, refining it with an associative scan.
The study finds that LMs do not typically use step-by-step simulation or fully parallel composition mechanisms, even when these are theoretically implementable. The findings are supported by evidence from representation interventions, probes, error patterns, attention maps, and training dynamics. The choice of mechanism affects model performance on long sequences, and although that choice is partly stochastic, it can be steered by intermediate training tasks that encourage or suppress parity heuristics. The paper suggests a mechanism by which real-world LMs might perform state tracking in language, code, and games.
Background and Preliminaries
The paper addresses the problem of state tracking, which involves inferring common ground, navigating environments, reasoning about code, and playing games. Prior research has explored whether and how LMs perform these tasks, with theoretical work noting the association of state-tracking problems with the complexity class NC1 and empirical work showing that LMs can solve these problems and encode state information.
The study focuses on permutation composition, where LMs are presented with objects and reshuffling operations, requiring them to compute the final order. Permutation tracking is NC1-complete, making it a suitable model for studying state tracking. The finite symmetric group Sn comprises permutations of n objects with a composition operation. Every permutation can be expressed as a composition of two-element swaps, and the parity (even or odd) is determined by the number of swaps.
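The group operations described above can be made concrete with a short sketch. This is illustrative code, not from the paper: permutations are represented as tuples, composition applies one permutation after another, and parity is computed by counting inversions.

```python
def compose(p, q):
    """Apply p first, then q: (q ∘ p)[i] = q[p[i]]."""
    return tuple(q[p[i]] for i in range(len(p)))

def parity(p):
    """0 for even permutations, 1 for odd, via inversion counting."""
    n = len(p)
    inversions = sum(p[i] > p[j] for i in range(n) for j in range(i + 1, n))
    return inversions % 2

# A single swap is odd; composing two swaps yields an even permutation.
swap01 = (1, 0, 2, 3, 4)   # swap objects 0 and 1 in S5
swap23 = (0, 1, 3, 2, 4)   # swap objects 2 and 3
assert parity(swap01) == 1
assert parity(compose(swap01, swap23)) == 0
```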
Given a sequence of permutations, the parity of the tth permutation is denoted as ϵ(at), and the parity of the state St is given by:
ϵ(St) = ϵ(a0...at) = ∑i ϵ(ai) mod 2
- ϵ(St): Parity of the state St
- ϵ(ai): Parity of the ith permutation
- a0...at: Sequence of permutations
- ∑i: Summation over the index i
- mod : Modulo operation
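The parity formula above can be checked numerically; the following minimal sketch (not from the paper) computes the running state parity as a prefix sum of action parities mod 2.

```python
def state_parity(action_parities):
    """Running parity: ϵ(S_t) = (ϵ(a_0) + ... + ϵ(a_t)) mod 2."""
    out, acc = [], 0
    for e in action_parities:
        acc = (acc + e) % 2
        out.append(acc)
    return out

# Three odd swaps: the state parity alternates odd, even, odd.
assert state_parity([1, 1, 1]) == [1, 0, 1]
```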
The LMs are trained to solve the word problem, taking a sequence of actions [a0,...,at] as input and outputting a sequence of state predictions [s0,...,st].
Interpretability Methods
The experiments use interpretability techniques to understand how LMs solve permutation word problems. The internal LM representation at token position t after Transformer layer l is denoted as ht,l, with T and L representing the maximum input length and number of layers, respectively.
Probing
In probing experiments, a smaller probe model (e.g., a linear classifier) maps LM hidden representations h to quantities z. The experiments evaluate whether the state st and the final state parity are linearly encoded in intermediate-layer representations. For each layer l, a state probe predicts p(st∣ht,l), and a parity probe predicts p(ϵ(St)∣ht,l).
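A linear probe of this kind can be sketched as follows. This is an illustrative stand-in, not the paper's implementation: it fits a closed-form least-squares probe on synthetic "hidden states" that linearly encode a hypothetical 6-way state label (a stand-in for |S3| = 6).

```python
import numpy as np

def fit_linear_probe(H, z, n_classes):
    """Least-squares probe: regress one-hot labels on hidden states ht,l."""
    Y = np.eye(n_classes)[z]                     # one-hot targets
    Hb = np.hstack([H, np.ones((len(H), 1))])    # append a bias column
    W, *_ = np.linalg.lstsq(Hb, Y, rcond=None)
    return W

def probe_accuracy(W, H, z):
    Hb = np.hstack([H, np.ones((len(H), 1))])
    return float(((Hb @ W).argmax(axis=1) == z).mean())

# Synthetic check: representations that linearly encode the label are
# decoded almost perfectly by the probe.
rng = np.random.default_rng(0)
n, d, k = 500, 32, 6                 # hypothetical sample count, width, classes
z = rng.integers(0, k, size=n)
H = np.eye(k)[z] @ rng.normal(size=(k, d)) + 0.1 * rng.normal(size=(n, d))
W = fit_linear_probe(H, z, k)
assert probe_accuracy(W, H, z) > 0.95
```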
Activation Patching
Activation patching determines which representations play a causal role in prediction. Portions of the LM's internal representations are overwritten ("patched") with representations from alternative inputs. The probability that an LM assigns to the output y given an input x, with the representation h replaced by some other representation h′, is denoted as p(y∣x;h←h′).
The normalized logit difference (NLD) is used to measure the shift in prediction toward the clean output y:
NLD = (LD(x′; ht,l ← hclean) − LD(x′)) / (LD(x) − LD(x′))
where
LD(·) = log p(y ∣ ·) − log p(y′ ∣ ·)
- NLD: Normalized Logit Difference
- LD: Logit Difference
- x′: Corrupted input
- x: Clean input
- y: Output from clean input
- y′: Output from corrupted input
- ht,l: Internal LM representation at token position t after Transformer layer l
- hclean: Hidden representation from the clean input x
In prefix patching, all hidden representations up to index t are patched at a layer l.
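Given the logit-difference values from a clean run, a corrupted run, and a patched corrupted run, the NLD is a simple normalization; a minimal sketch (illustrative, with made-up values):

```python
def logit_diff(logits, y, y_prime):
    """LD = log p(y∣·) − log p(y′∣·); with raw logits the softmax
    normalizer cancels, so this is just a logit subtraction."""
    return logits[y] - logits[y_prime]

def normalized_logit_diff(ld_patched, ld_clean, ld_corrupt):
    """NLD: 1 means patching fully restores the clean prediction,
    0 means the patch had no effect on the corrupted run."""
    return (ld_patched - ld_corrupt) / (ld_clean - ld_corrupt)

# Hypothetical values: the patch recovers 75% of the clean-vs-corrupt gap.
assert normalized_logit_diff(ld_patched=2.0, ld_clean=3.0, ld_corrupt=-1.0) == 0.75
```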
Theoretical Algorithms
The paper establishes a phenomenology for LM state tracking by identifying candidate algorithms and their empirical signatures.
Sequential Algorithm
The sequential algorithm composes permutations one at a time from left to right, analogous to step-by-step simulation. Each hidden state ht,l stores the associated action at until the state can be computed at layer l = t, so that ht,t = St.
The patching signature for the sequential algorithm is upper triangular, as any patching experiment that replaces hidden states with l>t will not affect predictions. The probing signature shows a linear dependence on depth, with the state probe accuracy increasing linearly and parity probe accuracy either increasing linearly or remaining constant.
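As a reference point, the sequential algorithm corresponds to an ordinary left-to-right fold; a minimal sketch (illustrative, not the model's actual computation):

```python
def sequential_states(actions, n=5):
    """Left-to-right scan: fold in one action per step, so computing S_t
    takes depth t, mirroring the sequential algorithm."""
    compose = lambda p, q: tuple(q[p[i]] for i in range(n))
    states, s = [], tuple(range(n))
    for a in actions:
        s = compose(s, a)   # S_t = S_{t-1} then a_t
        states.append(s)
    return states

# Applying the same swap twice returns to the identity.
swap01 = (1, 0, 2, 3, 4)
assert sequential_states([swap01, swap01]) == [swap01, (0, 1, 2, 3, 4)]
```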
Parallel Algorithm
The word problem on S5 belongs to NC1, requiring circuit depth that scales logarithmically with sequence length, while the word problem on S3 belongs to TC0, with constant-depth threshold circuits.
The patching signature for the parallel algorithm is L-shaped: interventions at or before a fixed layer lp, independent of sequence length, change predictions, while interventions at deeper layers have no effect. The probing signature shows the probe obtaining perfect accuracy within a constant number of layers, with state parity also computed as an intermediate quantity.
Associative Algorithm (AA)
In the associative algorithm (AA), Transformers compose permutations hierarchically, grouping adjacent subsequences and computing their product. This ensures that ht,l = a(t−2^l+1)...at, i.e., layer l composes the 2^l most recent actions, and thus ht,⌈log(t+1)⌉ = a0...at.
The patching signature for AA shows that the length of the prefix that must be modified to alter behavior increases exponentially with depth. The probing signature shows an exponentially increasing state probe accuracy, with parity probe accuracy also increasing exponentially if state parity is encoded in state representations.
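The doubling structure of AA matches a classic parallel prefix scan; the following sketch (a Hillis–Steele scan, offered as an illustration rather than the paper's construction) computes all prefix products in ⌈log2 T⌉ passes instead of T sequential steps.

```python
def associative_scan(actions, n=5):
    """AA-style parallel prefix scan: after pass l, position t holds the
    product of the 2^l most recent actions; all prefixes a_0...a_t are
    complete after ⌈log2 T⌉ passes."""
    compose = lambda p, q: tuple(q[p[i]] for i in range(n))
    h, span = list(actions), 1
    while span < len(h):
        # Each position composes the block ending 2^l earlier with its own block.
        h = [h[t] if t < span else compose(h[t - span], h[t]) for t in range(len(h))]
        span *= 2
    return h  # h[t] = a_0 ∘ ... ∘ a_t

# Three copies of the same swap: prefix products alternate swap, identity, swap.
swap01 = (1, 0, 2, 3, 4)
assert associative_scan([swap01] * 3) == [swap01, (0, 1, 2, 3, 4), swap01]
```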
Parity-Associative Algorithm (PAA)
The parity-associative algorithm (PAA) computes the final state in two stages: first the parity of the state, then the parity complement (the remaining information needed to identify the state within its parity class). Hidden states comprise registers E and K, storing the parity and complement respectively, i.e., ht,l = (Et,l, Kt,l).
The patching signature for PAA depends on whether the corrupted input has the same parity as the clean input. If the parities differ, prefix patching shows a parallel algorithm signature. If the parities are the same, the patching pattern is AA-like. The probing signature shows state probes improving exponentially with depth, while parity probes converge to 100% at a constant depth.
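PAA's first stage can be sketched as follows. This is an illustrative pruning step, not the paper's implementation: a cheap parity prefix-sum restricts the output space to the n!/2 permutations of matching parity, which the associative second stage must then resolve.

```python
from itertools import permutations

def paa_candidates(actions, n=3):
    """PAA stage 1: prune the output space to permutations whose parity
    matches the state's parity (stage 2, not sketched here, resolves
    which candidate is the actual state)."""
    def parity(p):
        return sum(p[i] > p[j] for i in range(n) for j in range(i + 1, n)) % 2
    target = sum(parity(a) for a in actions) % 2
    return [p for p in permutations(range(n)) if parity(p) == target]

# One odd swap in S3 leaves only the 3 odd permutations as candidates,
# so the identity (an even permutation) is ruled out immediately.
swap01 = (1, 0, 2)
assert len(paa_candidates([swap01], n=3)) == 3
assert (0, 1, 2) not in paa_candidates([swap01], n=3)
```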
Experimental Results
The paper compares theoretical mechanisms to empirical properties of LMs trained for permutation tasks, emphasizing that the signatures provide necessary but not sufficient conditions for implementation. Experiments identify algorithmic features shared between idealized mechanisms and transformer behavior, yielding evidence consistent with AA in some models and PAA in others.
Experimental Setup
The authors generated 1 million unique length-100 sequences of permutations in both S3 and S5, splitting the data 90/10 for training/analysis. They fine-tune Pythia-160M models (Biderman et al., 2023) to predict the state corresponding to each prefix of each action sequence, using a cross-entropy loss:
C = −∑(t=0 to 99) log PLM(St ∣ a0...at)
- C: Cross-entropy loss, summed over all prefixes
- PLM(St ∣ a0...at): Probability the LM places on state token St when conditioned on the length-(t+1) prefix a0...at
Models are fine-tuned for 20 epochs using the AdamW optimizer with a learning rate of 5e-5 and a batch size of 128. Larger models (above 700M parameters) are trained using bfloat16.
Activation Patching
Activation patching results exhibit two clusters of behavior, matching either the AA or PAA signature. Prototypical models on both S3 and S5 show this behavior. Patching intermediate representations of PAA-type models results in predictions with incorrect parity.
Probing
Test set accuracies of linear probes across LM layers l are consistent with the signatures predicted by AA and PAA. Models with AA-type probing signatures also have AA-type patching signatures, and vice versa. The paper refers to models as "AA-type" or "PAA-type" based on the cluster of signatures they exhibit. Probe accuracies broken down by sequence length confirm that models solve exponentially longer sequences at deeper layers.
For models that learn PAA on S3, representations of the final product can be geometrically decomposed into orthogonal directions, corresponding to the parity of the product and the cluster identity of the product.
Generalization by Sequence Length
AA- and PAA-type models were evaluated on their ability to generalize to sequences of varying lengths, assessing state accuracy and parity accuracy. Models generalize perfectly to sequences up to a cutoff length, followed by a steep accuracy drop-off.
For PAA models, the parity accuracy cutoff length is longer than the state accuracy cutoff length, whereas for AA models, the cutoff lengths are equal. AA models tend to generalize better overall.
Attention Patterns
Attention patterns differentiate between PAA and AA models. In early layers, PAA models exhibit parity heads, which attend to odd-parity actions. The parity of the state can be determined by counting odd-parity actions. No parity heads were found in AA models. Attention patterns in AA models sparsify in later layers, forming a tree-like pattern.
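A simple diagnostic for such parity heads (an illustrative sketch, not the paper's method) is to score each head by the fraction of its attention mass placed on odd-parity action tokens:

```python
import numpy as np

def parity_head_score(attn, odd_mask):
    """Average fraction of a head's attention mass placed on odd-parity
    action tokens; a score near 1 is consistent with a 'parity head'.
    attn: (queries, keys) attention weights; odd_mask: 1.0 at odd actions."""
    return float((attn * odd_mask[None, :]).sum(axis=-1).mean())

# Hypothetical 4-token sequence where tokens 0 and 2 are odd-parity actions.
odd_mask = np.array([1.0, 0.0, 1.0, 0.0])
uniform = np.full((4, 4), 0.25)                  # attends everywhere equally
focused = np.array([[0.5, 0.0, 0.5, 0.0]] * 4)  # attends only to odd actions
assert parity_head_score(uniform, odd_mask) == 0.5
assert parity_head_score(focused, odd_mask) == 1.0
```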
Factors Influencing Mechanism Choice
The paper studies the factors that determine which mechanism emerges during training.
Training Stage
An LM's eventual mechanism can be identified early in training based on prediction error patterns. AA models improve parity and state predictions in lockstep, while PAA models learn in two phases: converging on state parities and then accurately predicting the state.
Influential Factors
Whether an LM learns AA or PAA depends on model architecture, size, initialization scheme, and training data. Model architecture and initialization are bigger factors than model size. GPT-2 models split evenly between mechanisms, while Pythia models tend to learn AA when pre-trained and PAA when not.
Pretraining Effects
Intermediate tasks can encourage models to learn one mechanism or another. When from-scratch LMs are trained with a topic modeling next-token-prediction objective, they always learn an AA-type mechanism. Training LMs to predict state parity induces GPT-2 and Pythia models to learn PAA.
Conclusion
The paper demonstrates that LMs trained on permutation tracking tasks learn either an associative algorithm (AA) or a parity-associative algorithm (PAA). AA composes action subsequences in parallel, while PAA first computes a parity heuristic and then a parity complement. AA models generalize better and converge faster, with model architecture and pre-training influencing mechanism discovery.
The study acknowledges that while many state tracking tasks can be reduced to permutation tasks, the specific mechanisms LMs use for S5 may not generalize to other tasks, including those involving natural language.