Learning Cognitive Maps from Transformer Representations for Efficient Planning in Partially Observed Environments (2401.05946v1)
Abstract: Despite their stellar performance on a wide range of tasks, including in-context tasks only revealed during inference, vanilla transformers and variants trained for next-token prediction (a) do not learn an explicit world model of their environment that can be flexibly queried and (b) cannot be used for planning or navigation. In this paper, we consider partially observed environments (POEs), where an agent receives perceptually aliased observations as it navigates, which makes path planning hard. We introduce a transformer with (multiple) discrete bottleneck(s), TDB, whose latent codes learn a compressed representation of the history of observations and actions. After training a TDB to predict the future observation(s) given the history, we extract interpretable cognitive maps of the environment from its active bottleneck(s) indices. These maps are then paired with an external solver to solve (constrained) path planning problems. First, we show that a TDB trained on POEs (a) retains the near-perfect predictive performance of a vanilla transformer or an LSTM while (b) solving shortest path problems exponentially faster. Second, a TDB extracts interpretable representations from text datasets while reaching higher in-context accuracy than vanilla sequence models. Finally, in new POEs, a TDB (a) reaches near-perfect in-context accuracy, (b) learns accurate in-context cognitive maps, and (c) solves in-context path planning problems.
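To make the second half of this pipeline concrete, here is a minimal sketch of how extracted cognitive maps can be handed to an external solver. It assumes we already have, for each step of a trajectory, the active bottleneck index (a discrete latent code) and the action taken; treating codes as nodes and actions as edge labels yields a graph that an off-the-shelf solver such as networkx can query for shortest paths. The names (`codes`, `actions`, `build_cognitive_map`) are illustrative, not from the paper.

```python
import networkx as nx

def build_cognitive_map(codes, actions):
    """Build a directed graph from a trajectory of discrete bottleneck
    codes and the actions taken between consecutive steps.

    codes:   sequence of ints, the active bottleneck index at each step
    actions: sequence of labels; actions[t] moves codes[t] -> codes[t+1]
    """
    g = nx.DiGraph()
    for t in range(len(codes) - 1):
        # One edge per observed (state, action, next-state) transition;
        # a repeated transition simply overwrites the same edge.
        g.add_edge(codes[t], codes[t + 1], action=actions[t])
    return g

# Toy trajectory: four latent codes visited in a loop.
codes = [0, 1, 2, 3, 0, 1]
actions = ["right", "right", "down", "left", "right"]
g = build_cognitive_map(codes, actions)

# Path planning reduces to ordinary graph search on the learned map.
path = nx.shortest_path(g, source=0, target=3)
plan = [g.edges[u, v]["action"] for u, v in zip(path, path[1:])]
print(path)  # [0, 1, 2, 3]
print(plan)  # ['right', 'right', 'down']
```

Once the map is an explicit graph, constrained variants (e.g., forbidding certain nodes) amount to deleting nodes or edges before running the same search, which is why planning on the extracted map is so much cheaper than rolling out a sequence model.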