- The paper shows that transformers can exactly represent k-fold compositions with O(log k) layers by hierarchically composing permutation functions.
- It reveals a statistical-computational gap where training solely on hard examples requires exponentially many samples, emphasizing the need for structured data.
- The work proposes both explicit and implicit curriculum strategies that enable efficient learning with polynomial sample complexity for complex reasoning tasks.
This paper, "Learning Compositional Functions with Transformers from Easy-to-Hard Data" (2505.23683), investigates the learnability of complex compositional tasks using transformers via gradient descent, focusing particularly on the role of the training data distribution. The authors define a synthetic task, the "k-fold composition task", which serves as a controlled environment to paper the challenges and opportunities in learning compositional functions.
The k-fold Composition Task
The core task involves computing the result of applying an interleaved sequence of $k$ input-dependent permutations $\sigma_i$ and $k$ hidden (parametric) permutations $\pi_i$ to an initial element $x$. Formally, for an input $(\sigma, x) \in (S_N)^k \times [N]$, the task computes $f_\pi(\sigma, x) = (\sigma_k \circ \pi_k \circ \sigma_{k-1} \circ \pi_{k-1} \circ \cdots \circ \sigma_1 \circ \pi_1)(x)$, where $S_N$ is the set of permutations on $N$ elements. The hidden permutations $\pi = (\pi_1, \ldots, \pi_k)$ are fixed for a given target function $f_\pi$ but are unknown to the learner, representing "parametric knowledge". The input permutations $\sigma = (\sigma_1, \ldots, \sigma_k)$ vary per instance, representing "contextual knowledge". This structure mirrors real-world reasoning problems like multi-hop question answering, where a model must combine information from the input context with knowledge stored in its parameters. A cyclic variant of the task is also introduced, where the starting index $i$ is part of the input: $f_\pi^{\mathrm{cyc}}(\sigma, i, x) := (\sigma_{i+k-1} \circ \pi_{i+k-1} \circ \cdots \circ \sigma_i \circ \pi_i)(x)$, with permutation indices taken modulo $k$.
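To make the task concrete, here is a minimal NumPy sketch of sampling and evaluating $f_\pi(\sigma, x)$, with permutations represented as index arrays (the function names and sampling choices are illustrative, not taken from the paper's code):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_hidden(N, k):
    """Hidden ("parametric") permutations pi_1, ..., pi_k, fixed for a target f_pi."""
    return [rng.permutation(N) for _ in range(k)]

def sample_input(N, k):
    """Input ("contextual") permutations sigma_1, ..., sigma_k and a start element x, fresh per example."""
    return [rng.permutation(N) for _ in range(k)], int(rng.integers(N))

def k_fold(pis, sigmas, x):
    """Compute f_pi(sigma, x) = (sigma_k o pi_k o ... o sigma_1 o pi_1)(x)."""
    for pi, sigma in zip(pis, sigmas):
        x = sigma[pi[x]]          # apply pi_i first, then sigma_i
    return x

pis = sample_hidden(N=5, k=4)
sigmas, x = sample_input(N=5, k=4)
print(k_fold(pis, sigmas, x))     # the label the model must predict
```

The learner observes only `sigmas`, `x`, and the label; the hidden `pis` must be absorbed into the model's parameters during training.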
Transformer Architecture and Expressivity
The paper considers an attention-only transformer architecture with $L$ layers and embedding dimension $d$. The input $(\sigma, x)$ is first embedded into a sequence of $T = kN$ tokens. Each token corresponds to a pair $(i, j) \in [k] \times [N]$ and encodes information about the mapping $\sigma_i(j)$. The theoretical construction in the paper uses a fixed embedding function $\phi(i, j, \sigma_i(j)) \in \mathbb{R}^d$. The transformer processes this sequence, and a final readout layer decodes the output for the queried initial element $x$.
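The paper's fixed embedding $\phi$ is a specific construction with $d = \tilde{O}(Nk)$; purely to illustrate what information each of the $T = kN$ tokens carries, here is a hypothetical tokenizer that concatenates one-hot codes of $(i, j, \sigma_i(j))$ (an assumption for exposition, not the paper's embedding):

```python
import numpy as np

def one_hot(idx, size):
    v = np.zeros(size)
    v[idx] = 1.0
    return v

def tokenize(sigmas, N, k):
    """Build the T = k*N token sequence; token (i, j) encodes the fact sigma_i(j)."""
    tokens = []
    for i, sigma in enumerate(sigmas):   # block index i in [k]
        for j in range(N):               # domain element j in [N]
            # Illustrative encoding of the triple (i, j, sigma_i(j)); the paper's
            # fixed embedding phi differs, this only shows the token's content.
            tokens.append(np.concatenate([one_hot(i, k), one_hot(j, N), one_hot(sigma[j], N)]))
    return np.stack(tokens)              # shape (k*N, k + 2*N)

# Example: k = 3 blocks of N = 4 tokens each -> a (12, 11) token matrix.
rng = np.random.default_rng(0)
sigmas = [rng.permutation(4) for _ in range(3)]
print(tokenize(sigmas, N=4, k=3).shape)
```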
A key theoretical result (Theorem 1) shows that the $k$-fold composition task can be exactly expressed by a transformer with $L = O(\log k)$ layers. The construction works hierarchically: each layer composes pairs of partial compositions produced by the previous layer, so the number of $\sigma_i \circ \pi_i$ hops handled doubles with each layer. This hierarchical composition allows the transformer to compute all $2k$ constituent compositions in a logarithmic number of layers. The embedding dimension required for this construction is $d = \tilde{O}(Nk)$. This expressivity result highlights that transformers can represent such compositional tasks efficiently in terms of depth.
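The doubling idea can be sketched outside the transformer: compose adjacent maps in the sequence $\pi_1, \sigma_1, \ldots, \pi_k, \sigma_k$ pairwise, halving the number of remaining maps per round, so $\lceil \log_2(2k) \rceil$ rounds suffice. The sketch below is an illustration of that argument in NumPy, not the paper's construction of attention weights:

```python
import numpy as np

def compose(f, g):
    """Return g o f (apply f first, then g) for permutations given as index arrays."""
    return g[f]

def hierarchical_compose(maps):
    """Compose a left-to-right sequence of permutations by pairing adjacent maps,
    halving the list each round; the number of rounds is ceil(log2(len(maps)))."""
    while len(maps) > 1:
        nxt = [compose(a, b) for a, b in zip(maps[0::2], maps[1::2])]
        if len(maps) % 2 == 1:           # an unpaired leftover passes through unchanged
            nxt.append(maps[-1])
        maps = nxt
    return maps[0]

# Usage on the k-fold task: flatten the interleaved sequence pi_1, sigma_1, ..., pi_k, sigma_k.
rng = np.random.default_rng(0)
N, k = 5, 8
pis = [rng.permutation(N) for _ in range(k)]
sigmas = [rng.permutation(N) for _ in range(k)]
seq = [p for pair in zip(pis, sigmas) for p in pair]   # 2k maps, applied left to right
full_map = hierarchical_compose(seq)                   # built in O(log k) rounds
x = 3
print(full_map[x])   # equals (sigma_k o pi_k o ... o sigma_1 o pi_1)(x)
```

Each round of pairwise composition plays the role of one transformer layer in the paper's construction, which is why depth $O(\log k)$ suffices.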
Statistical-Computational Gap
Despite the positive expressivity result, the paper identifies a significant challenge for learning this task. Using the Statistical Query (SQ) model, which provides lower bounds for a wide class of learning algorithms including gradient descent on neural networks, the authors prove a lower bound (Theorem 2): any SQ learner training on data consisting only of $k$-fold composition examples (the "hardest" task) requires either a number of queries that grows as $N^{\Omega(k)}$ or a query tolerance as small as $\tau \le N^{-\Omega(k)}$, i.e., a cost exponential in $k$.
The practical implication is that training a transformer via standard gradient descent solely on hard $k$-fold examples is likely to be computationally intractable, requiring exponentially many samples or an exponential amount of computation, even though the function is representable by a relatively shallow network. This points to a statistical-computational gap specific to training on the hardest instances.
Efficient Learning via Easy-to-Hard Data
The paper then demonstrates that this computational barrier can be overcome by training on data that exposes the compositional structure, specifically by including "easy" examples. The authors propose two strategies that enable efficient learning with polynomial sample complexity ($\mathrm{poly}(N, k)$ samples):
- Explicit Curriculum Learning (Theorem 3): The model is trained in stages. In stage $\ell$, the model is trained on examples of the $2^{\ell-1}$-fold composition task (i.e., computing $\mathrm{hop}_i^{2^{\ell-1}} = \sigma_{i+2^{\ell-1}-1} \circ \pi_{i+2^{\ell-1}-1} \circ \cdots \circ \sigma_i \circ \pi_i$). The theoretical analysis shows that gradient descent successfully learns the layers sequentially: learning the $2^{\ell-1}$-hop task effectively trains the $\ell$-th layer of the transformer's construction. After $\log k$ stages, the model learns the full $k$-fold composition.
- Implementation: This involves generating data for intermediate hop lengths ($1, 2, 4, \ldots, k$) and training the model sequentially on these datasets, adapting the loss function in each stage to target the corresponding hop length (a data-generation sketch covering both strategies appears below).
- Implicit Curriculum via Data Mixture (Theorem 4): A potentially simpler approach in practice is to train the model simultaneously on a mixture of data from all difficulty levels ($1$-fold, $2$-fold, $4$-fold, $\ldots$, $k$-fold tasks). The theoretical analysis shows that this mixture also induces an "implicit curriculum" for gradient descent. The gradients for layers responsible for simpler compositions (fewer hops) naturally become dominant early in training, allowing these layers to be learned first, which then enables learning of more complex compositions by deeper layers.
- Implementation: This involves generating a single dataset containing examples from various hop lengths. The loss function would be a sum of losses for predicting different hops for each input, potentially weighted.
Both strategies show that presenting data with increasing or mixed difficulty allows gradient descent to leverage the underlying compositional structure and efficiently learn the task, avoiding the exponential cost associated with training only on the hardest k-fold instances.
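To make the two data strategies concrete, here is a minimal data-generation sketch for the cyclic variant; the hop schedule, dataset sizes, and function names are illustrative assumptions, not the paper's experimental setup:

```python
import numpy as np

rng = np.random.default_rng(0)

def hop(pis, sigmas, start, m, x):
    """m-hop label hop_start^m(x) = (sigma_{start+m-1} o pi_{start+m-1} o ... o sigma_start o pi_start)(x),
    with permutation indices taken modulo k as in the cyclic task."""
    k = len(pis)
    for t in range(m):
        idx = (start + t) % k
        x = sigmas[idx][pis[idx][x]]
    return x

def make_dataset(pis, N, k, m, n_examples):
    """Examples of the m-fold (m-hop) task: input (sigmas, start, x), label hop_start^m(x)."""
    data = []
    for _ in range(n_examples):
        sigmas = [rng.permutation(N) for _ in range(k)]
        start, x = int(rng.integers(k)), int(rng.integers(N))
        data.append((sigmas, start, x, hop(pis, sigmas, start, m, x)))
    return data

N, k = 5, 8
pis = [rng.permutation(N) for _ in range(k)]   # hidden permutations of the target function
hop_lengths = [1, 2, 4, 8]                     # easy-to-hard schedule 1, 2, 4, ..., k

# Explicit curriculum: one dataset per stage, consumed in order of increasing hop length.
stages = [make_dataset(pis, N, k, m, n_examples=1000) for m in hop_lengths]

# Implicit curriculum: a single mixed dataset covering all hop lengths at once.
mixture = [ex for stage in stages for ex in stage]
```

Under the explicit curriculum one would train on `stages[0]`, then `stages[1]`, and so on; under the implicit curriculum one would train on `mixture` throughout, summing (and possibly weighting) the per-hop-length losses.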
Practical Implications and Considerations
- Data Distribution is Crucial: The core takeaway is that for compositional tasks, the distribution of training data can be as important as the model architecture itself. Training on appropriately structured data (easy-to-hard) can unlock efficient learning for functions that are otherwise computationally hard to learn from only the final task output.
- Curriculum vs. Mixture: While explicit curriculum provides a clear theoretical path, mixing data offers a simpler practical implementation strategy that achieves similar theoretical benefits. This aligns with empirical findings in training LLMs on diverse tasks.
- Architecture Details: The theoretical model is simplified (attention-only, with value matrices fixed after initialization). Real-world transformers are more complex (MLPs, layer normalization, learned embeddings); experiments show the phenomenon holds for standard architectures. The theoretical embedding dimension is high ($O(Nk \log k)$), which might be a limitation for large $N$ and $k$, but the authors suggest trainable embeddings might mitigate this in practice.
- Task Design: The k-fold task is synthetic but captures the essence of composing structured knowledge. Designing similar synthetic tasks or analyzing real-world datasets for their compositional structure could guide data creation for training models on complex reasoning.
- Limitations: The analysis focuses on a specific synthetic task and a simplified transformer. Generalizing these results to arbitrary compositional functions and full-scale transformers is an area for future work. Learning the value matrices (fixed in the main theoretical analysis) is also noted as an important future direction.
In summary, this paper provides theoretical evidence that transformers can efficiently represent compositional functions with $O(\log k)$ layers, but learning these functions requires training data that exposes the hierarchical structure of the composition, either through an explicit curriculum or by mixing data of varying difficulty. This highlights the critical role of data distribution in enabling efficient gradient-based learning for complex reasoning tasks.