This paper investigates whether explicitly inferring task-relevant latent variables improves the in-context learning (ICL) capabilities and out-of-distribution (OOD) generalization of Transformer models. The authors hypothesize that standard Transformers might rely on "non-parametric shortcuts" (e.g., kernel-like mechanisms) rather than truly understanding the underlying task structure by inferring its generative latents. This reliance on shortcuts could hinder generalization.
To test this, they compare two types of models:
- Implicit Model: A standard Transformer architecture where the query can directly attend to all context examples. This model jointly performs context aggregation and predictive modeling.
- Explicit Model: A modified Transformer architecture designed to enforce the inference of task latents. It separates context aggregation from prediction. A context Transformer processes the in-context examples $\mathcal{D}_{\text{context}} = \{(x_i, y_i)\}_{i=1}^{n}$ to produce a fixed-size, low-dimensional bottleneck representation $z_\psi(\mathcal{D}_{\text{context}})$. This bottleneck is then fed, along with the query $x^*$, to a separate prediction network (either an MLP or another Transformer) to predict $y^*$. The key idea is that this bottleneck forces the model to summarize the context into essential latent information, preventing the query from directly using context exemplars for shortcuts.
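For concreteness, here is a minimal PyTorch sketch contrasting the two architectures. It assumes mean pooling into the bottleneck, an MLP prediction network for the explicit model, and illustrative dimensions; the class names `ImplicitICL` and `ExplicitICL` are hypothetical and this is not the authors' exact implementation.

```python
import torch
import torch.nn as nn

class ImplicitICL(nn.Module):
    """Standard Transformer: the query token attends directly to all context tokens."""
    def __init__(self, x_dim, d_model=64, n_layers=4, n_heads=4):
        super().__init__()
        self.embed = nn.Linear(x_dim + 1, d_model)            # embed (x, y) pairs; the query uses y = 0
        layer = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, n_layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, ctx_x, ctx_y, query_x):
        # ctx_x: (B, n, x_dim), ctx_y: (B, n, 1), query_x: (B, x_dim)
        query_tok = torch.cat([query_x, torch.zeros_like(query_x[:, :1])], dim=-1)
        tokens = torch.cat([torch.cat([ctx_x, ctx_y], dim=-1), query_tok.unsqueeze(1)], dim=1)
        h = self.encoder(self.embed(tokens))
        return self.head(h[:, -1])                             # read the prediction off the query position


class ExplicitICL(nn.Module):
    """Context aggregation is separated from prediction by a low-dimensional bottleneck z."""
    def __init__(self, x_dim, z_dim=8, d_model=64, n_layers=4, n_heads=4):
        super().__init__()
        self.embed = nn.Linear(x_dim + 1, d_model)
        layer = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
        self.context_encoder = nn.TransformerEncoder(layer, n_layers)
        self.to_bottleneck = nn.Linear(d_model, z_dim)          # z_psi(D_context)
        self.predictor = nn.Sequential(                         # MLP prediction network f(x*, z)
            nn.Linear(x_dim + z_dim, 128), nn.ReLU(), nn.Linear(128, 1))

    def forward(self, ctx_x, ctx_y, query_x):
        h = self.context_encoder(self.embed(torch.cat([ctx_x, ctx_y], dim=-1)))
        z = self.to_bottleneck(h.mean(dim=1))                   # pool the context into a fixed-size summary
        return self.predictor(torch.cat([query_x, z], dim=-1)), z
```

Note that the explicit model's forward pass also returns the bottleneck z, which is the quantity the paper later probes and intervenes on.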
The paper systematically compares these models across various ICL tasks, including:
- Synthetic Regression: Linear, nonlinear (MLP), sinusoidal, Gaussian Process (GP), and Hodgkin-Huxley ODE prediction.
- Synthetic Classification: Linear and nonlinear (MLP).
- Reasoning-based/Compositional Tasks: Raven's Progressive Matrices, Alchemy, and Gene Targeting.
Models are trained from scratch and evaluated on both in-distribution (ID) and out-of-distribution (OOD) data. For the synthetic tasks, OOD evaluation samples query points from a distribution with a larger standard deviation than the one used during training. For the reasoning tasks, OOD evaluation tests generalization to novel combinations of latent variable components (compositional generalization).
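To make the ID/OOD split concrete for the synthetic regression case, here is a minimal NumPy sketch that draws a fresh task latent per episode and widens only the query distribution at evaluation time; the context size, dimensionality, and standard deviations are placeholder values, not the paper's.

```python
import numpy as np

def linear_regression_episode(n_context, x_dim, query_std, rng, noise_std=0.1):
    """One ICL episode: a fresh latent weight vector w, context pairs, and a query point."""
    w = rng.normal(size=x_dim)                          # task-specific latent
    ctx_x = rng.normal(scale=1.0, size=(n_context, x_dim))
    ctx_y = ctx_x @ w + noise_std * rng.normal(size=n_context)
    query_x = rng.normal(scale=query_std, size=x_dim)   # ID: query_std == 1.0; OOD: query_std > 1.0
    return ctx_x, ctx_y, query_x, query_x @ w, w

rng = np.random.default_rng(0)
id_episode  = linear_regression_episode(n_context=32, x_dim=8, query_std=1.0, rng=rng)
ood_episode = linear_regression_episode(n_context=32, x_dim=8, query_std=3.0, rng=rng)
```

The same pattern carries over to the other synthetic tasks by swapping the generative function (e.g., an MLP or a sinusoid) while keeping the per-episode latent sampling and the widened query distribution for OOD evaluation.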
Key findings include:
- No General Performance Benefit from Explicit Models: Contrary to the initial hypothesis, explicit models generally do not outperform implicit models, either ID or OOD. In some synthetic regression tasks, the implicit model even showed slightly better OOD performance. Both model types struggled with OOD generalization on several tasks, while on others (classification, compositional tasks), both generalized reasonably well but with similar performance.
- Explicit Models Learn Correct Latent Variables: Despite the lack of performance improvement, investigations revealed that the bottleneck in the explicit model effectively learns to extract the true task-relevant latent variables. This was confirmed by successfully decoding these latents from the bottleneck representation using a linear probe, even without an auxiliary loss explicitly encouraging latent prediction. (This probe, together with the oracle swap and the counterfactual intervention described in the next findings, is sketched in code after this list.)
- Prediction Function is a Key Bottleneck: The primary reason for the explicit model's lack of superior OOD performance appears to be the inability of its prediction function to effectively utilize the correctly inferred latent variables. When the learned prediction function in the explicit model was replaced with an "oracle" (the ground-truth generative function g), OOD performance significantly improved on most tasks. This suggests that simply inferring the correct latents is insufficient; the model also needs the correct mechanism or strong inductive biases to use these latents for prediction.
- Explicit Models Offer Better Interpretability: The bottleneck in explicit models provides a clear locus for inspecting learned task representations. The authors successfully decoded true latent variables from this bottleneck and demonstrated that manipulating these bottleneck activations led to correct counterfactual predictions (using Distributed Alignment Search - DAS), a property not easily achieved with implicit models.
- Scaling Trends: Performance for both implicit and explicit models scaled similarly with task difficulty (input dimensionality, context length) and model size. Implicit models consistently (though slightly) outperformed explicit models unless the explicit model used the oracle prediction function. Latent variable decoding accuracy in the explicit model improved with less uncertainty (lower data dimensionality, longer context) and larger model capacity.
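The three mechanism-level checks above (linear-probe decoding of the latents, swapping in an oracle prediction function, and counterfactual intervention on the bottleneck) can be illustrated with the hypothetical `ExplicitICL` model from the earlier sketch. The probe choice (ridge regression), the linear-regression oracle $g(x^*, w) = x^{*\top} w$, and the simplified interchange intervention (full DAS additionally learns an alignment of the bottleneck before swapping) are assumptions for illustration, not the authors' exact procedures.

```python
import torch
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score

def fit_latent_probe(Z, W):
    """Linear probe: map bottleneck activations Z (episodes, z_dim) to true latents W (episodes, x_dim)."""
    probe = Ridge(alpha=1.0).fit(Z, W)
    print("probe R^2:", r2_score(W, probe.predict(Z)))   # ideally reported on held-out episodes
    return probe

def oracle_predict(model, probe, ctx_x, ctx_y, query_x):
    """Keep the learned context encoder, but replace the learned prediction head with the true g."""
    with torch.no_grad():
        _, z = model(ctx_x, ctx_y, query_x)                             # inferred bottleneck
    w_hat = torch.as_tensor(probe.predict(z.numpy()), dtype=query_x.dtype)
    return (query_x * w_hat).sum(dim=-1, keepdim=True)                  # g(x*, w) = x*.T w for the linear task

def interchange_intervention(model, episode_a, episode_b):
    """Patch episode B's bottleneck into episode A's prediction pass (simplified, not full DAS)."""
    (ax, ay, aq), (bx, by, bq) = episode_a, episode_b
    with torch.no_grad():
        _, z_b = model(bx, by, bq)                                      # bottleneck from the source episode
        counterfactual = model.predictor(torch.cat([aq, z_b], dim=-1))
    return counterfactual  # should track episode B's latent if the bottleneck is causally used
```

In terms of these sketches, the paper's diagnosis is that the probe succeeds (the latents are recoverable), the oracle swap restores OOD performance, and yet the learned predictor that the oracle bypasses is precisely the component that fails out of distribution.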
The authors conclude that while biasing Transformers towards inferring task-relevant latent variables (e.g., via a bottleneck) aids interpretability, it is not sufficient to improve ICL generalization. The limitation seems to lie more fundamentally in the Transformer's ability to learn the correct prediction function to leverage these latents, rather than solely in its tendency to take non-parametric shortcuts. The paper suggests that future work should focus on incorporating stronger inductive biases into the prediction mechanism to better utilize inferred latents, potentially drawing from amortized inference methods or neurosymbolic AI. The paper also highlights the value of controlled experiments on simpler, well-understood tasks for dissecting complex phenomena like ICL.