The Curse of Recursion: Training on Generated Data Makes Models Forget (2305.17493v3)

Published 27 May 2023 in cs.LG, cs.AI, cs.CL, cs.CR, and cs.CV

Abstract: Stable Diffusion revolutionised image creation from descriptive text. GPT-2, GPT-3(.5) and GPT-4 demonstrated astonishing performance across a variety of language tasks. ChatGPT introduced such LLMs to the general public. It is now clear that LLMs are here to stay, and will bring about drastic change in the whole ecosystem of online text and images. In this paper we consider what the future might hold. What will happen to GPT-{n} once LLMs contribute much of the language found online? We find that use of model-generated content in training causes irreversible defects in the resulting models, where tails of the original content distribution disappear. We refer to this effect as Model Collapse and show that it can occur in Variational Autoencoders, Gaussian Mixture Models and LLMs. We build theoretical intuition behind the phenomenon and portray its ubiquity amongst all learned generative models. We demonstrate that it has to be taken seriously if we are to sustain the benefits of training from large-scale data scraped from the web. Indeed, the value of data collected about genuine human interactions with systems will be increasingly valuable in the presence of content generated by LLMs in data crawled from the Internet.

The paper investigates the implications of training generative models, specifically large language models (LLMs), on data generated by previous iterations of similar models. The central finding is the discovery of "model collapse," a degenerative process in which models progressively lose the ability to represent the true underlying data distribution, with the tails of the distribution disappearing over time. This phenomenon is shown to occur in Gaussian Mixture Models (GMMs), Variational Autoencoders (VAEs), and LLMs, suggesting its ubiquity across learned generative models.

The authors identify two primary causes of model collapse:

  • Statistical approximation error, which arises due to the finite number of samples used during training.
  • Functional approximation error, stemming from limitations in the expressiveness of the function approximators used in the models (both sources are illustrated in the sketch below).
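
Both error sources can be seen in a toy experiment (a minimal numpy sketch, not from the paper): a single Gaussian fitted to a bimodal mixture cannot represent the two modes (functional error), and fitting it from a finite sample misestimates the mean and variance (statistical error); once each generation is trained only on the previous model's samples, the errors compound.

```python
import numpy as np

rng = np.random.default_rng(0)

# Original distribution: bimodal mixture of N(-3, 1) and N(+3, 1).
def sample_original(n):
    return rng.choice([-3.0, 3.0], size=n) + rng.normal(size=n)

n = 200                       # finite sample size -> statistical approximation error
data = sample_original(n)

for gen in range(4):
    between_modes = np.mean(np.abs(data) < 1.0)   # mass where the original put almost none
    print(f"gen {gen}: mean={data.mean():+.2f}, std={data.std(ddof=1):.2f}, "
          f"P(|x|<1)~{between_modes:.2f}")
    # Functional approximation error: a single Gaussian cannot represent two modes,
    # so refitting fills the gap between them and forgets the original shape.
    mu, sigma = data.mean(), data.std(ddof=1)
    data = rng.normal(mu, sigma, size=n)          # next generation sees only model output
```

After a single generation the bimodal structure is gone, and with small samples the estimated mean and standard deviation also drift from one generation to the next.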

The paper argues that access to the original data distribution is crucial for sustaining the benefits of training on large-scale data, especially for capturing low-probability events that are often relevant to marginalized groups and to understanding complex systems. The authors propose that data about genuine human interactions with systems will become increasingly valuable as LLM-generated content spreads through data crawled from the Internet.

The paper presents a theoretical analysis of model collapse, using simplified mathematical models to provide analytical expressions for quantities of interest. The analysis focuses on quantifying how different sources of error affect the overall approximation of the original distribution. The authors consider two cases: a discrete distribution in the absence of functional approximation error, and a single-dimensional Gaussian case that portrays how functional approximation error can compound with statistical error.

Key theoretical results include:

  • Demonstration that for discrete distributions with exact functional approximation, model collapse arises solely from statistical errors in the sampling step, leading to eventual convergence to a delta function (simulated in the sketch following this list).
  • Derivation of a lower bound on the risk, defined in terms of the Wasserstein distance from the true distribution, for a single-dimensional Gaussian. The risk diverges linearly with the number of generations, indicating that the sampling rate needs to increase superlinearly to maintain an accurate approximation of the original distribution.
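
The discrete case is easy to reproduce with a few lines of numpy (a sketch, not the paper's code): drawing a finite sample from the current distribution and re-normalizing the counts into the next generation's distribution is an absorbing random walk, so states are lost one by one until only a delta function remains.

```python
import numpy as np

rng = np.random.default_rng(1)

p = np.full(10, 0.1)     # original discrete distribution over 10 states
M = 100                  # samples drawn per generation

for gen in range(5001):
    counts = rng.multinomial(M, p)
    p = counts / M       # exact functional approximation: the empirical histogram
    if gen % 500 == 0:
        print(f"gen {gen}: surviving states={np.count_nonzero(p)}, max prob={p.max():.2f}")
    if np.count_nonzero(p) == 1:
        print(f"collapsed to a delta function at generation {gen}")
        break
```

A state that loses all its samples in one generation can never reappear, which is exactly how the low-probability tails vanish first.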

The paper also presents empirical results that support the theoretical analysis. Specifically, the authors demonstrate model collapse in GMMs and VAEs trained from scratch, showing that the models progressively lose information about the tails of the distribution and converge to a distribution with very small variance.
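
The GMM part of that experiment can be sketched with scikit-learn (an illustrative rewrite; the paper's actual configuration and scale differ): fit a mixture, sample from it, refit on the samples, and watch the within-component variance drift toward zero over the generations.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)

# Original data: two well-separated 2-D Gaussian clusters.
X = np.vstack([rng.normal(loc=(-4.0, 0.0), scale=1.0, size=(100, 2)),
               rng.normal(loc=(+4.0, 0.0), scale=1.0, size=(100, 2))])

for gen in range(1001):
    gmm = GaussianMixture(n_components=2, random_state=0).fit(X)
    if gen % 200 == 0:
        spread = float(np.trace(gmm.covariances_.sum(axis=0)))
        print(f"gen {gen:4d}: component means (x)={np.round(gmm.means_[:, 0], 1)}, "
              f"total within-component variance={spread:.3f}")
    X, _ = gmm.sample(200)   # each new generation is fit only on generated samples
```

With only a few hundred samples per generation, the fitted variances drift and shrink substantially over the run, mirroring the loss of tails reported for both GMMs and VAEs.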

In the context of LLMs, the paper investigates the effects of fine-tuning OPT-125m on data generated by previous iterations of the model. The results show that models trained on generated data exhibit degraded performance compared to models trained on original data. The generated data also exhibit longer tails, suggesting that the models are starting to misperceive reality based on errors introduced by their ancestors.

The authors conduct experiments with two training regimes: training for 5 epochs with no original training data, and training for 10 epochs with 10% of the original training data preserved. Both regimes lead to degraded performance, but preserving part of the original data allows for better fine-tuning and results in only minor degradation.
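
A schematic of this generational fine-tuning loop is sketched below. It is not the authors' code: the dataset is a placeholder, optimizer settings and generation parameters are assumptions, and the paper's wikitext2 preprocessing and evaluation are omitted; only the data-mixing logic between generations reflects the two regimes described above.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tok = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").to(device)

def fine_tune(model, texts, epochs, lr=1e-5):
    """Minimal causal-LM fine-tuning loop (no batching or scheduling, for brevity)."""
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    for _ in range(epochs):
        for text in texts:
            ids = tok(text, return_tensors="pt", truncation=True,
                      max_length=64).input_ids.to(device)
            loss = model(ids, labels=ids).loss
            loss.backward()
            opt.step()
            opt.zero_grad()
    return model

def generate_corpus(model, prompts, max_new_tokens=64):
    """Sample one continuation per prompt to build the next generation's data."""
    model.eval()
    with torch.no_grad():
        return [tok.decode(model.generate(tok(p, return_tensors="pt").input_ids.to(device),
                                          do_sample=True, max_new_tokens=max_new_tokens)[0],
                           skip_special_tokens=True)
                for p in prompts]

original_texts = ["..."]                 # placeholder for the human-written corpus
prompts = [t[:32] for t in original_texts]

data = original_texts
for generation in range(5):
    # Regime 1: 5 epochs, next generation trained purely on generated text.
    # Regime 2: 10 epochs, with ~10% of the original human data kept in the mix.
    model = fine_tune(model, data, epochs=5)
    generated = generate_corpus(model, prompts)
    keep_original = original_texts[: max(1, len(original_texts) // 10)]
    data = generated                     # regime 1; regime 2: generated + keep_original
```

The key variable, as in the paper, is how much genuine human data survives into each generation's training mix: dropping it entirely corresponds to the first regime, while retaining roughly 10% corresponds to the second and degrades far more slowly.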

The paper also addresses the issue of repeating phrases in generated text, showing that explicitly encouraging models to produce non-repeating sequences does not curb the effects of model collapse.

The paper concludes by discussing the implications of model collapse for the long-term sustainability of LLM training. The authors emphasize the importance of preserving access to the original data source and of distinguishing LLM-generated data from other data. They suggest that community-wide coordination may be necessary to track the provenance of content crawled from the Internet; otherwise, it may become increasingly difficult to train newer versions of LLMs without access to pre-LLM data or to data generated directly by humans at scale.

In the theoretical analysis, the authors model learning with generational data as a stochastic process. At generation $i$, the dataset $\mathcal{D}_i$ consists of i.i.d. random variables $X^i_j$, where $j \in \{1, \dots, M_i\}$ and $M_i \geq 2$. The distribution of $X^i$ is denoted $p_i$, with $p_0$ representing the original distribution. The transition from generation $i$ to $i+1$ estimates the distribution of the samples in $\mathcal{D}_i$ with an approximation $p_{\theta_{i+1}}$, where $\mathcal{F}_\theta: p_i \to p_{\theta_{i+1}}$ represents the functional approximation. The dataset $\mathcal{D}_{i+1}$ is then sampled from the distribution $p_{i+1} = \alpha_i p_{\theta_{i+1}} + \beta_i p_i + \gamma_i p_0$, with non-negative parameters $\alpha_i, \beta_i, \gamma_i$ summing to $1$.
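
One step of this process can be written out directly (a sketch; a 1-D Gaussian fit stands in for the functional approximation $\mathcal{F}_\theta$, and the mixing weights are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

def one_generation(data_i, data_0, M_next, alpha=0.9, beta=0.05, gamma=0.05):
    """Draw D_{i+1} from p_{i+1} = alpha*p_theta_{i+1} + beta*p_i + gamma*p_0."""
    mu, sigma = data_i.mean(), data_i.std(ddof=1)       # F_theta: fit p_theta_{i+1} to D_i
    n_model, n_prev, n_orig = rng.multinomial(M_next, [alpha, beta, gamma])
    return np.concatenate([
        rng.normal(mu, sigma, size=n_model),            # fresh samples from p_theta_{i+1}
        rng.choice(data_i, size=n_prev),                # reused samples standing in for p_i
        rng.choice(data_0, size=n_orig),                # original samples standing in for p_0
    ])

data_0 = rng.standard_normal(1000)                      # original distribution p_0 = N(0, 1)
data = data_0
for i in range(5):
    data = one_generation(data, data_0, M_next=1000)
    print(f"generation {i + 1}: mean={data.mean():+.3f}, std={data.std(ddof=1):.3f}")
```

Setting $\alpha_i = 1$ (training purely on the latest model's output) recovers the setting analysed in most of the paper, while a non-zero $\gamma_i$ corresponds to preserving some of the original data.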

For the single-dimensional Gaussian case, the authors consider $X^0 \sim \mathcal{N}(\mu, \sigma^2)$ and estimate the sample mean and variance using:

$$\mu_{i+1} = \frac{1}{M_i}\sum_j X^i_j$$

  • $\mu_{i+1}$ is the estimated sample mean at generation $i+1$
  • $M_i$ is the sample size at generation $i$
  • $X^i_j$ are the samples at generation $i$

$$\sigma^2_{i+1} = \frac{1}{M_i-1}\sum_j \left(X^i_j-\mu_{i+1}\right)^2$$

  • $\sigma^2_{i+1}$ is the estimated sample variance at generation $i+1$

They then derive the following expression for $X^n_j$:

$$X^n_j = \mu + \frac{\sigma}{\sqrt{M_0}}Z^1 + \frac{\sigma}{\sqrt{M_1}}\sqrt{S^1}\,Z^2 + \dots + \frac{\sigma}{\sqrt{M_{n-1}}}\sqrt{S^1\times\dots\times S^{n-1}}\,Z^n + \sigma\sqrt{S^1\times\dots\times S^{n}}\,Z^n_j$$

  • $Z^i$ are random variables distributed as $\mathcal{N}(0, 1)$
  • $S^i$ are random variables distributed as $\frac{1}{M_{i-1}-1}\Gamma\!\left(\frac{M_{i-1}-1}{2}, \frac{1}{2}\right)$

Assuming a constant sample size $M_i = M$, they derive the following approximation for the variance:

$$\operatorname{Var}(X^n_j) = \sigma^2\left(1+\frac{n}{M}\right)$$
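
Because the $S^i$ factors have unit mean and the $Z$ terms are independent, this variance is straightforward to check by Monte Carlo (a sketch with constant $M_i = M$):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, M, n_gen, n_runs = 0.0, 1.0, 100, 10, 20000

finals = np.empty(n_runs)
for r in range(n_runs):
    x = rng.normal(mu, sigma, size=M)       # generation 0 ~ N(mu, sigma^2)
    for _ in range(n_gen):
        m, s = x.mean(), x.std(ddof=1)      # re-estimate mean and variance
        x = rng.normal(m, s, size=M)        # resample the next generation
    finals[r] = x[0]                        # keep one sample X^n_j per run

print("empirical Var(X^n_j):  ", finals.var())
print("sigma^2 * (1 + n / M):  ", sigma ** 2 * (1 + n_gen / M))
```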

The authors then use the Wasserstein-2 distance to measure the distance between the true distribution and the approximated distribution at step $n+1$:

$$R^{n+1}_{W_2} := W^2_2\!\left(\mathcal{N}(\mu,\sigma^2),\mathcal{N}(\mu_{n+1},\sigma^2_{n+1})\right)=\|\mu_{n+1}-\mu\|^2 + \|\sigma_{n+1}-\sigma\|^2$$

Finally, they calculate the risk as:

$$\mathbb{E}_{\mu_{n+1},\sigma^2_{n+1}}\!\left[R^{n+1}_{W_2}\right]=\sigma^2\left(\frac{1}{M_0}+\frac{1}{M_1}+ \dots + \frac{3}{2M_{n}}\right)$$
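
The growth of this risk with the number of generations can also be estimated by simulation (a sketch with constant $M_i = M$, using the closed form of $W^2_2$ between univariate Gaussians from above):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, M, n_runs = 0.0, 1.0, 100, 2000

for n in (1, 5, 10, 20):
    risks = []
    for _ in range(n_runs):
        x = rng.normal(mu, sigma, size=M)                     # generation 0
        for _ in range(n):
            x = rng.normal(x.mean(), x.std(ddof=1), size=M)   # generations 1..n
        m, s = x.mean(), x.std(ddof=1)                        # (mu_{n+1}, sigma_{n+1})
        risks.append((m - mu) ** 2 + (s - sigma) ** 2)        # W_2^2 for 1-D Gaussians
    print(f"n={n:2d} generations: estimated expected risk = {np.mean(risks):.4f}")
```

For a fixed per-generation sample size $M$, the estimated risk grows roughly linearly in $n$, which is why the sample sizes $M_i$ have to grow superlinearly across generations for the approximation of the original distribution to stay accurate.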

Authors (6)
  1. Ilia Shumailov (72 papers)
  2. Zakhar Shumaylov (14 papers)
  3. Yiren Zhao (58 papers)
  4. Yarin Gal (170 papers)
  5. Nicolas Papernot (123 papers)
  6. Ross Anderson (46 papers)
Citations (226)