Efficient Centroids Module

Updated 1 December 2025
  • The paper introduces a novel centroids-based classification mechanism that minimizes catastrophic forgetting by representing classes in a shared embedding space.
  • It employs both exact and online running averages for centroid updates, ensuring computational efficiency and reduced memory overhead.
  • The approach integrates explicit regularization and model snapshot techniques to preserve embedding geometry, outperforming rehearsal and EWC baselines.

An Efficient Centroids Module is a lightweight, embedding-centric approach for continual-learning classifiers that addresses catastrophic forgetting by leveraging class prototypes in the model’s latent space. Unlike standard rehearsal or parameter-regularization techniques, this module represents each class as a centroid in embedding space and uses these centroids both for classification and for explicit regularization. This design achieves high accuracy across all tasks in task-incremental (TIL) or class-incremental (CIL) lifelong learning while markedly reducing memory and computational overhead.

1. Formal Definition of Embedding-Space Centroids

Let $\psi(\cdot)$ denote the backbone neural network extracting task-agnostic embeddings and $f_i(\cdot)$ denote the $i$-th task head. For each new task $i$, a random support set $S_i$ (typically 50–200 labeled examples per class) is sampled. For class $k$ within task $i$, define

$$c_i^k = \frac{1}{|S_i^k|} \sum_{(x,y)\in S_i^k} f_i(\psi(x)),$$

where $S_i^k \subset S_i$ collects the support examples of class $k$. These centroids are averaged representations in the output space of $f_i \circ \psi$ and parameterize the module.
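
As an illustration, the following minimal PyTorch sketch computes the per-class centroids from a support set; the names psi, f_i, support_x, and support_y are placeholders assumed for the example, not identifiers from the reference implementation.

import torch

def compute_centroids(psi, f_i, support_x, support_y, num_classes):
    """Per-class average of the task head's embeddings over the support set (c_i^k)."""
    with torch.no_grad():
        emb = f_i(psi(support_x))                       # [N, D] embeddings of S_i
    centroids = torch.zeros(num_classes, emb.size(1), device=emb.device)
    for k in range(num_classes):
        centroids[k] = emb[support_y == k].mean(dim=0)  # average over S_i^k
    return centroids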

2. Centroid Initialization and Updates

At the start of task $i$, centroids are initialized by averaging the current network’s embeddings of the support set. During task training, centroids may be:

  • Recomputed exactly from the entire support set $S_i^k$ at each step (cost $O(|S_i^k|)$)
  • Maintained as an online running average:

$$c_i^k \leftarrow \frac{N_k \cdot c_i^k + \sum_{x\in B_k} f_i(\psi(x))}{N_k + |B_k|},$$

with mini-batch $B_k$ and $N_k$ tracking the number of points seen so far.

After completing task $i$, only the final set of centroids $\{c_i^k\}_k$ is retained, not the raw data; a sketch of the online update follows.
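
Below is a minimal sketch of the online running-average update, assuming the centroids and per-class counts are kept as tensors; the helper name and arguments are illustrative.

import torch

def update_centroids_online(centroids, counts, emb, labels):
    """c_k <- (N_k * c_k + sum of batch embeddings of class k) / (N_k + |B_k|)."""
    with torch.no_grad():                              # buffer update, no gradients needed
        for k in labels.unique():
            batch_k = emb[labels == k]                 # B_k: embeddings of class k in the batch
            n_new = batch_k.size(0)
            centroids[k] = (counts[k] * centroids[k] + batch_k.sum(0)) / (counts[k] + n_new)
            counts[k] += n_new
    return centroids, counts

# counts starts as torch.zeros(num_classes); centroids as in compute_centroids above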

3. Centroid-Based Classification

Prediction for an input $x$ from task $i$ is based on distance in embedding space:

  • Compute the embedding $e_i(x) = f_i(\psi(x))$
  • Compute Euclidean distances $d(c_i^k, e_i(x)) = \|c_i^k - e_i(x)\|_2$
  • Posterior:

$$p(y=k \mid x,i) = \frac{\exp(-d(c_i^k, e_i(x)))}{\sum_{k'} \exp(-d(c_i^{k'}, e_i(x)))}$$

  • At inference (Task-Incremental), $\hat y = \arg\min_k \|e_i(x) - c_i^k\|_2^2$
  • At inference (Class-Incremental), embeddings may be projected into a shared space via small projection networks and then assigned by global minimal distance; a distance-based classification sketch follows
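
The following sketch illustrates distance-based classification over a batch, assuming embeddings and centroids have already been computed; the function name and tensor shapes are illustrative.

import torch
import torch.nn.functional as F

def centroid_predict(emb, centroids):
    """Softmax posterior over negative distances and nearest-centroid prediction."""
    dists = torch.cdist(emb, centroids, p=2)   # [B, C] Euclidean distances d(c_i^k, e_i(x))
    probs = F.softmax(-dists, dim=-1)          # p(y = k | x, i)
    y_hat = dists.argmin(dim=-1)               # TIL prediction: nearest centroid
    return probs, y_hat

# toy usage: 8 embeddings of dimension 64 against 5 class centroids
probs, y_hat = centroid_predict(torch.randn(8, 64), torch.randn(5, 64))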

4. Continual-Learning Regularization

To mitigate forgetting, the module preserves the geometry of embedding spaces of previous tasks:

  • After task $i-1$, a frozen copy of the model is retained, giving reference embedding functions $\overline{e}_j(\cdot)$ for the backbone and each previous head
  • For each new sample $x$ in task $t$, compute the regularization term

$$R(x,t) = \frac{1}{t} \sum_{j < t} d(\overline{e}_j(x), e_j(x))$$

  • Composite loss per sample $(x,k)$ in task $t$:

$$L_t(x,k) = -\log p(y=k \mid x,t) + \lambda R(x,t)$$

  • $\lambda$ is a regularization hyperparameter; a sketch of the composite loss follows
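
A sketch of the drift penalty and composite loss, under the assumptions that backbone is $\psi$, prev_heads are $f_1, \dots, f_{t-1}$, and the old_* arguments are the frozen snapshots; all identifiers are illustrative.

import torch
import torch.nn.functional as F

def embedding_drift_penalty(x, backbone, prev_heads, old_backbone, old_prev_heads):
    """R(x, t) = (1/t) * sum_{j < t} || e_bar_j(x) - e_j(x) ||_2, per sample."""
    t = len(prev_heads) + 1                    # current task index
    penalty = torch.zeros(x.size(0), device=x.device)
    with torch.no_grad():
        old_feats = old_backbone(x)            # frozen psi snapshot
    feats = backbone(x)
    for h_old, h in zip(old_prev_heads, prev_heads):
        with torch.no_grad():
            e_bar = h_old(old_feats)           # e_bar_j(x), fixed reference embedding
        penalty = penalty + (e_bar - h(feats)).norm(dim=-1)
    return penalty / t

# composite per-batch loss: cross-entropy on negative distances plus lambda times the mean drift penalty
# loss = F.cross_entropy(-dists, labels) + lam * embedding_drift_penalty(
#            x, backbone, prev_heads, old_backbone, old_prev_heads).mean()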

5. Algorithmic Workflow

Training (Task $t$)

Input: current network ψ, task heads {f_1, …, f_t}
1. Sample support set S_t; partition it into S_t^k by class
2. Initialize c_t^k = (1/|S_t^k|) Σ_{x ∈ S_t^k} f_t(ψ(x))
3. If t > 1: copy ψ, f_1, …, f_{t-1} to 'old_model'
4. For epoch = 1, …, E:
   For batch B in task t:
     For (x, k) in B, compute e_t(x) = f_t(ψ(x))
     Compute p(y=k|x,t) as softmax(-||e_t(x) - c_t^j||_2) over j
     If t > 1, compute
         R_batch = (1/t) Σ_{j<t} Σ_{x ∈ B} || old_model.e_j(x) - e_j(x) ||_2
     L = cross-entropy + λ R_batch
     Backpropagate L, update ψ and f_1, …, f_t
   Optionally, recompute c_t^k from S_t
5. Discard S_t; keep {c_t^k}
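
The per-task workflow above can be sketched in PyTorch as follows, reusing the compute_centroids and embedding_drift_penalty helpers from the earlier sections; the hyperparameters, optimizer choice, and all identifiers are assumptions made for illustration rather than the reference implementation.

import copy
import torch
import torch.nn.functional as F

def train_task(backbone, heads, loader, support_x, support_y, num_classes,
               lam=1.0, epochs=5, lr=1e-3):
    """Sketch of training task t: init centroids, snapshot old model, train, keep only centroids."""
    t = len(heads)                                              # heads = [f_1, ..., f_t]
    centroids = compute_centroids(backbone, heads[-1], support_x, support_y, num_classes)
    if t > 1:                                                   # step 3: freeze psi and f_1..f_{t-1}
        old_backbone = copy.deepcopy(backbone).eval()
        old_heads = [copy.deepcopy(h).eval() for h in heads[:-1]]
    params = [p for m in [backbone, *heads] for p in m.parameters()]
    opt = torch.optim.Adam(params, lr=lr)
    for _ in range(epochs):
        for x, y in loader:                                     # batches from task t only
            emb = heads[-1](backbone(x))                        # e_t(x)
            loss = F.cross_entropy(-torch.cdist(emb, centroids), y)
            if t > 1:                                           # embedding-drift regularization
                loss = loss + lam * embedding_drift_penalty(
                    x, backbone, heads[:-1], old_backbone, old_heads).mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
        # recompute c_t^k from the support set after each epoch (optional refinement)
        centroids = compute_centroids(backbone, heads[-1], support_x, support_y, num_classes)
    return centroids                                            # step 5: discard S_t, keep {c_t^k}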

Inference

  • TIL: given $x$ and task id $i$, assign $\hat y = \arg\min_k \|e_i(x) - c_i^k\|_2^2$
  • CIL: project embeddings, assign to global nearest centroid

6. Computational and Memory Complexity

  • Storage: one $D$-dimensional centroid per class per task; the total is $(\#\text{classes}) \cdot D$ floats
  • No raw example storage is needed for TIL; only a small replay buffer is needed for CIL in realistic scenarios; the module itself is always $O(C \cdot D)$.
  • Classification: $O(C \cdot D)$ per sample; pairwise distances are computed by broadcasting with efficient tensor operations.
  • Training: centroid re-averaging is $O(|S_i| \cdot D)$ and amortized. Embedding regularization is $O(t \cdot D)$ per sample, with $t \ll P$ (the number of model weights).
  • Empirical runtime: 10%–30% faster than rehearsal baselines with large buffers; 2–4$\times$ faster than EWC/OEWC when $P \gg C \cdot D$.
Method                          Memory Footprint     Train Time (per update)
Centroids Matching (proposed)   (#classes) · D       O(C · D) at test, O(|S_i| · D) at train
Rehearsal                       M · I · C            O(M) additional forward/backward passes
EWC/OEWC (regularization)       P                    O(P) extra gradient/Fisher operations
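
For intuition, a back-of-the-envelope comparison with purely hypothetical sizes (a 512-dimensional embedding, 100 classes, a 2000-image CIFAR-scale buffer, and an 11M-parameter backbone), all in float32:

C, D = 100, 512                      # classes, embedding dimension
centroid_bytes = C * D * 4           # ~0.2 MB for all centroids
M, image_size = 2000, 3 * 32 * 32    # rehearsal buffer of 2000 CIFAR-sized images
buffer_bytes = M * image_size * 4    # ~24.6 MB of stored examples
P = 11_000_000                       # backbone parameters (ResNet-18 scale)
ewc_bytes = P * 4                    # ~44 MB of per-weight importance terms
print(centroid_bytes, buffer_bytes, ewc_bytes)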

7. Software and Practical Implementation

  • Centroids stored as PyTorch buffers, e.g. self.register_buffer('centroids', torch.zeros(num_classes, D)), ensuring correct device placement and non-participation in gradients.
  • For mini-batch updates, use running sums and counts for numerically stable online averaging.
  • Distance computation exploits PyTorch broadcasting:
    # e: [B, D] batch of embeddings, centroids: [C, D]; dists: [B, C] squared distances
    dists = (e.unsqueeze(1) - centroids.unsqueeze(0)).pow(2).sum(-1)
    logits = -dists  # negative distances serve as class logits
  • Old model snapshots for regularization use copy.deepcopy(model) with torch.no_grad() and state management through model.state_dict(), as sketched after this list.
  • Centroids are sufficiently small to be included in checkpoints and support full state restoration.
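
A minimal sketch of such a snapshot helper, applicable to any nn.Module; the function name is illustrative.

import copy
import torch
from torch import nn

def snapshot(module: nn.Module) -> nn.Module:
    """Frozen deep copy used as the old-embedding reference for regularization."""
    frozen = copy.deepcopy(module)
    frozen.eval()                         # fix dropout / batch-norm statistics
    for p in frozen.parameters():
        p.requires_grad_(False)           # keep the snapshot out of the optimizer
    return frozen

# evaluate the snapshot under torch.no_grad() so no graph is built for it:
# with torch.no_grad():
#     e_bar = frozen_head(frozen_backbone(x))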

8. Empirical and Methodological Impact

The module enables a continual learning system that is explicitly geometry-preserving in task-embedding spaces. Catastrophic forgetting is minimized through explicit regularization of embedding drift, rather than parameter-level constraints or resource-intensive rehearsal. The result is a scalable, resource-efficient, and high-performing continual-learning pipeline suitable for realistic, non-idealized multi-task scenarios. The experimental evaluation demonstrates accuracy gains on several benchmarks and shows clear memory and runtime advantages over standard rehearsal and regularization techniques (Pomponi et al., 2022).

References

  • "Centroids Matching: an efficient Continual Learning approach operating in the embedding space" (Pomponi et al., 2022)