- The paper introduces Marill, a framework that reduces MPC computational burdens for LLM inference, achieving 3.6–11.3× runtime and 2.4–6.9× communication improvements.
- The methodology strategically shifts expensive operations outside of MPC by employing layer freezing, low-rank adaptation, and head merging during fine-tuning.
- The approach maintains high performance and security, paving the way for practical deployment of privacy-preserving AI in sensitive fields like healthcare and finance.
Essay on MPC-Minimized Secure LLM Inference
The paper "MPC-Minimized Secure LLM Inference" by Deevashwer Rathee et al. presents a novel framework, Marill, which is designed to optimize the process of secure inference for LLMs through the application of secure multi-party computation (MPC).
With the proliferation of transformer-based LLMs such as GPT-4, Llama, and Mistral, these models are now widely deployed in real-world applications like chatbots and virtual assistants. Their integration, however, raises substantial privacy concerns: the model weights are often proprietary and must stay confidential, while user inputs may contain sensitive information that must likewise be protected. Prior solutions based on secure MPC have been unable to overcome the prohibitive overheads of secure LLM inference, making practical deployment challenging.
To address these efficiency problems, Marill strategically reduces the amount of computation performed within MPC while maintaining the same level of security. Unlike previous approaches that approximate or simplify individual operations for computational efficiency, Marill makes high-level architectural changes during fine-tuning that shift expensive operations outside of MPC altogether, markedly reducing the cost of secure inference. The paper reports that Marill-generated models achieve 3.6–11.3× better runtime and 2.4–6.9× less communication during secure inference across various MPC settings, compared to standard fine-tuned models, while preserving over 90% of their performance on downstream tasks.
A key insight driving Marill is that an LLM's architecture can be adapted during fine-tuning in ways that minimize the computation left inside MPC. Marill employs three major strategies for this purpose: Layer Freezing, Low-rank Adaptation (LoRA), and Head Merging.
- Layer Freezing: Fine-tuning updates are restricted to the top layers of the LLM. The bottom layers keep their public pre-trained weights, so the computation over those layers can be performed in plaintext outside MPC, reducing both communication and computation costs (see the first sketch after this list).
- Low-rank Adaptation (LoRA): LoRA capitalizes on the knowledge already embedded in the pre-trained weights and confines adaptation to far fewer parameters. Since only the small rank-r update matrices are fine-tuned and therefore private, LoRA shrinks the dimensions of the matrix multiplications involving secret weights within MPC, tackling a significant runtime bottleneck (see the second sketch below).
- Head Merging: Unlike straightforward pruning of attention heads, which discards parameters, head merging combines groups of attention heads into fewer, wider heads while preserving all model parameters, which helps maintain accuracy. This addresses the overhead of the self-attention components inside MPC, a cost that scales quadratically with sequence length (see the third sketch below).
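To make the layer-freezing idea concrete, here is a minimal PyTorch sketch. It assumes a HuggingFace-style decoder whose transformer blocks live in `model.model.layers` and whose output projection is `model.lm_head`; these names, and the helper itself, are illustrative assumptions rather than Marill's actual code.

```python
# Minimal sketch of layer freezing (illustrative, not Marill's implementation).
import torch

def freeze_bottom_layers(model, num_trainable: int):
    """Freeze every parameter, then unfreeze only the top `num_trainable` blocks.

    The frozen bottom blocks keep their public pre-trained weights, so during
    secure inference they can be evaluated in plaintext outside MPC; only the
    fine-tuned top blocks (plus the output head) need to run under MPC.
    """
    for p in model.parameters():
        p.requires_grad = False

    blocks = model.model.layers               # assumed list of transformer blocks
    for block in blocks[-num_trainable:]:     # top-most blocks only
        for p in block.parameters():
            p.requires_grad = True

    for p in model.lm_head.parameters():      # assumed output head, also fine-tuned
        p.requires_grad = True
    return model
```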
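The LoRA point can likewise be sketched with a small PyTorch module; this is again an illustrative sketch under stated assumptions, not the paper's implementation. The full-rank weight stays frozen and public, while only the rank-r factors `A` and `B` are trained and therefore private. In secret-sharing-based MPC, multiplying shared activations by a public matrix typically needs no interactive multiplication protocol, so the expensive secure multiplications involve only the small rank-r factors.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen public base weight plus a trainable rank-r update (illustrative)."""

    def __init__(self, d_in: int, d_out: int, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = nn.Linear(d_in, d_out, bias=False)
        self.base.weight.requires_grad = False                # public, frozen
        self.A = nn.Parameter(torch.randn(r, d_in) * 0.01)    # private, trained
        self.B = nn.Parameter(torch.zeros(d_out, r))          # private, trained
        self.scale = alpha / r

    def forward(self, x):
        # Public full-rank path plus scaled low-rank private update:
        # y = x W^T + (alpha / r) * x A^T B^T
        return self.base(x) + self.scale * ((x @ self.A.T) @ self.B.T)
```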
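Finally, a sketch of head merging. One plausible reading, adopted here purely for illustration, is that groups of `merge_factor` heads are fused into a single wider head by concatenating their per-head query/key/value features; the paper's exact merging rule may differ. All projection parameters are kept, but the number of T×T attention-score matrices, and hence of expensive softmax evaluations inside MPC, drops by the merge factor.

```python
import torch
import torch.nn.functional as F

def _merge(x, merge_factor: int):
    """Fuse consecutive groups of heads into one wider head (illustrative)."""
    b, h, t, d = x.shape
    hm = h // merge_factor
    x = x.view(b, hm, merge_factor, t, d)           # group heads
    x = x.permute(0, 1, 3, 2, 4)                    # (b, hm, t, merge_factor, d)
    return x.reshape(b, hm, t, merge_factor * d)    # one wider head per group

def merged_attention(q, k, v, merge_factor: int):
    """q, k, v: (batch, num_heads, seq_len, head_dim) tensors."""
    q, k, v = (_merge(x, merge_factor) for x in (q, k, v))
    scores = q @ k.transpose(-1, -2) / (q.shape[-1] ** 0.5)
    probs = F.softmax(scores, dim=-1)               # merge_factor-times fewer T x T softmaxes
    return probs @ v                                # (b, h / merge_factor, t, merge_factor * d)

# Toy usage: 8 heads merged into 2 wider heads.
q = k = v = torch.randn(1, 8, 16, 64)
out = merged_attention(q, k, v, merge_factor=4)     # shape (1, 2, 16, 256)
```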
The implications of Marill are significant: it offers a practical path to deploying privacy-preserving AI services without prohibitive computation costs. That Marill produces models with explicitly public and private components also marks a shift in how secure inference is structured and operationalized. This potentially broadens the use of LLMs in privacy-sensitive environments such as healthcare or finance, where user queries involve proprietary or protected information and must remain confidential.
Looking toward future work, the paper notes that combining Marill with MPC-friendly approximations, such as ReLU-based replacements for non-linear activations, can yield further improvements. The methodology also opens avenues for optimizing other forms of secure computation by transparently exploiting the public portion of a model's structure. The research thus represents a notable contribution to both the theory and practice of securely deploying machine learning technologies.