Deep Approximate Shapley Propagation: Evaluating and Approximating Shapley Values in Deep Neural Networks
The paper "Explaining Deep Neural Networks with a Polynomial Time Algorithm for Shapley Values Approximation" addresses a critical challenge in Explainable AI: reliably attributing a deep neural network's (DNN's) output predictions to its input features. While many attribution methods exist, they frequently lack robust theoretical underpinnings, which casts doubt on their reliability. In this context, Shapley values, a construct from cooperative game theory, emerge as a principled method for distributing credit among input features, but computing them exactly is NP-hard, rendering exact computation infeasible for complex models with many features.
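For concreteness, this is the standard game-theoretic definition underlying the paper: the Shapley value assigns feature $i$ the credit

$$\phi_i \;=\; \sum_{S \subseteq N \setminus \{i\}} \frac{|S|!\,(|N|-|S|-1)!}{|N|!}\,\big[v(S \cup \{i\}) - v(S)\big],$$

where $N$ is the set of input features and $v(S)$ denotes the model output when only the features in coalition $S$ are present (absent features replaced by a reference baseline, a common convention). The sum ranges over all $2^{|N|-1}$ coalitions excluding $i$, which is why exact computation is intractable for high-dimensional inputs.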
The authors propose a novel approach, Deep Approximate Shapley Propagation (DASP), that approximates Shapley values tractably by leveraging uncertainty propagation techniques. DASP computes its approximation in polynomial time, yields markedly better approximation quality than previous biased attribution methods, and reaches comparable accuracy with far fewer model evaluations than unbiased sampling-based techniques.
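To see why the sampling-based estimators DASP is compared against need so many forward passes, here is a minimal sketch of the classic unbiased Shapley sampling baseline; the model `f`, input `x`, and `baseline` used to "remove" features are illustrative assumptions, not the authors' code.

```python
import numpy as np

def shapley_sampling(f, x, baseline, n_samples=1000, rng=None):
    """Unbiased Monte Carlo estimate of Shapley values (illustrative sketch).

    f        : callable mapping a (d,) array to a scalar model output
    x        : (d,) input to explain
    baseline : (d,) reference input representing "absent" features
    """
    rng = rng or np.random.default_rng(0)
    d = x.shape[0]
    phi = np.zeros(d)
    for _ in range(n_samples):
        perm = rng.permutation(d)      # random feature ordering
        z = baseline.copy()
        prev = f(z)                    # output with no features present
        for i in perm:                 # add features one at a time
            z[i] = x[i]
            cur = f(z)
            phi[i] += cur - prev       # marginal contribution of feature i
            prev = cur
    return phi / n_samples             # average over sampled permutations
```

Each sampled permutation costs d + 1 model evaluations, so tight estimates on high-dimensional inputs quickly become expensive; DASP replaces this sampling loop with a small number of probabilistic forward passes.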
Key Contributions and Methodology
This work makes three principal contributions:
- Theoretical Motivation: The authors rigorously argue for the superiority of Shapley values over existing attribution methods, based on a set of desirable axiomatic properties that hold even for non-linear models: completeness, null player, symmetry, linearity, continuity, and implementation invariance. Shapley values uniquely satisfy all of these axioms, making them an ideal candidate for reliable local explanations.
- Algorithm Design: DASP's innovation lies in approximating Shapley values without assuming model linearity. Each feature's expected marginal contribution to a random coalition of a given size is estimated by propagating mean and variance statistics through the network in a single probabilistic forward pass, in the style of Lightweight Probabilistic Networks (see the moment-propagation sketch after this list). This gives DASP a practical way to navigate the intractably large coalition space without enumerating coalitions explicitly.
- Empirical Benchmarking: In evaluations on Parkinson's disability assessment, DNA sequence classification, and digit recognition (MNIST), DASP approximated Shapley values effectively: it consistently outperformed biased alternatives and required significantly fewer model evaluations than unbiased methods such as Shapley sampling and KernelSHAP to reach similar approximation fidelity.
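The moment-propagation primitive referenced above can be sketched as follows. This is a minimal, self-contained illustration of propagating mean and variance through one dense layer followed by a ReLU, as in Lightweight Probabilistic Networks; it is not the authors' implementation, and it assumes independent, approximately Gaussian pre-activations.

```python
import numpy as np
from scipy.stats import norm

def dense_moments(mu, var, W, b):
    """Propagate per-unit mean/variance through a linear layer y = W x + b.

    Assumes input components are independent, so variances combine linearly.
    """
    mu_out = W @ mu + b
    var_out = (W ** 2) @ var
    return mu_out, var_out

def relu_moments(mu, var):
    """Closed-form mean/variance of ReLU(X) for X ~ N(mu, var), per unit."""
    sigma = np.sqrt(np.maximum(var, 1e-12))
    alpha = mu / sigma
    cdf, pdf = norm.cdf(alpha), norm.pdf(alpha)
    mu_out = mu * cdf + sigma * pdf                  # E[ReLU(X)]
    ex2 = (mu ** 2 + var) * cdf + mu * sigma * pdf   # E[ReLU(X)^2]
    var_out = np.maximum(ex2 - mu_out ** 2, 0.0)
    return mu_out, var_out
```

Stacking such layers lets a single network pass carry the expected output over all random coalitions of a given size k; repeating the pass for each coalition size and combining the resulting expected marginal contributions yields the Shapley approximation, which is where the polynomial (rather than exponential) evaluation count comes from.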
Implications and Future Directions
The introduction of DASP represents a meaningful advance toward interpretable DNNs, providing reliable, theoretically grounded explanations for their outputs. Beyond its theoretical appeal, the method scales well, promising practical utility in real-world applications where transparency and accountability are increasingly mandated, notably in finance and healthcare, where regulations such as the European Union's right to explanation apply.
Future research could extend DASP's applicability by incorporating advances in uncertainty propagation, potentially covering recurrent neural networks and transfer learning. Moreover, hybrid approaches that combine uncertainty propagation with sampling may further narrow the gap between exact theoretical guarantees and computationally feasible approximation, moving the field toward robust, interpretability-focused AI tools.
In conclusion, the paper presents a significant step forward in explainable AI by detailing an efficient, reliable method for approximating Shapley values in DNNs, underscoring the importance of leveraging cooperative game theory principles to build trustworthy machine learning models.