- The paper demonstrates that multi-task learning combined with a low-degree (low-complexity) bias enables neural networks to recover latent world models.
- The paper leverages Boolean Fourier-Walsh analysis to quantify representation complexity and establish conditions favoring shared latent recovery.
- The paper shows that a bias in the task distribution, combined with a compatible architecture, yields improved out-of-distribution generalization even when proxy tasks are complex.
This paper, "When do Neural Networks Learn World Models?", investigates the conditions under which neural networks can recover latent data-generating variables, a process equated with learning "world models." The authors provide theoretical results suggesting that a combination of multi-task learning, a low-complexity bias (specifically a low-degree bias in Boolean function space), and compatible model architecture enables neural networks to learn these underlying structures, even when proxy tasks are complex, non-linear functions of the latents.
1. Formulation of Learning World Models
The paper defines learning world models as the recovery of latent variables (z) that generate observed data (x). The data generation process is defined as:
- Sample z∼p(z) from a latent space Z.
- Generate x=ψ(z) through an invertible, non-linear function ψ:Z→X.
A representation Φ:X→Z learns the world model up to a set of simple transforms 𝒯 (e.g., linear maps) if there exists T∈𝒯 such that Φ(x)=T(z) for all x=ψ(z).
Models are trained on proxy tasks, h:X→Y. A realization f:X→Y of task h satisfies f(x)=h(x) for all observed data. The goal is to learn a hierarchical realization g∘Φ∈R(h) where Φ learns the world model.
A core challenge is the non-identifiability of latent variables: multiple solutions (different Φ′) can fit the observed data equally well, making it impossible to recover the true latents based on task performance alone. The paper proposes that the implicit bias of neural networks towards low-complexity solutions can resolve this ambiguity.
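To make the setup concrete, here is a minimal toy instantiation (hypothetical, not taken from the paper): three Boolean latents, an invertible non-linear map ψ, and a proxy task that two different representations realize equally well, illustrating the non-identifiability issue.

```python
# Hypothetical toy instance of the setup: d = m = 3 Boolean latents, an invertible
# non-linear map psi, and a proxy task defined on observations x = psi(z).
import numpy as np

def psi(z):
    # Invertible map Z -> X on {-1,1}^3: each observed bit is a parity of latent bits.
    z1, z2, z3 = z
    return np.array([z1 * z2, z2, z2 * z3])

def psi_inv(x):
    # Inverse map X -> Z: psi_inv(psi(z)) == z for every z in {-1,1}^3.
    x1, x2, x3 = x
    return np.array([x1 * x2, x2, x2 * x3])

def proxy_task(x):
    # h = g o psi^{-1}: degree 1 in the latents (it returns z1), degree 2 in the observations.
    return psi_inv(x)[0]

# Non-identifiability: both representations below realize the proxy task exactly on all data
# (take g(y) = y[0]), but only Phi_true recovers the latents up to a simple transform.
Phi_true = psi_inv
Phi_alt = lambda x: np.array([x[0] * x[1], x[0], x[2]])  # equals (z1, z1*z2, z2*z3): a non-linear scramble of z
```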
2. Complexity Measures via Boolean Models
To formalize "low-complexity," the paper models all variables (observed data x, latent variables z, and internal representations) as Boolean strings (e.g., x∈{−1,1}^m, z∈{−1,1}^d). This allows leveraging the Fourier-Walsh transform for Boolean functions:
Any function f:{±1}^n→R can be uniquely expressed as f(x) = ∑_{S⊆[n]} f̂(S) χ_S(x), where χ_S(x) = ∏_{i∈S} x_i are the parity functions.
The complexity of a Boolean function f is measured by its degree: deg(f) = max{|S| : f̂(S) ≠ 0}.
For a hierarchical realization h^(1)∘⋯∘h^(q), the realization degree is ∑_{i∈[q]} deg(h^(i)). Models with a low-degree bias are assumed to minimize this realization degree.
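As a concrete sketch of these definitions (illustrative code, not from the paper), the Fourier-Walsh coefficients and the degree of a small Boolean function can be computed by brute-force enumeration over {−1,1}^n:

```python
# Brute-force Fourier-Walsh coefficients and degree for small n (illustrative sketch only).
from itertools import product, combinations
import numpy as np

def fourier_coefficient(f, S, n):
    # f_hat(S) = E_x[f(x) * chi_S(x)] under the uniform distribution on {-1,1}^n
    total = 0.0
    for x in product([-1, 1], repeat=n):
        chi_S = np.prod([x[i] for i in S]) if S else 1.0
        total += f(x) * chi_S
    return total / 2 ** n

def degree(f, n, tol=1e-9):
    # deg(f) = max{|S| : f_hat(S) != 0}
    d = 0
    for k in range(n + 1):
        for S in combinations(range(n), k):
            if abs(fourier_coefficient(f, S, n)) > tol:
                d = k
    return d

print(degree(lambda x: x[0], n=3))                # 1: a single coordinate is degree 1
print(degree(lambda x: x[0] * x[1] * x[2], n=3))  # 3: the full parity is degree 3
```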
Other relevant definitions:
- Min-degree solutions H(h): Solutions in R(h) that minimize deg(f).
- Conditional degree deg(h∣Φ): deg(h∗)−max{deg(g):g∘Φ∈R(h)}, where h∗∈H(h). A positive conditional degree means Φ simplifies the task.
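As a hypothetical worked example of the conditional degree (assuming ψ is the identity for simplicity): if h(z) = z1z2z3z4, a min-degree flat realization has deg(h∗) = 4; with Φ(z) = (z1z2, z3z4), the task reduces to g(y) = y1y2 of degree 2, so deg(h∣Φ) = 4 − 2 = 2 > 0, i.e., Φ simplifies the task.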
3. Theoretical Analysis
The paper presents several key theoretical findings:
- Theorem 3.1 (Single-task learning): For a single task h, a flat realization h∗∈H(h) always has a realization degree less than or equal to that of any hierarchical realization g∘Φ. This suggests that without further constraints, models prefer not to learn intermediate representations for single tasks.
- Implication: Learning representations often requires more than a single, isolated task.
- Theorem 3.2 (Multi-task learning): In a multi-task setting with n tasks (h_1,…,h_n), a hierarchical realization g∘Φ∗ (where g=(g_1,…,g_n) and Φ∗ is a min-degree realization of Φ) is favored over a set of independent flat realizations h∗=(h_1∗,…,h_n∗) if the sum of conditional degrees ∑_{i∈[n]} deg(h_i∣Φ∗) is sufficiently large, i.e., if the total degree saved across tasks outweighs the degree cost of computing Φ∗ (see the numerical sketch after this list).
- Implication: Multi-task learning, like next-token prediction in LLMs, can drive the learning of shared, general-purpose representations if these representations simplify a sufficient number of tasks. Proxy tasks should be chosen such that conditioning on the desired representation Φ significantly reduces task complexity (i.e., deg(hi∣Φ)>0).
- Theorem 3.4 (Representational no free lunch): If proxy tasks are sampled uniformly from all possible functions of the true latents (F_d∘ψ^{-1}), then as the number of tasks n→∞, all viable representations (those that can solve all tasks via some bijective transform T from the true latents) have the same task-averaged realization complexity.
- Implication: Without a bias in the task distribution, world models are only learnable up to arbitrary bijective (highly non-linear) transforms.
- k-degree tasks: Tasks h are k-degree if h∈F_k^d∘ψ^{-1}, meaning they can be solved by a function of degree ≤ k on top of the true latents z.
- Theorem 3.7 (World model learning): If tasks are sampled such that lower-degree functions of the true latents are preferred (e.g., sampling from k-degree tasks with p_k>0 for k<d and non-zero p_1), then the representation Φ∗ minimizing the average realization degree will learn the world model up to negations and permutations of the true latent variables z.
- Implication: A bias in the task distribution towards simpler functions of the true latents is crucial. This result aligns with the "linear representation hypothesis" in LLMs, where interpretable features are often found as linear directions in activation space (permutations and negations are degree-1 Boolean functions, analogous to linear functions).
- Theorem 3.8 (Benefits of learning world models): In an out-of-distribution (OOD) generalization setting where training latents are drawn from a Hamming ball B_r (a subset of {−1,1}^d) and test latents are drawn from the full space {−1,1}^d, a model Φ∗ that recovers the latents (as in Thm 3.7) achieves zero test MSE on a parity task h (where h∘ψ is a parity function of degree q), provided deg(h∣ψ^{-1}) is large enough. In contrast, any flat realization h∗ has test MSE > 1.
- Implication: World models offer provable OOD generalization benefits, especially when downstream tasks become simpler given the true latents.
- Theorem 3.10 (Impact of model architecture - Basis Compatibility): Model architecture influences how functions are represented. If the Fourier-Walsh basis χ_S is considered the "natural" basis, a model architecture can be seen as inducing a new basis via a transform U. If U is compatible (i.e., deg_U(f)=deg(f) because deg(U(χ_S))=deg(χ_S)), then Φ∗ still learns the world model up to negations and permutations. If U is incompatible, Φ∗ might recover a more complex, non-linear transformation of the true latents.
- Implication: The choice of model architecture (e.g., activation functions) should align with the "natural" complexity of the tasks and latent variables to facilitate world model learning. Architectures should be biased towards representing low-degree functions in the natural basis efficiently.
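The multi-task trade-off in Theorem 3.2 can be illustrated with a back-of-the-envelope degree count (the numbers below are hypothetical, not the paper's accounting):

```python
# Hypothetical degree accounting for the multi-task setting (illustrative numbers only).
def flat_total_degree(task_degrees):
    # Independent flat realizations: each task pays its full min-degree.
    return sum(task_degrees)

def hierarchical_total_degree(task_degrees_given_phi, deg_phi):
    # Shared representation: pay deg(Phi*) once, plus deg(g_i) per task on top of Phi*.
    return deg_phi + sum(task_degrees_given_phi)

n_tasks = 10
task_degrees = [4] * n_tasks            # each task alone needs a degree-4 flat realization
task_degrees_given_phi = [1] * n_tasks  # but only degree 1 on top of the shared Phi*
deg_phi = 4                             # cost of computing Phi* itself

print(flat_total_degree(task_degrees))                             # 40
print(hierarchical_total_degree(task_degrees_given_phi, deg_phi))  # 14
# The hierarchical realization wins when the summed conditional degrees (here 10 * 3 = 30)
# exceed deg(Phi*) (here 4), so a low-degree bias favors the shared representation.
```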
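Similarly, the OOD benefit of Theorem 3.8 can be sketched with a toy parity task (a hypothetical instance that takes ψ to be the identity, so observations equal latents):

```python
# Toy OOD illustration: train on a Hamming ball, test on the full cube (hypothetical instance).
from itertools import product
import numpy as np

d = 5
full_cube = np.array(list(product([-1, 1], repeat=d)))
train_ball = full_cube[(full_cube == -1).sum(axis=1) <= 1]  # Hamming ball B_1 around the all-ones string

parity = lambda Z: np.prod(Z, axis=1)          # the true degree-5 task on the latents
flat_fit = lambda Z: Z.sum(axis=1) - (d - 1)   # a degree-1 function that matches the parity on B_1

# Both realizations agree on the training support ...
assert np.allclose(parity(train_ball), flat_fit(train_ball))
# ... but only the latent-based realization generalizes to the full cube.
print(np.mean((flat_fit(full_cube) - parity(full_cube)) ** 2))  # large test MSE (> 1)
print(np.mean((parity(full_cube) - parity(full_cube)) ** 2))    # 0.0
```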
4. Algorithmic Implications and Experiments
The paper demonstrates the implications of basis compatibility with two experiments:
- Polynomial Extrapolation:
- Standard ReLU MLPs fail to extrapolate polynomials of degree > 1 beyond the training region. This is attributed to the incompatibility of ReLU activations with the polynomial basis {1, x, x^2, …}.
- Solution: An MLP where half of the ReLUs in each layer are replaced by identity (σ(x)=x) and quadratic (σ(x)=x^2) activations.
- Result: This modified MLP significantly improves extrapolation performance for degree-2 and degree-3 polynomials, suggesting it learns a more compatible basis.
# Pseudocode for the modified MLP layer (PyTorch-style; channels split across ReLU, identity, and quadratic activations)
import torch

def modified_mlp_layer(input_tensor, weights, biases, n_relu, n_identity):
    linear_output = input_tensor @ weights + biases
    # Split channels for different activations
    out_relu = torch.relu(linear_output[:, :n_relu])
    out_identity = linear_output[:, n_relu : n_relu + n_identity]
    out_quadratic = linear_output[:, n_relu + n_identity:] ** 2
    return torch.cat([out_relu, out_identity, out_quadratic], dim=-1)
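A hypothetical usage of the layer defined above (shapes and split sizes are illustrative only):

```python
# Hypothetical usage of modified_mlp_layer: 96 hidden units split as 32 ReLU / 32 identity / 32 quadratic.
x = torch.randn(8, 64)                             # batch of 8 inputs with 64 features
W, b = torch.randn(64, 96) * 0.1, torch.zeros(96)
y = modified_mlp_layer(x, W, b, n_relu=32, n_identity=32)
print(y.shape)                                     # torch.Size([8, 96])
```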
- Learning Physical Laws:
- Tasks involved predicting single-object parabolic motion and two-object elastic collision motion, evaluated on OOD settings (different initial velocities/sizes).
- A Transformer model with its MLP components replaced by the modified MLP (using GELU in place of the remaining ReLUs) was compared against a standard Transformer; a sketch of such a feed-forward block follows this list.
- Result: The modified Transformer achieved lower prediction error in OOD settings, suggesting it better captured the underlying physical laws.
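As a rough sketch of what such a block might look like (an assumption about the architecture, not the authors' code), the Transformer's feed-forward sub-layer can split its hidden units across GELU, identity, and quadratic activations:

```python
# Sketch of a mixed-activation Transformer feed-forward block (assumed split ratios).
import torch
import torch.nn as nn
import torch.nn.functional as F

class MixedActivationFFN(nn.Module):
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_hidden)
        self.fc2 = nn.Linear(d_hidden, d_model)
        self.n_gelu = d_hidden // 2       # half the units keep a GELU non-linearity
        self.n_identity = d_hidden // 4   # the rest is split between identity and quadratic

    def forward(self, x):
        h = self.fc1(x)
        h_gelu = F.gelu(h[..., :self.n_gelu])
        h_id = h[..., self.n_gelu:self.n_gelu + self.n_identity]
        h_quad = h[..., self.n_gelu + self.n_identity:] ** 2
        return self.fc2(torch.cat([h_gelu, h_id, h_quad], dim=-1))

# This module would be a drop-in replacement for the standard MLP inside each Transformer block.
```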
5. Limitations and Future Work
- The model of representation Φ could be more structured to reflect hierarchical representations in deep networks.
- The analysis does not assume specific structures in latent variables (e.g., causality); integrating this is a future direction.
- The study primarily focuses on low-complexity bias; other implicit biases and more fine-grained complexity measures warrant investigation.
- The requirement of p_1>0 (explicitly sampling degree-1 tasks) in Theorem 3.7 is a limitation the authors conjecture might be removable in many settings.
In summary, the paper provides a novel theoretical framework using Boolean analysis to understand when and how neural networks might learn world models. It highlights the critical roles of multi-task learning, a bias towards low-degree functions (in both models and task distributions), and architectural compatibility with the underlying structure of the data generation process. The findings offer theoretical support for empirical observations in LLMs and suggest design principles for building more robust and generalizable AI systems.