
Captum: A unified and generic model interpretability library for PyTorch (2009.07896v1)

Published 16 Sep 2020 in cs.LG, cs.AI, and stat.ML

Abstract: In this paper we introduce a novel, unified, open-source model interpretability library for PyTorch [12]. The library contains generic implementations of a number of gradient and perturbation-based attribution algorithms, also known as feature, neuron and layer importance algorithms, as well as a set of evaluation metrics for these algorithms. It can be used for both classification and non-classification models including graph-structured models built on Neural Networks (NN). In this paper we give a high-level overview of supported attribution algorithms and show how to perform memory-efficient and scalable computations. We emphasize that the three main characteristics of the library are multimodality, extensibility and ease of use. Multimodality supports different input modalities such as image, text, audio or video. Extensibility allows adding new algorithms and features. The library is also designed for easy understanding and use. We also introduce an interactive visualization tool called Captum Insights that is built on top of the Captum library and allows sample-based model debugging and visualization using feature importance metrics.

Citations (744)

Summary

  • The paper introduces Captum, a generic model interpretability library for PyTorch that integrates diverse attribution methods like Integrated Gradients, DeepLift, and SHAP.
  • It supports multiple attribution levels—including input features, neurons, and layers—to provide detailed insights into neural network decision processes.
  • The library is designed to scale, with support for multi-GPU processing, and provides quantitative evaluation metrics (infidelity and maximum sensitivity) for assessing the quality of attributions.

Captum: A Unified and Generic Model Interpretability Library for PyTorch

The paper "Captum: A unified and generic model interpretability library for PyTorch," by Narine Kokhlikyan et al., introduces Captum, an open-source library for interpreting models built with PyTorch [Paszke et al., 2019]. Captum provides generic implementations of gradient- and perturbation-based attribution algorithms that apply to a broad range of neural network models, including non-classification and graph-structured networks.

Model Interpretability

Interpreting complex neural networks (NNs) remains a significant challenge in modern machine learning. Because these models behave largely as black boxes, understanding how they arrive at their predictions is important, especially in high-stakes applications such as healthcare, finance, and autonomous systems, and it typically requires post-hoc interpretability techniques. Captum addresses this need with a suite of attribution algorithms that quantify feature, neuron, and layer importance.

Supported Attribution Algorithms

Captum groups its attribution techniques into three categories, each addressing a different aspect of model interpretation:

  1. Primary attribution: attributes the model's output to its input features.
  2. Neuron attribution: attributes the output of a particular hidden neuron to the input features.
  3. Layer attribution: attributes the model's output to each neuron of a given hidden layer.

These algorithms encompass well-known methods such as Integrated Gradients, DeepLift, SHAP variants, and Occlusion. Techniques like GradCam and GuidedBackProp, traditionally associated with computer vision tasks, are also implemented generically within Captum, making them applicable beyond convolutional models.
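To make primary attribution concrete, the sketch below approximates Integrated Gradients for a tiny hand-written function: each feature's attribution is its distance from a baseline multiplied by the gradient averaged along the straight path from the baseline to the input. It is a library-free illustration under stated assumptions, not Captum's implementation (Captum exposes this method as `captum.attr.IntegratedGradients` and may use a different numerical integration scheme). The toy model `f`, its gradient `grad_f`, the zero baseline, and the midpoint rule with 50 steps are all hypothetical choices made for the example.

```python
from typing import Callable, List

Vector = List[float]


def integrated_gradients(
    grad_f: Callable[[Vector], Vector],
    x: Vector,
    baseline: Vector,
    n_steps: int = 50,
) -> Vector:
    """Approximate Integrated Gradients with a midpoint Riemann sum.

    IG_i(x) = (x_i - b_i) * integral over a in [0, 1] of dF/dx_i(b + a * (x - b))
    """
    avg_grad = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # midpoint of the k-th sub-interval
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_f(point)):
            avg_grad[i] += g / n_steps
    # Scale the path-averaged gradient by the input's distance from the baseline.
    return [(xi - b) * ag for xi, b, ag in zip(x, baseline, avg_grad)]


# Hypothetical toy model F(x) = x0 * x1 + x2**2 and its analytic gradient.
def f(x: Vector) -> float:
    return x[0] * x[1] + x[2] ** 2


def grad_f(x: Vector) -> Vector:
    return [x[1], x[0], 2.0 * x[2]]


x = [2.0, 3.0, 1.0]
baseline = [0.0, 0.0, 0.0]
attr = integrated_gradients(grad_f, x, baseline)
print([round(a, 4) for a in attr])  # [3.0, 3.0, 1.0]
# Completeness check: attributions sum to F(x) - F(baseline) = 7.0.
print(round(sum(attr), 4), f(x) - f(baseline))
```

The two factors of the interaction term x0 * x1 receive equal credit, and the attributions sum to the change in output relative to the baseline (the completeness property). For real, non-linear models the gap between the two sides reflects approximation error; Captum can report it through the `return_convergence_delta` option of `IntegratedGradients.attribute`.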

Scalability and Efficiency

One of Captum's notable features is its emphasis on scalable, memory-efficient computation. Algorithms such as Integrated Gradients evaluate the model on many expanded copies of the input (one per integration step); to avoid running out of memory, Captum can split these expanded inputs into smaller chunks and process the chunks sequentially. Captum also supports models wrapped in PyTorch's DataParallel, which distributes the computation across multiple GPUs. The paper's experiments show a significant reduction in execution time as the number of GPUs increases, indicating that Captum makes effective use of the available hardware.
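The chunking idea can be sketched as follows: the integration points of Integrated Gradients are evaluated in sequential batches of at most `internal_batch_size` points, so peak memory depends on the chunk size rather than on the total number of steps, while the result is unchanged. The parameter name mirrors the `internal_batch_size` argument of Captum's `attribute` methods, but this is an illustrative re-implementation, not Captum's code; `batched_grad` is a hypothetical stand-in for one batched forward/backward pass of a real model.

```python
from typing import Callable, List

Vector = List[float]


def batched_grad(points: List[Vector]) -> List[Vector]:
    """Hypothetical batched gradient of F(x) = x0 * x1 + x2**2.

    In a real model this would be one forward/backward pass over a batch,
    whose peak memory grows with len(points).
    """
    return [[p[1], p[0], 2.0 * p[2]] for p in points]


def ig_chunked(
    grad_fn: Callable[[List[Vector]], List[Vector]],
    x: Vector,
    baseline: Vector,
    n_steps: int = 50,
    internal_batch_size: int = 8,
) -> Vector:
    """Integrated Gradients whose n_steps path points are evaluated
    in sequential chunks of at most internal_batch_size points."""
    alphas = [(k + 0.5) / n_steps for k in range(n_steps)]
    grad_sum = [0.0] * len(x)
    for start in range(0, n_steps, internal_batch_size):
        chunk = alphas[start:start + internal_batch_size]
        points = [[b + a * (xi - b) for xi, b in zip(x, baseline)] for a in chunk]
        # Only len(chunk) points are passed to the model at once.
        for g in grad_fn(points):
            for i, gi in enumerate(g):
                grad_sum[i] += gi
    return [(xi - b) * s / n_steps for xi, b, s in zip(x, baseline, grad_sum)]


# Record the batch sizes the "model" actually sees.
seen_sizes: List[int] = []


def recording_grad(points: List[Vector]) -> List[Vector]:
    seen_sizes.append(len(points))
    return batched_grad(points)


x, baseline = [2.0, 3.0, 1.0], [0.0, 0.0, 0.0]
small = ig_chunked(recording_grad, x, baseline, internal_batch_size=8)
whole = ig_chunked(batched_grad, x, baseline, internal_batch_size=50)
print(seen_sizes)  # [8, 8, 8, 8, 8, 8, 2]
print(all(abs(a - b) < 1e-9 for a, b in zip(small, whole)))  # True
```

The trade-off is more sequential passes for a smaller memory footprint; since the gradient contributions are simply summed, the chunk size changes only how the work is scheduled, not the attribution values.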

Evaluation Metrics

Evaluating the quality of attributions is non-trivial: visual explanations can be misleading, and quantitative metrics can be domain-specific or subjective. To address this, Captum implements two quantitative evaluation metrics from [Yeh et al., 2019]: infidelity and maximum sensitivity. Infidelity measures how well the attributions predict the change in model output when the input is perturbed; it is the expected squared difference between the dot product of the perturbation with the attributions and the actual change in the output. Maximum sensitivity measures robustness as the largest change in the attributions when the input is perturbed within a specified radius. Together, these metrics provide a structured way to assess the reliability of attribution methods.
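Following the definitions in [Yeh et al., 2019], infidelity is INFD(Φ, f, x) = E_I[(Iᵀ Φ(f, x) − (f(x) − f(x − I)))²] and maximum sensitivity is SENS_MAX(Φ, f, x, r) = max over ‖y − x‖ ≤ r of ‖Φ(f, y) − Φ(f, x)‖. The sketch below estimates both by sampling for a hypothetical linear model. Captum provides these metrics as `captum.metrics.infidelity` (which takes a user-supplied perturbation function) and `captum.metrics.sensitivity_max`; the code here is an independent illustration, and its Gaussian perturbations, uniform sampling within an L∞ box of radius `radius`, sample counts, and the model weights `w` are all assumptions.

```python
import random
from typing import Callable, List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))


def infidelity(
    f: Callable[[Vector], float],
    attr: Vector,
    x: Vector,
    n_samples: int = 1000,
    sigma: float = 0.1,
    seed: int = 0,
) -> float:
    """Monte Carlo estimate of E_I[(I . attr - (f(x) - f(x - I)))**2],
    assuming Gaussian perturbations I ~ N(0, sigma**2) per coordinate."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        pert = [rng.gauss(0.0, sigma) for _ in x]
        shifted = [xi - pi for xi, pi in zip(x, pert)]
        total += (dot(pert, attr) - (f(x) - f(shifted))) ** 2
    return total / n_samples


def sensitivity_max(
    explain: Callable[[Vector], Vector],
    x: Vector,
    radius: float = 0.02,
    n_samples: int = 100,
    seed: int = 0,
) -> float:
    """Sampled lower bound on max ||explain(y) - explain(x)||_2
    over y with every |y_i - x_i| <= radius."""
    rng = random.Random(seed)
    base = explain(x)
    worst = 0.0
    for _ in range(n_samples):
        y = [xi + rng.uniform(-radius, radius) for xi in x]
        diff = [a - b for a, b in zip(explain(y), base)]
        worst = max(worst, sum(d * d for d in diff) ** 0.5)
    return worst


# Hypothetical linear model f(x) = w . x and two candidate explanations.
w = [1.0, -2.0, 0.5]
x = [0.3, 0.1, -0.4]


def f(v: Vector) -> float:
    return dot(w, v)


def gradient(v: Vector) -> Vector:
    return list(w)


def grad_times_input(v: Vector) -> Vector:
    return [wi * vi for wi, vi in zip(w, v)]


for name, explain in [("gradient", gradient), ("gradient * input", grad_times_input)]:
    print(name, infidelity(f, explain(x), x), sensitivity_max(explain, x))
```

On this linear model the plain gradient scores (numerically) zero on both metrics: f(x) − f(x − I) equals w · I exactly, and the gradient does not depend on x. Gradient × input scores worse on both, which illustrates how the metrics separate explanations that look similar at a single point.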

Practical Applications

Captum's versatility extends to a range of applications, including text classification, regression, and multimodal neural networks. For text classification on the IMDB dataset, Integrated Gradients identifies the tokens that contribute most to the model's predictions. For regression, layer conductance can attribute the output of a model predicting Boston house prices to its hidden neurons, and the resulting attributions align with the learned model weights. Captum's multimodal support also makes it possible to interpret models that combine several types of data, such as text and images in Visual Question Answering systems.
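For text models, Integrated Gradients is usually computed with respect to the token embeddings, so each token receives one attribution per embedding dimension. A common post-processing step, assumed here rather than specified in the paper, is to sum over the embedding dimension and normalize across tokens to obtain one score per token. The tokens and attribution values below are hypothetical.

```python
import math
from typing import List, Tuple


def token_scores(embedding_attr: List[List[float]]) -> List[float]:
    """Collapse per-dimension attributions (one row per token) into one score
    per token by summing over the embedding dimension, then L2-normalizing."""
    summed = [sum(row) for row in embedding_attr]
    norm = math.sqrt(sum(s * s for s in summed)) or 1.0  # avoid division by zero
    return [s / norm for s in summed]


def top_tokens(tokens: List[str], scores: List[float], k: int = 2) -> List[Tuple[str, float]]:
    """Return the k tokens with the largest (most positive) scores."""
    return sorted(zip(tokens, scores), key=lambda t: t[1], reverse=True)[:k]


# Hypothetical attributions for a 4-token review with a 3-dimensional embedding.
tokens = ["the", "movie", "was", "wonderful"]
embedding_attr = [
    [0.01, -0.02, 0.00],
    [0.10, 0.05, -0.05],
    [0.00, 0.01, 0.00],
    [0.40, 0.30, 0.10],
]
scores = token_scores(embedding_attr)
print(top_tokens(tokens, scores))
```

Summing keeps the sign, so tokens pushing the prediction toward the target class and away from it stay distinguishable; normalizing only rescales the scores for display.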

Captum Insights

To facilitate model debugging and visualization, Captum includes an interactive tool called Captum Insights. This tool allows users to sub-sample inputs, apply various attribution algorithms, and visualize the results interactively. In multi-modal models, it can even aggregate feature importances across different modalities, providing a comprehensive view of the model’s decision-making process.
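One simple way to summarize importance per modality is to compute each modality's share of the total absolute attribution. This is a hypothetical aggregation chosen for illustration; the paper does not specify the exact formula Captum Insights uses, and the modality names and values below are made up.

```python
from typing import Dict, List


def modality_importance(attributions: Dict[str, List[float]]) -> Dict[str, float]:
    """Share of total absolute attribution assigned to each modality."""
    totals = {name: sum(abs(a) for a in attr) for name, attr in attributions.items()}
    grand_total = sum(totals.values()) or 1.0  # avoid division by zero
    return {name: t / grand_total for name, t in totals.items()}


# Hypothetical attributions from a VQA-style model with image and question inputs.
attributions = {
    "image": [0.2, -0.1, 0.3, 0.0],
    "question": [0.05, -0.15, 0.2],
}
print(modality_importance(attributions))
```

Using absolute values counts evidence for and against the prediction alike, which answers "how much did this modality matter" rather than "in which direction did it push".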

Conclusion and Future Directions

Captum is a significant step toward making model interpretability broadly accessible within the PyTorch ecosystem. Its unified interface, broad algorithmic coverage, and built-in evaluation metrics make it useful in both research and production settings. Future directions for Captum include additional attribution techniques, integration with adversarial robustness research, and global model interpretability through concept-based explanations. Visualizing high-dimensional embeddings and understanding the role of individual neurons in latent layers are further promising areas for investigation.

References

  • [Paszke et al., 2019]
  • [Ancona et al., 2018]
  • [Yeh et al., 2019]
