Elastic Weight Consolidation (EWC)
- EWC is a regularization technique that anchors neural network parameters crucial for prior tasks to mitigate catastrophic forgetting.
- It introduces a quadratic penalty based on the Fisher information matrix to preserve vital information from previous tasks during training.
- Empirical results on datasets like Permuted-MNIST and Sequential Atari demonstrate its effectiveness in maintaining low error rates and stable performance.
Elastic Weight Consolidation (EWC) is a regularization technique designed to address catastrophic forgetting in sequential learning scenarios. Catastrophic forgetting is the phenomenon whereby a neural network, when trained on multiple tasks in sequence, loses performance on previously learned tasks as it updates its parameters to solve new ones. EWC mitigates forgetting by selectively constraining parameters critical to past tasks, thereby enabling neural networks to maintain expertise on earlier tasks even after extensive training on new, unrelated data (Kirkpatrick et al., 2016).
1. Loss Function and Optimization Objective
EWC introduces a quadratic penalty to the standard loss function when learning a new task after completing a previous one. The total loss when training on task $B$ after task $A$ is defined as:

$$\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \sum_i \frac{\lambda}{2} F_i \left(\theta_i - \theta^*_{A,i}\right)^2$$

where:
- $\theta$ denotes all network parameters (including weights and biases),
- $\mathcal{L}_B(\theta)$ is the standard loss for task $B$,
- $\theta^*_A$ is the parameter vector optimized on task $A$,
- $F_i$ is the $i$-th diagonal element of the Fisher information matrix at $\theta^*_A$,
- $\lambda$ is a hyperparameter modulating the penalty strength.

The penalty anchors parameters to their optimal values for the previous task, weighted by an importance factor $F_i$. The larger $F_i$, the higher the penalty for deviating from $\theta^*_{A,i}$. For supervised classification or policy learning, $F_i$ quantifies the average squared sensitivity (gradient) of the log-likelihood with respect to each parameter, approximated as:

$$F_i = \mathbb{E}_{(x,y) \sim \mathcal{D}_A}\left[\left(\frac{\partial \log p(y \mid x, \theta)}{\partial \theta_i}\Big|_{\theta = \theta^*_A}\right)^2\right]$$
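The penalty is simple to compute given the stored anchor and Fisher values. A minimal PyTorch sketch follows; the names `ewc_penalty`, `fisher`, and `theta_star` are illustrative, not from the source:

```python
import torch


def ewc_penalty(model, fisher, theta_star, lam):
    """Quadratic EWC penalty: (lam/2) * sum_i F_i * (theta_i - theta*_i)^2.

    fisher and theta_star map parameter names to tensors saved after
    training on the previous task; only the penalty term is new, the
    task loss and optimizer are unchanged.
    """
    penalty = torch.zeros((), device=next(model.parameters()).device)
    for name, param in model.named_parameters():
        if name in fisher:
            penalty = penalty + (fisher[name] * (param - theta_star[name]).pow(2)).sum()
    return (lam / 2.0) * penalty


# Usage inside a training step on task B (loss_b is the task-B loss):
# loss = loss_b + ewc_penalty(model, fisher_a, theta_star_a, lam=100.0)
# loss.backward()
```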
2. Bayesian and Laplace Approximation Perspective
EWC has a principled Bayesian interpretation. After training on data $\mathcal{D}_A$ for task $A$, the posterior is $p(\theta \mid \mathcal{D}_A)$; for task $B$, this posterior acts as the prior. The log-posterior on both tasks is:

$$\log p(\theta \mid \mathcal{D}) = \log p(\mathcal{D}_B \mid \theta) + \log p(\theta \mid \mathcal{D}_A) - \log p(\mathcal{D}_B)$$

EWC approximates $p(\theta \mid \mathcal{D}_A)$ as a Gaussian centered at $\theta^*_A$ with precision given by the Fisher information:

$$p(\theta \mid \mathcal{D}_A) \approx \mathcal{N}\left(\theta^*_A, \operatorname{diag}(F)^{-1}\right)$$

By maximizing the joint log-posterior, the augmented loss function emerges. The addition of $\lambda$ allows empirical tuning of the regularization effect. Hence, the quadratic consolidation penalty is a Laplace (second-order) approximation of the log-posterior of parameters after task $A$, interpreted as a locally quadratic constraint on the parameter space (Kirkpatrick et al., 2016).
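Writing out the Gaussian approximation makes this correspondence explicit; the following short derivation restates the argument above, with $\text{const}$ absorbing all terms independent of $\theta$:

$$
\begin{aligned}
\log p(\theta \mid \mathcal{D}_A) &\approx -\tfrac{1}{2} \sum_i F_i \left(\theta_i - \theta^*_{A,i}\right)^2 + \text{const}, \\
-\log p(\theta \mid \mathcal{D}) &\approx \mathcal{L}_B(\theta) + \tfrac{1}{2} \sum_i F_i \left(\theta_i - \theta^*_{A,i}\right)^2 + \text{const},
\end{aligned}
$$

where $\mathcal{L}_B(\theta) = -\log p(\mathcal{D}_B \mid \theta)$; scaling the quadratic term by $\lambda$ recovers the EWC objective of Section 1.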
3. Sequential Learning Procedure
The EWC method for continual learning is operationalized as follows:
- For a sequence of tasks $T_1, \dots, T_K$ and regularization hyperparameter $\lambda$:
  - Initialize parameters $\theta$ randomly.
  - For each task $k = 1$ to $K$:
    - If $k > 1$, add the penalty $\frac{\lambda}{2} \sum_i F^{\mathrm{cum}}_i \left(\theta_i - \theta^*_{k-1,i}\right)^2$ to the optimizer's loss, with $F^{\mathrm{cum}}$ the cumulative Fisher from all previous tasks.
    - Train on task $T_k$ using standard optimizers (SGD, RMSprop) until convergence, yielding $\theta^*_k$.
    - Estimate $F_k$ by accumulating diagonal squared gradients of the log-likelihood over mini-batches sampled from $T_k$.
    - Update the cumulative precision: $F^{\mathrm{cum}} \leftarrow F^{\mathrm{cum}} + F_k$ for $k > 1$; else $F^{\mathrm{cum}} \leftarrow F_k$.
This procedure ensures that parameters crucial for previous tasks are protected during optimization for subsequent tasks by the sum-of-quadratics term (Kirkpatrick et al., 2016).
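A sketch of the per-task consolidation step in PyTorch, under the assumption that the model is a classifier whose log-likelihood is the log-softmax of its outputs; the helper names `estimate_diag_fisher` and `consolidate` are illustrative, not from the source:

```python
import torch
import torch.nn.functional as F


def estimate_diag_fisher(model, data_loader, n_batches):
    """Diagonal Fisher: average squared gradient of the log-likelihood."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    model.eval()
    for i, (x, y) in enumerate(data_loader):
        if i >= n_batches:
            break
        model.zero_grad()
        log_probs = F.log_softmax(model(x), dim=1)
        # Log-likelihood of the observed labels for this mini-batch.
        # (Batch-averaged gradients are a common practical shortcut;
        # per-example gradients give a finer estimate.)
        loss = F.nll_loss(log_probs, y)
        loss.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach().pow(2)
    return {n: f / n_batches for n, f in fisher.items()}


def consolidate(model, fisher_cum, fisher_k):
    """After finishing task k: store the anchor and accumulate precision."""
    theta_star = {n: p.detach().clone() for n, p in model.named_parameters()}
    if fisher_cum is None:
        fisher_cum = fisher_k
    else:
        fisher_cum = {n: fisher_cum[n] + fisher_k[n] for n in fisher_cum}
    return theta_star, fisher_cum
```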
4. Hyperparameter Selection and Implementation Considerations
Key hyperparameters and details include:
- Regularization strength $\lambda$: Controls the trade-off between retaining performance on old tasks and learning new tasks. Small $\lambda$ increases forgetting; large $\lambda$ impedes new learning. Optimal values are determined via cross-validation. For example, on MNIST, typical values range from 1 to 100; for Atari, values near 400 are used.
- Number of mini-batches for the Fisher estimate: stable Fisher estimates on Atari are obtained by averaging squared gradients over a sufficient number of mini-batches; too few batches yields a noisy importance estimate.
- Batch size and optimizer: These follow standard practice for the problem domain (e.g., batch size 200 for MNIST, 32 for DQN replay).
- Penalty accumulation: EWC can accumulate quadratic penalties either by summing all prior tasks' terms or by maintaining one cumulative Fisher vector and anchor; because a sum of diagonal quadratics is itself a diagonal quadratic, the two are equivalent up to an additive constant when the cumulative anchor is the Fisher-weighted mean of the per-task anchors (see the numerical check after this list).
- Scalability: Each new task adds one diagonal vector (Fisher) and anchor, with linear scaling in the number of tasks.
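The accumulation equivalence can be checked numerically: completing the square collapses a sum of diagonal quadratics into a single quadratic with summed precision and a precision-weighted anchor. A standalone NumPy sketch (all names illustrative; `a1`, `a2` stand in for the per-task anchors $\theta^*_k$):

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 5
F1, F2 = rng.uniform(0.1, 1.0, dim), rng.uniform(0.1, 1.0, dim)  # per-task Fishers
a1, a2 = rng.normal(size=dim), rng.normal(size=dim)              # per-task anchors

# Merged penalty: cumulative precision and Fisher-weighted mean anchor.
F_cum = F1 + F2
a_cum = (F1 * a1 + F2 * a2) / F_cum

theta = rng.normal(size=dim)
sum_of_quadratics = 0.5 * (F1 * (theta - a1) ** 2 + F2 * (theta - a2) ** 2).sum()
merged = 0.5 * (F_cum * (theta - a_cum) ** 2).sum()

# The two forms differ only by a theta-independent constant.
const = sum_of_quadratics - merged
theta2 = rng.normal(size=dim)
sum2 = 0.5 * (F1 * (theta2 - a1) ** 2 + F2 * (theta2 - a2) ** 2).sum()
merged2 = 0.5 * (F_cum * (theta2 - a_cum) ** 2).sum()
assert np.isclose(sum2 - merged2, const)
```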
5. Empirical Results on Permuted-MNIST and Sequential Atari
Extensive experiments demonstrate EWC's efficacy:
Permuted-MNIST Sequence (10 Tasks):
- Baseline SGD exhibits catastrophic forgetting; after all ten tasks, error on the initial task grows from its single-task baseline to approximately 90%.
- Uniform (L2) regularization toward the previous task's weights protects old tasks but significantly underfits new ones.
- EWC retains high accuracy across tasks; error rates on early tasks remain near 2–3% after training on all tasks.
- Fisher-overlap analysis reveals parameter sharing when tasks are similar and allocation of new weights for divergent tasks (Kirkpatrick et al., 2016).
Sequential Atari 2600 (10 Games):
- Standard DQN, trained sequentially, achieves an aggregate human-normalized score below 1, i.e., it only retains competency in one game.
- DQN augmented with EWC achieves a steadily rising aggregate score, reaching 6–8 out of 10 across all games, without expanding network capacity.
- Providing explicit task labels offers only marginal improvement beyond EWC with latent task recognition.
- Weight-perturbation experiments validated that the Fisher information reliably predicts parameter importance: perturbations along low-Fisher directions degrade performance substantially less than perturbations along high-Fisher directions.
These results support the weight-consolidation principle, where the parameter “stiffness” induced by EWC prevents catastrophic forgetting in both supervised and reinforcement-learning domains (Kirkpatrick et al., 2016).
6. Discussion and Continual Learning Implications
EWC’s diagonal Fisher-based quadratic penalty framework is distinct from uniform or other regularizations. It leverages task-specific parameter importance, operationalized via the Fisher information, as a proxy for the posterior’s precision. This approach enables effective sequential multi-task learning without replay buffers or integration of old task data. The approach scales linearly with the number of tasks, offering practical feasibility for multi-task continual learning settings (Kirkpatrick et al., 2016). A plausible implication is that future work could extend beyond diagonal approximations to leverage full-rank covariance structures or enhance consolidation via more expressive posterior approximations.