- The paper introduces control vectors added to the residual stream to steer LLM reasoning during inference.
- It employs methods including averaging, contrastive analysis, and PCA to derive effective vectors from model activations.
- Experiments on IOI, bAbI, and GSM8K demonstrate improved accuracy and cross-task generalization in reasoning.
This paper proposes a method to improve the reasoning performance of LLMs by applying interventions directly to their internal representations, specifically within the residual stream. The core idea is to derive "control vectors" from the model's activations when processing examples of successful reasoning and then add these vectors to the activations during inference to encourage a desired reasoning behavior. This approach falls under the domain of representation engineering, treating reasoning ability as a modifiable direction in the model's latent space.
The authors conceptualize the transformer architecture as a process where computational blocks (attention and MLP layers) read from and write to a residual stream. The hidden state vector at layer $\ell$ is denoted $x_\ell$. Their intervention is modeled as a simple addition to the residual stream after the MLP block in layer $\ell$:

$$x_{\ell+1} = \mathrm{LayerNorm}\big(y_\ell + \mathrm{MLP}(y_\ell)\big) + c_\ell \cdot \alpha$$

where $y_\ell$ denotes the activations after the attention mechanism, $c_\ell$ is the layer-specific control vector, and $\alpha$ is a scalar controlling the magnitude and direction of the intervention.
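In PyTorch, this addition can be implemented with a forward hook on the chosen decoder block. The sketch below is a minimal illustration under that assumption; the module path and helper names are mine, not the paper's code (Pythia models expose their blocks as `model.gpt_neox.layers`):

```python
import torch

def make_steering_hook(control_vector: torch.Tensor, alpha: float):
    """Build a forward hook that adds alpha * control_vector to a block's output."""
    def hook(module, inputs, output):
        # HuggingFace decoder blocks typically return a tuple whose first
        # element is the hidden state of shape (batch, seq_len, d_model).
        hidden = output[0] if isinstance(output, tuple) else output
        hidden = hidden.clone()
        # The paper applies the vector at the final token position only.
        hidden[:, -1, :] += alpha * control_vector
        return (hidden,) + output[1:] if isinstance(output, tuple) else hidden
    return hook

# Hypothetical usage, intervening on the middle layer:
# layers = model.gpt_neox.layers
# handle = layers[len(layers) // 2].register_forward_hook(
#     make_steering_hook(c_mid, alpha=4.0))
# ... run inference ...
# handle.remove()
```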
The control vectors $c_\ell$ are derived from hidden state activations $H_\ell(P_i)$ obtained by processing a set of training prompts $P$. Three methods for deriving $c_\ell$ are explored (a code sketch of all three follows the list):
- Reading Vector: the average of activations over the prompt set, $c_\ell = \frac{1}{|P|} \sum_{i=1}^{|P|} H_\ell(P_i)$.
- Contrastive Reading Vector: the average difference between activations from positive ($P^+$) and negative ($P^-$) prompt pairs, $c_\ell = \frac{1}{|P^\pm|} \sum_{i=1}^{|P^\pm|} \big( H_\ell(P_i^+) - H_\ell(P_i^-) \big)$. For reasoning tasks, positive examples are prompts the model reasoned through correctly, while negative examples aim to capture representations of poor reasoning (the tested schemes included incorrect model outputs and random character strings).
- PCA Contrastive Vector: the first principal component of the difference vectors $H_\ell(P_i^+) - H_\ell(P_i^-)$, obtained via Principal Component Analysis and rescaled to a norm similar to the average activation norm so that magnitudes remain comparable.
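A minimal sketch of the three schemes, assuming the layer-$\ell$ activations at the final token have already been collected into `(num_prompts, d_model)` tensors; the variable and function names are illustrative, not taken from the paper's released code:

```python
import torch

def reading_vector(acts_pos: torch.Tensor) -> torch.Tensor:
    # Plain average of activations over the prompt set.
    return acts_pos.mean(dim=0)

def contrastive_vector(acts_pos: torch.Tensor, acts_neg: torch.Tensor) -> torch.Tensor:
    # Average difference between paired positive and negative activations.
    return (acts_pos - acts_neg).mean(dim=0)

def pca_contrastive_vector(acts_pos: torch.Tensor, acts_neg: torch.Tensor) -> torch.Tensor:
    # First principal component of the difference vectors, rescaled to roughly
    # the average activation norm so magnitudes stay comparable.
    diffs = acts_pos - acts_neg
    _, _, v = torch.pca_lowrank(diffs, q=1)  # v has shape (d_model, 1)
    return v[:, 0] * acts_pos.norm(dim=1).mean()
```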
The authors evaluate this method on three reasoning tasks:
- Indirect-Object-Identification (IOI): A simple inductive task involving identifying the indirect object in sentences like "Mary and John went to the store. John gave the groceries to Mary."
- bAbI Task 15: A deductive reasoning task requiring chaining facts from a short passage to answer a question.
- GSM8K: A dataset of grade school mathematical word problems requiring multi-step reasoning and calculation.
Experiments are conducted on Pythia-1.4B, Pythia-2.8B, and Mistral-7B-Instruct. Control vectors are derived from examples in a training split of each dataset and applied only to the model's middle layer at the final token position. Performance on a test set is measured with logit-based accuracy: a prediction counts as correct if the correct answer token receives the highest logit among the candidate answers, rather than requiring a strict exact-match generation. The intervention is further analyzed via the KL divergence and entropy of the logit distribution and the average probability of correct versus incorrect tokens as a function of $\alpha$.
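A sketch of the logit-based accuracy check and the distribution metrics, under the assumption that each candidate answer maps to a single token; `candidate_ids`, `correct_idx`, and the function names are illustrative:

```python
import torch

@torch.no_grad()
def logit_based_accuracy(model, input_ids, candidate_ids, correct_idx):
    # Logits at the final position, restricted to the candidate answer tokens.
    logits = model(input_ids).logits[:, -1, :]       # (batch, vocab)
    cand = logits[:, candidate_ids]                  # (batch, num_candidates)
    preds = cand.argmax(dim=-1)                      # highest-logit candidate
    return (preds == correct_idx).float().mean().item()

@torch.no_grad()
def distribution_metrics(logits_steered: torch.Tensor, logits_base: torch.Tensor):
    # Entropy of the steered next-token distribution and KL(base || steered),
    # averaged over the batch; the KL direction here is an assumption.
    logp_s = logits_steered.log_softmax(dim=-1)
    logp_b = logits_base.log_softmax(dim=-1)
    entropy = -(logp_s.exp() * logp_s).sum(dim=-1).mean()
    kl = (logp_b.exp() * (logp_b - logp_s)).sum(dim=-1).mean()
    return entropy.item(), kl.item()
```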
Key findings include:
- Applying control vectors can improve performance on the specified reasoning tasks across different models.
- The optimal scaling factor $\alpha$ varies by model and task, sometimes requiring a negative value (e.g., for GSM8K on Mistral).
- For smaller Pythia models, slight accuracy improvements were observed on the IOI task, with metrics showing the intervention's effect on the logit distribution.
- For Mistral-7B-Instruct, improvements were seen on the bAbI and notably the more complex GSM8K tasks.
- A significant finding is that control vectors derived from one reasoning task (bAbI) can improve performance on a different reasoning task (GSM8K) and vice versa, suggesting the control vector captures a more general "reasoning" related direction in the model's latent space.
- Qualitative examples show that applying the intervention can influence the model's generated reasoning trace in GSM8K, leading to a correct answer where the original trace failed.
From an implementation perspective, the method requires:
- Access to model internals to extract hidden state activations. HuggingFace Transformers can load the models, and PyTorch (or TensorFlow) exposes intermediate layer outputs; a minimal extraction sketch follows this list.
- A dataset of task examples, ideally with corresponding correct outputs. Deriving contrastive vectors additionally requires examples where the model succeeds and examples where it fails. The paper suggests using few-shot examples both when deriving the vectors and at inference time to ground the model.
- Implementation of the control vector calculation logic (averaging, PCA on differences).
- A mechanism to modify the model's forward pass during inference to add the scaled control vector to the residual stream at specified layers and tokens. This might involve custom forward hooks or modifying the model's architecture definition.
- Evaluation scripts to calculate metrics like logit-based accuracy, KL divergence, and entropy.
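As a concrete example of the extraction step above, hidden states can be read out with `output_hidden_states=True` in HuggingFace Transformers. The Pythia checkpoint matches the paper's models, while the helper itself is only a sketch:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-1.4b")

@torch.no_grad()
def final_token_activation(prompt: str, layer: int) -> torch.Tensor:
    ids = tok(prompt, return_tensors="pt").input_ids
    out = model(ids, output_hidden_states=True)
    # hidden_states[0] is the embedding output; hidden_states[layer] is the
    # residual stream after block `layer`. Take the final token's vector.
    return out.hidden_states[layer][0, -1, :]
```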
The authors provide code publicly, which would be essential for reproducing and extending this work.
Practical considerations for applying this method:
- Computational Cost: Extracting activations requires multiple forward passes over the training data. Applying the control vector during inference adds a small computational overhead (vector addition). Deriving control vectors (especially PCA) might require storing and processing large matrices of activations.
- Data Requirements: The effectiveness depends on having a representative set of positive and negative examples to derive the control vector. Defining "unsuccessful reasoning" in practice can be challenging.
- Hyperparameter Tuning: The optimal $\alpha$ must be determined empirically, potentially for each task and model. The choice of which layer(s) to intervene on (the paper focuses on the middle layer) and at which token(s) may also require tuning; a minimal sweep is sketched after this list.
- Robustness: The jagged trend lines in some results (e.g., Mistral on GSM8K) suggest the intervention may be sensitive to the exact $\alpha$ value or to the specific examples used to derive the vector.
- Generalization: While cross-task generalization was observed between bAbI and GSM8K, further research is needed to see how broadly this "reasoning" direction generalizes to other reasoning tasks or domains.
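A simple way to pick $\alpha$ in practice is a validation sweep, reusing the hook helper sketched earlier; `layer`, `c_mid`, `evaluate`, and `val_set` are hypothetical placeholders:

```python
# Sweep a small grid of alpha values and keep the best validation accuracy.
best_alpha, best_acc = 0.0, -1.0
for alpha in (-8, -4, -2, -1, 0, 1, 2, 4, 8):
    handle = layer.register_forward_hook(make_steering_hook(c_mid, float(alpha)))
    acc = evaluate(model, val_set)  # e.g., the logit-based accuracy above
    handle.remove()
    if acc > best_acc:
        best_alpha, best_acc = float(alpha), acc
```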
The paper concludes that reasoning performance can be modulated via representation engineering, suggesting that aspects of reasoning are encoded in the residual stream much like other model characteristics such as sentiment. While the authors acknowledge limitations in the model scales and task complexity studied, the results, particularly on GSM8K and the cross-task effect, are promising for future work on steerable and potentially more reliable LLMs that require no additional training.