Transformer Learns Optimal Variable Selection in Group-Sparse Classification (2504.08638v1)

Published 11 Apr 2025 in stat.ML and cs.LG

Abstract: Transformers have demonstrated remarkable success across various applications. However, the success of transformers has not been understood in theory. In this work, we give a case study of how transformers can be trained to learn a classic statistical model with "group sparsity", where the input variables form multiple groups, and the label only depends on the variables from one of the groups. We theoretically demonstrate that a one-layer transformer trained by gradient descent can correctly leverage the attention mechanism to select variables, disregarding irrelevant ones and focusing on those beneficial for classification. We also demonstrate that a well-pretrained one-layer transformer can be adapted to new downstream tasks to achieve good prediction accuracy with a limited number of samples. Our study sheds light on how transformers effectively learn structured data.

Authors (3)
  1. Chenyang Zhang
  2. Xuran Meng
  3. Yuan Cao

Summary

This paper investigates how a standard one-layer Transformer architecture, trained with gradient descent, can learn to perform optimal variable selection in a specific type of classification task known as "group-sparse linear classification".

Problem Setting: Group Sparse Classification

  1. Data Generation: The input features $\Xb \in \mathbb{R}^{d \times D}$ consist of $D$ groups, each with $d$ features. Each group $\xb_j \in \mathbb{R}^d$ is drawn independently from a Gaussian distribution $N(\mathbf{0}, \sigma_x^2 \mathbf{I}_d)$.
  2. Labeling: The true label $y \in \{-1, +1\}$ depends only on the features from a single, predefined "label-relevant" group $j^*$. Specifically, $y = \text{sign}(\langle \xb_{j^*}, \vb^* \rangle)$, where $\vb^* \in \mathbb{R}^d$ is a fixed ground-truth direction vector (normalized to $\|\vb^*\|_2=1$). All other groups $j \neq j^*$ are irrelevant to the label.
  3. Input Representation: Each feature group $\xb_j$ is concatenated with a unique positional encoding $\pb_j \in \mathbb{R}^D$ (using orthogonal sine functions) to form the input token $\zb_j = [\xb_j^\top, \pb_j^\top]^\top \in \mathbb{R}^{d+D}$. The full input sequence is $\Zb = [\zb_1, \dots, \zb_D]$. A data-generation sketch follows this list.
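
To make this data model concrete, here is a minimal NumPy sketch of the generative process. The discrete-sine positional basis, the scale `sigma_x`, and the helper name `make_group_sparse_data` are illustrative assumptions rather than the paper's exact construction.

```python
import numpy as np

def make_group_sparse_data(n, d, D, j_star, v_star, sigma_x=1.0, seed=0):
    """Sample n pairs (Z, y) from the group-sparse model sketched above."""
    rng = np.random.default_rng(seed)
    # Orthogonal sinusoidal positional encodings p_1, ..., p_D (columns of a
    # discrete sine basis; any orthonormal family would serve the same role).
    k = np.arange(1, D + 1)[:, None]
    t = np.arange(1, D + 1)[None, :]
    P = np.sqrt(2.0 / (D + 1)) * np.sin(np.pi * k * t / (D + 1))   # (D, D), orthonormal columns
    Z, y = [], []
    for _ in range(n):
        X = sigma_x * rng.standard_normal((d, D))      # D groups of d Gaussian features
        label = np.sign(X[:, j_star] @ v_star)         # label depends only on group j*
        Z.append(np.vstack([X, P]))                    # tokens z_j = [x_j; p_j] in R^{d+D}
        y.append(label)
    return np.stack(Z), np.array(y)

# Example: d = 4 features per group, D = 8 groups, relevant group j* = 2.
d, D, j_star = 4, 8, 2
v_star = np.zeros(d); v_star[0] = 1.0                  # unit-norm ground-truth direction
Z_all, y_all = make_group_sparse_data(n=16, d=d, D=D, j_star=j_star, v_star=v_star)
print(Z_all.shape, y_all[:5])                          # (16, 12, 8), labels in {-1, +1}
```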

Model and Training

  1. Architecture: A simplified single-head, one-layer self-attention Transformer is used: $f(\Zb, \Wb, \vb) = \sum_{j=1}^{D} \vb^\top \Zb \mathcal{S}(\Zb^\top \Wb \zb_j) = \vb^\top \Zb \mathbf{S} \mathbf{1}_D$.
    • $\Wb \in \mathbb{R}^{(d+D) \times (d+D)}$ is a trainable matrix combining query and key transformations.
    • $\vb \in \mathbb{R}^{d+D}$ is a trainable value vector.
    • $\mathcal{S}(\cdot)$ is the softmax function applied column-wise, and $\mathbf{S} \in \mathbb{R}^{D \times D}$ is the resulting attention score matrix, where $\mathbf{S}_{j', j}$ represents the attention paid by token $j$ to token $j'$.
  2. Training: The model parameters $(\Wb, \vb)$ are trained jointly using gradient descent on the population cross-entropy loss $\mathcal{L}(\vb, \Wb) = \mathbb{E}_{(\Xb, y)\sim \mathcal{D}}[\log(1+\exp(-y \cdot f(\Zb, \Wb, \vb)))]$. Training starts from zero initialization ($\Wb^{(0)}=\mathbf{0}, \vb^{(0)}=\mathbf{0}$) with a shared learning rate $\eta$; a minimal implementation sketch follows this list.
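
Below is a minimal sketch of this architecture and its gradient-descent training, continuing the data sketch above (it reuses `make_group_sparse_data`, `d`, `D`, `j_star`, and `v_star`). The gradients follow from the formula for $f$ by the chain rule; the sample size, step count, and learning rate are illustrative, and the population loss is approximated by a finite-sample average.

```python
def col_softmax(A):
    """Column-wise softmax S(.)."""
    A = A - A.max(axis=0, keepdims=True)
    E = np.exp(A)
    return E / E.sum(axis=0, keepdims=True)

def forward(Z, W, v):
    """f(Z, W, v) = v^T Z S 1_D, where S[:, j] = softmax(Z^T W z_j)."""
    S = col_softmax(Z.T @ W @ Z)                 # (D, D) attention matrix
    return float(v @ Z @ S.sum(axis=1)), S

def grads(Z, y, W, v):
    """Gradients of log(1 + exp(-y f)) w.r.t. (W, v) for a single sample."""
    f, S = forward(Z, W, v)
    dldf = -y / (1.0 + np.exp(y * f))            # derivative of the logistic loss in f
    u = Z.T @ v
    # df/dW = Z G Z^T, where column j of G is (diag(s_j) - s_j s_j^T) Z^T v.
    G = S * u[:, None] - S * (S * u[:, None]).sum(axis=0, keepdims=True)
    dW = dldf * (Z @ G @ Z.T)
    dv = dldf * (Z @ S.sum(axis=1))              # df/dv = Z S 1_D
    return dW, dv

# Gradient descent from zero initialization with a shared learning rate eta.
Z_all, y_all = make_group_sparse_data(n=512, d=d, D=D, j_star=j_star, v_star=v_star)
W, v, eta = np.zeros((d + D, d + D)), np.zeros(d + D), 0.5
for _ in range(200):
    dW, dv = np.zeros_like(W), np.zeros_like(v)
    for Z_i, y_i in zip(Z_all, y_all):
        gW, gv = grads(Z_i, y_i, W, v)
        dW += gW / len(y_all); dv += gv / len(y_all)
    W -= eta * dW; v -= eta * dv
print("training loss:", np.mean([np.log1p(np.exp(-y_i * forward(Z_i, W, v)[0]))
                                 for Z_i, y_i in zip(Z_all, y_all)]))
```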

Key Theoretical Findings (Theorem 3.1 & Section 5)

Under mild conditions (e.g., $D$ sufficiently large relative to the desired loss tolerance $\epsilon$), the paper proves that gradient descent successfully trains the Transformer to learn the group-sparse structure:

  1. Optimal Variable Selection via Attention: The attention mechanism learns to isolate the relevant group $j^*$. After sufficient training iterations ($T^*$), the attention scores satisfy $\mathbf{S}_{j^*, j}^{(T^*)} \approx 1$ and $\mathbf{S}_{j', j}^{(T^*)} \approx 0$ for $j' \neq j^*$, with high probability for any input $\Zb$. This means the model effectively "attends" only to the features from the correct group $j^*$.
  2. Value Vector Alignment: The trainable value vector $\vb$ aligns correctly:
    • The first block $\vb_1 \in \mathbb{R}^d$ (corresponding to features) aligns its direction with the ground-truth vector $\vb^*$, i.e., $\vb_1^{(T^*)} / \|\vb_1^{(T^*)}\|_2 \approx \vb^*$.
    • The second block $\vb_2 \in \mathbb{R}^D$ (corresponding to positional encodings) remains approximately zero, $\vb_2^{(T^*)} \approx \mathbf{0}$. This ensures positional information is used for attention calculation but not directly included in the final output prediction.
  3. Loss Convergence: The population cross-entropy loss $\mathcal{L}(\vb^{(T^*)}, \Wb^{(T^*)})$ converges to an arbitrarily small value (below $\epsilon$, bounded by $1/D^2$), with tight upper and lower bounds on the convergence rate. A diagnostic snippet for these predictions follows this list.
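
With parameters trained as in the sketch above, these predictions can be checked directly; the following diagnostic snippet (reusing `forward`, `make_group_sparse_data`, and the trained `W`, `v`) is an illustration under the same assumptions, not the paper's evaluation code.

```python
# Check the predicted structure of the trained (W, v).
Z_test, _ = make_group_sparse_data(n=1, d=d, D=D, j_star=j_star, v_star=v_star, seed=123)
_, S = forward(Z_test[0], W, v)
print("avg. attention on group j*:", S[j_star, :].mean())            # Theorem 3.1: close to 1
v1, v2 = v[:d], v[d:]
print("cos(v_1, v*):", v1 @ v_star / (np.linalg.norm(v1) + 1e-12))   # close to 1 (direction aligned)
print("||v_2||_2:", np.linalg.norm(v2))                              # close to 0
```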

Mechanism Explained (Proof Sketch - Section 5)

The paper provides insights into how this learning happens by analyzing the structure of the learned weight matrix $\Wb^{(T^*)}$:

  1. Learned $\Wb$ Structure: Gradient descent drives $\Wb$ towards a specific block structure:
    • $\Wb_{1,2}^{(T^*)} \approx \mathbf{0}$ and $\Wb_{2,1}^{(T^*)} \approx \mathbf{0}$: No direct interaction between feature vectors and positional encodings in the attention score calculation.
    • $\Wb_{1,1}^{(T^*)}$ (feature-feature interaction) aligns primarily with $\vb^* \vb^{*\top}$.
    • $\Wb_{2,2}^{(T^*)}$ (position-position interaction) develops a specific low-rank structure related to $(\pb_{j^*} - \pb_j)$ terms.
  2. Attention Focus: Due to the orthogonality of the chosen positional encodings $\pb_j$ and the learned structure of $\Wb_{2,2}^{(T^*)}$, the position-position interaction term $\pb_{j'}^\top \Wb_{2,2}^{(T^*)} \pb_j$ becomes significantly larger when $j' = j^*$ than for other $j'$. This term dominates the softmax calculation, causing $\mathbf{S}_{j^*, j}^{(T^*)}$ to approach 1, while the feature interaction term $\xb_{j'}^\top \Wb_{1,1}^{(T^*)} \xb_j$ is shown to be smaller in magnitude. An illustrative sketch follows this list.
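
The dominance of the position-position term can be illustrated with an idealized $\Wb$ in which only the $\Wb_{2,2}$ block is nonzero and, as a rough stand-in for the learned structure, is built from $(\pb_{j^*} - \pb_j)$ terms; the scale $\beta$ and this exact form are assumptions made for illustration. Reusing `col_softmax` and the encoding construction from the earlier sketches:

```python
# Idealized position-position block built from (p_{j*} - p_j) terms.
k = np.arange(1, D + 1)[:, None]; t = np.arange(1, D + 1)[None, :]
P = np.sqrt(2.0 / (D + 1)) * np.sin(np.pi * k * t / (D + 1))   # orthonormal p_1, ..., p_D (columns)
beta = 5.0                                                      # assumed scale of the learned block
W22 = beta * sum((P[:, j_star] - P[:, j])[:, None] * P[:, j][None, :] for j in range(D))
A = P.T @ W22 @ P        # A[j', j] = p_{j'}^T W22 p_j: row j' = j* dominates every column j != j*
S_pos = col_softmax(A)
print(np.round(S_pos[j_star, :], 2))   # near 1 for all j != j*; column j* stays uniform here
```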

Transfer Learning Application (Section 4, Theorem 4.1)

The paper demonstrates the practical benefit of this learned structure for transfer learning:

  1. Downstream Task: Consider a new classification task with the same group sparsity pattern ($j^*$ is the same) but a potentially different data distribution (sub-Gaussian features, linear separability margin $\gamma$).
  2. Fine-tuning: Initialize a new model with the pre-trained $\Wb^{(T^*)}$ (denoted $\tilde{\Wb}^{(0)}$) and $\tilde{\vb}^{(0)}=\mathbf{0}$. Fine-tune using online SGD on $n$ samples from the downstream task; a fine-tuning sketch follows this list.
  3. Improved Sample Complexity: The average prediction error on the downstream task is bounded by $\tilde{O}\left(\frac{d+D}{n\gamma^2}\right)$. This is significantly better than the corresponding $\Omega\left(\frac{dD}{n\gamma^2}\right)$ rate for a standard linear classifier trained on the full $d \times D$ vectorized features, demonstrating the efficiency gained by reusing the learned variable selection mechanism.
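
A minimal sketch of this transfer step, reusing the pre-trained `W` and the helpers from the earlier sketches. For simplicity, only the value vector is updated by one pass of online SGD while the pre-trained attention weights stay fixed, and the synthetic sampler stands in for the downstream sub-Gaussian distribution; both are simplifications of the setting analysed in Theorem 4.1.

```python
# Online SGD fine-tuning on a downstream task sharing the same relevant group j*.
Z_down, y_down = make_group_sparse_data(n=256, d=d, D=D, j_star=j_star, v_star=v_star, seed=7)
W_tilde = W.copy()                       # pre-trained attention weights, \tilde{W}^{(0)}
v_tilde = np.zeros(d + D)                # \tilde{v}^{(0)} = 0
eta_ft = 0.2
for Z_i, y_i in zip(Z_down, y_down):     # one pass, one sample per step
    f_i, S_i = forward(Z_i, W_tilde, v_tilde)
    dldf = -y_i / (1.0 + np.exp(y_i * f_i))
    v_tilde -= eta_ft * dldf * (Z_i @ S_i.sum(axis=1))    # gradient w.r.t. the value vector only
err = np.mean([np.sign(forward(Z_i, W_tilde, v_tilde)[0]) != y_i
               for Z_i, y_i in zip(Z_down, y_down)])
print("downstream training error:", err)
```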

Implementation Considerations & Practical Implications

  • Architecture Choice: The simplified one-layer architecture with a combined QK matrix and a value vector is amenable to theoretical analysis. Real-world applications might use standard Transformer blocks, but the core principle of attention learning structure could still apply.
  • Initialization: Zero initialization is crucial for the theoretical analysis. Practical Transformers often use specific initialization schemes (like Xavier/He), but the paper shows learning is possible from zero.
  • Optimization: Gradient descent is shown to work. The analysis uses population loss (infinite data), but experiments show similar behavior with SGD on finite datasets.
  • Positional Encoding: The specific orthogonal sinusoidal encoding facilitates the analysis. Other encodings might work but could change the learned structure of $\Wb_{2,2}$.
  • Benefit of Pre-training: The transfer learning result highlights a key benefit: pre-training on tasks with inherent structure (like group sparsity) allows the Transformer to learn efficient representations (focusing attention) that significantly accelerate learning on similar downstream tasks, requiring fewer samples.
  • Variable Selection: This work provides theoretical backing for the intuition that attention mechanisms can perform feature/variable selection, identifying and focusing on the most relevant parts of the input sequence.

Experiments

Numerical experiments on synthetic data confirm the theoretical predictions: the loss converges, the value vector aligns, and the attention matrix correctly focuses on the $j^*$-th group. Experiments on a modified CIFAR-10 task (embedding real images into noisy patches) further validate that the mechanism works on more complex, real-world-like data and that the model can identify the correct patch ($j^*$) containing the true image, achieving good classification accuracy. High-dimensional experiments ($d=100$, $D=100$) also show successful learning and attention focusing.
