Overcoming catastrophic forgetting in neural networks (1612.00796v2)
Abstract: The ability to learn tasks in a sequential fashion is crucial to the development of artificial intelligence. Neural networks are not, in general, capable of this and it has been widely thought that catastrophic forgetting is an inevitable feature of connectionist models. We show that it is possible to overcome this limitation and train networks that can maintain expertise on tasks which they have not experienced for a long time. Our approach remembers old tasks by selectively slowing down learning on the weights important for those tasks. We demonstrate our approach is scalable and effective by solving a set of classification tasks based on the MNIST hand written digit dataset and by learning several Atari 2600 games sequentially.
Summary
- The paper introduces Elastic Weight Consolidation (EWC), which mitigates catastrophic forgetting by selectively protecting neural network weights crucial to previously learned tasks.
- It builds on a Bayesian framework and uses the diagonal of the Fisher Information Matrix to weight a per-parameter quadratic penalty, so that parameters important to earlier tasks stay close to their learned values during new training.
- Experiments on permuted MNIST and on sequentially learned Atari 2600 games show that EWC outperforms plain SGD (and, on MNIST, L2 regularization), learning new tasks while retaining much of its performance on earlier ones.
This paper, "Overcoming catastrophic forgetting in neural networks" (1612.00796), introduces Elastic Weight Consolidation (EWC), a novel algorithm designed to address the critical challenge of catastrophic forgetting in neural networks when learning tasks sequentially.
The Problem: Traditional artificial neural networks suffer from catastrophic forgetting. When trained on a new task (Task B) after mastering a previous one (Task A), the network's weights are adjusted to optimize performance on Task B, often drastically altering the weights important for Task A, leading to a severe loss of performance on the old task. This makes sequential learning of multiple tasks difficult and often requires keeping data from all previous tasks available for joint training (multitask learning), which is impractical for a large number of tasks.
Biological Inspiration: The algorithm is inspired by synaptic consolidation in the mammalian brain, where synapses crucial for previously learned skills become less plastic and more stable over time. This biological mechanism protects existing knowledge while allowing new learning to occur in other parts of the network or through adjustments of less critical synapses.
The Solution: Elastic Weight Consolidation (EWC): EWC mimics this synaptic consolidation by selectively slowing down learning on the weights deemed important to previously learned tasks. When training on a new task B after learning task A, EWC adds a penalty term to the loss function of task B. This penalty acts like a spring that pulls the network parameters $\theta$ back toward their values $\theta^*_A$ after learning task A. The strength of the pull is not uniform: it differs per parameter and is stronger for parameters that matter more to task A.
Theoretical Basis: EWC is grounded in a Bayesian perspective. Learning a task is viewed as inferring the posterior distribution over the network parameters given the data. When task B is learned after task A, the posterior obtained after task A serves as the prior for task B. Because the true posterior is intractable, EWC approximates it with a Gaussian centered at the learned parameters $\theta^*_A$. The precision (inverse variance) of this Gaussian is given by the diagonal of the Fisher Information Matrix $F$. Near a minimum, the Fisher approximates the second derivative of the loss, so it measures how sensitive the model's predictions, and hence the loss, are to changes in each parameter. Parameters with high Fisher information are crucial for task A. Those with low Fisher information can change with little effect on it.
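The step from this Bayesian view to a penalty term can be written out explicitly. The derivation below assumes that the data $\mathcal{D} = \mathcal{D}_A \cup \mathcal{D}_B$ of the two tasks are independent, and that $L_B(\theta) = -\log p(\mathcal{D}_B \mid \theta)$, i.e. the task-B loss is a negative log-likelihood. The first line is Bayes' rule. The second line plugs in the diagonal Gaussian approximation centered at $\theta^*_A$ with precisions $F_i$. With $\lambda = 1$ the penalty matches this Gaussian log-prior exactly, and other values of $\lambda$ rescale it.

$$
\begin{aligned}
\log p(\theta \mid \mathcal{D}) &= \log p(\mathcal{D}_B \mid \theta) + \log p(\theta \mid \mathcal{D}_A) - \log p(\mathcal{D}_B) \\
\log p(\theta \mid \mathcal{D}_A) &\approx -\sum_i \tfrac{1}{2} F_i \left(\theta_i - \theta^*_{A,i}\right)^2 + \text{const} \\
-\log p(\theta \mid \mathcal{D}) &\approx L_B(\theta) + \sum_i \tfrac{1}{2} F_i \left(\theta_i - \theta^*_{A,i}\right)^2 + \text{const}
\end{aligned}
$$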
This leads to the EWC loss function for task B:
$$L(\theta) = L_B(\theta) + \sum_i \frac{\lambda}{2} F_i \left(\theta_i - \theta^*_{A,i}\right)^2$$
Here:
- $L_B(\theta)$ is the standard loss function for the current task (task B).
- $i$ indexes each parameter of the network.
- $\theta_i$ is the current value of parameter $i$.
- $\theta^*_{A,i}$ is the value of parameter $i$ after training on task A.
- $F_i$ is the diagonal element of the Fisher Information Matrix for parameter $i$, calculated on task A's data. It quantifies the importance of parameter $i$ for task A.
- $\lambda$ is a hyperparameter that sets how important the old task is relative to the new one.
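To make the selective pull concrete, here is a minimal pure-Python sketch on a hypothetical two-parameter problem. It is not the paper's code. Task B's loss is assumed to be the quadratic $\frac{1}{2}\lVert\theta - b\rVert^2$, and $\theta^*_A$, $b$, and the Fisher values are made up. The helper names `task_b_loss`, `ewc_loss`, and `train_on_b` are illustrative only. The script runs gradient descent from $\theta^*_A$ three ways: with no penalty ($\lambda = 0$), with a uniform L2 pull toward $\theta^*_A$ (all weights 1), and with the Fisher-weighted EWC penalty.

```python
def task_b_loss(theta, target_b):
    # Assumed toy task-B loss: 0.5 * ||theta - target_b||^2
    return sum(0.5 * (t - b) ** 2 for t, b in zip(theta, target_b))


def ewc_loss(theta, target_b, theta_a, fisher, lam):
    # L(theta) = L_B(theta) + sum_i (lam / 2) * F_i * (theta_i - theta*_{A,i})^2
    penalty = sum(0.5 * lam * f * (t - a) ** 2
                  for t, a, f in zip(theta, theta_a, fisher))
    return task_b_loss(theta, target_b) + penalty


def train_on_b(theta_a, target_b, fisher, lam, steps=5000):
    theta = list(theta_a)  # start from the parameters learned on task A
    # Step size chosen so the stiffest parameter is still stable.
    lr = 1.0 / (1.0 + lam * max(fisher))
    for _ in range(steps):
        # dL/dtheta_i = (theta_i - b_i) + lam * F_i * (theta_i - theta*_{A,i})
        grads = [(t - b) + lam * f * (t - a)
                 for t, b, a, f in zip(theta, target_b, theta_a, fisher)]
        theta = [t - lr * g for t, g in zip(theta, grads)]
    return theta


theta_a = [1.0, 1.0]      # hypothetical theta*_A
target_b = [-1.0, -1.0]   # hypothetical optimum of task B on its own
fisher = [100.0, 0.01]    # parameter 0 matters for task A, parameter 1 barely does

plain = train_on_b(theta_a, target_b, fisher, lam=0.0)
l2 = train_on_b(theta_a, target_b, [1.0, 1.0], lam=1.0)
ewc = train_on_b(theta_a, target_b, fisher, lam=1.0)
print([round(x, 4) for x in plain])  # [-1.0, -1.0]
print([round(x, 4) for x in l2])     # [0.0, 0.0]
print([round(x, 4) for x in ewc])    # [0.9802, -0.9802]
```

Plain training moves both parameters to task B's optimum, so task A is forgotten. The uniform L2 pull stops both parameters halfway, which also hinders learning task B. EWC keeps the parameter with large $F_i$ near its task-A value and lets the low-importance parameter move almost all the way to task B's optimum.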
For subsequent tasks, the penalties accumulate: the loss includes one quadratic constraint for each previous task. Because a sum of quadratic penalties is itself a quadratic penalty, these constraints can also be merged into one. Computing the diagonal of the Fisher Information Matrix is cheap. It requires only first-order derivatives, and its runtime is linear in the number of parameters and in the number of data samples.
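The following is a minimal sketch of a diagonal Fisher estimate that uses only first-order gradients, shown for a hypothetical binary logistic-regression model $p(y{=}1 \mid x) = \sigma(w \cdot x)$ rather than a deep network. As an assumption, it takes the expectation over labels $y$ under the model's own predictive distribution. The summary above does not say which estimator the paper uses; a common alternative is to plug in the observed labels (the "empirical" Fisher). The data and the helper name `fisher_diagonal` are made up.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def fisher_diagonal(w, xs):
    """Diagonal Fisher of a logistic-regression model p(y=1|x) = sigmoid(w . x).

    F_i = mean over x of E_{y ~ p(y|x,w)} [(d log p(y|x,w) / d w_i)^2].
    Only squared first-order gradients are needed, and the double loop
    touches each (example, parameter) pair a constant number of times, so
    the cost is linear in both the number of examples and of parameters.
    """
    fisher = [0.0] * len(w)
    for x in xs:
        p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
        # For y in {0, 1}, d log p(y|x) / d w_i = (y - p) * x_i.
        for y, prob_y in ((1, p), (0, 1.0 - p)):
            for i, xi in enumerate(x):
                g = (y - p) * xi
                fisher[i] += prob_y * g * g
    return [f / len(xs) for f in fisher]


w = [0.5, -1.0]                              # hypothetical trained weights
xs = [[1.0, 0.0], [1.0, 2.0], [0.0, 3.0]]    # hypothetical task-A inputs
fisher = fisher_diagonal(w, xs)
print(fisher)
```

For this model the per-example expectation has the closed form $p(1-p)\,x_i^2$, which makes the sketch easy to check. For a deep network the same quantity would come from per-example gradients computed by automatic differentiation.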
Implementation and Application - Supervised Learning (Permuted MNIST):
The paper demonstrates EWC on a sequence of supervised learning tasks derived from the MNIST dataset. Each task is digit classification, but the input pixels are shuffled by a fixed random permutation that differs from task to task. The labels stay the same, yet each task requires a distinct mapping from input to output.
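The task construction can be sketched as follows. This is a minimal illustration, assuming 28×28 images flattened to 784 values and one seeded permutation per task. The helper names `make_permutations` and `permute_image` are hypothetical, and a synthetic stand-in replaces a real MNIST digit so the snippet has no data dependency. The summary does not say whether the first task uses unpermuted pixels, so here every task gets a random permutation.

```python
import random

NUM_PIXELS = 28 * 28  # MNIST images flattened to 784 values


def make_permutations(num_tasks, seed=0):
    """One fixed random pixel permutation per task (hypothetical helper)."""
    rng = random.Random(seed)
    perms = []
    for _ in range(num_tasks):
        perm = list(range(NUM_PIXELS))
        rng.shuffle(perm)
        perms.append(perm)
    return perms


def permute_image(image, perm):
    """Reorder a flattened image: output pixel j is input pixel perm[j]."""
    return [image[k] for k in perm]


perms = make_permutations(num_tasks=3)
image = [i / NUM_PIXELS for i in range(NUM_PIXELS)]  # stand-in for a real digit
task_inputs = [permute_image(image, p) for p in perms]
# Same pixel values and the same label, but a different spatial layout per task.
```

Because each permutation is fixed for its task, the network sees a consistent but unfamiliar input layout whenever a new task begins.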
- Setup: A fully connected network is trained sequentially on multiple permuted MNIST tasks.
- Comparison: EWC is compared against plain Stochastic Gradient Descent (SGD) and SGD with L2 regularization.
- Results: Plain SGD shows severe catastrophic forgetting. L2 regularization prevents forgetting to some extent but hinders the learning of new tasks by constraining all weights equally. EWC effectively learns new tasks while largely preserving performance on old ones, demonstrating scalability to a significant number of sequential tasks.
- Analysis: By analyzing the overlap of the Fisher Information Matrices between tasks, the authors show that EWC doesn't simply partition the network into disjoint subnetworks for each task. While earlier layers might show less overlap for dissimilar tasks (reflecting input domain differences), later layers demonstrate significant overlap, indicating that EWC allows for sharing of representations across tasks when possible (e.g., for the shared output classification task).
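The summary does not give the overlap measure. The sketch below assumes the overlap is defined as $1 - d^2$, where $d^2$ is the squared Fréchet distance between Fisher matrices normalized to unit trace. We believe this matches the paper's Methods, but treat it as an assumption here. For diagonal Fishers, $d^2 = \frac{1}{2}\sum_i(\sqrt{a_i} - \sqrt{b_i})^2$, which gives 1 for proportional Fishers and 0 for Fishers with disjoint support. The per-layer Fisher values are invented for illustration.

```python
import math


def normalize(fisher):
    """Scale a diagonal Fisher so that its trace (sum) is 1."""
    total = sum(fisher)
    return [f / total for f in fisher]


def fisher_overlap(f1, f2):
    """Overlap = 1 - d^2, where d^2 = 0.5 * sum_i (sqrt(a_i) - sqrt(b_i))^2
    for the unit-trace diagonal Fishers a and b (assumed definition)."""
    a, b = normalize(f1), normalize(f2)
    d2 = 0.5 * sum((math.sqrt(x) - math.sqrt(y)) ** 2 for x, y in zip(a, b))
    return 1.0 - d2


# Hypothetical per-layer Fisher diagonals for two permuted-MNIST tasks.
layer_fishers = {
    "input_layer": ([5.0, 0.1, 0.1, 4.0], [0.1, 5.0, 4.0, 0.1]),
    "output_layer": ([2.0, 1.0, 3.0, 1.0], [2.5, 1.0, 2.5, 1.2]),
}
for name, (fa, fb) in layer_fishers.items():
    print(name, round(fisher_overlap(fa, fb), 3))
```

Normalizing to unit trace makes the measure depend only on which parameters matter, not on the overall scale of each Fisher.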
Implementation and Application - Reinforcement Learning (Atari 2600):
EWC is also applied to a more complex domain: sequential learning of multiple Atari 2600 games using Deep Q Networks (DQNs).
- Setup: A modified DQN agent learns a sequence of 10 randomly chosen Atari games. The agent plays the games one after another and may return to games it has already seen.
- Modifications: The agent architecture is similar to a standard DQN but includes:
  - A larger network capacity.
  - Task-specific biases and multiplicative gains at each layer to allow some specialization.
  - A task-recognition module (based on a Hidden Markov Model and a "Forget-Me-Not" process) that infers which game is being played from the observations.
  - Separate experience replay buffers for each recognized task.
- The EWC penalty is applied to protect knowledge of a game once the agent has trained on it for a sufficient time.
- Results: Plain SGD performs poorly, exhibiting catastrophic forgetting and failing to learn more than one game effectively. EWC enables the agent to learn and retain performance on multiple games sequentially, achieving a significantly higher cumulative score across tasks. Explicitly providing the correct task label (instead of using the task-recognition module) results in only a modest performance improvement, suggesting the task recognition module is reasonably effective.
- Analysis: An empirical analysis of parameter sensitivity shows that perturbing the weights with noise scaled by the inverse of the Fisher diagonal hurts performance less than uniform perturbations. This supports the Fisher as a measure of parameter importance. However, perturbing weights that the Fisher estimates to be unimportant (its null space) still degrades performance, suggesting that the diagonal Fisher approximation may underestimate the true parameter uncertainty.
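A toy calculation shows why Fisher-shaped noise should hurt less. The sketch works under the same quadratic model that EWC assumes and is not the paper's experiment, which measured actual game scores. For zero-mean noise $\delta$, the predicted loss increase is $\frac{1}{2}\sum_i F_i \operatorname{Var}(\delta_i)$. The code compares uniform noise with noise whose variance is proportional to $1/F_i$, matched to the same total variance. The Fisher values are hypothetical and must be positive. In this model a parameter with $F_i = 0$ would cost nothing to perturb, which is exactly the prediction the null-space result above contradicts.

```python
def expected_loss_increase(fisher, variances):
    """Quadratic model: E[delta L] ~= 0.5 * sum_i F_i * Var(delta_i) for zero-mean noise."""
    return 0.5 * sum(f * v for f, v in zip(fisher, variances))


def matched_variances(fisher, total_variance):
    """Uniform and inverse-Fisher noise variances with the same total variance."""
    n = len(fisher)
    uniform = [total_variance / n] * n
    inv = [1.0 / f for f in fisher]  # assumes every F_i > 0
    scale = total_variance / sum(inv)
    inverse_fisher = [scale * v for v in inv]
    return uniform, inverse_fisher


fisher = [50.0, 5.0, 0.5, 0.05]  # hypothetical Fisher diagonal
uniform, inv_fisher = matched_variances(fisher, total_variance=1.0)
print("uniform noise:       ", expected_loss_increase(fisher, uniform))
print("inverse-Fisher noise:", expected_loss_increase(fisher, inv_fisher))
```

By the Cauchy–Schwarz inequality, $n^2 \le \sum_i F_i \sum_i 1/F_i$. This means inverse-Fisher noise never predicts a larger increase than uniform noise of the same total variance in this model.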
Discussion and Future Work:
The paper concludes that EWC is an effective and scalable method for preventing catastrophic forgetting in neural networks, applicable to both supervised and reinforcement learning. Its grounding in Bayesian principles and parallels to biological synaptic consolidation are highlighted. While effective, the reliance on the diagonal Fisher approximation might be a limitation (underestimating parameter uncertainty), suggesting potential improvements could come from more sophisticated Bayesian methods like those used in Bayesian neural networks. The success of EWC provides computational support for neurobiological theories suggesting the importance of synaptic consolidation and uncertainty encoding for continual learning in the brain.