Averaged Gradient Episodic Memory (A-GEM)
- A-GEM is a continual learning algorithm that mitigates catastrophic forgetting by using a single projection constraint based on the average past gradient.
- It achieves competitive accuracy, e.g., around 89.1% on Permuted MNIST, while drastically reducing computational and memory costs compared to GEM.
- The method employs reservoir sampling for episodic memory and a strict two-stream evaluation protocol to ensure robustness in single-pass lifelong learning.
Averaged Gradient Episodic Memory (A-GEM) is a continual learning algorithm designed to balance computational efficiency, memory economy, and resistance to catastrophic forgetting in single-pass lifelong learning scenarios. A-GEM is an advancement over Gradient Episodic Memory (GEM), offering similar or superior accuracy with dramatically lower computational and memory costs by introducing a novel projection constraint on the average past gradient (Chaudhry et al., 2018).
1. Lifelong Learning Setup and Evaluation Protocols
In lifelong learning (LLL), the objective is to learn a predictor $f_\theta : \mathcal{X} \times \mathcal{T} \to \mathcal{Y}$ (for example, a neural network parameterized by $\theta$) over a sequence of tasks. Each task $k$ is associated with a dataset $\mathcal{D}_k = \{(x_i, t_i, y_i)\}_{i=1}^{n_k}$, where $x_i \in \mathcal{X}$ is the input, $t_i \in \mathcal{T}$ is a task descriptor, and $y_i \in \mathcal{Y}$ is the label. The learner observes each instance exactly once, with all tasks presented in sequence.
To mitigate catastrophic forgetting, methods maintain a small per-task episodic memory $\mathcal{M}_k$, typically much smaller than the task dataset ($\lvert \mathcal{M}_k \rvert \ll \lvert \mathcal{D}_k \rvert$). The union of all past task memories before task $t$ is denoted $\mathcal{M} = \bigcup_{k < t} \mathcal{M}_k$.
A-GEM evaluations are conducted via a two-stream protocol:
- $\mathcal{D}^{CV}$: A held-out cross-validation stream for hyper-parameter optimization, allowing arbitrary replay.
- $\mathcal{D}^{EV}$: An evaluation stream processed in a single pass with fixed hyper-parameters.
This separation prevents information leakage from evaluation tasks during hyper-parameter search and enforces a strictly single-pass regime for reporting metrics.
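The two-stream protocol can be sketched in a few lines of Python. This is a minimal illustration, not code from the paper: `split_streams`, `tune_then_evaluate`, and the scalar-score `train_fn` interface are hypothetical names chosen for the sketch.

```python
def split_streams(tasks, n_cv):
    """D^CV: first n_cv tasks, used only for hyper-parameter search
    (replay allowed). D^EV: remaining tasks, seen in a single pass."""
    return tasks[:n_cv], tasks[n_cv:]

def tune_then_evaluate(tasks, n_cv, candidate_lrs, train_fn):
    """Pick hyper-parameters on D^CV alone, then report metrics from
    exactly one pass over D^EV, so evaluation tasks never leak into
    the hyper-parameter search."""
    d_cv, d_ev = split_streams(tasks, n_cv)
    # Multiple passes are permitted on the cross-validation stream.
    best_lr = max(candidate_lrs, key=lambda lr: train_fn(d_cv, lr, passes=3))
    # Reporting is strictly single-pass with the frozen setting.
    return train_fn(d_ev, best_lr, passes=1)
```
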
2. Evaluation Metrics
A-GEM is evaluated using several metrics that quantify accuracy and the dynamics of knowledge retention and acquisition:
(a) Final Average Accuracy $A_T$:

$$A_T = \frac{1}{T} \sum_{j=1}^{T} a_{T,j}$$

Here, $a_{k,j}$ is the test accuracy on task $j$ after training on all minibatches of task $k$; $A_T$ is the terminal metric.
(b) Forgetting $F_T$:

$$F_T = \frac{1}{T-1} \sum_{j=1}^{T-1} \left( \max_{l \in \{1, \dots, T-1\}} a_{l,j} - a_{T,j} \right)$$

This quantifies the deterioration in performance on previous tasks due to new learning.
(c) Learning Curve Area (LCA):

$$Z_b = \frac{1}{T} \sum_{k=1}^{T} a_{k,b}, \qquad \mathrm{LCA}_\beta = \frac{1}{\beta+1} \sum_{b=0}^{\beta} Z_b$$

where $a_{k,b}$ is the accuracy on task $k$ after observing $b$ mini-batches of that task. LCA evaluates both few-shot and progressive learning by averaging accuracy up to $\beta$ training steps.
3. From GEM to A-GEM: Mathematical Formulation
GEM constrains gradient updates to avoid loss increases on any previous task's memory, projecting the current gradient $g$ onto the intersection of half-spaces:

$$\min_{\tilde{g}} \; \frac{1}{2} \lVert g - \tilde{g} \rVert_2^2 \quad \text{s.t.} \quad \langle \tilde{g}, g_k \rangle \geq 0 \;\; \forall k < t,$$

with $g_k = \nabla_\theta \ell(f_\theta, \mathcal{M}_k)$. This requires solving a quadratic program with $t-1$ constraints and storing all $g_k$.
A-GEM simplifies the constraint to a single condition on the average past gradient $g_{\mathrm{ref}}$, computed from a mini-batch sampled from $\mathcal{M} = \bigcup_{k < t} \mathcal{M}_k$. If $\langle g, g_{\mathrm{ref}} \rangle \geq 0$, $g$ is left untouched; else, the projection has a closed form:

$$\tilde{g} = g - \frac{\langle g, g_{\mathrm{ref}} \rangle}{\langle g_{\mathrm{ref}}, g_{\mathrm{ref}} \rangle} \, g_{\mathrm{ref}}.$$

This single-constraint projection reduces computational complexity and storage, enabling scalability to longer task sequences and larger networks.
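The closed-form projection is a one-liner over flattened gradients. A minimal NumPy sketch; `agem_project` is an illustrative name, and the small `eps` guard against a zero reference gradient is an addition not present in the paper's formula:

```python
import numpy as np

def agem_project(g, g_ref, eps=1e-12):
    """A-GEM gradient correction: if g conflicts with the average past
    gradient g_ref (negative inner product), remove the conflicting
    component along g_ref; otherwise return g unchanged."""
    dot = float(np.dot(g, g_ref))
    if dot >= 0.0:
        return g
    return g - (dot / (float(np.dot(g_ref, g_ref)) + eps)) * g_ref
```

By construction, the corrected gradient satisfies $\langle \tilde{g}, g_{\mathrm{ref}} \rangle = 0$ whenever the constraint was violated, so the loss on the memory mini-batch is not increased to first order.
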
4. Algorithmic Implementation and Complexity
A-GEM maintains a global episodic memory $\mathcal{M}$ with reservoir sampling to ensure a uniform selection from all encountered data. At each training step, a mini-batch from $\mathcal{M}$ provides $g_{\mathrm{ref}}$. The update is as follows:
```
for (x, y) in D_t:
    if M ≠ ∅:
        sample (x_ref, y_ref) ~ M
        g_ref ← ∇_θ ℓ(f_θ(x_ref, t), y_ref)
    else:
        g_ref ← 0
    g ← ∇_θ ℓ(f_θ(x, t), y)
    if ⟨g, g_ref⟩ < 0:
        g ← g − (⟨g, g_ref⟩ / ⟨g_ref, g_ref⟩) · g_ref
    θ ← θ − η · g
    update M with reservoir sampling from D_t
```
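The reservoir-sampling update in the last line keeps $\mathcal{M}$ a uniform sample of the stream without knowing its length in advance. A minimal sketch of the classic algorithm; `reservoir_update` is an illustrative helper name, not an identifier from the A-GEM codebase:

```python
import random

def reservoir_update(memory, item, n_seen, capacity):
    """Classic reservoir sampling: after processing the (n_seen+1)-th
    stream item, every item seen so far is retained in `memory` with
    probability capacity / (n_seen + 1)."""
    if len(memory) < capacity:
        memory.append(item)
    else:
        j = random.randint(0, n_seen)  # uniform over [0, n_seen], inclusive
        if j < capacity:
            memory[j] = item  # replace a uniformly chosen slot
    return memory
```
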
Complexity Table
| Method | Time (per step) | Memory |
|---|---|---|
| Vanilla | $O(BP)$ | $O(P)$ |
| EWC | $O(BP)$ + diag-updates | $O(P)$ extra for the Fisher diagonal |
| GEM | $O(tBP)$ + QP with $t-1$ constraints | $O(tP + \lvert \mathcal{M} \rvert)$ |
| A-GEM | $O(BP)$ (≈ $2\times$ vanilla) | $O(P + \lvert \mathcal{M} \rvert)$ |
Here $P$ = number of parameters, $B$ = mini-batch size, $H$ = activation size, and $\lvert \mathcal{M} \rvert$ = episodic memory size. In practice, A-GEM is substantially faster and more memory efficient than GEM on MNIST/CIFAR.
5. Empirical Results and Benchmark Performance
Experiments evaluate A-GEM on Permuted MNIST, Split CIFAR-100, Split CUB, and Split AWA, using MLP and ResNet architectures. A-GEM's final accuracy ($A_T$) matches or slightly trails GEM (e.g., 89.1% vs. 89.5% on Permuted MNIST) while outperforming all regularization-based baselines (EWC, PI, MAS, RWalk) in the single-pass regime (e.g., EWC: 68%, A-GEM: 89% on Permuted MNIST). Forgetting remains lowest among methods with bounded memory.
Incorporating compositional task descriptors with a joint-embedding model (the "-je" variant) further improves $A_T$, zero-shot and few-shot performance (LCA), and learning speed for A-GEM and other methods.
Normalized summary (Permuted MNIST, Split CIFAR):
| Method | $A_T$ (%) ↑ | LCA ↑ | Time ↓ | Mem ↓ |
|---|---|---|---|---|
| Vanilla | 47.9 | 0.26 | 0.06 | 0.06 |
| EWC | 68.3 | 0.27 | 0.14 | 0.14 |
| GEM | 89.5 | 0.23 | 1.00 | 1.00 |
| A-GEM | 89.1 | 0.29 | 0.14 | 0.11 |
A-GEM is thus Pareto-optimal in the joint space of accuracy, forgetting, LCA, time, and memory.
6. Ablations, Sensitivity Analyses, and Algorithmic Variants
A-GEM projections are only required on a small fraction of steps, in contrast to GEM's frequent constraints as tasks accumulate. The "Stochastic GEM" (s-GEM) variant, which randomly samples past constraints, is still more costly and slightly less effective than A-GEM.
EWC's efficacy is highly sensitive to the number of epochs and model capacity; in single-pass and small-network settings, it only marginally outperforms vanilla SGD. Only with over-parameterized models and multiple passes does EWC approach A-GEM's performance.
Hyper-parameter search spaces and selected settings are detailed in the appendix of (Chaudhry et al., 2018).
7. Key Insights, Limitations, and Future Directions
A-GEM achieves the core objectives of lifelong learning—retaining prior knowledge and enabling forward transfer—while being computationally and memory efficient. The Learning Curve Area (LCA) metric, introduced alongside A-GEM, provides a finer quantification of few-shot learning dynamics.
Observed limitations include the residual gap between single-pass continual learning (even with A-GEM) and the multi-task upper bound (IID setting). LCA differences among advanced continual methods shrink once catastrophic forgetting is controlled; the field thus requires strategies for enhancing positive backward transfer.
Natural extensions include applying A-GEM to unsupervised, reinforcement, or streaming non-i.i.d. learning settings. The open-source codebase is provided at https://github.com/facebookresearch/agem (Chaudhry et al., 2018).