Enhancing Neural Networks with Multi-Axis Query Sparsity
Introduction to N:M Sparsity
Deep convolutional neural networks (CNNs) have made tremendous strides in various computer vision tasks. Yet their widespread deployment is often hampered by high memory and computational demands, which are particularly challenging on mobile or edge devices. Network sparsity has emerged as an effective remedy, offering both memory and computation savings. Among sparsity techniques, the N:M pattern has drawn growing interest because it balances accuracy and latency effectively. N:M sparsity keeps only N non-zero weights out of every M consecutive weights in the network, producing a fine-grained yet regular sparsity structure.
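As a concrete illustration (not taken from the paper), the sketch below builds a binary 2:4 mask in PyTorch by keeping the two largest-magnitude weights in every group of four consecutive weights; the function name and the grouping along the flattened weight are assumptions made for this example.

```python
import torch

def nm_mask(weight: torch.Tensor, n: int = 2, m: int = 4) -> torch.Tensor:
    """Binary N:M mask: keep the n largest-magnitude weights in every
    group of m consecutive weights (weight.numel() assumed divisible by m)."""
    groups = weight.reshape(-1, m)                     # group consecutive weights
    idx = groups.abs().topk(n, dim=1).indices          # positions of the n largest magnitudes
    mask = torch.zeros_like(groups).scatter_(1, idx, 1.0)
    return mask.reshape(weight.shape)

# Example: a 2:4-sparse version of a linear layer's weight
w = torch.randn(64, 64)
w_sparse = w * nm_mask(w, n=2, m=4)
```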
Despite its promise, previous methods for implementing N:M sparsity did not fully exploit the relative importance of weights beyond each individual block, leading to sub-optimal performance.
MaxQ: A Multi-Axis Query Approach
The paper introduces a Multi-Axis Query methodology, named MaxQ, designed to address the limitations of prior N:M sparsity implementations. Unlike earlier methods that consider the weights within each N:M block independently, MaxQ assesses weight importance across multiple axes to identify the more significant connections in the network.
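The exact query formulation belongs to the paper, but a minimal sketch of the general idea, combining element-wise magnitudes with importance scores aggregated along the output- and input-channel axes of a 2-D weight, might look like the following. The function, the softmax/sigmoid weighting, and the temperature parameter are all illustrative assumptions, not the paper's definition.

```python
import torch

def multi_axis_soft_mask(weight: torch.Tensor, n: int = 2, m: int = 4,
                         temperature: float = 1.0) -> torch.Tensor:
    """Hypothetical sketch: an N:M selection guided by importance scores
    aggregated over the output- and input-channel axes of a 2-D weight."""
    score = weight.abs()                                              # element-wise importance
    row_imp = torch.softmax(score.sum(dim=1) / temperature, dim=0)    # per output channel
    col_imp = torch.softmax(score.sum(dim=0) / temperature, dim=0)    # per input channel

    # Boost elements that sit in globally important rows/columns.
    combined = score * (1.0 + row_imp[:, None]) * (1.0 + col_imp[None, :])

    # Hard N:M selection on the combined score (in_features assumed divisible by m)...
    groups = combined.reshape(-1, m)
    idx = groups.topk(n, dim=1).indices
    hard = torch.zeros_like(groups).scatter_(1, idx, 1.0).reshape(weight.shape)

    # ...applied as a *soft* mask during training, so surviving weights are
    # re-weighted by their importance rather than passed through unchanged.
    return hard * torch.sigmoid(combined / temperature)
```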
MaxQ operates dynamically, generating 'soft' N:M masks throughout training. These masks up-weight the more significant weights during updates, ensuring that crucial connections are not undervalued. A particularly innovative aspect of MaxQ is its sparsity strategy: the proportion of weight blocks that follow the N:M pattern increases gradually as training progresses. This incremental approach, sketched below, lets the network recover from the disruption caused by pruning, leading to more stable training and a stronger final model.
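A hedged sketch of such an incremental schedule follows: the fraction of blocks constrained to N:M ramps up over training, and at each stage only the blocks with the lowest total magnitude are constrained. The cubic ramp and the block-scoring criterion used here are illustrative choices, not necessarily those of the paper.

```python
import torch

def nm_block_ratio(epoch: int, total_epochs: int) -> float:
    """Fraction of weight blocks constrained to N:M at a given epoch.
    A cubic ramp from 0 to 1 is assumed purely for illustration."""
    t = min(epoch / max(total_epochs, 1), 1.0)
    return 1.0 - (1.0 - t) ** 3

def incremental_nm_mask(weight: torch.Tensor, n: int, m: int, ratio: float) -> torch.Tensor:
    """Apply the N:M constraint only to the `ratio` fraction of blocks with the
    lowest total magnitude; leave the remaining blocks dense for now."""
    blocks = weight.reshape(-1, m)                         # assumes numel divisible by m
    block_score = blocks.abs().sum(dim=1)                  # importance of each block
    num_constrained = int(ratio * blocks.size(0))

    # Per-block N:M mask (keep the top-n magnitudes in each block).
    idx = blocks.abs().topk(n, dim=1).indices
    nm = torch.zeros_like(blocks).scatter_(1, idx, 1.0)

    # Only the least important blocks are constrained at this stage.
    constrained = torch.zeros(blocks.size(0), dtype=torch.bool, device=weight.device)
    if num_constrained > 0:
        constrained[block_score.topk(num_constrained, largest=False).indices] = True
    mask = torch.where(constrained[:, None], nm, torch.ones_like(blocks))
    return mask.reshape(weight.shape)
```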
At inference time, these soft N:M masks can be precomputed and folded into the weights, so they add no extra computation and do not disturb the N:M sparse pattern.
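In practice, folding can be as simple as the sketch below, which multiplies each weight tensor by its final mask once before deployment. The `masks` dictionary (parameter name to final mask) and the function name are hypothetical.

```python
import torch

@torch.no_grad()
def fold_masks_into_weights(model: torch.nn.Module, masks: dict) -> None:
    """Multiply each weight by its precomputed soft N:M mask in place,
    so inference runs on plain sparse weights with no masking op."""
    for name, param in model.named_parameters():
        if name in masks:
            param.mul_(masks[name].to(param.device))
```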
Comprehensive Evaluation
The effectiveness of MaxQ was evaluated across different CNN architectures and sparsity patterns. Improvements were consistent, and particularly substantial for heavyweight architectures such as ResNet. For example, MaxQ pushed a 1:16 sparse ResNet50 to 74.6% top-1 accuracy on ImageNet, improving on the previous best result by 2.8%.
Moreover, MaxQ's multi-axis soft-masking approach also transfers to downstream tasks beyond image classification, such as object detection and instance segmentation, where it matches the performance of the dense baselines.
Advantages and Practical Implications
MaxQ is not solely about achieving high compression. The method applies to a range of N:M sparse patterns without significant modification, and MaxQ networks are trained in a single pass, without the iterative pre-training or fine-tuning stages that previous methods required, which simplifies the overall workflow.
Finally, MaxQ combines well with quantization methods, even outperforming some earlier approaches under quantization. Together with its ability to preserve the N:M sparsity structure, this makes MaxQ a significant step forward in optimizing networks for deployment on resource-constrained devices.
Conclusion
The Multi-Axis Query technique offers an interpretable, effective, and efficient way to exploit N:M sparsity in CNNs. Its dynamic query-based mask generation and progressive sparsification during training produce networks that retain high performance while meeting stringent resource constraints. Its results across various computer vision tasks and CNN architectures may set a reference point for future sparsity-based network optimizations.