When do neural networks learn world models?

Published 13 Feb 2025 in cs.LG | (2502.09297v4)

Abstract: Humans develop world models that capture the underlying generation process of data. Whether neural networks can learn similar world models remains an open problem. In this work, we present the first theoretical results for this problem, showing that in a multi-task setting, models with a low-degree bias provably recover latent data-generating variables under mild assumptions--even if proxy tasks involve complex, non-linear functions of the latents. However, such recovery is sensitive to model architecture. Our analysis leverages Boolean models of task solutions via the Fourier-Walsh transform and introduces new techniques for analyzing invertible Boolean transforms, which may be of independent interest. We illustrate the algorithmic implications of our results and connect them to related research areas, including self-supervised learning, out-of-distribution generalization, and the linear representation hypothesis in LLMs.

Summary

  • The paper demonstrates that multi-task learning combined with a low-degree (low-complexity) bias enables neural networks to recover latent world models.
  • The paper leverages Boolean Fourier-Walsh analysis to quantify representation complexity and establish conditions favoring shared latent recovery.
  • The paper shows that bias in task selection and architecture compatibility yield improved out-of-distribution generalization in complex proxy tasks.

This paper, "When do Neural Networks Learn World Models?", investigates the conditions under which neural networks can recover latent data-generating variables, a process equated with learning "world models." The authors provide theoretical results suggesting that a combination of multi-task learning, a low-complexity bias (specifically a low-degree bias in Boolean function space), and compatible model architecture enables neural networks to learn these underlying structures, even when proxy tasks are complex, non-linear functions of the latents.

1. Formulation of Learning World Models

The paper defines learning world models as the recovery of latent variables ($z$) that generate observed data ($x$). The data generation process is defined as:

  1. Sample $z \sim p(z)$ from a latent space $\mathcal{Z}$.
  2. Generate $x = \psi(z)$ through an invertible, non-linear function $\psi: \mathcal{Z} \to \mathcal{X}$.

A representation $\Phi: \mathcal{X} \to \mathcal{Z}$ learns the world model up to a set of simple transforms $\mathcal{T}$ (e.g., linear transforms) if there exists $T \in \mathcal{T}$ such that $\Phi(x) = T(z)$ for $x = \psi(z)$.

Models are trained on proxy tasks $h: \mathcal{X} \to \mathcal{Y}$. A realization $f: \mathcal{X} \to \mathcal{Y}$ of task $h$ satisfies $f(x) = h(x)$ for all observed data. The goal is to learn a hierarchical realization $g \circ \Phi \in \mathcal{R}(h)$ where $\Phi$ learns the world model.

A core challenge is the non-identifiability of latent variables: multiple solutions (different $\Phi$) can fit the observed data equally well, making it impossible to recover the true latents based on task performance alone. The paper proposes that the implicit bias of neural networks towards low-complexity solutions can resolve this ambiguity.
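
To make the setup concrete, here is a minimal sketch of this generative process in the Boolean setting introduced in Section 2 below; the latent dimension, the invertible map psi, and the proxy task are illustrative choices, not taken from the paper.

    import itertools

    d = 3  # latent dimension (illustrative)

    def psi(z):
        # Invertible, non-linear map from latents z to observations x:
        # each observed coordinate is a parity of latent coordinates.
        z1, z2, z3 = z
        return (z1, z1 * z2, z2 * z3)

    def psi_inv(x):
        # The inverse map, recovering the latents from observations.
        x1, x2, x3 = x
        return (x1, x1 * x2, x1 * x2 * x3)

    # psi is a bijection on the latent cube {-1, 1}^d.
    for z in itertools.product([-1, 1], repeat=d):
        assert psi_inv(psi(z)) == z

    # A proxy task: simple over the latents (h(psi(z)) = z3, degree 1),
    # but a degree-3 parity when written directly over the observations x.
    def h(x):
        return psi_inv(x)[2]

Note that any bijective re-encoding $T \circ \psi^{-1}$ of the latents would solve such proxy tasks equally well; this is exactly the non-identifiability that the low-complexity bias is meant to resolve.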

2. Complexity Measures via Boolean Models

To formalize "low-complexity," the paper models all variables (observed data xx, latent variables zz, and internal representations) as Boolean strings (e.g., x{1,1}m,z{1,1}dx \in \{-1,1\}^m, z \in \{-1,1\}^d). This allows leveraging the Fourier-Walsh transform for Boolean functions: Any function f:{±1}nRf: \{\pm 1\}^n \to \mathbb{R} can be uniquely expressed as f(x)=S[n]f^(S)χS(x)f(\mathbf{x}) = \sum_{S\subseteq [n]}\hat{f}(S)\chi_S(\mathbf{x}), where χS(x)=iSxi\chi_S(\mathbf{x}) = \prod_{i\in S}x_i are parity functions.

The complexity of a Boolean function $f$ is measured by its degree: $deg(f) = \max\{|S| : \hat{f}(S) \ne 0\}$. For hierarchical realizations $h_{(1)} \circ \dots \circ h_{(q)}$, the realization degree is $\sum_{i \in [q]} deg(h_{(i)})$. Models with a low-degree bias are assumed to minimize this realization degree.
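
For concreteness, the following brute-force sketch computes Walsh-Fourier coefficients and the degree of a Boolean function by enumerating the whole cube (exponential in $n$, so only usable for tiny examples; the helper names are ours).

    import itertools
    from math import prod

    def fourier_coefficients(f, n):
        # Walsh-Fourier coefficients f_hat(S) = E_x[f(x) * chi_S(x)],
        # computed exactly by averaging over the whole cube {-1, 1}^n.
        cube = list(itertools.product([-1, 1], repeat=n))
        subsets = itertools.chain.from_iterable(
            itertools.combinations(range(n), r) for r in range(n + 1))
        return {S: sum(f(x) * prod(x[i] for i in S) for x in cube) / len(cube)
                for S in subsets}

    def degree(f, n, tol=1e-9):
        # deg(f) = max{|S| : f_hat(S) != 0}
        coeffs = fourier_coefficients(f, n)
        return max((len(S) for S, c in coeffs.items() if abs(c) > tol), default=0)

    parity = lambda x: x[0] * x[1] * x[2]
    print(degree(parity, 3))          # 3  (the parity chi_{1,2,3})
    print(degree(lambda x: x[0], 3))  # 1  (a single-coordinate function)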

Other relevant definitions:

  • Min-degree solutions $H(h)$: Solutions in $\mathcal{R}(h)$ that minimize $deg(f)$.
  • Conditional degree $deg(h \mid \Phi)$: defined as $deg(h^*) - \max\{deg(g) : g \circ \Phi \in \mathcal{R}(h)\}$, where $h^* \in H(h)$. A positive conditional degree means $\Phi$ simplifies the task.
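
As a toy computation of the conditional degree, reusing the psi / psi_inv / degree helpers sketched above (all illustrative): the task $h$ with $h(\psi(z)) = z_3$ has degree 3 as a flat function of the observations, but only degree 1 on top of $\Phi = \psi^{-1}$, so its conditional degree is positive and $\Phi$ genuinely simplifies it.

    # Flat realization of h, written directly over observations x.
    h_flat = lambda x: x[0] * x[1] * x[2]      # equals z3 after applying psi_inv
    # Remaining function g once Phi = psi_inv has been applied.
    g_on_latents = lambda z: z[2]

    print(degree(h_flat, 3))        # 3
    print(degree(g_on_latents, 3))  # 1  -> deg(h | Phi) = 3 - 1 = 2 > 0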

3. Theoretical Analysis

The paper presents several key theoretical findings:

  • Theorem 3.1 (Single-task learning): For a single task $h$, a flat realization $h^* \in H(h)$ always has a realization degree less than or equal to that of any hierarchical realization $g \circ \Phi$. This suggests that without further constraints, models prefer not to learn intermediate representations for single tasks.
    • Implication: Learning representations often requires more than a single, isolated task.
  • Theorem 3.2 (Multi-task learning): In a multi-task setting with $n$ tasks $(h_1, \dots, h_n)$, a hierarchical realization $g \circ \Phi^*$ (where $g = (g_1, \dots, g_n)$ and $\Phi^*$ is a min-degree realization of $\Phi$) is favored over a set of independent flat realizations $h^* = (h_1^*, \dots, h_n^*)$ if the sum of conditional degrees $\sum_{i \in [n]} deg(h_i \mid \Phi^*)$ is sufficiently large.
    • Implication: Multi-task learning, like next-token prediction in LLMs, can drive the learning of shared, general-purpose representations if these representations simplify a sufficient number of tasks. Proxy tasks should be chosen such that conditioning on the desired representation $\Phi$ significantly reduces task complexity (i.e., $deg(h_i \mid \Phi) > 0$).
  • Theorem 3.4 (Representational no free lunch): If proxy tasks are sampled uniformly from all possible functions of the true latents ($\mathcal{F}^d \circ \psi^{-1}$), then as the number of tasks $n \to \infty$, all viable representations (those that can solve all tasks via some bijective transform $T$ from the true latents) have the same task-averaged realization complexity.
    • Implication: Without a bias in the task distribution, world models are only learnable up to arbitrary bijective (highly non-linear) transforms.
  • $k$-degree tasks: Tasks $h$ are $k$-degree if $h \in \mathcal{F}^d_k \circ \psi^{-1}$, meaning they can be solved by a function of degree $\le k$ on top of the true latents $z$.
  • Theorem 3.7 (World model learning): If tasks are sampled such that lower-degree functions of the true latents are preferred (e.g., sampling from $k$-degree tasks with $p_k > 0$ for $k < d$ and non-zero $p_1$), then the representation $\Phi^*$ minimizing the average realization degree will learn the world model up to negations and permutations of the true latent variables $z$.
    • Implication: A bias in the task distribution towards simpler functions of the true latents is crucial. This result aligns with the "linear representation hypothesis" in LLMs, where interpretable features are often found as linear directions in activation space (permutations and negations are degree-1 Boolean functions, analogous to linear functions).
  • Theorem 3.8 (Benefits of learning world models): In an out-of-distribution (OOD) generalization setting where training latents come from a Hamming ball $B_r$ (a subset of $\{-1,1\}^d$) and test latents come from the full space $\{-1,1\}^d$, a model $\Phi^*$ that recovers the latents (as in Theorem 3.7) achieves zero test MSE on a parity task $h$ (where $h \circ \psi$ is a parity function of degree $q$), provided $deg(h \mid \psi^{-1})$ is large enough. In contrast, any flat realization $h^*$ has test MSE $> 1$.
    • Implication: World models offer provable OOD generalization benefits, especially when downstream tasks become simpler given the true latents (a toy numerical sketch of this effect follows this list).
  • Theorem 3.10 (Impact of model architecture - Basis Compatibility): Model architecture influences how functions are represented. If the Fourier-Walsh basis $\{\chi_S\}$ is considered the "natural" basis, a model architecture can be seen as inducing a new basis via a transform $U$. If $U$ is compatible (i.e., $deg_U(f) = deg(f)$ because $deg(U(\chi_S)) = deg(\chi_S)$ for all $S$), then $\Phi^*$ still learns the world model up to negations and permutations. If $U$ is incompatible, $\Phi^*$ might recover a more complex, non-linear transformation of the true latents.
    • Implication: The choice of model architecture (e.g., activation functions) should align with the "natural" complexity of the tasks and latent variables to facilitate world model learning. Architectures should be biased towards representing low-degree functions in the natural basis efficiently.
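
As the toy numerical sketch promised above (our own construction, working directly in latent space rather than through $\psi$): on the radius-1 Hamming ball around the all-ones vector, the degree-4 parity $z_1 z_2 z_3 z_4$ coincides with the degree-1 function $z_1 + z_2 + z_3 + z_4 - 3$, so a low-degree fit to in-distribution data can be perfect on the ball yet badly wrong on latents with two or more flipped coordinates.

    import itertools

    d = 4
    parity = lambda z: z[0] * z[1] * z[2] * z[3]   # degree-4 task over the latents
    low_degree_fit = lambda z: sum(z) - 3          # degree-1 function

    cube = list(itertools.product([-1, 1], repeat=d))
    # Hamming ball of radius 1 around (1, 1, 1, 1): at most one flipped coordinate.
    ball = [z for z in cube if sum(zi == -1 for zi in z) <= 1]

    # The two functions agree on every in-distribution point ...
    assert all(parity(z) == low_degree_fit(z) for z in ball)

    # ... but the low-degree fit incurs a large error off the ball.
    ood = [z for z in cube if z not in ball]
    mse = sum((parity(z) - low_degree_fit(z)) ** 2 for z in ood) / len(ood)
    print(mse)  # ~20.4, far above the MSE > 1 bound stated for flat realizations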

4. Algorithmic Implications and Experiments

The paper demonstrates the implications of basis compatibility with two experiments:

  • Polynomial Extrapolation:
    • Standard ReLU MLPs fail to extrapolate polynomials of degree $> 1$ beyond the training region. This is attributed to the incompatibility of ReLU activations with the polynomial basis $\{1, x, x^2, \dots\}$.
    • Solution: An MLP where half of the ReLUs in each layer are replaced by identity ($\sigma(x) = x$) and quadratic ($\sigma(x) = x^2$) activations.
    • Result: This modified MLP significantly improves extrapolation performance for degree-2 and degree-3 polynomials, suggesting it learns a more compatible basis.
    • A minimal runnable sketch of the modified layer (made concrete here in PyTorch; the channel-split sizes n_relu and n_identity are illustrative). A tiny expressiveness check using this kind of layer follows the experiment list below.

      # Modified MLP layer: one affine map, then the hidden channels are split
      # across ReLU, identity, and quadratic activations and re-concatenated.
      import torch
      import torch.nn.functional as F

      def modified_mlp_layer(input_tensor, weights, biases, n_relu, n_identity):
          linear_output = input_tensor @ weights + biases

          # Split channels for the different activations
          out_relu = F.relu(linear_output[:, :n_relu])
          out_identity = linear_output[:, n_relu:n_relu + n_identity]
          out_quadratic = linear_output[:, n_relu + n_identity:] ** 2

          return torch.cat([out_relu, out_identity, out_quadratic], dim=1)
  • Learning Physical Laws:
    • Tasks involved predicting single-object parabolic motion and two-object elastic collision motion, evaluated on OOD settings (different initial velocities/sizes).
    • A Transformer model whose MLP components were replaced by the modified MLP (with GELU used in place of the remaining ReLUs) was compared against a standard Transformer.
    • Result: The modified Transformer achieved lower prediction error in OOD settings, suggesting it better captured the underlying physical laws.
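
As the expressiveness check mentioned above (our own toy usage of the modified layer, not an experiment from the paper): with one identity channel and one quadratic channel, a single modified layer followed by a fixed linear read-out represents $p(x) = x^2 + 2x$ exactly on all of $\mathbb{R}$, so extrapolation is trivially exact, whereas a finite ReLU network can only match a quadratic on a bounded region.

    import torch

    # One modified layer on scalar inputs: two hidden channels, one kept as-is
    # (identity activation) and one squared (quadratic activation).
    W = torch.tensor([[2.0, 1.0]])          # input -> (identity channel, quadratic channel)
    b = torch.zeros(2)
    readout = torch.tensor([[1.0], [1.0]])

    def modified_net(x):                                    # x: (batch, 1)
        h = x @ W + b
        h = torch.cat([h[:, :1], h[:, 1:] ** 2], dim=1)     # identity | quadratic
        return h @ readout                                  # = 2x + x^2

    x = torch.tensor([[-5.0], [0.5], [10.0]])  # far outside any plausible training range
    print(modified_net(x))                     # exactly x**2 + 2x: [[15.], [1.25], [120.]]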

5. Limitations and Future Work

  • The model of representation $\Phi$ could be more structured to reflect hierarchical representations in deep networks.
  • The analysis does not assume specific structures in latent variables (e.g., causality); integrating this is a future direction.
  • The study primarily focuses on low-complexity bias; other implicit biases and more fine-grained complexity measures warrant investigation.
  • The requirement of $p_1 > 0$ (explicitly sampling degree-1 tasks) in Theorem 3.7 is a limitation the authors conjecture might be removable in many settings.

In summary, the paper provides a novel theoretical framework using Boolean analysis to understand when and how neural networks might learn world models. It highlights the critical roles of multi-task learning, a bias towards low-degree functions (in both models and task distributions), and architectural compatibility with the underlying structure of the data generation process. The findings offer theoretical support for empirical observations in LLMs and suggest design principles for building more robust and generalizable AI systems.
