Shapley Transform in Neural Networks
- Shapley Transform is a formal mechanism that embeds exact Shapley attributions as intrinsic, differentiable representations within neural network architectures.
- It enables layerwise computation of feature contributions, which supports efficient test-time attribution, explanation regularization, and dynamic pruning.
- Shallow and Deep ShapNets leverage classical SHAP axioms to ensure local accuracy, missingness, and consistency while avoiding post-hoc computational overhead.
The Shapley transform is a formal mechanism for embedding exact Shapley attributions as intrinsic, differentiable representations within neural network architectures. It provides a principled framework to compute and utilize Shapley values layerwise, enabling explanation regularization, efficient test-time attribution, and dynamic pruning. Operationalized as Shapley modules and composed into Shallow and Deep ShapNets, the Shapley transform leverages both classical SHAP axioms and modern deep learning constructs, ensuring local accuracy, missingness, and scalability while avoiding the intractable overhead of post-hoc SHAP computation (Wang et al., 2021).
1. Mathematical Definition of the Shapley Transform
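Let $x \in \mathbb{R}^{d \times c}$ be an input tensor, with $d$ explainable dimensions and $c$ channel dimensions. For a given vector-valued function $f : \mathbb{R}^{d \times c} \to \mathbb{R}^{c'}$, with component functions $f_j$, $j = 1, \dots, c'$, the Shapley transform $\Omega$ is defined as
$$[\Omega(x, f)]_{i,j} = \phi_i(x, f_j), \qquad i = 1, \dots, d, \quad j = 1, \dots, c',$$
where $\phi_i(x, f_j)$ is the classical Shapley value of the $i$th feature for the scalar function $f_j$. The resulting tensor $Z = \Omega(x, f) \in \mathbb{R}^{d \times c'}$ is referred to as the Shapley representation of $x$ under $f$.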
Local accuracy, missingness, and consistency—SHAP axioms established by Lundberg and Lee—are preserved in this construction. Specifically:
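- Local Accuracy: for scalar $f$ and reference input $r$, $\sum_{i=1}^{d} \phi_i(x, f) = f(x) - f(r)$.
- Missingness: If feature $i$ is absent (set to its reference value $r_i$), then $\phi_i(x, f) = 0$.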
- Consistency: If the marginal contribution of a feature does not decrease for all subsets, the corresponding Shapley value cannot decrease.
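A foundational lemma establishes the linearity of the Shapley transform: For any matrix $A \in \mathbb{R}^{c'' \times c'}$, $\Omega(x, Af) = \Omega(x, f)\,A^{\top}$, with $(Af)_k = \sum_{j} A_{kj} f_j$. Thus, linear mixing of Shapley channels yields valid Shapley representations for composed functions.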
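The axioms and the linearity lemma can be checked numerically on a toy case. The following is a minimal sketch, not the authors' implementation: it assumes scalar features (one channel), a zero reference, and two hypothetical component functions `f1` and `f2`. The helper names `shapley_values` and `shapley_transform` are illustrative.

```python
from itertools import combinations
from math import factorial

def shapley_values(f, x, r):
    """Exact baseline Shapley values of a scalar function f at x, reference r."""
    d = len(x)

    def evaluate(subset):
        # Features in `subset` are taken from x, all others from the reference r.
        return f([x[i] if i in subset else r[i] for i in range(d)])

    phi = []
    for i in range(d):
        others = [j for j in range(d) if j != i]
        total = 0.0
        for size in range(d):
            weight = factorial(size) * factorial(d - size - 1) / factorial(d)
            for S in combinations(others, size):
                total += weight * (evaluate(set(S) | {i}) - evaluate(set(S)))
        phi.append(total)
    return phi

def shapley_transform(fs, x, r):
    """Return Z with Z[i][j] = Shapley value of feature i for component fs[j]."""
    columns = [shapley_values(f, x, r) for f in fs]
    return [[col[i] for col in columns] for i in range(len(x))]

# Hypothetical component functions on d = 3 scalar features.
f1 = lambda v: v[0] * v[1] + v[2]
f2 = lambda v: max(v[0], v[2])
x, r = [2.0, 3.0, 0.0], [0.0, 0.0, 0.0]   # feature 2 equals its reference

Z = shapley_transform([f1, f2], x, r)

# Local accuracy: each column of Z sums to f_j(x) - f_j(r).
col_sums = [sum(row[j] for row in Z) for j in range(2)]

# Linearity: the transform of A f equals the channel-mixed representation Z A^T.
A = [[1.0, 2.0], [0.5, -1.0]]
mixed_fs = [lambda v, a=a: a[0] * f1(v) + a[1] * f2(v) for a in A]
Z_of_Af = shapley_transform(mixed_fs, x, r)
Z_times_At = [[sum(a[j] * row[j] for j in range(2)) for a in A] for row in Z]
```

In this example the row of `Z` for feature 2 is zero (missingness), the column sums match `f1(x) - f1(r)` and `f2(x) - f2(r)` (local accuracy), and `Z_of_Af` agrees with `Z_times_At` up to floating-point error (linearity).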
2. Shapley Modules: Neural Implementation of the Transform
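A Shapley module is a neural network block realizing the Shapley transform for a scalar target function $f$ that depends only on a small active set $\mathcal{A} \subseteq \{1, \dots, d\}$ of $k = |\mathcal{A}|$ features, i.e., $f(x) = f(x_{\mathcal{A}})$. The outputs are
$$\phi_i(x, f) = \begin{cases} \sum_{S \subseteq \mathcal{A} \setminus \{i\}} \frac{|S|!\,(k - |S| - 1)!}{k!} \left[ f(x_{S \cup \{i\}}; r) - f(x_S; r) \right], & i \in \mathcal{A}, \\ 0, & i \notin \mathcal{A}, \end{cases}$$
using a reference tensor $r$, where $f(x_S; r)$ denotes $f$ evaluated with the features in $S$ taken from $x$ and all others from $r$. By restricting each $f$ to few inputs, the $O(2^k)$ cost of computing exact Shapley values becomes practical for $k = 2$–$4$. This enables differentiable, tractable attribution propagation as an intrinsic component of the model, in contrast to expensive post-hoc SHAP estimators.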
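A Shapley module of this kind can be sketched as follows. This is an illustrative sketch under simplifying assumptions (scalar features, a plain Python callable `g` standing in for the learned function, a zero reference). The name `shapley_module` and its signature are hypothetical, not taken from the paper's code.

```python
from itertools import combinations
from math import factorial

def shapley_module(g, active, x, r):
    """Attributions of g restricted to the `active` feature indices.

    Returns a length-d list holding exact Shapley values on the active set
    and zeros elsewhere, plus the number of evaluations of g that were used.
    """
    k = len(active)
    # Evaluate g once per subset of the active set: 2**k evaluations in total.
    value = {}
    for size in range(k + 1):
        for S in combinations(range(k), size):
            inputs = [x[active[p]] if p in S else r[active[p]] for p in range(k)]
            value[frozenset(S)] = g(inputs)

    phi = [0.0] * len(x)
    for p in range(k):
        rest = [q for q in range(k) if q != p]
        for size in range(k):
            weight = factorial(size) * factorial(k - size - 1) / factorial(k)
            for S in combinations(rest, size):
                S = frozenset(S)
                phi[active[p]] += weight * (value[S | {p}] - value[S])
    return phi, len(value)

# Hypothetical module: a product of two features out of d = 4.
g = lambda v: v[0] * v[1]
x, r = [2.0, 5.0, 3.0, 7.0], [0.0, 0.0, 0.0, 0.0]
phi, n_evals = shapley_module(g, active=[0, 2], x=x, r=r)
```

With an active set of size $k = 2$, the module needs four evaluations of `g`. The product $2 \cdot 3 = 6$ is split evenly between features 0 and 2, and features outside the active set receive zero.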
3. Shallow ShapNets: Exact Attributions in a Forward Pass
A Shallow ShapNet is constructed by composing a bank of Shapley modules (one per output channel) followed by a sum over the explainable dimensions: $G(x) = \sum_{i=1}^{d} [\Omega(x, f)]_{i,:}$, with $\Omega(x, f) \in \mathbb{R}^{d \times c'}$ representing the Shapley values for each output. In this setting:
- The output prediction equals $f(x) - f(r)$ (that is, $f(x)$ whenever $f(r) = 0$), and the internal representation matches the exact Shapley attributions.
- All SHAP properties (local accuracy, missingness, consistency) are satisfied identically.
- Coverage of all features is achieved by applying modules to pairs or small subsets of features and then linearly mixing their outputs; by the linearity lemma, the mixed result is still a valid Shapley representation.
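The construction above can be sketched in a few lines. This is a simplified illustration, not the reference implementation: it assumes one module per output channel, scalar features, a zero reference, and hand-written module functions in place of learned ones. The names `module_attributions` and `shallow_shapnet` are hypothetical.

```python
from itertools import combinations
from math import factorial

def module_attributions(g, active, x, r):
    """Exact Shapley values of g over the `active` indices; zeros elsewhere."""
    k = len(active)

    def evaluate(S):
        # Active features in S come from x, the remaining active ones from r.
        return g([x[i] if i in S else r[i] for i in active])

    phi = [0.0] * len(x)
    for i in active:
        rest = [j for j in active if j != i]
        for size in range(k):
            weight = factorial(size) * factorial(k - size - 1) / factorial(k)
            for S in combinations(rest, size):
                phi[i] += weight * (evaluate(set(S) | {i}) - evaluate(set(S)))
    return phi

def shallow_shapnet(modules, x, r):
    """modules: one (g, active) pair per output channel.

    Returns (prediction, Z), where Z[i][j] is the attribution of feature i
    to output channel j and prediction[j] is the sum of column j of Z.
    """
    columns = [module_attributions(g, active, x, r) for g, active in modules]
    Z = [[col[i] for col in columns] for i in range(len(x))]
    prediction = [sum(col) for col in columns]
    return prediction, Z

# Hypothetical two-output model on d = 4 features with pairwise modules.
modules = [
    (lambda v: v[0] * v[1] + v[0], [0, 1]),   # channel 0 uses features 0 and 1
    (lambda v: v[0] - 2.0 * v[1], [2, 3]),    # channel 1 uses features 2 and 3
]
x, r = [1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 0.0, 0.0]
prediction, Z = shallow_shapnet(modules, x, r)
```

The prediction and its explanation come from the same computation: `prediction` is the column sum of `Z`, so the attributions are exact for the modeled function by construction rather than estimated afterwards.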
4. Deep ShapNets: Layerwise Attribution and Dynamic Pruning
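A Deep ShapNet stacks Shapley transforms: $Z^{(l)} = \Omega\big(Z^{(l-1)}, f^{(l)}\big)$ for $l = 1, \dots, L$, where $Z^{(0)} = x$, with final output
$$H(x) = \sum_{i=1}^{d} Z^{(L)}_{i,:}.$$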
Each layer consists of parallel Shapley modules (with small active sets) and optional learned linear mixing. A canonical “butterfly” (FFT-style) schedule organizes disjoint pairs per layer, ensuring that all pairwise feature interactions are eventually captured for $L \geq \log_2 d$ layers.
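One way to generate such a schedule is sketched below. This is an assumption-laden illustration: it takes $d$ to be a power of two and uses the standard FFT stride pattern; the paper's exact pairing may differ. The helper names `butterfly_pairs` and `reachable_inputs` are hypothetical.

```python
def butterfly_pairs(d):
    """Disjoint index pairs per layer for d = 2**m features (FFT-style strides)."""
    assert d > 1 and d & (d - 1) == 0, "this sketch assumes d is a power of two"
    layers, stride = [], 1
    while stride < d:
        # Pair position i with i + stride inside each block of size 2 * stride.
        pairs = [(i, i + stride) for i in range(d) if (i // stride) % 2 == 0]
        layers.append(pairs)
        stride *= 2
    return layers

def reachable_inputs(layers, d):
    """For each position, the set of input features that can influence it."""
    reach = [{i} for i in range(d)]
    for pairs in layers:
        new_reach = [set(s) for s in reach]
        for a, b in pairs:
            merged = reach[a] | reach[b]
            new_reach[a], new_reach[b] = merged, set(merged)
        reach = new_reach
    return reach

layers = butterfly_pairs(8)
reach = reachable_inputs(layers, 8)
```

For $d = 8$ this yields three layers, and after the last one every position depends on all eight inputs, so any two features have met inside some module's receptive field.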
Key properties include:
- Local accuracy: Achieved in the final prediction via telescoping of Shapley sums.
- Missingness: If any feature’s attribution becomes zero at a layer, it is guaranteed to remain zero in all subsequent layers (Corollary 4.1), facilitating input-adaptive, dynamic pruning at inference.
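The propagation of zeros through stacked layers can be illustrated with a toy Deep ShapNet. The sketch below assumes scalar representations, a zero reference at every layer, a single hypothetical pairwise function `g` shared by all modules, and no linear mixing. `pair_shapley` and `deep_shapnet` are illustrative names.

```python
def pair_shapley(g, za, zb):
    """Exact Shapley values of g(a, b) at (za, zb) with reference (0, 0)."""
    g00, ga0, g0b, gab = g(0.0, 0.0), g(za, 0.0), g(0.0, zb), g(za, zb)
    phi_a = 0.5 * ((ga0 - g00) + (gab - g0b))
    phi_b = 0.5 * ((g0b - g00) + (gab - ga0))
    return phi_a, phi_b

def deep_shapnet(x, schedule, g):
    """Stack Shapley transforms over disjoint pairs.

    Returns the output, the list of layerwise representations, and the
    number of modules skipped because both of their inputs were zero.
    """
    Z, history, skipped = list(x), [], 0
    for pairs in schedule:
        Z_next = [0.0] * len(Z)
        for a, b in pairs:
            if Z[a] == 0.0 and Z[b] == 0.0:
                skipped += 1      # both inputs at the reference: outputs are zero
                continue
            Z_next[a], Z_next[b] = pair_shapley(g, Z[a], Z[b])
        Z = Z_next
        history.append(Z)
    return sum(Z), history, skipped

# Hypothetical shared module function and a two-layer pairing schedule.
g = lambda a, b: a * b + a + 2.0 * b
schedule = [[(0, 1), (2, 3)], [(0, 2), (1, 3)]]
x = [1.0, 2.0, 0.0, 0.0]          # features 2 and 3 sit at the zero reference
output, history, skipped = deep_shapnet(x, schedule, g)
```

Features 2 and 3 start at the reference, receive zero attribution in the first layer, and stay at zero in the second, so the module acting only on them can be skipped without changing the result.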
5. Complexity, Efficiency, and Training Regularization
Computational Complexity and Efficiency
- Per-module cost: $2^k$ forward passes of the module function for exact Shapley value computation on an active set of size $k$.
- End-to-end model: A single forward pass through all Shapley modules yields both prediction and all internal explanations, matching the runtime of conventional DNNs of similar depth and width.
- Test-time dynamic pruning: Features for which attributions become zero at any layer can be pruned in all downstream computations. This input-adaptive pruning reduces computation substantially for sparse or irrelevant features (Corollary 4.1).
Regularization via Shapley Representations
Since each intermediate tensor $Z^{(l)}$ is itself a Shapley attribution, regularization can be directly imposed on the intrinsic attributions. For example, a sparsity penalty such as $\lambda \sum_{l=1}^{L} \lVert Z^{(l)} \rVert_1$ can be added to the training loss. Such regularizers enforce sparsity or smoothness in the layerwise Shapley values, directly shaping attribution characteristics during training, rather than relying on smoothing post-hoc saliency maps.
6. Implications and Distinction from Post-hoc SHAP Approaches
The Shapley transform framework makes per-feature attributions intrinsic to the model, rather than an external analysis step. Unlike post-hoc SHAP methods (e.g., DeepSHAP, KernelSHAP), which require expensive sampling or backward passes, Shallow and Deep ShapNets compute attributions concurrently with predictions. This design provides:
- Computation of exact (Shallow case) or approximate (Deep case) SHAP attributions for arbitrary inputs in the same forward pass as the prediction,
- Full support for attribution-based regularization in gradient-based optimization,
- Input-specific dynamic computational savings via layerwise pruning,
- Satisfaction of SHAP properties by construction: local accuracy, missingness, and consistency in Shallow ShapNets, and local accuracy and missingness in Deep ShapNets.
A plausible implication is more interpretable architectures for domains where feature attributions and compliance with explainability standards are critical, without sacrificing efficiency, trainability, or expressivity (Wang et al., 2021).