
Learning Compositional Functions with Transformers from Easy-to-Hard Data (2505.23683v1)

Published 29 May 2025 in cs.LG

Abstract: Transformer-based LLMs have demonstrated impressive capabilities across a range of complex reasoning tasks. Prior theoretical work exploring the expressive power of transformers has shown that they can efficiently perform multi-step reasoning tasks involving parallelizable computations. However, the learnability of such constructions, particularly the conditions on the data distribution that enable efficient learning via gradient-based optimization, remains an open question. Towards answering this question, in this work we study the learnability of the $k$-fold composition task, which requires computing an interleaved composition of $k$ input permutations and $k$ hidden permutations, and can be expressed by a transformer with $O(\log k)$ layers. On the negative front, we prove a Statistical Query (SQ) lower bound showing that any SQ learner that makes only polynomially-many queries to an SQ oracle for the $k$-fold composition task distribution must have sample size exponential in $k$, thus establishing a statistical-computational gap. On the other hand, we show that this function class can be efficiently learned, with runtime and sample complexity polynomial in $k$, by gradient descent on an $O(\log k)$-depth transformer via two different curriculum learning strategies: one in which data consists of $k'$-fold composition functions with $k' \le k$ presented in increasing difficulty, and another in which all such data is presented simultaneously. Our work sheds light on the necessity and sufficiency of having both easy and hard examples in the data distribution for transformers to learn complex compositional tasks.

Summary

  • The paper shows that transformers can exactly represent k-fold compositions with O(log k) layers by hierarchically composing permutation functions.
  • It reveals a statistical-computational gap where training solely on hard examples requires exponentially many samples, emphasizing the need for structured data.
  • The work proposes both explicit and implicit curriculum strategies that enable efficient learning with polynomial sample complexity for complex reasoning tasks.

This paper, "Learning Compositional Functions with Transformers from Easy-to-Hard Data" (2505.23683), investigates the learnability of complex compositional tasks using transformers via gradient descent, focusing particularly on the role of the training data distribution. The authors define a synthetic task, the "$k$-fold composition task", which serves as a controlled environment to study the challenges and opportunities in learning compositional functions.

The $k$-fold Composition Task

The core task involves computing the result of applying an interleaved sequence of $k$ input-dependent permutations $\sigma_i$ and $k$ hidden (parametric) permutations $\pi_i$ to an initial element $x$. Formally, for an input $(\sigma, x) \in (S_N)^k \times [N]$, the task $f_\pi(\sigma, x)$ computes $(\sigma_k \circ \pi_k \circ \sigma_{k-1} \circ \pi_{k-1} \circ \cdots \circ \sigma_1 \circ \pi_1)(x)$. Here, $S_N$ is the set of permutations on $N$ elements. The hidden permutations $\pi = (\pi_1, \dots, \pi_k)$ are fixed for a given target function $f_\pi$ but are unknown to the learner, representing "parametric knowledge". The input permutations $\sigma = (\sigma_1, \dots, \sigma_k)$ vary per instance, representing "contextual knowledge". This structure mirrors real-world reasoning problems like multi-hop question answering, where a model must combine information from the input context with knowledge stored in its parameters. A cyclic variant of the task is also introduced, in which the starting index $i$ is an additional input: $f_\pi^{\mathrm{cyc}}(\sigma, i, x) := (\sigma_{i+k-1} \circ \pi_{i+k-1} \circ \cdots \circ \sigma_i \circ \pi_i)(x)$, where permutation indices are taken modulo $k$.
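
To make the task concrete, here is a minimal Python sketch (not from the paper) that evaluates $f_\pi(\sigma, x)$ on a toy instance, with each permutation represented as a list mapping $j \mapsto \sigma(j)$:

```python
import random

def compose_apply(sigmas, pis, x):
    """Evaluate f_pi(sigma, x) = (sigma_k . pi_k . ... . sigma_1 . pi_1)(x)."""
    for sigma_i, pi_i in zip(sigmas, pis):
        x = pi_i[x]      # apply the hidden (parametric) permutation first ...
        x = sigma_i[x]   # ... then the contextual permutation from the input
    return x

# Toy instance: N = 5 elements, k = 3 interleaved pairs.
N, k = 5, 3
pis = [random.sample(range(N), N) for _ in range(k)]     # hidden, fixed for a given f_pi
sigmas = [random.sample(range(N), N) for _ in range(k)]  # contextual, fresh per example
print(compose_apply(sigmas, pis, x=0))
```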

Transformer Architecture and Expressivity

The paper considers an attention-only transformer architecture with $L$ layers and embedding dimension $d$. The input $(\sigma, x)$ is first embedded into a sequence of $T = kN$ tokens. Each token corresponds to a pair $(i, j) \in [k] \times [N]$ and encodes information about the mapping $\sigma_i(j)$. The theoretical construction in the paper uses a fixed embedding function $\phi(i, j, \sigma_i(j)) \in \mathbb{R}^d$. The transformer processes this sequence, and a final readout layer decodes the output for the queried initial element $x$.
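
As an illustration of this tokenization, the sketch below builds the $T = kN$ token sequence with a simple concatenated one-hot encoding; this encoding is an expository assumption, not the paper's actual fixed embedding $\phi$:

```python
import numpy as np

def one_hot(idx, size):
    v = np.zeros(size)
    v[idx] = 1.0
    return v

def embed_sequence(sigmas, N):
    """Build the T = k*N token sequence: one token per pair (i, j), encoding sigma_i(j).
    The concatenated one-hot code is an illustrative stand-in for phi(i, j, sigma_i(j))."""
    k = len(sigmas)
    tokens = []
    for i, sigma_i in enumerate(sigmas):
        for j in range(N):
            tokens.append(np.concatenate([
                one_hot(i, k),            # which permutation slot i
                one_hot(j, N),            # source element j
                one_hot(sigma_i[j], N),   # image sigma_i(j)
            ]))
    return np.stack(tokens)               # shape (k*N, k + 2*N)
```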

A key theoretical result (Theorem 1) shows that the $k$-fold composition task can be exactly expressed by a transformer with $L = O(\log k)$ layers. The construction works by doubling: layer $\ell$ composes the $2^{\ell-1}$-hop compositions produced by the previous layer in pairs, yielding $2^\ell$-hop compositions. This hierarchical composition allows the transformer to compute the $2k$ total compositions in a logarithmic number of steps. The embedding dimension required for this construction is $d = \tilde{O}(Nk)$. This expressivity result highlights that transformers can represent such compositional tasks efficiently in terms of depth.
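
The following sketch illustrates the doubling idea at the level of plain permutations rather than attention layers (assuming $k$ is a power of two); it is an analogy to the Theorem 1 construction, not the construction itself:

```python
def compose(p, q):
    """Return the permutation p . q, i.e. apply q first, then p."""
    return [p[q[j]] for j in range(len(q))]

def kfold_by_doubling(sigmas, pis):
    """Compute sigma_k . pi_k . ... . sigma_1 . pi_1 in O(log k) rounds of
    pairwise block composition (assumes k is a power of two)."""
    blocks = [compose(s, p) for s, p in zip(sigmas, pis)]  # the k one-hop blocks sigma_i . pi_i
    while len(blocks) > 1:
        # each round merges adjacent blocks, doubling the hop length each block covers
        blocks = [compose(blocks[t + 1], blocks[t]) for t in range(0, len(blocks), 2)]
    return blocks[0]
```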

Statistical-Computational Gap

Despite the positive expressivity result, the paper identifies a significant challenge for learning this task. Using the Statistical Query (SQ) model, which lower-bounds a wide class of learning algorithms including gradient descent on neural networks, the authors prove (Theorem 2) that any SQ learner trained on data consisting only of $k$-fold composition examples (the "hardest" task) must either make $N^{\Omega(k)}$ queries or use a query tolerance $\tau \le N^{-\Omega(k)}$, i.e., it needs a sample size exponential in $k$.

The practical implication is that training a transformer via standard gradient descent solely on hard $k$-fold examples is likely to be computationally intractable, requiring exponentially many samples or an enormous amount of computation, despite the function being representable by a relatively shallow network. This points to a statistical-computational gap specific to training on the hardest instances.

Efficient Learning via Easy-to-Hard Data

The paper then demonstrates that this computational barrier can be overcome by training on data that exposes the compositional structure, specifically by including "easy" examples. The authors propose two strategies that allow efficient learning with polynomial sample complexity ($\mathrm{poly}(N, k)$ samples):

  1. Explicit Curriculum Learning (Theorem 3): The model is trained in stages. In stage $\ell$, the model is trained on examples of the $2^{\ell-1}$-fold composition task (i.e., computing $\mathrm{hop}_i^{2^{\ell-1}} = \sigma_{i+2^{\ell-1}-1} \circ \pi_{i+2^{\ell-1}-1} \circ \cdots \circ \sigma_i \circ \pi_i$). The theoretical analysis shows that gradient descent learns the layers sequentially: learning the $2^{\ell-1}$-hop task effectively trains the $\ell$-th layer of the transformer's construction. After $\log k$ stages, the model learns the full $k$-fold composition.
    • Implementation: This involves generating data for intermediate hop lengths ($1, 2, 4, \dots, k$) and training the model sequentially on these datasets, adapting the loss in each stage to target the corresponding hop length.
  2. Implicit Curriculum via Data Mixture (Theorem 4): A potentially simpler approach in practice is to train the model simultaneously on a mixture of data from all difficulty levels ($1$-fold, $2$-fold, $4$-fold, $\dots$, $k$-fold tasks). The theoretical analysis shows that this mixture also induces an "implicit curriculum" for gradient descent: gradients for layers responsible for simpler compositions (fewer hops) dominate early in training, so those layers are learned first, which in turn enables deeper layers to learn the more complex compositions.
    • Implementation: This involves generating a single dataset containing examples from various hop lengths; the loss is a (possibly weighted) sum of the losses for predicting the different hop targets for each input (see the data-generation sketch after this list).
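
The sketch below illustrates how the two data recipes might be generated. The dataset sizes, the cyclic indexing of hops, and the helper names are illustrative assumptions, not the paper's exact protocol:

```python
import random

def hop_target(sigmas, pis, i, m, x):
    """hop_i^m(x) = (sigma_{i+m-1} . pi_{i+m-1} . ... . sigma_i . pi_i)(x),
    with indices taken cyclically (mod k), as in the cyclic variant."""
    k = len(sigmas)
    for t in range(m):
        idx = (i + t) % k
        x = sigmas[idx][pis[idx][x]]
    return x

def make_example(pis, N, k, m):
    """One (input, label) pair for the m-hop task with fresh contextual permutations."""
    sigmas = [random.sample(range(N), N) for _ in range(k)]
    i, x = random.randrange(k), random.randrange(N)
    return (sigmas, i, x), hop_target(sigmas, pis, i, m, x)

N, k = 5, 8
pis = [random.sample(range(N), N) for _ in range(k)]
hop_lengths = [1, 2, 4, 8]   # the 2^{l-1}-hop tasks up to k

# Explicit curriculum: one dataset per stage, presented in order of increasing hop length.
curriculum = {m: [make_example(pis, N, k, m) for _ in range(1000)] for m in hop_lengths}

# Implicit curriculum: a single mixed dataset containing all hop lengths at once.
mixture = [make_example(pis, N, k, random.choice(hop_lengths)) for _ in range(4000)]
```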

Both strategies show that presenting data with increasing or mixed difficulty allows gradient descent to leverage the underlying compositional structure and efficiently learn the task, avoiding the exponential cost associated with training only on the hardest $k$-fold instances.

Practical Implications and Considerations

  • Data Distribution is Crucial: The core takeaway is that for compositional tasks, the distribution of training data can be as important as the model architecture itself. Training on appropriately structured data (easy-to-hard) can unlock efficient learning for functions that are otherwise computationally hard to learn from only the final task output.
  • Curriculum vs. Mixture: While an explicit curriculum provides a clear theoretical path, mixing data offers a simpler practical implementation strategy that achieves similar theoretical benefits. This aligns with empirical findings in training LLMs on diverse tasks.
  • Architecture Details: The theoretical model is simplified (attention-only, fixed value matrices after initialization). Real-world transformers are more complex (MLPs, layer normalization, learned embeddings). Experiments show the phenomenon holds for standard architectures. The theoretical embedding dimension is high ($O(Nk \log k)$), which might be a limitation for large $N$ and $k$, but the authors suggest trainable embeddings might mitigate this in practice.
  • Task Design: The $k$-fold task is synthetic but captures the essence of composing structured knowledge. Designing similar synthetic tasks or analyzing real-world datasets for their compositional structure could guide data creation for training models on complex reasoning.
  • Limitations: The analysis focuses on a specific synthetic task and a simplified transformer. Generalizing these results to arbitrary compositional functions and full-scale transformers is an area for future work. Learning the value matrices (fixed in the main theoretical analysis) is also noted as an important future direction.

In summary, this paper provides theoretical evidence that transformers can efficiently represent compositional functions with $O(\log k)$ layers, but learning these functions requires training data that exposes the hierarchical structure of the composition, either through an explicit curriculum or by mixing data of varying difficulty. This highlights the critical role of data distribution in enabling efficient gradient-based learning for complex reasoning tasks.
