Long-Tailed Classification by Keeping the Good and Removing the Bad Momentum Causal Effect (2009.12991v4)

Published 28 Sep 2020 in cs.CV, cs.LG, and stat.ML

Abstract: As the number of classes grows, maintaining a balanced dataset across many classes is challenging because the data are long-tailed in nature; it is even impossible when samples of interest co-exist in one collectable unit, e.g., multiple visual instances in one image. Therefore, long-tailed classification is the key to deep learning at scale. However, existing methods are mainly based on re-weighting/re-sampling heuristics that lack a fundamental theory. In this paper, we establish a causal inference framework, which not only unravels the whys of previous methods, but also derives a new principled solution. Specifically, our theory shows that the SGD momentum is essentially a confounder in long-tailed classification. On one hand, it has a harmful causal effect that biases tail predictions towards the head. On the other hand, its induced mediation also benefits representation learning and head prediction. Our framework elegantly disentangles the paradoxical effects of the momentum by pursuing the direct causal effect caused by an input sample. In particular, we use causal intervention in training, and counterfactual reasoning in inference, to remove the "bad" while keeping the "good". We achieve new state-of-the-art results on three long-tailed visual recognition benchmarks: Long-tailed CIFAR-10/-100 and ImageNet-LT for image classification, and LVIS for instance segmentation.

Authors (3)
  1. Kaihua Tang (13 papers)
  2. Jianqiang Huang (62 papers)
  3. Hanwang Zhang (161 papers)
Citations (406)

Summary

A Causal Framework for Enhancing Long-Tailed Classification

This paper introduces a novel approach to the long-standing problem of long-tailed classification in deep learning. The goal is to move beyond heuristic methods and provide a theoretically grounded treatment of class imbalance. Central to this work is the use of causal inference to unravel and address the biases introduced during model training, particularly those arising from Stochastic Gradient Descent (SGD) momentum.

Problem Formulation and Proposed Solution

Long-tailed classification addresses datasets in which a few classes (head classes) have abundant samples while most classes (tail classes) are severely under-represented. Existing solutions predominantly rely on re-sampling or re-weighting heuristics, which lack a unifying theoretical account of why and when they work.

This research introduces a causal inference framework that analyzes the role of SGD momentum as a confounder in model training. Accumulated momentum inadvertently favors head classes during optimization, leading to predictions biased towards them; at the same time, it aids representation learning, particularly for head classes. The authors formalize this with a causal graph relating the momentum, the feature vector of an input sample, and the classification outcome.
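
One way to make the confounder concrete is to maintain an exponential moving average of backbone features during training and treat its direction as a proxy for the momentum-induced "head" drift. The sketch below is illustrative only: the class name, decay rate, and normalization are assumptions, not the paper's exact bookkeeping.

```python
import torch


class FeatureMomentumTracker:
    """Track an exponential moving average of backbone features.

    The averaged direction serves as a proxy for the momentum-induced,
    head-biased direction that the paper treats as a confounder.
    Decay rate and normalization here are illustrative choices.
    """

    def __init__(self, feat_dim: int, decay: float = 0.9):
        self.decay = decay
        self.avg_feat = torch.zeros(feat_dim)

    @torch.no_grad()
    def update(self, batch_feats: torch.Tensor) -> None:
        # batch_feats: (batch, feat_dim) features from the backbone
        batch_mean = batch_feats.mean(dim=0)
        self.avg_feat = self.decay * self.avg_feat + (1 - self.decay) * batch_mean

    @property
    def direction(self) -> torch.Tensor:
        # Unit vector d_hat pointing along the accumulated (head-biased) direction
        return self.avg_feat / (self.avg_feat.norm() + 1e-12)
```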

The key insight is to decompose the effects of momentum into a "good" and a "bad" component. The "bad" component arises from the confounding induced by momentum, while the "good" component is its role as a mediator in representation learning. The authors apply causal intervention during training to mitigate the confounding effect, and counterfactual reasoning during inference to retain the beneficial, mediation-related part. This leads to what is termed Total Direct Effect (TDE) inference.
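
In standard mediation-analysis notation (X the input, M the mediator, Y the prediction, and x_0 a null or reference input; these are generic symbols, not necessarily the paper's), the quantity being pursued can be written as:

```latex
\mathrm{TDE}(x) \;=\; \mathbb{E}\big[\,Y_{x,\,M_x}\,\big] \;-\; \mathbb{E}\big[\,Y_{x_0,\,M_x}\,\big]
```

That is, the input is switched from the reference x_0 to the actual sample x while the mediator is held at the value it would take under x, so the indirect, momentum-carried path is cancelled out of the comparison and only the direct effect of the sample remains.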

Implementation and Results

The proposed solution involves two primary phases: de-confounded training and TDE inference. During training, causal intervention suppresses the confounding effect of momentum, allowing the model to learn more balanced feature representations. At the prediction stage, TDE inference subtracts the counterfactual, momentum-biased component of the logits to deliver debiased class predictions.
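
As a concrete illustration of what the inference step could look like in code, the sketch below applies a TDE-style correction on top of a cosine classifier, using the moving-average direction d_hat from the tracker above. The trade-off coefficient alpha, the logit scale tau, and the exact projection form are illustrative assumptions, not a reproduction of the paper's exact formula.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def tde_logits(feats: torch.Tensor,
               class_weights: torch.Tensor,
               d_hat: torch.Tensor,
               alpha: float = 1.0,
               tau: float = 16.0) -> torch.Tensor:
    """Counterfactual TDE-style inference (illustrative sketch).

    feats:         (batch, feat_dim) test-time features
    class_weights: (num_classes, feat_dim) classifier weights
    d_hat:         (feat_dim,) unit vector of the moving-average feature direction
    alpha:         trade-off weight for the subtracted counterfactual term (assumed)
    tau:           logit scale for the normalized (cosine) classifier (assumed)
    """
    w_norm = F.normalize(class_weights, dim=1)            # (num_classes, feat_dim)

    # Factual logits from the normalized classifier
    logits = tau * F.normalize(feats, dim=1) @ w_norm.t() # (batch, num_classes)

    # Counterfactual logits: what the classifier would output if the input
    # were replaced by its component along the momentum direction d_hat only
    cos_to_d = F.cosine_similarity(feats, d_hat.unsqueeze(0), dim=1)  # (batch,)
    cf_logits = tau * d_hat @ w_norm.t()                               # (num_classes,)

    # Subtract the head-biased indirect contribution, keeping the direct effect
    return logits - alpha * cos_to_d.unsqueeze(1) * cf_logits.unsqueeze(0)
```

The design intent is that the subtracted term grows with how strongly a test feature aligns with the accumulated head-biased direction, so tail samples that merely drift towards the head are corrected while genuinely head-like samples are largely unaffected.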

The methodology was validated on Long-tailed CIFAR-10/-100 and ImageNet-LT for image classification, and on LVIS for instance segmentation. The method achieved state-of-the-art performance across these benchmarks, substantially improving recognition of under-represented classes without sacrificing accuracy on head classes.

Implications and Future Work

The framework offers a principled way to improve classifiers under class imbalance, demonstrating its utility in both theoretical understanding and practical application. The work advances the field by providing an analytical lens to dissect and rectify bias, which is crucial for training fair and equitable machine learning models, especially in AI applications with societal implications (e.g., facial recognition, autonomous vehicles).

Future developments could explore refinements to the causal graph, incorporation into more complex architectures such as transformers, and application to non-visual domains. Extending the framework to online learning settings, where the data distribution shifts continuously, would further test its robustness and adaptability.

Conclusion

The paper presents a compelling case for the integration of causal inference in tackling long-tailed classification, balancing theoretical rigor with empirical success. It sets a foundational precedent for future work in ensuring fair and unbiased model training amidst inherently unbalanced real-world datasets.