
(How) Do Language Models Track State?

Published 4 Mar 2025 in cs.CL, cs.AI, and cs.LG | arXiv:2503.02854v2

Abstract: Transformer LMs exhibit behaviors -- from storytelling to code generation -- that appear to require tracking the unobserved state of an evolving world. How do they do so? We study state tracking in LMs trained or fine-tuned to compose permutations (i.e., to compute the order of a set of objects after a sequence of swaps). Despite the simple algebraic structure of this problem, many other tasks (e.g., simulation of finite automata and evaluation of boolean expressions) can be reduced to permutation composition, making it a natural model for state tracking in general. We show that LMs consistently learn one of two state tracking mechanisms for this task. The first closely resembles the "associative scan" construction used in recent theoretical work by Liu et al. (2023) and Merrill et al. (2024). The second uses an easy-to-compute feature (permutation parity) to partially prune the space of outputs, then refines this with an associative scan. The two mechanisms exhibit markedly different robustness properties, and we show how to steer LMs toward one or the other with intermediate training tasks that encourage or suppress the heuristics. Our results demonstrate that transformer LMs, whether pretrained or fine-tuned, can learn to implement efficient and interpretable state tracking mechanisms, and the emergence of these mechanisms can be predicted and controlled.

Summary

  • The paper investigates how Transformer language models track state in permutation tasks, finding that they learn specific algorithms rather than step-by-step simulation.
  • The study identifies two key mechanisms: an Associative Algorithm (AA) that composes permutations in parallel, and a Parity-Associative Algorithm (PAA) that first applies a parity heuristic.
  • Mechanism choice depends on architecture and pre-training, with AA models generalizing better, though the findings may not extend beyond permutation tasks.

The paper investigates how Transformer LMs track the state of an evolving world, focusing on the task of permutation composition. The central question is whether LMs simulate state evolution step-by-step, use heuristics, or if state tracking is an illusion. The study uses permutation composition as a model system, where LMs predict the final position of objects after a series of swaps. This task is relevant because many state tracking tasks can be reduced to permutation tracking.

The authors analyze LMs by:

  • Introducing state tracking problems and the permutation composition task.
  • Reviewing interpretability tools for analyzing LM computations.
  • Describing potential algorithms LMs might use and their expected signatures.
  • Presenting experimental findings on state tracking mechanisms.

The paper identifies two state tracking mechanisms that LMs consistently learn:

  1. Associative Algorithm (AA): Resembles the associative scan construction.
  2. Parity-Associative Algorithm (PAA): Uses permutation parity to prune the output space, refining it with an associative scan.

The study finds that LMs do not typically use step-by-step simulation or fully parallel composition mechanisms, even when theoretically implementable. The findings are supported by evidence from representation interventions, probes, error patterns, attention maps, and training dynamics. The choice of mechanism affects model performance on long sequences, and while stochastic, can be influenced by intermediate training tasks that encourage or suppress parity heuristics. The paper suggests a mechanism by which real-world LMs might perform state tracking in language, code, and games.

Background and Preliminaries

The paper addresses the problem of state tracking, which involves inferring common ground, navigating environments, reasoning about code, and playing games. Prior research has explored whether and how LMs perform these tasks, with theoretical work noting the association of state-tracking problems with the complexity class NC^1 and empirical work showing that LMs can solve these problems and encode state information.

The study focuses on permutation composition, where LMs are presented with objects and reshuffling operations, requiring them to compute the final order. Permutation tracking is NC^1-complete, making it a suitable model for studying state tracking. The finite symmetric group S_n comprises the permutations of n objects under a composition operation. Every permutation can be expressed as a composition of two-element swaps, and its parity (even or odd) is determined by the number of swaps.

Given a sequence of permutations, the parity of the t-th permutation is denoted \epsilon(a_t), and the parity of the state S_t is given by:

\epsilon(S_t) = \epsilon(a_0 \cdots a_t) = \sum_i \epsilon(a_i) \bmod 2

  • \epsilon(S_t): parity of the state S_t
  • \epsilon(a_i): parity of the i-th permutation
  • a_0 \cdots a_t: the sequence of permutations
  • \sum_i: summation over the index i
  • \bmod: the modulo operation

The LMs are trained to solve the word problem, taking a sequence of actions [a_0, ..., a_t] as input and outputting a sequence of state predictions [s_0, ..., s_t].
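The task setup can be sketched in plain Python. The composition convention below (apply a_0 first) is one of two equivalent choices; the parity identity above holds either way:

```python
from itertools import permutations
import random

def compose(p, q):
    """Compose two permutations given as tuples: (p o q)[i] = p[q[i]]."""
    return tuple(p[i] for i in q)

def parity(p):
    """Parity of a permutation: 0 if even, 1 if odd (inversion count mod 2)."""
    inv = sum(1 for i in range(len(p))
              for j in range(i + 1, len(p)) if p[i] > p[j])
    return inv % 2

# A length-5 word problem in S_3: actions are random permutations,
# states are the running left-to-right composition.
random.seed(0)
S3 = list(permutations(range(3)))
actions = [random.choice(S3) for _ in range(5)]
states = []
state = tuple(range(3))  # identity
for a in actions:
    state = compose(a, state)
    states.append(state)

# The state parity equals the sum of action parities mod 2.
for t in range(5):
    assert parity(states[t]) == sum(parity(a) for a in actions[:t + 1]) % 2
```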

Interpretability Methods

The experiments use interpretability techniques to understand how LMs solve permutation word problems. The internal LM representation at token position t after Transformer layer l is denoted h_{t,l}, with T and L denoting the maximum input length and the number of layers, respectively.

Probing

In probing experiments, a smaller probe model (e.g., a linear classifier) maps LM hidden representations h to quantities z. The experiments evaluate whether the state s_t and the final state parity are linearly encoded in intermediate-layer representations. For each layer l, a state probe predicts p(s_t | h_{t,l}) and a parity probe predicts p(\epsilon(S_t) | h_{t,l}).
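A minimal sketch of the probing idea, using synthetic "hidden states" rather than the paper's trained models: the binary label is linearly encoded along a fixed direction by construction, so a simple least-squares linear probe recovers it.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 64, 2000
parity_dir = rng.normal(size=d)                 # direction encoding the label
z = rng.integers(0, 2, size=n)                  # binary parity labels
# Synthetic hidden states: Gaussian noise plus a signed shift along parity_dir.
H = rng.normal(size=(n, d)) + np.outer(2 * z - 1, parity_dir)

# Least-squares linear probe (a stand-in for the paper's probe classifiers).
X = np.hstack([H, np.ones((n, 1))])             # add a bias column
w, *_ = np.linalg.lstsq(X[:1500], 2 * z[:1500] - 1, rcond=None)
pred = (X[1500:] @ w > 0).astype(int)
acc = (pred == z[1500:]).mean()
print(f"probe accuracy: {acc:.2f}")             # near 1.0 on this synthetic data
```

High probe accuracy on real hidden states is what would indicate the quantity is linearly decodable at that layer.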

Activation Patching

Activation patching determines which representations play a causal role in prediction. Portions of the LM's internal representations are overwritten ("patched") with representations from alternative inputs. The probability that an LM assigns to the output y given an input x, with the representation h replaced by some other representation h', is denoted p(y | x; h \leftarrow h').

The normalized logit difference (NLD) is used to measure the shift in prediction toward the clean output y:

NLD = \frac{LD(x'; h_{t,l} \leftarrow h_{clean}) - LD(x')}{LD(x) - LD(x')}

where

LD(\cdot) = \log p(y \mid \cdot) - \log p(y' \mid \cdot)

  • NLD: normalized logit difference
  • LD: logit difference
  • x': corrupted input
  • x: clean input
  • y: output from the clean input
  • y': output from the corrupted input
  • h_{t,l}: internal LM representation at token position t after Transformer layer l
  • h_{clean}: hidden representation from the clean input x

In prefix patching, all hidden representations up to index t are patched at a layer l.
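The NLD metric can be sketched directly from the definitions above. The log-probabilities below are hypothetical placeholders, not model outputs: a patch that fully restores the clean behavior drives NLD toward 1, while an ineffective patch leaves it near 0.

```python
import math

def logit_diff(logp_y, logp_yprime):
    """LD = log p(y | .) - log p(y' | .) for clean answer y, corrupted answer y'."""
    return logp_y - logp_yprime

def normalized_logit_diff(ld_patched, ld_corrupt, ld_clean):
    """NLD = (LD(x'; h <- h_clean) - LD(x')) / (LD(x) - LD(x'))."""
    return (ld_patched - ld_corrupt) / (ld_clean - ld_corrupt)

# Hypothetical probabilities for the clean run, the corrupted run,
# and the corrupted run with a patched-in clean representation.
ld_clean = logit_diff(math.log(0.9), math.log(0.05))
ld_corrupt = logit_diff(math.log(0.05), math.log(0.9))
ld_patched = logit_diff(math.log(0.8), math.log(0.1))
nld = normalized_logit_diff(ld_patched, ld_corrupt, ld_clean)
print(f"NLD = {nld:.2f}")  # between 0 and 1: a mostly effective patch
```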

Theoretical Algorithms

The paper establishes a phenomenology for LM state tracking by identifying candidate algorithms and their empirical signatures.

Sequential Algorithm

The sequential algorithm composes permutations one at a time from left to right, analogous to step-by-step simulation. Each hidden state h_{t,l} stores the associated action a_t until S_t can be computed, maintaining h_{t,t} = S_t.

The patching signature for the sequential algorithm is upper triangular: any patching experiment that replaces hidden states with l > t will not affect predictions. The probing signature shows a linear dependence on depth, with state probe accuracy increasing linearly and parity probe accuracy either increasing linearly or remaining constant.

Parallel Algorithm

The word problem on S_5 belongs to NC^1, requiring circuit depth that scales logarithmically with sequence length, while the word problem on S_3 belongs to TC^0 and admits constant-depth threshold circuits.

The patching signature for the parallel algorithm is L-shaped: interventions at or before some fixed layer l_p change predictions, while interventions at deeper layers have no effect. The probing signature shows the probe reaching perfect accuracy within a constant number of layers, with state parity also computed as an intermediate quantity.

Associative Algorithm (AA)

In the associative algorithm (AA), Transformers compose permutations hierarchically, grouping adjacent subsequences and computing their product. This algorithm ensures that h_{t,l} = a_{t-2^l+1} \cdots a_t, and thus h_{t, \log(t+1)} = a_0 \cdots a_t.

The patching signature for AA shows that the length of the prefix that must be modified to alter behavior increases exponentially with depth. The probing signature shows an exponentially increasing state probe accuracy, with parity probe accuracy also increasing exponentially if state parity is encoded in state representations.
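The hierarchical composition behind AA can be sketched as a Hillis-Steele-style inclusive scan (one possible associative-scan variant, used here purely for illustration), which produces every prefix product in logarithmically many rounds:

```python
def compose(p, q):
    """(p o q)[i] = p[q[i]] for permutations as tuples."""
    return tuple(p[i] for i in q)

def associative_scan(actions):
    """Compute all prefix products a_0...a_t in O(log n) rounds by combining
    each position with the one an exponentially growing gap behind it."""
    n = len(actions)
    prefix = list(actions)
    gap, rounds = 1, 0
    while gap < n:
        prefix = [
            compose(prefix[t], prefix[t - gap]) if t >= gap else prefix[t]
            for t in range(n)
        ]
        gap *= 2
        rounds += 1
    return prefix, rounds

actions = [(1, 0, 2), (0, 2, 1), (2, 1, 0), (1, 2, 0)]
prefixes, rounds = associative_scan(actions)
assert rounds == 2  # log2(4) rounds instead of 3 sequential steps

# Cross-check against sequential left-to-right composition.
seq = actions[0]
for t in range(1, len(actions)):
    seq = compose(actions[t], seq)
assert prefixes[-1] == seq
```

Depth growing like log(t) rather than t is exactly why the AA patching and probing signatures spread exponentially with layer.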

Parity-Associative Algorithm (PAA)

The parity-associative algorithm (PAA) computes the final state in two stages: computing the parity of the state, then computing the parity complement. Hidden states comprise registers E and K, storing the parity and the complement, respectively, i.e., h_{t,l} = (E_{t,l}, K_{t,l}).

The patching signature for PAA depends on whether the corrupted input has the same parity as the clean input. If the parities differ, prefix patching shows a parallel algorithm signature. If the parities are the same, the patching pattern is AA-like. The probing signature shows state probes improving exponentially with depth, while parity probes converge to 100% at a constant depth.
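A sketch of the PAA decomposition in S_3, with the second stage stood in for by exact composition (the model's actual refinement mechanism is learned, not exact): the cheap parity register alone prunes half of the candidate states.

```python
from itertools import permutations

def compose(p, q):
    """(p o q)[i] = p[q[i]] for permutations as tuples."""
    return tuple(p[i] for i in q)

def parity(p):
    """0 if even, 1 if odd (inversion count mod 2)."""
    inv = sum(p[i] > p[j] for i in range(len(p)) for j in range(i + 1, len(p)))
    return inv % 2

actions = [(1, 0, 2), (0, 2, 1), (1, 2, 0)]

# Stage 1: cheap parity heuristic (sum of action parities mod 2).
eps = sum(parity(a) for a in actions) % 2
candidates = [p for p in permutations(range(3)) if parity(p) == eps]
assert len(candidates) == 3  # half of |S_3| = 6 pruned by parity alone

# Stage 2: the harder residual computation identifies the state
# within the surviving parity class.
state = actions[0]
for a in actions[1:]:
    state = compose(a, state)
assert state in candidates
```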

Experimental Results

The paper compares theoretical mechanisms to empirical properties of LMs trained for permutation tasks, emphasizing that the signatures provide necessary but not sufficient conditions for implementation. Experiments identify algorithmic features shared between idealized mechanisms and transformer behavior, yielding evidence consistent with AA in some models and PAA in others.

Experimental Setup

The authors generated 1 million unique length-100 sequences of permutations in both S_3 and S_5, splitting the data 90/10 for training/analysis. They fine-tune Pythia-160M models (Biderman et al., 2023) to predict the state corresponding to each prefix of each action sequence, using a cross-entropy loss:

C = -\sum_{t=0}^{99} \log P_{LM}(S_t \mid a_0 \cdots a_t)

  • C: cross-entropy loss
  • P_{LM}(S_t | a_0 ... a_t): the probability the LM places on the state token S_t when conditioned on the prefix a_0 ... a_t of the document.

Models are fine-tuned for 20 epochs using the AdamW optimizer with a learning rate of 5e-5 and a batch size of 128. Larger models (above 700M parameters) are trained using bfloat16.
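The objective can be sketched in numpy with random stand-in logits (no actual model involved); it is a standard cross-entropy summed over all 100 prefix positions:

```python
import numpy as np

rng = np.random.default_rng(0)
T, V = 100, 6                      # sequence length, |S_3| = 6 state tokens
logits = rng.normal(size=(T, V))   # stand-in for LM output scores per prefix
targets = rng.integers(0, V, T)    # gold state index at each prefix

# Log-softmax over the state vocabulary, then negative log-likelihood
# of the gold state summed over positions t = 0..99.
log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
loss = -log_probs[np.arange(T), targets].sum()
print(f"loss = {loss:.2f}")
```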

Activation Patching

Activation patching results exhibit two clusters of behavior, matching either the AA or the PAA signature. Prototypical models on both S_3 and S_5 show this behavior. Patching intermediate representations of PAA-type models results in predictions with incorrect parity.

Probing

Test set accuracies of linear probes across LM layers l are consistent with the signatures predicted by AA and PAA. Models with AA-type probing signatures also have AA-type patching signatures, and vice versa. The paper refers to models as "AA-type" or "PAA-type" based on the cluster of signatures they exhibit. Probe accuracies broken down by sequence length confirm that models solve exponentially longer sequences at deeper layers.

For models that learn PAA on S_3, representations of the final product can be geometrically decomposed into orthogonal directions corresponding to the parity of the product and the cluster identity of the product.

Generalization by Sequence Length

AA- and PAA-type models were evaluated on their ability to generalize to sequences of varying lengths, assessing state accuracy and parity accuracy. Models generalize perfectly to sequences up to a cutoff length, followed by a steep accuracy drop-off.

For PAA models, the parity accuracy cutoff length is longer than the state accuracy cutoff length, whereas for AA models, the cutoff lengths are equal. AA models tend to generalize better overall.

Attention Patterns

Attention patterns differentiate between PAA and AA models. In early layers, PAA models exhibit parity heads, which attend to odd-parity actions. The parity of the state can be determined by counting odd-parity actions. No parity heads were found in AA models. Attention patterns in AA models sparsify in later layers, forming a tree-like pattern.

Factors Influencing Mechanism Choice

The paper studies the factors that determine which mechanism emerges during training.

Training Stage

An LM's eventual mechanism can be identified early in training based on prediction error patterns. AA models improve parity and state predictions in lockstep, while PAA models learn in two phases: converging on state parities and then accurately predicting the state.

Influential Factors

Whether an LM learns AA or PAA depends on model architecture, size, initialization scheme, and training data. Model architecture and initialization are bigger factors than model size. GPT-2 models split evenly between mechanisms, while Pythia models tend to learn AA when pre-trained and PAA when not.

Pretraining Effects

Intermediate tasks can encourage models to learn one mechanism or another. When from-scratch LMs are trained with a topic modeling next-token-prediction objective, they always learn an AA-type mechanism. Training LMs to predict state parity induces GPT-2 and Pythia models to learn PAA.

Conclusion

The paper demonstrates that LMs trained on permutation tracking tasks learn either an associative algorithm (AA) or a parity-associative algorithm (PAA). AA composes action subsequences in parallel, while PAA first computes a parity heuristic and then a parity complement. AA models generalize better and converge faster, with model architecture and pre-training influencing mechanism discovery.

The study acknowledges that while many state tracking tasks can be reduced to permutation tasks, the specific mechanisms LMs use for S_5 may not generalize to other tasks, including those involving natural language.
