Neural Additive Models: Interpretable Machine Learning with Neural Nets
The paper presents Neural Additive Models (NAMs), a novel approach aimed at achieving both interpretability and accuracy by combining the flexibility of deep neural networks with the transparency of Generalized Additive Models (GAMs). To tackle the common “black-box” problem associated with deep neural networks (DNNs), a NAM learns a linear combination of neural networks, each corresponding to a single input feature, so interpretability is preserved by construction. This is an ambitious attempt to leverage the strengths of DNNs while maintaining the interpretability essential in critical domains such as healthcare, finance, and criminal justice.
NAMs belong to the family of GAMs, with each component function of the model represented by a neural network that learns the contribution of a single input feature independently of the others. This structure provides a clear interpretability advantage: the influence of each feature on the output is isolated, allowing feature impacts to be visualized and understood directly.
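Concretely, for an input with features x_1, …, x_K, a NAM takes the standard GAM form

    g(E[y]) = β + f_1(x_1) + f_2(x_2) + … + f_K(x_K),

where g is a link function (for example, the logit for binary classification) and each shape function f_i is parameterized by its own neural network rather than by splines or boosted trees.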
Key Contributions
- Model Design and Training: NAMs associate each input feature with its own neural network that learns a shape function for that feature; the outputs of these networks are summed (plus a bias) to produce the prediction. Because the whole model is differentiable, the feature networks are trained jointly with backpropagation and can fit complex shape functions (a minimal code sketch follows this list).
- Comparative Performance: The paper reports that NAMs achieve accuracy comparable to state-of-the-art GAMs fit with boosted trees, while being more flexible due to their neural-network parameterization. This is supported by experiments on regression and classification tasks showing that NAMs typically outperform simpler interpretable models such as logistic regression and decision trees.
- Regularization Techniques: To ensure that the learned shape functions are neither overly smooth nor needlessly jumpy (i.e., overfit), several regularization techniques are applied: dropout, weight decay, an output penalty, and feature dropout (the latter two also appear in the sketch below).
- Practical Application: The utility of NAMs is demonstrated in multitask learning and parameter generation problems, extending their usage to settings beyond traditional GAMs. For example, NAMs can be flexibly applied to multitask learning for datasets like COMPAS, elucidating separate relationships for different demographic groups.
- Interpretability and Flexibility: NAMs make the model's decision-making process transparent: because a prediction is simply the sum of the per-feature shape functions plus a bias, each shape plot is an exact description of how its feature contributes to the prediction, which is critical in high-stakes applications.
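To make the architecture and the output-penalty/feature-dropout regularizers concrete, here is a minimal PyTorch-style sketch. It is not the authors' implementation: the module names (FeatureNet, NAM), hidden sizes, dropout rates, and penalty weight are illustrative assumptions, and feature dropout is approximated here by dropping individual per-feature contributions.

```python
import torch
import torch.nn as nn

class FeatureNet(nn.Module):
    """Small MLP mapping one scalar feature to its shape-function value f_i(x_i)."""
    def __init__(self, hidden=64, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden), nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):                 # x: (batch, 1)
        return self.net(x)                # (batch, 1)

class NAM(nn.Module):
    """Sum of per-feature networks plus a learned bias."""
    def __init__(self, num_features, feature_dropout=0.05):
        super().__init__()
        self.feature_nets = nn.ModuleList([FeatureNet() for _ in range(num_features)])
        self.bias = nn.Parameter(torch.zeros(1))
        # Approximate "feature dropout": randomly zero whole feature contributions.
        self.feature_dropout = nn.Dropout(feature_dropout)

    def forward(self, x):                 # x: (batch, num_features)
        contribs = torch.cat(
            [net(x[:, i:i + 1]) for i, net in enumerate(self.feature_nets)], dim=1
        )                                 # per-feature contributions, (batch, num_features)
        contribs = self.feature_dropout(contribs)
        return contribs.sum(dim=1) + self.bias, contribs

# One training step on toy data: cross-entropy plus an L2 "output penalty"
# on the per-feature contributions (data and penalty weight are placeholders).
model = NAM(num_features=10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
x, y = torch.randn(32, 10), torch.randint(0, 2, (32,)).float()

optimizer.zero_grad()
logits, contribs = model(x)
loss = nn.functional.binary_cross_entropy_with_logits(logits, y)
loss = loss + 1e-3 * contribs.pow(2).mean()        # output penalty
loss.backward()
optimizer.step()

# A shape plot for feature i is just its network evaluated over a grid of values.
grid = torch.linspace(-3, 3, 100).unsqueeze(1)
shape_0 = model.feature_nets[0](grid)              # plot shape_0 against grid
```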
Implications and Future Directions
The introduction of NAMs signals an important step towards making neural networks interpretable, thereby broadening their applicability in trust-critical fields. Their differentiable and modular architecture allows NAMs to serve as components in larger neural networks while retaining interpretability. Moreover, NAMs could potentially be integrated with other deep learning paradigms to create hybrid models that balance accuracy and intelligibility.
Future research could focus on enhancing the expressivity of NAMs by efficiently incorporating interactions between features, and on evaluating their performance on more complex datasets, especially in domains such as computer vision and natural language processing, where interpretable features are harder to define. Exploring non-standard activation functions and initializations, such as the ExU units that have shown promise in learning jumpy functions, could also widen the applicability of NAMs.
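For reference, the ExU (“exp-centered”) unit computes h(x) = f(exp(w) · (x − b)), so that small changes in the weights produce sharp changes in the output. Below is a minimal sketch in the same PyTorch style as above; the ReLU-1 cap and the initialization of w around 3–4 follow my reading of the paper and should be treated as assumptions to tune:

```python
import torch
import torch.nn as nn

class ExU(nn.Module):
    """Exp-centered unit: h(x) = f(exp(w) * (x - b))."""
    def __init__(self, in_features, out_features):
        super().__init__()
        # Initializing w around 3-4 makes exp(w) large, encouraging "jumpy" functions.
        self.w = nn.Parameter(torch.empty(in_features, out_features).normal_(mean=4.0, std=0.5))
        self.b = nn.Parameter(torch.empty(in_features).normal_(std=0.5))

    def forward(self, x):                        # x: (batch, in_features)
        h = (x - self.b) @ torch.exp(self.w)     # (batch, out_features)
        return torch.clamp(h, 0.0, 1.0)          # ReLU-1: ReLU capped at 1
```

An ExU layer like this could, for example, replace the first linear layer of each feature network in the earlier sketch.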
In conclusion, NAMs provide a promising template for interpretable neural networks, blending flexibility, transparency, and competitive accuracy in a design that respects the needs of end-users who demand interpretability without sacrificing the expressive power of neural networks.