MaskLLM: Learnable Semi-Structured Sparsity for LLMs
Introduction
The paper "MaskLLM: Learnable Semi-Structured Sparsity for LLMs" addresses the issue of computational overhead associated with the deployment of LLMs, which often come with massive parameter counts. Traditional approaches to manage this overhead have involved semi-structured pruning to introduce sparsity in LLMs. However, current methods rely heavily on heuristic importance criteria and small calibration datasets, limiting their scalability and accuracy. This paper proposes MaskLLM, a novel learnable pruning method that incorporates end-to-end training and leverages Gumbel Softmax sampling to optimize mask selection directly.
Methodology
MaskLLM takes a probabilistic approach to mask selection, learning N:M sparsity patterns through gradient descent. Instead of relying on hand-crafted criteria to score parameter importance, MaskLLM models the mask of each parameter block as a learnable distribution over candidate masks, which allows the method to scale to large datasets.
N:M Sparsity
The key idea is to enforce sparsity within each block of M consecutive parameters by retaining only N non-zero values. For example, under a 2:4 pattern, each group of 4 parameters is pruned so that only 2 values remain non-zero. Pruning is framed as selecting one mask from a candidate set for each block, and the selection is relaxed with Gumbel Softmax to make it differentiable. This allows MaskLLM to train end to end and optimize the masks directly with respect to the language modeling loss.
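To make the candidate-set framing concrete, here is a minimal sketch, assuming a PyTorch setting, that enumerates the C(4,2) = 6 candidate masks for 2:4 sparsity and applies one of them to a block of weights; it is an illustration, not the paper's code.

```python
# Minimal sketch: enumerate the candidate masks for 2:4 sparsity.
# Keeping 2 of 4 positions yields C(4,2) = 6 binary candidates, and pruning a
# block reduces to picking one of them.
from itertools import combinations

import torch

M, N = 4, 2  # keep N non-zero values in every block of M consecutive weights

candidates = []
for keep in combinations(range(M), N):
    mask = torch.zeros(M)
    mask[list(keep)] = 1.0
    candidates.append(mask)
candidate_masks = torch.stack(candidates)   # shape (6, 4)

# Applying one candidate to a block of 4 weights zeroes the other 2 positions.
weights = torch.tensor([0.3, -1.2, 0.05, 0.8])
pruned = weights * candidate_masks[0]       # keeps positions 0 and 1
print(candidate_masks.shape, pruned)
```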
Mask Selection and Learning Mechanism
The differentiable sampling mechanism enabled by Gumbel Softmax allows effective mask exploration during training. By re-parameterizing the sampling process, the method optimizes mask probabilities with gradient descent on the actual training data instead of relying on a small calibration set.
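The sketch below illustrates how this re-parameterization can look in PyTorch, assuming per-block learnable logits over the six 2:4 candidate masks from the previous snippet; the temperature schedule, scaling, and other details of the paper's actual implementation are not reproduced here.

```python
# Hedged sketch: differentiable mask selection via Gumbel Softmax.
from itertools import combinations

import torch
import torch.nn.functional as F

M, N = 4, 2
candidate_masks = torch.stack([
    torch.tensor([1.0 if i in keep else 0.0 for i in range(M)])
    for keep in combinations(range(M), N)
])                                                     # (6, 4)

num_blocks = 8                                         # toy number of weight blocks
logits = torch.zeros(num_blocks, candidate_masks.shape[0], requires_grad=True)

# Gumbel Softmax yields differentiable, approximately one-hot samples over the
# candidates; the soft mask is a convex combination of candidate masks.
tau = 1.0                                              # temperature; lower values sharpen selection
probs = F.gumbel_softmax(logits, tau=tau, hard=False)  # (num_blocks, 6)
soft_mask = probs @ candidate_masks                    # (num_blocks, 4)

# Gradients of any loss computed on the masked weights flow back to the logits,
# so the mask distribution can be trained end to end (stand-in loss shown here).
weights = torch.randn(num_blocks, M)
loss = ((weights * soft_mask) ** 2).mean()
loss.backward()
print(logits.grad.shape)                               # torch.Size([8, 6])
```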
To mitigate vanishing gradients caused by the zeroed-out weights during training, a Sparse Weight Regularization term is introduced. It keeps the magnitude of the remaining weights sufficiently large, which facilitates better mask learning and improves transferability to downstream tasks.
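The exact formulation of the regularizer is not reproduced here; the sketch below shows one plausible form, assumed for illustration, in which subtracting the squared norm of the retained weights from the training loss rewards a larger magnitude in the surviving weights.

```python
import torch

def sparse_weight_regularizer(weights: torch.Tensor,
                              soft_mask: torch.Tensor,
                              lam: float = 1e-4) -> torch.Tensor:
    """Assumed form: a penalty that decreases as the retained weights grow,
    so minimizing (loss + penalty) keeps the surviving weights large."""
    retained = weights * soft_mask
    return -lam * (retained ** 2).sum()

# Hypothetical usage inside the training loop:
# total_loss = lm_loss + sparse_weight_regularizer(weights, soft_mask)
```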
Experimental Evaluation
Scalability and Accuracy
The efficacy of MaskLLM was evaluated on several models, including LLaMA-2, Nemotron-4, and GPT-3 variants, with sizes ranging from 843M to 15B parameters. The experiments showed that MaskLLM scales to large training datasets and produces higher-quality masks than state-of-the-art one-shot methods such as SparseGPT and Wanda. For 2:4 sparsity on LLaMA-2 7B, MaskLLM achieved a Wikitext perplexity (PPL) of 6.72, significantly outperforming SparseGPT's 10.42.
Mask Transferability
MaskLLM also introduces the concept of mask transferability, where learned masks can be adapted to different tasks or domains. Because mask selection is probabilistic, masks can be initialized from pre-computed priors produced by one-shot methods and then refined through additional training. This was shown to provide substantial improvements in mask quality and training efficiency on downstream tasks.
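As a rough illustration of how such a prior could seed the learnable distribution, the sketch below maps a pre-computed one-shot mask (for example, from magnitude pruning) to initial logits; the scoring rule and the `prior_strength` constant are assumptions for illustration, not the paper's exact scheme.

```python
import torch

def init_logits_from_prior(prior_mask_blocks: torch.Tensor,   # (num_blocks, M), binary
                           candidate_masks: torch.Tensor,     # (num_candidates, M), binary
                           prior_strength: float = 3.0) -> torch.Tensor:
    """Give the candidate that agrees with the one-shot prior the largest
    initial logit, then refine the distribution with end-to-end training."""
    agreement = prior_mask_blocks @ candidate_masks.t()        # (num_blocks, num_candidates)
    logits = prior_strength * agreement
    return logits.detach().requires_grad_(True)
```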
Implications and Future Work
Practical Implications
By leveraging large-scale datasets and employing an end-to-end training paradigm, MaskLLM offers a robust solution for reducing the computational overhead of LLMs without sacrificing accuracy. This could lead to more efficient deployment of LLMs in real-world applications, particularly where computational resources are constrained.
Theoretical Implications
The probabilistic modeling of mask selection extends the scope of learnable sparsity patterns, offering a compelling alternative to traditional heuristic-based methods. This approach can be considered for other forms of model pruning and compression, potentially leading to new research directions in the field.
Conclusion
MaskLLM presents a novel approach to semi-structured pruning for LLMs, addressing scalability and accuracy issues inherent in current methods. Through probabilistic modeling and differentiable sampling, MaskLLM achieves high-quality mask learning and transferability, representing a significant advancement in the efficient deployment of LLMs. Future work could explore further optimizations to enhance training efficiency and generalize the approach to other sparsity patterns and model architectures.