
MaskLLM: Learnable Semi-Structured Sparsity for Large Language Models (2409.17481v2)

Published 26 Sep 2024 in cs.AI, cs.CL, and cs.LG

Abstract: LLMs are distinguished by their massive parameter counts, which typically result in significant redundancy. This work introduces MaskLLM, a learnable pruning method that establishes Semi-structured (or "N:M") Sparsity in LLMs, aimed at reducing computational overhead during inference. Instead of developing a new importance criterion, MaskLLM explicitly models N:M patterns as a learnable distribution through Gumbel Softmax sampling. This approach facilitates end-to-end training on large-scale datasets and offers two notable advantages: 1) High-quality Masks - our method effectively scales to large datasets and learns accurate masks; 2) Transferability - the probabilistic modeling of mask distribution enables the transfer learning of sparsity across domains or tasks. We assessed MaskLLM using 2:4 sparsity on various LLMs, including LLaMA-2, Nemotron-4, and GPT-3, with sizes ranging from 843M to 15B parameters, and our empirical results show substantial improvements over state-of-the-art methods. For instance, leading approaches achieve a perplexity (PPL) of 10 or greater on Wikitext compared to the dense model's 5.12 PPL, but MaskLLM achieves a significantly lower 6.72 PPL solely by learning the masks with frozen weights. Furthermore, MaskLLM's learnable nature allows customized masks for lossless application of 2:4 sparsity to downstream tasks or domains. Code is available at https://github.com/NVlabs/MaskLLM.

Authors (8)
  1. Gongfan Fang (33 papers)
  2. Hongxu Yin (49 papers)
  3. Saurav Muralidharan (14 papers)
  4. Greg Heinrich (12 papers)
  5. Jeff Pool (11 papers)
  6. Jan Kautz (215 papers)
  7. Pavlo Molchanov (70 papers)
  8. Xinchao Wang (203 papers)

Summary

MaskLLM: Learnable Semi-Structured Sparsity for LLMs

Introduction

The paper "MaskLLM: Learnable Semi-Structured Sparsity for Large Language Models" addresses the computational overhead of deploying LLMs, whose massive parameter counts typically contain significant redundancy. A common way to reduce this overhead is semi-structured pruning, which introduces sparsity into LLMs. However, existing methods rely heavily on hand-crafted importance criteria and small calibration datasets, which limits their scalability and accuracy. The paper proposes MaskLLM, a learnable pruning method that trains masks end to end and uses Gumbel Softmax sampling to optimize mask selection directly.

Methodology

MaskLLM introduces a probabilistic approach to mask selection, facilitating the learning of N:M sparsity patterns through gradient descent. Instead of employing hand-crafted criteria to determine parameter importance, MaskLLM models masks as learnable distributions, which enables the algorithm to adapt to large-scale datasets.

N:M Sparsity

The key idea is to enforce sparsity by retaining only N non-zero values in every group of M consecutive parameters. For example, under 2:4 sparsity each group of 4 parameters keeps exactly 2 non-zero values, so each group has C(4,2) = 6 possible masks. Mask learning is therefore framed as selecting one mask per group from this candidate set, and Gumbel Softmax makes the selection differentiable. This allows MaskLLM to train end to end and optimize the masks directly against the language modeling loss.
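To make the candidate set concrete, the sketch below enumerates every 2:4 mask and prunes a weight row one group of four at a time. It is plain Python for illustration only: the names `candidate_masks` and `apply_nm_masks` are hypothetical, not taken from the MaskLLM code base, and a real implementation would operate on weight tensors.

```python
from itertools import combinations


def candidate_masks(n, m):
    """Enumerate every binary mask of length m that keeps exactly n entries."""
    return [tuple(int(i in kept) for i in range(m)) for kept in combinations(range(m), n)]


def apply_nm_masks(weights, group_masks, m):
    """Prune a flat weight row group by group, one candidate mask per group of m."""
    if len(weights) % m != 0:
        raise ValueError("row length must be a multiple of m")
    if len(group_masks) != len(weights) // m:
        raise ValueError("need exactly one mask per group")
    pruned = []
    for g, mask in enumerate(group_masks):
        group = weights[g * m:(g + 1) * m]
        # Pruned positions become zero; kept positions are unchanged.
        pruned.extend(w * bit for w, bit in zip(group, mask))
    return pruned


cands = candidate_masks(2, 4)
print(len(cands))  # 6
print(cands[0])    # (1, 1, 0, 0)

row = [0.9, -0.1, 0.4, -0.7, 0.2, 0.3, -0.8, 0.05]
# Masks picked by hand for illustration; MaskLLM would learn this choice.
pruned = apply_nm_masks(row, [cands[2], cands[1]], 4)
```

Every candidate keeps exactly two of four positions, so any per-group choice yields a valid 2:4 row; the open question the method answers is which candidate to pick for each group.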

Mask Selection and Learning Mechanism

The differentiable sampling enabled by Gumbel Softmax allows effective exploration of masks during training. Because the sampling process is re-parameterized, the mask probabilities can be optimized by gradient descent on the large-scale training data instead of being derived from a small calibration set.
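A minimal sketch of Gumbel Softmax sampling for one 2:4 group follows, assuming one learnable logit per candidate mask and a soft mask formed as the probability-weighted sum of the candidates. The function name `gumbel_softmax_mask` and the temperature value are hypothetical. Plain Python only shows the forward computation; in an autograd framework such as PyTorch, gradients would flow from the soft mask back into the logits.

```python
import math
import random
from itertools import combinations


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    mx = max(xs)
    exps = [math.exp(x - mx) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_softmax_mask(logits, candidates, tau=1.0, rng=random):
    """Sample a soft N:M mask for one group of M weights.

    logits: one learnable score per candidate mask.
    candidates: the candidate binary masks (all of length M).
    tau: temperature; smaller values push the sample towards one-hot.
    """
    gumbels = []
    for _ in logits:
        # Clamp u away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        gumbels.append(-math.log(-math.log(u)))
    # log(softmax(logits)) differs from the logits by a constant, which the
    # softmax below ignores, so the raw logits can be used directly.
    y = softmax([(l + g) / tau for l, g in zip(logits, gumbels)])
    m = len(candidates[0])
    # Soft mask = convex combination of the candidate masks weighted by y.
    soft_mask = [sum(y_i * c[j] for y_i, c in zip(y, candidates)) for j in range(m)]
    return y, soft_mask


candidates = [tuple(int(i in kept) for i in range(4)) for kept in combinations(range(4), 2)]
rng = random.Random(0)
y, soft = gumbel_softmax_mask([0.0] * 6, candidates, tau=1.0, rng=rng)
y_sharp, hard_like = gumbel_softmax_mask([20.0, 0, 0, 0, 0, 0], candidates, tau=0.01, rng=rng)
```

Because every candidate keeps two positions, the soft mask always sums to 2 per group. At a low temperature and with one dominant logit, the sample is close to that candidate's hard mask; taking the most probable candidate per group at the end is one natural way (assumed here) to obtain the final 2:4 mask.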

To mitigate vanishing gradients caused by pruning weights to zero during training, the paper introduces a Sparse Weight Regularization term. This regularization keeps the magnitude of the remaining (unpruned) weights sufficiently large, which facilitates mask learning and improves transferability to downstream tasks.
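The exact form of the regularizer is not given in this summary. A minimal sketch, assuming the term is the squared norm of the kept weights W * M (elementwise) subtracted from the language modeling loss with a hypothetical weight `lam`, shows how such a term favors masks that keep large weights. `lm_loss` is a placeholder number standing in for the model's actual loss, and `regularized_objective` is a hypothetical name.

```python
def regularized_objective(lm_loss, weights, mask, lam=0.1):
    """Hypothetical objective: lm_loss - lam * ||W * M||^2 (elementwise product).

    Subtracting the squared norm of the kept weights lowers the objective
    for masks that preserve large-magnitude weights. The mask may be hard
    (0/1) or soft (values in [0, 1]).
    """
    kept_norm_sq = sum((w * m) ** 2 for w, m in zip(weights, mask))
    return lm_loss - lam * kept_norm_sq


group = [0.9, -0.1, 0.4, -0.7]
keep_large = (1, 0, 0, 1)  # keeps 0.9 and -0.7
keep_small = (0, 1, 1, 0)  # keeps -0.1 and 0.4

# Same placeholder LM loss for both, so only the regularizer differs.
obj_large = regularized_objective(2.0, group, keep_large)
obj_small = regularized_objective(2.0, group, keep_small)
```

With equal task loss, the mask that keeps larger weights gets the lower objective. In training, the term would be evaluated on soft masks from Gumbel Softmax sampling, and `lam` would need tuning.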

Experimental Evaluation

Scalability and Accuracy

The efficacy of MaskLLM was evaluated on several models, including LLaMA-2, Nemotron-4, and GPT-3, ranging from 843M to 15B parameters. The experiments showed that MaskLLM scales effectively to large datasets and produces higher-quality masks than state-of-the-art methods such as SparseGPT and Wanda. For LLaMA-2 7B with 2:4 sparsity, MaskLLM reached a perplexity (PPL) of 6.72 on Wikitext by learning only the masks while keeping the weights frozen, compared with 10.42 for SparseGPT and 5.12 for the dense model.

Mask Transferability

MaskLLM also introduces the concept of mask transferability, where learned masks can be adapted to different tasks or domains. The probabilistic nature of the mask selection process allows for the initialization of masks using pre-computed priors from one-shot methods, which are then refined through additional training. This feature was shown to provide substantial improvements in mask quality and training efficiency for downstream tasks.
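The summary does not specify how a prior mask is turned into initial probabilities. A minimal sketch, assuming each candidate's initial logit is proportional to its overlap with the prior mask, scaled by a hypothetical `kappa`: magnitude pruning stands in for the one-shot method here only because it is short to write (the text names SparseGPT and Wanda as one-shot baselines). The same function could take a mask learned on another domain as the prior, which is one way to read the transfer setting. All function names are hypothetical.

```python
import math
from itertools import combinations


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    mx = max(xs)
    exps = [math.exp(x - mx) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def magnitude_prior(group, n):
    """One-shot stand-in: keep the n largest-magnitude weights of the group."""
    keep = sorted(range(len(group)), key=lambda i: abs(group[i]), reverse=True)[:n]
    return tuple(int(i in keep) for i in range(len(group)))


def prior_logits(prior_mask, candidates, kappa=3.0):
    """Hypothetical initialization: logit = kappa * overlap with the prior mask."""
    return [kappa * sum(a * b for a, b in zip(c, prior_mask)) for c in candidates]


candidates = [tuple(int(i in kept) for i in range(4)) for kept in combinations(range(4), 2)]
prior = magnitude_prior([0.9, -0.1, 0.4, -0.7], 2)  # keeps positions 0 and 3
probs = softmax(prior_logits(prior, candidates))
best = candidates[probs.index(max(probs))]
```

Because the prior only shifts the logits, every other candidate keeps non-zero probability, so further training can still move away from the one-shot choice.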

Implications and Future Work

Practical Implications

By leveraging large-scale datasets and end-to-end training, MaskLLM reduces the computational overhead of LLMs with a much smaller accuracy loss than prior one-shot methods, and with task-specific masks it allows lossless 2:4 sparsity on downstream tasks. This could lead to more efficient deployment of LLMs in real-world applications, particularly where computational resources are constrained.

Theoretical Implications

The probabilistic modeling of mask selection extends the scope of learnable sparsity patterns, offering a compelling alternative to traditional heuristic-based methods. This approach can be considered for other forms of model pruning and compression, potentially leading to new research directions in the field.

Conclusion

MaskLLM presents a novel approach to semi-structured pruning for LLMs, addressing scalability and accuracy issues inherent in current methods. Through probabilistic modeling and differentiable sampling, MaskLLM achieves high-quality mask learning and transferability, representing a significant advancement in the efficient deployment of LLMs. Future work could explore further optimizations to enhance training efficiency and generalize the approach to other sparsity patterns and model architectures.
