Causal Distillation in CIL
- The paper introduces a novel causal distillation method that restores the causal effect of old data via structural causal models, effectively mitigating catastrophic forgetting without relying on data replay.
- It applies a K-nearest neighbor strategy in the old feature space to compute weighted targets, replacing explicit replay with a storage-free, neighbor-aggregated distillation approach.
- Empirical studies across vision, NLP, and video tasks demonstrate significant accuracy boosts and reduced forgetting through adaptive loss terms, curriculum strategies, and cross-channel compensation.
Causal distillation in class-incremental learning (CIL) is a principled methodology that seeks to mitigate catastrophic forgetting by explicitly restoring the causal effect of previously encountered data on predictions for newly introduced classes. Unlike standard distillation, which focuses on aligning features or logits, causal distillation leverages structural causal models (SCMs) to reconstruct the influence of old data, thereby achieving anti-forgetting without incurring the storage costs of data replay. This paradigm has recently seen application in both vision and natural language processing, with theoretical and empirical support across diverse modalities (Hu et al., 2021, Zheng et al., 2022, Chen et al., 13 Jan 2025).
1. Structural Causal Models for Catastrophic Forgetting
Causal distillation reframes CIL through SCMs that make the flow of information across task increments explicit. In the standard setting, data from previous tasks (old data $D$) and the incoming data batch (new input $I$) generate features $X_o$ via the old model, which in turn produce logits $Y_o$. The new model computes updated features $X$ and logits $Y$ from $I$. In the absence of replay or explicit regularization, all causal paths from $D$ to $Y$ are blocked by the collider at $X_o$, yielding $\mathit{Effect}_D = 0$ and catastrophic forgetting (Hu et al., 2021). Analogous graphical structures generalize to sequence tagging (e.g., CL-NER), where explicit treatment of “Other-Class” tokens is crucial, and to video adapters, where spatial and temporal channels introduce multi-modal causal dependencies (Zheng et al., 2022, Chen et al., 13 Jan 2025).
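The blocking role of a collider can be illustrated numerically. The following is a toy simulation, not the SCM of any cited paper: the Gaussian variables, the additive collider, and the conditioning window are all assumptions chosen for illustration. It shows that two independent causes stay uncorrelated until their common effect is conditioned on, which is the mechanism that colliding-effect distillation exploits.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000

# Two independent causes (toy stand-ins for old data D and new input I).
d = rng.normal(size=n)
i = rng.normal(size=n)

# A collider: a common effect of both causes (toy stand-in for old features X_o).
x_o = d + i + 0.1 * rng.normal(size=n)

# Marginally, D and I are uncorrelated: the path through the collider is blocked.
marginal_corr = np.corrcoef(d, i)[0, 1]

# Conditioning on the collider (here: keeping only samples with X_o near 0)
# opens the path and induces a strong dependence between D and I.
mask = np.abs(x_o) < 0.1
conditional_corr = np.corrcoef(d[mask], i[mask])[0, 1]

print(f"marginal corr(D, I):          {marginal_corr:+.3f}")
print(f"corr(D, I | X_o close to 0):  {conditional_corr:+.3f}")
```

The marginal correlation is close to zero, while the conditional one is strongly negative: once the collider is fixed, knowing one cause is informative about the other.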
2. Derivation and Implementation of Colliding-Effect Distillation
Causal distillation activates the blocked causal path through the collider $X_o$ (the old-model feature) by conditioning on it. The formal causal effect of the old data $D$ on the prediction $Y$ is

$$\mathit{Effect}_D = P(Y \mid D = d) - P(Y \mid D = 0),$$

which is zero in vanilla fine-tuning. Causal distillation replaces data replay by matching each input $I$ to its $K$ nearest neighbors in the old feature space $X_o$. Distillation targets are computed as a weighted average of the new-model logits $\hat{Y}_j$ over these neighbors, $j = 1, \dots, K$. The weights $w_j$ are monotonic, larger for more similar neighbors, and sum to 1. The objective is to minimize the cross-entropy between the weighted prediction and the ground truth $y$:

$$\mathcal{L} = \mathrm{CE}\Big(\sum_{j=1}^{K} w_j \hat{Y}_j,\; y\Big), \qquad \sum_{j=1}^{K} w_j = 1.$$
This objective is provably equivalent, in terms of restoring the causal effect, to explicit data replay but does not require storage of old exemplars (Hu et al., 2021). In specialized domains such as class-incremental NER, causal effects are isolated for both new-entity tokens and “Defined-Other” tokens, each with dedicated loss terms (cross-entropy and KL-divergence, respectively) (Zheng et al., 2022).
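The neighbor-aggregated objective described above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the function names `rank_weights` and `collided_cross_entropy` are hypothetical, the linear rank-based weight decay is one arbitrary choice of monotonic weights, neighbors are found by Euclidean distance, and each sample is counted as its own nearest neighbor.

```python
import numpy as np

def log_softmax(z):
    """Numerically stable log-softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def rank_weights(k):
    """Weights over neighbor rank (nearest first), decreasing and summing to 1.
    The linear decay is an assumption made for this sketch."""
    raw = np.arange(k, 0, -1, dtype=float)
    return raw / raw.sum()

def collided_cross_entropy(old_feats, new_logits, labels, k=3):
    """Cross-entropy between neighbor-averaged new-model logits and labels.

    old_feats:  (n, d) features from the frozen old model
    new_logits: (n, c) logits from the model being trained
    labels:     (n,)   integer ground-truth labels
    """
    # Pairwise Euclidean distances in the OLD feature space.
    diff = old_feats[:, None, :] - old_feats[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    # Indices of the k nearest neighbors, nearest first (includes the sample itself).
    nn_idx = np.argsort(dist, axis=1, kind="stable")[:, :k]
    w = rank_weights(k)
    # Weighted average of new-model logits over the neighbors: shape (n, c).
    collided = (new_logits[nn_idx] * w[None, :, None]).sum(axis=1)
    n = len(labels)
    loss = -log_softmax(collided)[np.arange(n), labels].mean()
    return loss, collided

# Toy usage with random data.
rng = np.random.default_rng(0)
old_feats = rng.normal(size=(8, 4))
new_logits = rng.normal(size=(8, 5))
labels = rng.integers(0, 5, size=8)
loss, collided = collided_cross_entropy(old_feats, new_logits, labels, k=3)
print(f"collided CE loss: {loss:.4f}")
```

With `k=1` the sketch reduces to ordinary cross-entropy on each sample's own logits, which makes the role of the neighbor aggregation easy to isolate in experiments.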
3. Objective Functions and Curriculum Strategies
The causal distillation loss combines multiple terms capturing causal effects from distinct slices of the data. In CIL-NER, the total effect is $\mathit{Effect} = \mathit{Effect}_E + \mathit{Effect}_O$, with
- $\mathit{Effect}_E$: cross-entropy between new-entity tokens’ collided predictions and gold labels,
- $\mathit{Effect}_O$: KL-divergence between Defined-Other tokens’ collided predictions and old-model soft targets, plus standard KD for Undefined-Other tokens.
A self-adaptive weight $\lambda$ scales $\mathit{Effect}_O$ relative to $\mathit{Effect}_E$:
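$$\lambda = \lambda_{\text{base}} \sqrt{C^{O} / C^{N}},$$
where $C^{O}$ and $C^{N}$ are the numbers of old and new entity types and $\lambda_{\text{base}}$ is a base weight, so that retention and acquisition stay balanced (Zheng et al., 2022).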
To make training more robust to noisy old-model predictions, a curriculum learning strategy is introduced. A linearly decaying confidence threshold $\delta$ determines, in each epoch, which token predictions are eligible for the colliding-based KL loss; lower-confidence tokens revert to standard KD. This staged exposure mitigates overfitting to label noise and stabilizes convergence.
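The curriculum schedule can be sketched as a simple per-epoch filter. This is an illustrative sketch under assumptions, not the schedule from Zheng et al. (2022): the helper names, the start and end threshold values, and the use of the old model's maximum class probability as the confidence score are all hypothetical choices.

```python
import numpy as np

def linear_threshold(epoch, num_epochs, start=1.0, end=0.5):
    """Threshold decaying linearly from `start` (epoch 0) to `end` (last epoch).
    The default start/end values are assumptions for this sketch."""
    if num_epochs <= 1:
        return end
    frac = epoch / (num_epochs - 1)
    return start + frac * (end - start)

def select_for_colliding_kl(old_probs, epoch, num_epochs, start=1.0, end=0.5):
    """Boolean mask per token: True -> colliding-based KL loss,
    False -> fall back to standard KD.
    Confidence is taken as the old model's max class probability (assumed)."""
    confidence = old_probs.max(axis=-1)
    return confidence >= linear_threshold(epoch, num_epochs, start, end)

# Toy usage: three tokens with old-model confidences 0.9, 0.6 and 0.55.
old_probs = np.array([[0.90, 0.10],
                      [0.60, 0.40],
                      [0.55, 0.45]])
for epoch in range(5):
    mask = select_for_colliding_kl(old_probs, epoch, num_epochs=5)
    print(epoch, linear_threshold(epoch, 5), mask.tolist())
```

Early epochs admit only the most confident tokens into the colliding-based loss; as the threshold decays, progressively noisier tokens are included.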
4. Causal Distillation for Structured and Multi-modal Adaptation
Causal distillation generalizes to modalities where class-incremental updates affect distinct representational channels. In exemplar-free video CIL, adapters are split into spatial and temporal modules. Causal distillation is realized as relation recovery: the new adapters are regularized so that their relation vectors remain close to those of the preceding task, using hybrid cosine losses averaged over the top-K most similar old examples.
Causal compensation further mitigates interference: the aligned gradient directions between the spatial and temporal channels are used to compute small corrective logits, which are added to the final classifier’s output. This mechanism enables constructive cross-channel recall even in the absence of replayed exemplars (Chen et al., 13 Jan 2025).
5. Algorithmic Workflow
A general class-incremental causal distillation step proceeds as follows:
| Step | Description | Source |
|---|---|---|
| 1 | Initialize new model from previous, extend classifier | (Hu et al., 2021, Zheng et al., 2022) |
| 2 | Compute old-model features for new data; build KNN index | (Hu et al., 2021, Zheng et al., 2022) |
| 3 | Partition data as needed (e.g., new-entity, Defined-Other) | (Zheng et al., 2022) |
| 4 | For each training epoch: curriculum-scheduled thresholding | (Zheng et al., 2022) |
| 5 | Batch-wise: forward pass, KNN neighbor retrieval, weighted prediction aggregation | (Hu et al., 2021, Zheng et al., 2022, Chen et al., 13 Jan 2025) |
| 6 | Compute and aggregate losses (collide CE/KL, KD, relation recovery, channel compensation) | (Hu et al., 2021, Zheng et al., 2022, Chen et al., 13 Jan 2025) |
| 7 | Backpropagate (optionally freezing old model parts/adapters) | (Hu et al., 2021, Zheng et al., 2022, Chen et al., 13 Jan 2025) |
Extensions include momentum correction at inference, which dynamically debiases the classifier output along the historical "head direction" of feature drift induced by SGD with momentum (Hu et al., 2021).
6. Empirical Evidence and Ablation Studies
Causal distillation frameworks have been empirically validated across vision, NLP, and video tasks. On CIFAR-100 (T=5, replay=5 exemplars), causal distillation boosts LUCIR accuracy by 9.06% over the baseline, with absolute forgetting reduced by up to 16.4%. Gains persist across ImageNet-Sub/Full and in the zero-replay regime (Hu et al., 2021). In CIL-NER (OntoNotes5, i2b2, CoNLL2003), CFNER exceeds ExtendNER by 3–9 points (Micro-F1), with ablations confirming that the colliding effects, the curriculum, and the adaptive weights each contribute substantially (Zheng et al., 2022). For video CIL (ActivityNet, Kinetics), causal distillation with channel compensation delivers state-of-the-art gains (e.g., 59.24% vs 54.83% on ActivityNet-10, an improvement of 4.41 percentage points) (Chen et al., 13 Jan 2025).
Ablation studies reveal that omitting colliding effect distillation or compensation mechanisms leads to significant drops in performance, confirming the necessity of explicitly modeling and recovering the causal influence of old knowledge in the class-incremental pipeline (Zheng et al., 2022, Chen et al., 13 Jan 2025).
7. Theoretical and Practical Implications
Causal distillation establishes a formal correspondence between explicit data replay and storage-free, neighbor-aggregated distillation objectives, providing an end-to-end method that retains the early-layer and inter-class influences overlooked by standard logit or feature distillation. The framework is model-agnostic: it can be added on top of existing CIL methods (e.g., LUCIR, PODNet, BiC), where it is reported to improve incremental accuracy and to increase robustness against long-tail class imbalance (Hu et al., 2021).
A plausible implication is that future CIL approaches may generalize this principle to any structured multi-channel representation or hierarchical label schema, provided the appropriate causal dependencies can be identified and exploited through neighbor-conditioned or relation-preserving losses. The approach also suggests a theoretical unification of feature, logit, and data replay techniques through the SCM formalism, clarifying not only how forgetting arises but also how it can be reversed most effectively.
References
- "Distilling Causal Effect of Data in Class-Incremental Learning" (Hu et al., 2021)
- "Distilling Causal Effect from Miscellaneous Other-Class for Continual Named Entity Recognition" (Zheng et al., 2022)
- "CSTA: Spatial-Temporal Causal Adaptive Learning for Exemplar-Free Video Class-Incremental Learning" (Chen et al., 13 Jan 2025)