An Analysis of Attention via the Lens of Exchangeability and Latent Variable Models (2212.14852v3)
Abstract: With the attention mechanism, transformers achieve significant empirical success. Despite the intuitive understanding that transformers perform relational inference over long sequences to produce desirable representations, we lack a rigorous theory of how the attention mechanism achieves this. In particular, several intriguing questions remain open: (a) What makes a desirable representation? (b) How does the attention mechanism infer the desirable representation within the forward pass? (c) How does a pretraining procedure learn to infer the desirable representation through the backward pass? We observe that, as in BERT and ViT, input tokens are often exchangeable because they already include positional encodings. The notion of exchangeability induces a latent variable model that is invariant to input sizes, which enables our theoretical analysis.
- To answer (a) on representation, we establish the existence of a sufficient and minimal representation of input tokens. In particular, such a representation instantiates the posterior distribution of the latent variable given the input tokens, which plays a central role in predicting output labels and solving downstream tasks.
- To answer (b) on inference, we prove that attention with the desired parameter infers the latent posterior up to an approximation error that decreases with the input size. In detail, we quantify how attention approximates the conditional mean of the value given the key, which characterizes how it performs relational inference over long sequences (see the sketch after this list).
- To answer (c) on learning, we prove that both supervised and self-supervised objectives allow empirical risk minimization to learn the desired parameter up to a generalization error that is independent of the input size. Notably, in the self-supervised setting, we identify a condition number that is pivotal to solving downstream tasks.
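To make the conditional-mean view in (b) concrete, here is a minimal NumPy sketch of the kernel-regression reading of attention; it is an illustration, not the paper's construction. The key-to-value relation `f`, the unit-norm keys and query, and the temperature schedule that shrinks with the number of tokens are assumptions introduced for this toy example.

```python
import numpy as np

def softmax_attention(query, keys, values, temperature=1.0):
    """Single-query dot-product attention: a softmax-weighted average of the values."""
    scores = keys @ query / temperature            # similarity of the query to each key
    weights = np.exp(scores - scores.max())        # softmax, shifted for numerical stability
    weights /= weights.sum()
    return weights @ values                        # convex combination of the values

# Toy check of the conditional-mean reading: with unit-norm keys, the softmax
# weights concentrate on keys close to the query, so the output acts like a
# Nadaraya-Watson style estimate of E[value | key = query].
rng = np.random.default_rng(0)
f = lambda k: np.sin(3.0 * k[:, 0])                # hypothetical key-to-value relation
query = np.array([0.6, 0.8])                       # unit-norm query
target = np.sin(3.0 * query[0])                    # conditional mean at the query
for n in (100, 1_000, 100_000):                    # longer inputs -> smaller error
    keys = rng.normal(size=(n, 2))
    keys /= np.linalg.norm(keys, axis=1, keepdims=True)
    values = f(keys) + 0.1 * rng.normal(size=n)    # noisy values generated from the keys
    estimate = softmax_attention(query, keys, values, temperature=1.0 / np.sqrt(n))
    print(f"n={n:>7d}  attention output={estimate:+.3f}  target={target:+.3f}")
```

With these illustrative choices, the printed attention output moves toward the target conditional mean as the number of tokens grows, mirroring the claim that the approximation error of attention decreases with the input size.

References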
- Learning and generalization in overparameterized neural networks, going beyond two layers. In Neural Information Processing Systems.
- A convergence theory for deep learning via over-parameterization. In International Conference on Machine Learning.
- On the convergence rate of training recurrent neural networks. In Neural Information Processing Systems.
- Fine-grained analysis of optimization and generalization for overparameterized two-layer neural networks. In International Conference on Machine Learning.
- Bartlett, P. (1996). For valid generalization the size of the weights is more important than the size of the network. Neural Information Processing Systems.
- Gradient descent with identity initialization efficiently learns positive definite linear transformations by deep residual networks. In International Conference on Machine Learning.
- Representing smooth functions as compositions of near-identity functions with implications for deep network optimization. arXiv preprint arXiv:1804.05012.
- Spectrally-normalized margin bounds for neural networks. Neural Information Processing Systems.
- Rademacher and Gaussian complexities: Risk bounds and structural results. Journal of Machine Learning Research.
- Relational inductive biases, deep learning, and graph networks. arXiv preprint arXiv:1806.01261.
- Probabilistic symmetries and invariant neural networks. Journal of Machine Learning Research.
- Geometric deep learning: Grids, groups, graphs, geodesics, and gauges. arXiv preprint arXiv:2104.13478.
- Language models are few-shot learners. Neural Information Processing Systems.
- Generalization bounds of stochastic gradient descent for wide and deep neural networks. Neural Information Processing Systems.
- Optimal rates for the regularized least-squares algorithm. Foundations of Computational Mathematics.
- Decision transformer: Reinforcement learning via sequence modeling. Neural Information Processing Systems.
- On lazy training in differentiable programming. Neural Information Processing Systems.
- Transformer-XL: Attentive language models beyond a fixed-length context. arXiv preprint arXiv:1901.02860.
- de Finetti, B. (1937). La prévision: ses lois logiques, ses sources subjectives. In Annales de l’institut Henri Poincaré.
- On conditional density estimation. Statistica Neerlandica.
- BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.
- Finite exchangeable sequences. Annals of Probability.
- An image is worth 16×16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929.
- Gradient descent finds global minima of deep neural networks. In International Conference on Machine Learning.
- Gradient descent provably optimizes over-parameterized neural networks. arXiv preprint arXiv:1810.02054.
- Inductive biases and variable creation in self-attention mechanisms. arXiv preprint arXiv:2110.10090.
- Elesedy, B. (2021). Provably strict generalisation benefit for invariance in kernel methods. Neural Information Processing Systems.
- Fisher, R. A. (1922). On the mathematical foundations of theoretical statistics. Philosophical Transactions of the Royal Society A, 222 309–368.
- Fukumizu, K. (2015). Nonparametric Bayesian inference with kernel mean embedding. In Modern Methodology and Applications in Spatial-Temporal Modeling. Springer, 1–24.
- What can transformers learn in-context? A case study of simple function classes. arXiv preprint arXiv:2208.01066.
- Geometrically equivariant graph neural networks: A survey. arXiv preprint arXiv:2202.07230.
- Identity matters in deep learning. arXiv preprint arXiv:1611.04231.
- Masked autoencoders are scalable vision learners. In Computer Vision and Pattern Recognition.
- Infinite attention: NNGP and NTK for deep attention networks. In International Conference on Machine Learning.
- LieTransformer: Equivariant self-attention for Lie groups. In International Conference on Machine Learning.
- Neural tangent kernel: Convergence and generalization in neural networks. Neural Information Processing Systems.
- Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178.
- Highly accurate protein structure prediction with AlphaFold. Nature.
- Universal invariant and equivariant graph neural networks. Neural Information Processing Systems.
- Self-attention between datapoints: Going beyond individual input-output pairs in deep learning. Neural Information Processing Systems.
- Probability in Banach Spaces: Isoperimetry and Processes, vol. 23. Springer Science & Business Media.
- Set transformer: A framework for attention-based permutation-invariant neural networks. In International Conference on Machine Learning.
- Learning overparameterized neural networks via stochastic gradient descent on structured data. Neural Information Processing Systems.
- A kernel-based view of language model fine-tuning. arXiv preprint arXiv:2210.05643.
- Maurer, A. (2016). A vector-contraction inequality for Rademacher complexities. In International Conference on Algorithmic Learning Theory. Springer.
- Mean-field theory of two-layers neural networks: Dimension-free bounds and kernel limit. In Annual Conference on Learning Theory.
- A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences.
- Foundations of machine learning. MIT press.
- Kernel mean embedding of distributions: A review and beyond. arXiv preprint arXiv:1605.09522.
- Nguyen, P.-M. (2019). Mean field limit of the learning dynamics of multilayer neural networks. arXiv preprint arXiv:1902.02880.
- Pearl, J. (2009). Causality. Cambridge University Press.
- Improving language understanding by generative pre-training. Technical Report.
- Language models are unsupervised multitask learners. OpenAI blog.
- Zero-shot text-to-image generation. In International Conference on Machine Learning.
- Group equivariant stand-alone self-attention for vision. arXiv preprint arXiv:2010.00977.
- Parameters as interacting particles: Long time convergence and asymptotic error scaling of neural networks. Neural Information Processing Systems.
- Improved generalization bounds of group invariant/equivariant deep networks via quotient feature spaces. In Uncertainty in Artificial Intelligence.
- E(n) equivariant normalizing flows. arXiv preprint arXiv:2105.09016.
- The graph neural network model. IEEE Transactions on Neural Networks.
- A tutorial on Gaussian process regression: Modelling, exploring, and exploiting functions. Journal of Mathematical Psychology.
- Kernel methods for pattern analysis. Cambridge University Press.
- Kernel instrumental variable regression. Neural Information Processing Systems.
- Mean field analysis of neural networks: A central limit theorem. Stochastic Processes and Their Applications.
- Generalization error of invariant classifiers. In Artificial Intelligence and Statistics.
- Nonparametric estimation of multi-view latent variable models. In International Conference on Machine Learning.
- Hilbert space embeddings of conditional distributions with applications to dynamical systems. In International Conference on Machine Learning.
- Transformer dissection: A unified understanding of transformer’s attention via the lens of kernel. arXiv preprint arXiv:1908.11775.
- Generalization bounds for deep learning. arXiv preprint arXiv:2012.04115.
- Attention is all you need. In Neural Information Processing Systems.
- A mathematical theory of attention. arXiv preprint arXiv:2007.02876.
- Wainwright, M. J. (2019). High-dimensional statistics: A non-asymptotic viewpoint. Cambridge University Press.
- Wasserman, L. (2000). Bayesian model selection and model averaging. Journal of Mathematical Psychology, 44 92–107.
- Statistically meaningful approximation: A case study on approximating Turing machines with transformers. arXiv preprint arXiv:2107.13163.
- Transformers: State-of-the-art natural language processing. In Empirical Methods in Natural Language Processing.
- An explanation of in-context learning as implicit Bayesian inference. arXiv preprint arXiv:2111.02080.
- Yang, G. (2020). Tensor programs II: Neural tangent kernel for any architecture. arXiv preprint arXiv:2006.14548.
- Tensor programs IIb: Architectural universality of neural tangent kernel training dynamics. In International Conference on Machine Learning.
- Deep sets. Neural Information Processing Systems.
- Relational reasoning via set transformers: Provable efficiency and applications to MARL. arXiv preprint arXiv:2209.09845.
- Learning one-hidden-layer ReLU networks via gradient descent. In International Conference on Machine Learning.
- Unveiling transformers with LEGO: A synthetic reasoning task. arXiv preprint arXiv:2206.04301.
- Divide and conquer kernel ridge regression: A distributed algorithm with minimax optimal rates. Journal of Machine Learning Research, 16 3299–3340.
- Understanding the generalization benefit of model invariance from a data perspective. Neural Information Processing Systems.
- Stochastic gradient descent optimizes over-parameterized deep ReLU networks. arXiv preprint arXiv:1811.08888.
- An improved analysis of training over-parameterized deep neural networks. Neural Information Processing Systems.