Generative Adversarial Networks (GANs) are powerful models for learning complex probability distributions, but they are known to be challenging to train. A common failure mode is mode collapse, where the generator produces samples from only a subset of the modes in the target distribution and fails to capture its full diversity. This paper (Durr et al., 2022) investigates mode collapse from a dynamical systems perspective by introducing a simplified, interpretable model that replaces the generator network with a collection of particles in data-space.
The standard GAN objective involves a generator $G$ and a discriminator $D$, trained adversarially. The discriminator tries to distinguish real data samples ($x \sim p_{\mathrm{data}}$) from generated samples ($G(z)$, where $z$ is drawn from a latent prior), while the generator tries to fool the discriminator. The objective typically compares the discriminator's expected output on real and generated samples and includes a regularizer on the discriminator. Training typically involves alternating gradient ascent on $D$ and gradient descent on $G$. Mode collapse manifests as the generator's distribution oscillating between different modes of the data.
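As a concrete illustration, one objective of this kind can be written as follows. The sign convention (the generator minimizes the discriminator's expected output on generated samples) and the symbol $R(D)$ for the discriminator regularizer are assumptions made here for illustration; the paper's exact form may differ.

```latex
\min_{G}\,\max_{D}\;\;
  \mathbb{E}_{z}\!\left[ D(G(z)) \right]
  \;-\; \mathbb{E}_{x \sim p_{\mathrm{data}}}\!\left[ D(x) \right]
  \;-\; R(D)
```

Under this convention the discriminator assigns low values to regions dominated by real data, so generated samples that descend $D$ move toward the data.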
Instead of tracking generator parameters $\theta$, this paper studies the dynamics of the generator's outputs, $x_i = G(z_i)$, treated as particles in data-space. The evolution of these particles is related to the generator parameter updates through the Jacobian of the generator with respect to its parameters, and hence through the Neural Tangent Kernel (NTK), which is the inner product of these Jacobians at two latent inputs. For infinite-width neural networks, the NTK becomes static, fixed at initialization.
The paper proposes a simplified NTK structure for a system of particles in data-space, abstracting away the latent inputs $z_i$. This coarse-grained NTK is diagonal across output dimensions, which implies no correlation between gradients across different output dimensions, and its value depends only on whether the two particles are the same ($i = j$, the diagonal value) or distinct ($i \neq j$, the off-diagonal value). This simplification is motivated by NTK properties of wide networks with certain activations (like ReLU) when latent inputs are sampled from a high-dimensional sphere.
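One way to write a kernel with this structure is sketched below. The symbols are assumptions introduced here for illustration and are not necessarily the paper's notation: $a, b$ index output dimensions, $i, j$ index particles, and $g_0$ and $g_1$ denote the values for identical and distinct particles.

```latex
\Gamma^{ab}_{ij} \;=\; \delta^{ab}\left[ (g_0 - g_1)\,\delta_{ij} + g_1 \right]
```

In this form, $g_1 = 0$ gives a purely diagonal kernel, and the ratio $g_1 / g_0$ measures how strongly distinct particles are coupled.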
Under this simplified NTK, the dynamics of particle $i$ follow from gradient descent on the generator's objective (minimizing the discriminator's expected output): the particle's velocity is the negative NTK-weighted sum of the discriminator gradients evaluated at all particles. Substituting the simplified NTK reduces this sum to two terms, the discriminator gradient at particle $i$ itself and the average discriminator gradient over all particles. This highlights a key dynamic: each particle's velocity is a combination of its local discriminator gradient and the average gradient across the entire particle ensemble, weighted by terms related to the diagonal and off-diagonal NTK values.
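The reduction from a kernel-weighted sum to a local term plus an ensemble-average term can be checked numerically. The following is a minimal sketch, assuming a kernel with one value `g0` for identical particles and another value `g1` for distinct particles; both names, and the overall normalization of the velocity, are hypothetical choices made for illustration rather than the paper's notation.

```python
import numpy as np

def velocity_full_kernel(grad_d, g0, g1):
    """Velocity from the explicit N x N kernel: v_i = -sum_j K_ij * grad_d[j]."""
    n = grad_d.shape[0]
    # Kernel has g0 on the diagonal and g1 everywhere else (assumed form).
    kernel = (g0 - g1) * np.eye(n) + g1 * np.ones((n, n))
    return -kernel @ grad_d

def velocity_reduced(grad_d, g0, g1):
    """Same velocity written as a local term plus an ensemble-average term."""
    n = grad_d.shape[0]
    mean_grad = grad_d.mean(axis=0, keepdims=True)
    return -((g0 - g1) * grad_d + n * g1 * mean_grad)

# Random stand-in for discriminator gradients at 200 particles in 2D.
rng = np.random.default_rng(0)
grad_d = rng.normal(size=(200, 2))
v_full = velocity_full_kernel(grad_d, g0=1.0, g1=0.2)
v_reduced = velocity_reduced(grad_d, g0=1.0, g1=0.2)
print(np.allclose(v_full, v_reduced))
```

The reduced form avoids building the N x N kernel and makes the two competing contributions explicit: setting `g1 = 0` leaves only the local gradient, while a large `g1` lets the shared average term dominate every particle's motion.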
To study mode collapse, the paper applies this model to a 2D toy problem: generating samples from a distribution of 8 Gaussians arranged in a circle, a common benchmark for mode collapse. The "generator" is the set of 200 particles, initialized as a Gaussian cluster. The "discriminator" is a single-hidden-layer neural network. Training alternates between updating the discriminator parameters via gradient ascent on the objective and updating the particle positions via the derived gradient dynamics (Algorithm \ref{alg:gan_training_2}).
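The toy setup can be sketched in a few lines of NumPy. The ring radius, the per-mode standard deviation, and the width of the initial particle cluster are not given in this summary, so the values below are placeholder assumptions, and the function names are hypothetical.

```python
import numpy as np

def ring_mode_centers(n_modes=8, radius=2.0):
    """Centers of n_modes Gaussians evenly spaced on a circle (radius is assumed)."""
    angles = 2.0 * np.pi * np.arange(n_modes) / n_modes
    return radius * np.stack([np.cos(angles), np.sin(angles)], axis=1)

def sample_ring_gaussians(n_samples, rng, n_modes=8, radius=2.0, std=0.05):
    """Draw real samples: pick a mode uniformly, then add isotropic Gaussian noise."""
    centers = ring_mode_centers(n_modes, radius)
    labels = rng.integers(0, n_modes, size=n_samples)
    return centers[labels] + std * rng.normal(size=(n_samples, 2))

def init_particles(n_particles, rng, std=0.1):
    """Initialize the particle 'generator' as a Gaussian cluster at the origin."""
    return std * rng.normal(size=(n_particles, 2))

rng = np.random.default_rng(0)
real_batch = sample_ring_gaussians(512, rng)
particles = init_particles(200, rng)
```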
The paper quantifies mode collapse using a metric based on the entropy of the distribution of particles over their nearest modes (Equation \ref{eq:mode_collapse_metric}). A low metric value (near 0) indicates particles are distributed across all 8 modes (convergence), while a high value, near the metric's maximum, indicates collapse to a single mode.
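An entropy-based metric with these properties can be sketched as follows. The exact definition in the paper is not reproduced in this summary; the form used here, the log of the number of modes minus the entropy of the nearest-mode assignment, is an assumption chosen so that 0 means uniform coverage and the maximum means a single occupied mode.

```python
import numpy as np

def mode_collapse_metric(particles, centers):
    """log(K) minus the entropy of nearest-mode occupancy (assumed definition)."""
    # Distance from every particle to every mode center, shape (N, K).
    dists = np.linalg.norm(particles[:, None, :] - centers[None, :, :], axis=2)
    nearest = dists.argmin(axis=1)
    counts = np.bincount(nearest, minlength=len(centers))
    probs = counts / counts.sum()
    nonzero = probs[probs > 0]  # skip empty modes; 0 * log(0) is treated as 0
    entropy = -np.sum(nonzero * np.log(nonzero))
    return np.log(len(centers)) - entropy
```

With 8 modes, this version ranges from 0 (equal occupancy of all modes) to log 8 (all particles nearest to one mode).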
Experiments varying the NTK ratio (the off-diagonal value relative to the diagonal value) and the discriminator's relative training time (the number of discriminator steps per generator step) reveal a transition boundary between convergence and mode collapse.
- When the off-diagonal value is zero (diagonal NTK), particles evolve independently based on local gradients, and the system converges to cover all modes (Figure \ref{fig:no_ntk_2d}).
- When the ratio is sufficiently large (e.g., 1/5), the average gradient term dominates, causing the entire particle cluster to move together and chase discriminator minima from mode to mode, which is the signature of mode collapse (Figure \ref{fig:with_ntk_2d}).
The paper relates the shape of this boundary to the discriminator's ability to learn "high-frequency" spatial features. To break apart a particle cluster, the discriminator needs to create a minimum near the cluster's center, which requires learning finer spatial details (higher frequencies). The spatial scale required depends on the NTK ratio. The time (or number of discriminator steps) needed for the discriminator to learn such features depends on its frequency-dependent learning rate. For ReLU discriminators, this rate is known to follow a power law in frequency, leading to a power-law boundary in the plane spanned by the NTK ratio and the number of discriminator steps (Figure \ref{fig:relu_data}). For Tanh discriminators, the rate is exponential in frequency, resulting in an exponential boundary (Figure \ref{fig:tanh_data}). This suggests that the Frequency Principle in neural networks plays a direct role in GAN convergence behavior.
The paper also shows how the model can incorporate regularizers. A kinetic energy regularizer on generator parameters translates to a term involving the NTK and discriminator gradients (Equation \ref{eq:reg}). This term acts analogously to damping in a physical system. Adding this regularizer to the model GAN at an NTK ratio that otherwise produces collapse restores convergence (Figure \ref{fig:reg_dyn_scatters}). Varying the regularization strength reveals under-, critically-, and over-regularized regimes, mirroring damping dynamics (Figure \ref{fig:reg_dyn}), which suggests that finding the optimal regularization strength is crucial in practice.
Practical Implications and Implementation:
- Interpreting Dynamics: The particle model offers a simplified, visualizable way to understand the complex adversarial dynamics of GANs and why mode collapse occurs due to correlated particle movement.
- Diagnosing Mode Collapse: The competition between local and average gradients provides an intuitive explanation for mode collapse: when the global influence (the average-gradient term) is too strong relative to the local influence (the local-gradient term), or when the discriminator cannot sufficiently minimize the average gradient over a region, particles fail to split and cover multiple modes.
- Improving Training: The findings suggest practical strategies:
  - Modifying Generator Architecture: Architectures yielding a smaller off-diagonal to diagonal ratio in their infinite-width NTK might be inherently more stable against mode collapse. While directly calculating the NTK is hard, design choices could implicitly affect this ratio.
  - Discriminator Training Speed: The relative training speeds of the generator and discriminator (their learning rates and the number of discriminator steps per generator step) are critical. Sufficient discriminator training is needed to learn fine-grained features and break particle clusters.
  - Discriminator Properties: Discriminators with faster learning rates for high-frequency functions could potentially mitigate mode collapse more effectively, especially when particle correlation (the NTK ratio) is high.
  - Regularization: Regularizers acting on generator gradients (explicitly or implicitly via the discriminator objective) can act like damping, stabilizing training and preventing oscillatory mode switching. Tuning the strength is important to avoid over-regularization.
- Implementing the Model: Simulating this model involves:
  - Representing the generator distribution by a set of points (particles) in data-space.
  - Implementing the chosen discriminator network (e.g., simple ReLU/Tanh MLP).
  - Implementing the training loop: sample real data, calculate the discriminator loss, and update the discriminator via gradient ascent. Then calculate discriminator gradients for all particles, compute the average gradient, and update particle positions using the derived dynamics equation with the simplified NTK terms. Regularization terms would be added to the discriminator's objective before its update step.
  - Calculating the mode collapse metric by assigning particles to the nearest mode and computing the entropy of the mode distribution.
- Limitations: The model relies on significant simplifications (a static NTK, a fixed particle set representing the distribution) that might not fully capture the complexities of real GAN training with dynamic NTKs, varying batch samples, and the full generator network parameter space. However, its value lies in providing an interpretable lower-dimensional system in which to study fundamental dynamics.
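The simulation steps listed under "Implementing the Model" can be sketched end to end in NumPy. This is a minimal illustration under stated assumptions, not the paper's algorithm: the Tanh discriminator uses hand-written gradients, the sign convention has the discriminator ascend `mean D(particles) - mean D(real)` while particles descend `D`, a simple L2 penalty (`reg`) stands in for the discriminator regularizer, and all names and default values (`g0`, `g1`, `lr_d`, `lr_p`, ring radius 2.0, noise 0.05) are hypothetical.

```python
import numpy as np

def init_discriminator(rng, hidden=64, dim=2):
    """Parameters of a single-hidden-layer tanh network D(x) = w2 . tanh(W1 x + b1) + b2."""
    return {"W1": rng.normal(size=(hidden, dim)), "b1": np.zeros(hidden),
            "w2": rng.normal(size=hidden) / np.sqrt(hidden), "b2": 0.0}

def disc_forward(params, x):
    """Return D(x) for each row of x, plus the hidden activations."""
    h = np.tanh(x @ params["W1"].T + params["b1"])
    return h @ params["w2"] + params["b2"], h

def disc_input_grad(params, x):
    """Gradient of D with respect to each input point, shape (n, dim)."""
    _, h = disc_forward(params, x)
    return ((1.0 - h**2) * params["w2"]) @ params["W1"]

def disc_param_grad(params, x):
    """Gradient of the batch mean of D(x) with respect to each parameter."""
    _, h = disc_forward(params, x)
    delta = (1.0 - h**2) * params["w2"]  # shape (n, hidden)
    return {"W1": delta.T @ x / x.shape[0], "b1": delta.mean(axis=0),
            "w2": h.mean(axis=0), "b2": 1.0}

def train(n_steps=100, n_disc_steps=1, g0=1.0, g1=0.0,
          lr_d=0.05, lr_p=0.01, reg=0.01, seed=0):
    rng = np.random.default_rng(seed)
    angles = 2.0 * np.pi * np.arange(8) / 8
    centers = 2.0 * np.stack([np.cos(angles), np.sin(angles)], axis=1)
    particles = 0.1 * rng.normal(size=(200, 2))  # Gaussian cluster at the origin
    params = init_discriminator(rng)
    n = particles.shape[0]
    for _ in range(n_steps):
        for _ in range(n_disc_steps):
            real = centers[rng.integers(0, 8, size=n)] + 0.05 * rng.normal(size=(n, 2))
            g_fake = disc_param_grad(params, particles)
            g_real = disc_param_grad(params, real)
            for k in params:
                # Gradient ascent on mean D(fake) - mean D(real) - (reg/2)|params|^2.
                params[k] = params[k] + lr_d * (g_fake[k] - g_real[k] - reg * params[k])
        grad = disc_input_grad(params, particles)
        mean_grad = grad.mean(axis=0, keepdims=True)
        # Particles descend D: local term plus ensemble-average term.
        particles = particles - lr_p * ((g0 - g1) * grad + n * g1 * mean_grad)
    return particles, params

final_particles, final_params = train(n_steps=20, n_disc_steps=2, g1=0.0)
```

Sweeping `g1 / g0` and `n_disc_steps` over a grid and recording a mode-coverage metric for `final_particles` would give a rough picture of where this sketch converges or collapses; the step sizes here are untuned, so the location of any boundary should not be read as matching the paper's figures.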
In summary, the paper provides a valuable theoretical framework grounded in physics concepts (dynamical systems, NTK, damping) to understand mode collapse in GANs. The proposed particle model simplifies the problem while retaining key properties, allowing for concrete experiments demonstrating how internal dynamics governed by the NTK and the discriminator's properties (learning speed, frequency principle) drive the transition to mode collapse. This work provides principles that could guide the development of more stable GAN architectures and training algorithms.