Efficient Centroids Module
- The paper introduces a novel centroids-based classification mechanism that minimizes catastrophic forgetting by representing classes in a shared embedding space.
- It supports both exact recomputation and online running-average updates of the centroids, ensuring computational efficiency and reduced memory overhead.
- The approach integrates explicit regularization and model snapshot techniques to preserve embedding geometry, outperforming rehearsal and EWC baselines.
An Efficient Centroids Module is a lightweight, embedding-centric approach for continual-learning classifiers that addresses catastrophic forgetting by leveraging class prototypes in the model's latent space. Unlike standard rehearsal or parameter-regularization techniques, this module represents each class as a centroid in embedding space and uses these centroids both for classification and for explicit regularization. This design maintains high accuracy across tasks in task-incremental and class-incremental lifelong learning while markedly reducing memory and computational overhead.
1. Formal Definition of Embedding-Space Centroids
Let $\psi$ denote the backbone neural network extracting task-agnostic embeddings and $f_i$ denote the $i$-th task head, so that $e_i(x) = f_i(\psi(x))$ is the task-$i$ embedding of an input $x$. For each new task $t$, a random support set $S_t$ (typically 50–200 labeled examples per class) is sampled. For class $k$ within task $t$, define

$$c_t^k = \frac{1}{|S_t^k|} \sum_{x \in S_t^k} f_t(\psi(x)),$$

where $S_t^k \subseteq S_t$ collects the support examples of class $k$. These centroids are averaged representations in the output space of $f_t$ and parameterize the module.
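For concreteness, here is a minimal PyTorch sketch of this computation, assuming a `backbone` (playing the role of $\psi$), a task head `head` ($f_t$), and support tensors `support_x`, `support_y` (all names are illustrative, not from the paper):

```python
import torch

@torch.no_grad()
def compute_centroids(backbone, head, support_x, support_y, num_classes):
    """Average the support embeddings f_t(psi(x)) per class (hypothetical helper)."""
    emb = head(backbone(support_x))                        # (N, d) embeddings
    centroids = torch.zeros(num_classes, emb.size(1), device=emb.device)
    for k in range(num_classes):
        centroids[k] = emb[support_y == k].mean(dim=0)     # c_t^k
    return centroids
```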
2. Centroid Initialization and Updates
At the start of task $t$, centroids are initialized by averaging the current network's embeddings of the support set. During task training, centroids may be:
- Recomputed exactly from the entire support set at each step, at cost $O(|S_t|)$ per recomputation
- Maintained as an online running average (see the sketch below):
  $$c_t^k \leftarrow \frac{n_k\, c_t^k + \sum_{x \in B^k} f_t(\psi(x))}{n_k + |B^k|}, \qquad n_k \leftarrow n_k + |B^k|,$$
  with mini-batch $B$ (of which $B^k$ are the examples of class $k$) and $n_k$ tracking the number of points seen so far.
After completing task $t$, only the final set of centroids is retained, not the raw data.
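A minimal sketch of the online running-average update, assuming a `(K, d)` `centroids` tensor and a per-class `counts` tensor (names and shapes are assumptions):

```python
import torch

def update_centroids_online(centroids, counts, batch_emb, batch_y):
    """Fold a mini-batch of embeddings into per-class running averages in place."""
    with torch.no_grad():
        for k in batch_y.unique():
            mask = batch_y == k
            n_new = mask.sum()
            total = counts[k] + n_new
            # Weighted average of the previous centroid and the new batch sum.
            centroids[k] = (counts[k] * centroids[k] + batch_emb[mask].sum(dim=0)) / total
            counts[k] = total
    return centroids, counts
```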
3. Centroid-Based Classification
Prediction for an input $x$ from task $t$ is based on distance in embedding space (see the sketch after this list):
- Compute the embedding $e_t(x) = f_t(\psi(x))$
- Euclidean distances $d_k = \lVert e_t(x) - c_t^k \rVert_2$ to every centroid of task $t$
- Posterior: $p(y = k \mid x, t) = \dfrac{\exp(-d_k)}{\sum_j \exp(-d_j)}$
- At inference (task-incremental), $\hat{y} = \arg\min_k d_k$
- At inference (class-incremental), embeddings may be projected into a shared space via small nets, then assigned by global minimal distance
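A minimal sketch of the distance-based posterior and the task-incremental prediction, assuming a batch of embeddings `e` of shape `(B, d)` and a `(K, d)` `centroids` tensor (illustrative names):

```python
import torch
import torch.nn.functional as F

def classify_with_centroids(e, centroids):
    """Return p(y=k|x,t) and the nearest-centroid prediction for a batch."""
    dists = torch.cdist(e, centroids)        # (B, K) Euclidean distances d_k
    posterior = F.softmax(-dists, dim=1)     # softmax over negative distances
    prediction = dists.argmin(dim=1)         # TIL: arg-min-distance class
    return posterior, prediction
```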
4. Continual-Learning Regularization
To mitigate forgetting, the module preserves the geometry of the embedding spaces of previous tasks (a minimal sketch of the penalty follows this list):
- After each task, a frozen copy of the backbone and of the previously trained heads is retained, providing reference embeddings $\bar{e}_j(\cdot)$ for each previous task $j$
- For each new sample $x$ in task $t$, compute the regularization term $R(x) = \frac{1}{t} \sum_{j < t} \lVert \bar{e}_j(x) - e_j(x) \rVert_2$
- Composite loss per sample in task $t$: $L = L_{\mathrm{CE}}(x) + \lambda\, R(x)$
- $\lambda$ is a regularization hyperparameter
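A minimal sketch of this penalty, assuming the current `backbone`/`heads` and a frozen snapshot `old_backbone`/`old_heads` kept from before task `t` (all names are assumptions; the `1/t` normalization follows the workflow pseudocode in Section 5):

```python
import torch

def embedding_drift_penalty(x, backbone, heads, old_backbone, old_heads, t):
    """L2 drift between current embeddings e_j(x) and the frozen snapshot's, for j < t."""
    penalty = x.new_zeros(())
    feats = backbone(x)
    with torch.no_grad():
        old_feats = old_backbone(x)
    for j in range(t - 1):                    # previous tasks (0-indexed heads)
        e = heads[j](feats)
        with torch.no_grad():
            old_e = old_heads[j](old_feats)
        penalty = penalty + (old_e - e).norm(p=2, dim=1).sum()
    return penalty / t
```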
5. Algorithmic Workflow
Training (Task t)
```
Input: current network ψ, task heads {f₁,…,f_t}
1. Sample support set S_t; partition it into S_t^k by class
2. Initialize c_t^k = (1/|S_t^k|) ∑_{x ∈ S_t^k} f_t(ψ(x))
3. If t > 1: copy ψ, f₁,…,f_{t-1} to 'old_model' (kept frozen)
4. For epoch = 1,…,E:
     For batch B in task t:
       For (x, k) in B, compute e_t(x) = f_t(ψ(x))
       Compute p(y = k | x, t) = softmax_j(-||e_t(x) - c_t^j||₂)
       If t > 1, compute
         R_batch = (1/t) ∑_{j<t} ∑_{x∈B} ||old_model.e_j(x) - e_j(x)||₂
       L = cross-entropy + λ·R_batch   (R_batch = 0 if t = 1)
       Backpropagate L, update ψ and f₁,…,f_t
     Optionally, recompute c_t^k from S_t
5. Discard S_t; keep {c_t^k}
```
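Tying the pieces together, here is a compact PyTorch training step under the same assumptions, reusing the hypothetical helpers sketched in Sections 3 and 4 (`classify_with_centroids`, `embedding_drift_penalty`; all argument names are illustrative):

```python
import torch
import torch.nn.functional as F

def train_step(x, y, backbone, heads, centroids, old_backbone, old_heads,
               t, optimizer, lam):
    """One optimization step for task t."""
    optimizer.zero_grad()
    e = heads[t - 1](backbone(x))                          # e_t(x)
    posterior, _ = classify_with_centroids(e, centroids)   # p(y|x,t)
    loss = F.nll_loss(torch.log(posterior + 1e-12), y)     # cross-entropy term
    if t > 1:
        loss = loss + lam * embedding_drift_penalty(
            x, backbone, heads, old_backbone, old_heads, t)
    loss.backward()
    optimizer.step()
    return loss.detach()
```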
Inference
- TIL: given $x$ and the task id $t$, assign $\hat{y} = \arg\min_k \lVert e_t(x) - c_t^k \rVert_2$
- CIL: project embeddings into the shared space, assign the global nearest centroid (see the sketch below)
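For the class-incremental case, a simplified sketch that assumes embeddings have already been projected into the shared space and skips the projection networks themselves (names are illustrative):

```python
import torch

@torch.no_grad()
def cil_predict(e_shared, centroids_per_task):
    """Global nearest-centroid prediction over all tasks seen so far."""
    all_centroids = torch.cat(centroids_per_task, dim=0)   # (sum_t K_t, d)
    dists = torch.cdist(e_shared, all_centroids)           # (B, sum_t K_t)
    return dists.argmin(dim=1)                             # global class index
```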
6. Computational and Memory Complexity
- Storage: one $d$-dimensional centroid per class per task; the total is $d \sum_t K_t$ floats, where $K_t$ is the number of classes in task $t$
- No raw example storage for TIL; only a small replay buffer is needed for CIL in realistic scenarios; the centroid store itself is always $O(d \sum_t K_t)$
- Classification: $O(K_t \cdot d)$ per sample; pairwise distances are broadcast with efficient tensor operations
- Training: centroid re-averaging is $O(|S_t|)$ and amortized over the epoch; the embedding regularization adds one forward pass through the frozen snapshot plus an $O(t \cdot d)$ distance penalty per sample, and unlike EWC-style penalties it does not introduce per-weight terms over the $P$ model parameters
- Empirical runtime: 10%–30% faster than rehearsal baselines with large buffers; 2–4× faster than EWC/OEWC
| Method | Memory Footprint | Train Time (per update) |
|---|---|---|
| Centroids Matching (proposed) | $O(d \sum_t K_t)$ centroids at test; plus one frozen model snapshot at train | Standard FWD/BWD plus one forward through the frozen snapshot |
| Rehearsal | $O(\text{buffer size})$ stored raw examples | Additional FWD/BWD passes over the replay buffer |
| EWC/OEWC (reg) | $O(P)$ Fisher/importance terms | Extra gradient/Fisher ops |
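As a back-of-the-envelope illustration (the task count, class count, and embedding width below are assumed for the example, not figures from the paper): with $T = 10$ tasks, $K = 10$ classes per task, and $d = 512$ float32 embeddings, the centroid store occupies

$$10 \times 10 \times 512 \times 4 \text{ bytes} = 204{,}800 \text{ bytes} \approx 200\ \mathrm{KB},$$

orders of magnitude below a typical raw-image replay buffer.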
7. Software and Practical Implementation
- Centroids are stored as PyTorch buffers, e.g. `self.register_buffer('centroids', torch.zeros(num_classes, D))`, ensuring correct device placement and non-participation in gradients.
- For mini-batch updates, use running sums and counts for numerically stable online averaging.
- Distance computation exploits PyTorch broadcasting:

```python
# Squared Euclidean distances between a batch of embeddings e (B, d)
# and the centroids (K, d); negated distances serve as logits.
dists = (e.unsqueeze(1) - centroids.unsqueeze(0)).pow(2).sum(-1)
logits = -dists
```
- Old model snapshots for regularization use `copy.deepcopy(model)` with `torch.no_grad()` and state management through `model.state_dict()` (a minimal snapshot helper is sketched below).
- Centroids are sufficiently small to be included in checkpoints and support full state restoration.
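A minimal snapshot helper consistent with these practices (the function name and structure are assumptions, not from the paper):

```python
import copy
import torch

def make_snapshot(model: torch.nn.Module) -> torch.nn.Module:
    """Deep-copy and freeze the current model to serve as the drift-regularization target."""
    snapshot = copy.deepcopy(model)
    snapshot.eval()                      # fix batch-norm/dropout behaviour
    for p in snapshot.parameters():
        p.requires_grad_(False)          # exclude the copy from optimization
    return snapshot

# During training, old embeddings are then produced under torch.no_grad().
```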
8. Empirical and Methodological Impact
The module enables a continual learning system that explicitly preserves the geometry of task-embedding spaces. Catastrophic forgetting is minimized through direct regularization of embedding drift, rather than parameter-level constraints or resource-intensive rehearsal. The result is a scalable, resource-efficient, and high-performing continual-learning pipeline suitable for realistic, non-idealized multi-task scenarios. The experimental evaluation demonstrates accuracy gains on several benchmarks and shows clear memory and runtime advantages over standard rehearsal and regularization techniques (Pomponi et al., 2022).
References
- "Centroids Matching: an efficient Continual Learning approach operating in the embedding space" (Pomponi et al., 2022)