Memory-Efficient LLM Training with Gradient Low-Rank Projection
Introduction to Gradient Low-Rank Projection (\lowrank{})
Training LLMs poses significant memory challenges due to the large size of the weights and optimizer states involved. Existing memory-reduction techniques often rely on low-rank adaptation methods, such as Low-Rank Adaptation (LoRA), which reparameterizes each layer's weight matrix as the sum of a frozen pre-trained weight and a trainable low-rank update. Despite their efficacy in reducing the number of trainable parameters and the associated optimizer states, these methods often underperform full-rank training in both the pre-training and fine-tuning stages. We attribute this limitation mainly to the restrictive nature of the low-rank parameterization and to the way it alters training dynamics.
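For reference, in the usual notation LoRA reparameterizes a pre-trained weight $W_0 \in \mathbb{R}^{m \times n}$ as
\[
W = W_0 + BA, \qquad B \in \mathbb{R}^{m \times r},\; A \in \mathbb{R}^{r \times n},\; r \ll \min(m, n),
\]
where only $B$ and $A$ (and their optimizer states) are trained; this is the source of both the memory savings and the restricted expressiveness noted above.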
To address these challenges, we introduce Gradient Low-Rank Projection (\textbf{\lowrank{}}), a training strategy designed for both pre-training and fine-tuning LLMs that is more memory-efficient than traditional low-rank methods. Unlike LoRA, which directly imposes a low-rank structure on model weights, \lowrank{} capitalizes on the inherently low-rank structure of gradient updates during training. This strategy enables full-parameter learning while significantly reducing memory consumption.
Theoretical Insights and Methodology
Our work begins by showing theoretically that the gradient matrix produced by backpropagation becomes increasingly low-rank as training progresses. This insight leads to the core idea of \lowrank{}: projecting gradients into a low-rank subspace before applying the optimizer update. Specifically, at time step $t$, \lowrank{} projects the gradient $G_t \in \mathbb{R}^{m \times n}$ onto projection matrices $P_t \in \mathbb{R}^{m \times r}$ and $Q_t \in \mathbb{R}^{n \times r}$, yielding the low-rank form $P_t^\top G_t Q_t$. Consequently, only these low-rank projections of the gradients need to be stored in the optimizer states, resulting in substantial memory savings.
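As a concrete illustration, the PyTorch sketch below performs a single \lowrank{}-style Adam step with a one-sided projection: it keeps only the rank-$r$ projected gradient and its Adam moments in memory. The helper name, the fixed rank, and the use of a single left projection matrix (rather than the general pair $P_t$, $Q_t$) are simplifying assumptions for exposition, not the exact implementation.
\begin{verbatim}
import torch

def galore_step(weight, grad, state, rank=4, lr=1e-3,
                beta1=0.9, beta2=0.999, eps=1e-8):
    # Illustrative low-rank-projected Adam step: the optimizer statistics
    # live in the r x n projected space rather than the full m x n space.
    if "P" not in state:
        # Top-r left singular vectors of the gradient serve as the projection;
        # in practice P would be refreshed periodically, not fixed forever.
        U, _, _ = torch.linalg.svd(grad, full_matrices=False)
        state["P"] = U[:, :rank]                       # m x r projection
        state["m"] = torch.zeros(rank, grad.shape[1])  # low-rank first moment
        state["v"] = torch.zeros(rank, grad.shape[1])  # low-rank second moment
        state["t"] = 0
    P = state["P"]
    state["t"] += 1
    r_grad = P.T @ grad                                # r x n projected gradient
    state["m"] = beta1 * state["m"] + (1 - beta1) * r_grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * r_grad ** 2
    m_hat = state["m"] / (1 - beta1 ** state["t"])
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    # Project the normalized low-rank update back to the full weight shape.
    weight -= lr * (P @ (m_hat / (v_hat.sqrt() + eps)))
    return weight

# Toy usage: a 64 x 64 weight with random stand-in gradients.
W, state = torch.randn(64, 64), {}
for _ in range(3):
    W = galore_step(W, torch.randn(64, 64), state, rank=4)
\end{verbatim}
In this sketch the two Adam moments occupy $r \times n$ entries each instead of $m \times n$, plus an $m \times r$ projection matrix, which is where the memory saving comes from.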
Moreover, we provide a convergence analysis of \lowrank{} under a specific family of gradient update rules, grounding its effectiveness theoretically as well as empirically. Importantly, \lowrank{} allows the projection matrices to be adjusted dynamically during training, switching between low-rank subspaces so that it supports full-parameter learning without increasing the memory footprint.
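One way to realize these dynamic adjustments, sketched below under the assumption that the projection is refreshed at a fixed interval, is to recompute it from the current gradient every few hundred steps; the helper and interval names are illustrative.
\begin{verbatim}
import torch

def maybe_refresh_projection(grad, state, rank=4, update_proj_gap=200):
    # Periodically recompute the rank-r projection from the current gradient
    # so training is not confined to a single fixed low-rank subspace.
    # `update_proj_gap` is an illustrative name for the refresh interval.
    if "P" not in state or state.get("t", 0) % update_proj_gap == 0:
        U, _, _ = torch.linalg.svd(grad, full_matrices=False)
        state["P"] = U[:, :rank]
    return state["P"]
\end{verbatim}
Because each new subspace is derived from the current gradient while the weights themselves are updated in full, no additional full-size optimizer state is ever materialized.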
Experimental Results
We thoroughly evaluate \lowrank{} on LLaMA-based models ranging from 60M to 7B parameters, using the C4 dataset for pre-training. Our findings show that \lowrank{} closely matches the performance of full-rank training while significantly reducing memory usage, and that it outperforms low-rank adaptation methods such as LoRA and ReLoRA. In particular, for a 7B-parameter model, combining \lowrank{} with an 8-bit optimizer and layer-wise weight updates yields a much smaller memory footprint than full-rank training without sacrificing training effectiveness.
Notably, the memory savings enabled by 8-bit \lowrank{} make it feasible to pre-train a 7B-parameter model on a consumer-grade GPU such as the NVIDIA RTX 4090, demonstrating its practical utility for large-scale LLM training under tight memory budgets.
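To make the combination concrete, the PyTorch sketch below shows only the layer-wise 8-bit update machinery (not the \lowrank{} projection itself): each parameter gets its own 8-bit optimizer from the bitsandbytes library and is updated as soon as its gradient is ready, so gradients for the whole model never need to be held at once. The toy model, learning rate, and use of the post-accumulate-gradient hook are illustrative assumptions; bitsandbytes' 8-bit optimizers also require a CUDA device and PyTorch 2.1 or newer for the hook.
\begin{verbatim}
import torch
import bitsandbytes as bnb  # provides optimizers with 8-bit states

device = "cuda"  # bitsandbytes' 8-bit optimizers expect CUDA tensors
# Toy stand-in for a transformer; the mechanics are the same for a 7B model.
model = torch.nn.Sequential(torch.nn.Linear(256, 256),
                            torch.nn.Linear(256, 256)).to(device)

# One 8-bit optimizer per parameter, stepped inside the gradient hook, so the
# full set of gradients never has to be resident at the same time.
opts = {p: bnb.optim.AdamW8bit([p], lr=1e-4) for p in model.parameters()}

def layerwise_step(param):
    opts[param].step()
    opts[param].zero_grad()

for p in model.parameters():
    # Fires as soon as this parameter's gradient has been accumulated.
    p.register_post_accumulate_grad_hook(layerwise_step)

loss = model(torch.randn(8, 256, device=device)).pow(2).mean()
loss.backward()  # parameters are updated layer by layer during backprop
\end{verbatim}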
Concluding Thoughts and Future Directions
\lowrank{} exemplifies a novel approach to memory-efficient LLM training that exploits the low-rank structure of gradient updates. Its effectiveness in both pre-training and fine-tuning marks a notable step toward reducing the computational and environmental costs of LLM training. Looking forward, promising directions for continued research include designing more memory-efficient projection matrices and extending \lowrank{} to other model architectures and optimization strategies.