Layerwise Importance Sampled AdamW (LISA)
- The paper introduces LISA, a novel fine-tuning method that selects important layers via importance sampling to achieve on-par or superior performance compared to full-parameter AdamW.
- It employs selective freezing of layers to dramatically reduce GPU memory usage by updating only a dynamically chosen subset during backpropagation.
- Empirical evaluations show that LISA converges faster and boosts performance metrics in tasks like MT-Bench, GSM8K, and PubMedQA across various LLMs.
Layerwise Importance Sampled AdamW (LISA) is a memory-efficient fine-tuning strategy for LLMs that leverages the empirical skewness of weight-norm changes across model layers during adaptation. Unlike Low-Rank Adaptation (LoRA), which inserts learnable low-rank adapters at each layer, LISA applies importance sampling to select a small, dynamically changing subset of layers for AdamW optimization, randomly freezing the remainder. This approach maintains or exceeds the fine-tuning performance of both LoRA and full-parameter AdamW, while matching or reducing memory requirements associated with optimizer state and parameter updates (Pan et al., 2024).
1. Background and Motivation
Full-parameter AdamW fine-tuning of LLMs requires substantial GPU memory, as it necessitates storing all gradients, first and second moment buffers per parameter, and activations. For example, a 7B parameter model typically needs at least 60 GB of GPU memory, posing a barrier for researchers without access to large hardware resources.
LoRA significantly reduces trainable parameters by introducing low-rank adapters into each linear layer. However, in many large-scale settings, such as continual pre-training and instruction tuning, LoRA can underperform relative to full fine-tuning, as its parameter search is confined to a low-rank subspace. Detailed examination reveals that LoRA’s layerwise weight-norm changes are highly skewed: only the embedding and head layers exhibit substantial updates, while intermediate self-attention blocks receive minimal changes. By contrast, full-parameter tuning yields more uniform layer updates.
This consistent skewness motivates a strategy that prioritizes “important” layers—those experiencing larger updates—by allocating computational resources preferentially to them while freezing less critical layers. The LISA algorithm operationalizes this insight through stochastic, importance-driven parameter updates.
2. The LISA Algorithm: Core Mechanism
LISA (Layerwise Importance Sampled AdamW) is defined by layerwise importance sampling and selective freezing within the AdamW optimization framework. The algorithm proceeds in the following steps:
- For a model with $N_L$ transformer layers (including the embedding and head), total iterations $T$, sampling interval $K$, and importance sampling probabilities $\{p_\ell\}_{\ell=1}^{N_L}$, initialize the parameters $\theta$ and the AdamW moment buffers $m, v$.
- Every $K$ steps, sample an “active set” $\mathcal{A}$ of layers according to the fixed distribution $\{p_\ell\}$, with special treatment to always include the embedding and head layers. The complement set $\mathcal{F} = \{1, \dots, N_L\} \setminus \mathcal{A}$ is frozen.
- Conduct the forward computation through all layers (activations are still required), but during backpropagation set the gradients of frozen layers to zero: $g_\ell = 0$ for all $\ell \in \mathcal{F}$.
- Apply AdamW updates only to active layers. For each $\ell \in \mathcal{A}$: $m_\ell \leftarrow \beta_1 m_\ell + (1-\beta_1) g_\ell$, $v_\ell \leftarrow \beta_2 v_\ell + (1-\beta_2) g_\ell^2$, and $\theta_\ell \leftarrow \theta_\ell - \eta \left( \hat{m}_\ell / (\sqrt{\hat{v}_\ell} + \epsilon) + \lambda \theta_\ell \right)$, where $\hat{m}_\ell, \hat{v}_\ell$ are the bias-corrected moments and $\lambda$ is the weight-decay coefficient.
- Frozen layers maintain their current parameter values and moment estimates.

Formally, with active set $\mathcal{A}$ and frozen set $\mathcal{F}$:

$$
\theta_\ell^{(t+1)} =
\begin{cases}
\theta_\ell^{(t)} - \eta \left( \hat{m}_\ell^{(t)} / \big(\sqrt{\hat{v}_\ell^{(t)}} + \epsilon\big) + \lambda\, \theta_\ell^{(t)} \right), & \ell \in \mathcal{A}, \\
\theta_\ell^{(t)}, & \ell \in \mathcal{F}.
\end{cases}
$$
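The core loop above can be sketched in pure Python (a minimal illustration with one flat parameter list per layer and made-up helper names; this is not the authors’ implementation):

```python
import math
import random

def lisa_train(layers, grad_fn, T=100, K=10, gamma=1, lr=0.05,
               beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.1, seed=0):
    """Minimal LISA sketch. `layers` is a list of per-layer parameter lists;
    layer 0 plays the role of the embedding and layer n-1 the head, so both
    are always active. Every K steps, gamma intermediate layers are
    re-sampled uniformly; all other layers are frozen (no gradient, no
    moment update)."""
    rng = random.Random(seed)
    n = len(layers)
    m = [[0.0] * len(p) for p in layers]  # first-moment buffers
    v = [[0.0] * len(p) for p in layers]  # second-moment buffers
    active = set()
    for step in range(1, T + 1):
        if (step - 1) % K == 0:  # periodically re-sample the active set
            mid = rng.sample(range(1, n - 1), gamma)
            active = {0, n - 1, *mid}
        grads = grad_fn(layers)
        for l in range(n):
            if l not in active:
                continue  # frozen layer: parameters and moments untouched
            for i, g in enumerate(grads[l]):
                m[l][i] = beta1 * m[l][i] + (1 - beta1) * g
                v[l][i] = beta2 * v[l][i] + (1 - beta2) * g * g
                mhat = m[l][i] / (1 - beta1 ** step)  # bias correction
                vhat = v[l][i] / (1 - beta2 ** step)
                layers[l][i] -= lr * (mhat / (math.sqrt(vhat) + eps)
                                      + weight_decay * layers[l][i])
    return layers
```

For instance, minimizing a sum of squares (gradient $2\theta$) with four layers drives all parameters toward zero even though only a subset is active in any interval.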
3. Importance Sampling for Layer Selection
A critical component of LISA is the construction of the importance sampling distribution over layers. The theoretical layerwise importance $w_\ell$ can be formulated as either the L2 norm of the layer’s parameters, $w_\ell = \|\theta_\ell\|_2$, or the expected L2 norm of its gradient, $w_\ell = \mathbb{E}\left[\|\nabla_{\theta_\ell} \mathcal{L}\|_2\right]$. These are normalized to derive sampling probabilities:

$$
p_\ell = \frac{w_\ell}{\sum_{j=1}^{N_L} w_j}.
$$
Empirical investigation reveals that the embedding and head layer norms dominate, so their probabilities are fixed at $p_\ell = 1.0$. The remaining probability mass is distributed uniformly among the intermediate layers, each sampled with $p_\ell = \gamma / (N_L - 2)$. For practical purposes, exactly $\gamma$ non-embedding, non-head layers are selected every $K$ steps, while the embedding and head layers are always active.
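This sampling scheme can be written as a short sketch (function names are illustrative; layer 0 is assumed to be the embedding and the last layer the head):

```python
import random

def lisa_probs(n_layers, gamma):
    """Sampling probabilities per the scheme above: embedding and head are
    always active (p = 1.0); the gamma sampled slots are spread uniformly
    over the n_layers - 2 intermediate layers."""
    p = [gamma / (n_layers - 2)] * n_layers
    p[0] = p[-1] = 1.0
    return p

def sample_active_set(n_layers, gamma, rng=random):
    """Draw exactly gamma intermediate layers, plus embedding and head."""
    mid = rng.sample(range(1, n_layers - 1), gamma)
    return sorted({0, n_layers - 1, *mid})
```

Because the intermediate distribution is uniform, sampling reduces to a plain `random.sample` over the intermediate indices; a non-uniform $\{p_\ell\}$ would require weighted sampling instead.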
4. Memory Complexity and Efficiency
Let $D$ denote the number of scalar parameters, $N_L$ the layer count, and $r$ the LoRA rank per layer. The memory usage patterns of the competing approaches are:

| Method | Trainable Parameters | Optimizer State | Activations | Adapter Overhead |
|---|---|---|---|---|
| AdamW (full tuning) | $D$ | $2D$ | proportional to full model | none |
| LoRA (rank $r$) | low-rank adapters only | proportional to adapter size | proportional to full model | rank-$r$ adapters in each linear layer |
| LISA ($\gamma$ layers active) | $\approx D\gamma/N_L$ per interval | $\approx 2D\gamma/N_L$ (active layers only) | proportional to full model | none |
For $\gamma \ll N_L$, LISA’s optimizer-state memory is reduced to roughly $2D\gamma/N_L$, significantly less than the $2D$ moment buffers required by full AdamW and generally smaller than LoRA’s adapter overhead. Empirical results indicate that, with typical settings (LoRA ranks up to $256$; $\gamma$ up to $4$), LISA’s peak memory is within 5–10% of LoRA’s (see Table 1 of Pan et al., 2024).
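As a back-of-the-envelope check, assuming fp32 moment buffers (4 bytes each) and a 32-layer 7B model (illustrative figures, not taken from the paper):

```python
def optimizer_state_gb(n_params, frac_active=1.0, bytes_per_moment=4):
    """AdamW stores two moment buffers per trainable parameter.
    For LISA, frac_active ~ gamma / N_L scales this down, since only
    active layers carry optimizer state."""
    return 2 * n_params * frac_active * bytes_per_moment / 1024**3

full = optimizer_state_gb(7e9)                      # full-parameter AdamW
lisa = optimizer_state_gb(7e9, frac_active=2 / 32)  # e.g. gamma=2, N_L=32
# full is roughly 52 GB of moment buffers; lisa roughly 3.3 GB
```

Under these assumptions, activating 2 of 32 layers cuts moment-buffer memory by a factor of 16, consistent with the $2D\gamma/N_L$ estimate above.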
5. Empirical Evaluation
Experiments evaluate LISA, LoRA, and full-tuning across models including GPT2-Small, TinyLlama (1.1B), Phi-2 (2.7B), Mistral-7B, LLaMA-2-7B, and LLaMA-2-70B. Tasks encompass instruction following (fine-tuning on Alpaca-GPT4, measured by MT-Bench), mathematics (GSM8K), and medical QA (PubMedQA). Key findings include:
- MT-Bench (LLaMA-2-7B): LISA achieves 5.42, surpassing full-tuning (5.18) and LoRA (4.86), improvements of roughly +11% over LoRA and +4.6% over full-tuning.
- MT-Bench (LLaMA-2-70B): LISA yields 7.05, exceeding full-tuning (6.66) and LoRA (6.52).
- GSM8K (LLaMA-2-70B): Accuracy improves from 59.4% (LoRA) to 61.1% (LISA).
- PubMedQA (LLaMA-2-70B): Accuracy increases from 90.8% (LoRA) to 91.6% (LISA).
- Gains are larger on smaller models: TinyLlama MT-Bench average rises from 2.03 (LoRA) to 2.78 (LISA, +37%); Mistral-7B from 4.71 (LoRA) to 5.23 (LISA, +11%).
- LISA converges faster, as shown in training loss trajectories, and can outperform full-tuning in aspects sensitive to alignment, such as writing and humanities.
6. Ablation and Sensitivity Analysis
Performance trade-offs are analyzed with respect to the number of active layers ($\gamma$) and the sampling interval ($K$):
- Increasing $\gamma$ leads to improved MT-Bench scores but higher memory consumption.
- Reducing $K$ (more frequent re-sampling) accelerates convergence up to an optimal point.
- Sensitivity to sampling randomness is minimal: across three random seeds, MT-Bench score variance is ≤0.13.
7. Implementation and Practical Considerations
LISA is compatible with any PyTorch-style training loop, either by zeroing gradients for frozen layers or by toggling `requires_grad` on their parameters. Recommended hyperparameters:
- Learning rate: LISA and LoRA generally tolerate a larger learning rate than full-parameter tuning on 1B–7B models; tune each method separately.
- Number of active layers: a small $\gamma$ suffices at 7B scale; 70B-scale models benefit from a larger $\gamma$.
- Sampling interval: $K$ between 3–10 is effective; larger values work for large-scale runs.
- AdamW settings: $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$, weight decay $= 0.1$.
- Fixed probabilities $p_\ell = 1.0$ for the embedding and head layers; uniform allocation across the remaining layers.
- Can be combined with DeepSpeed ZeRO-Offload or quantization techniques (e.g., QLoRA-style 4-bit base weights) for additional memory savings.
A plausible implication is that LISA constitutes a tractable alternative to adapter-based or full-parameter fine-tuning frameworks for LLMs, particularly in GPU-limited environments. By exploiting skewed utility across transformer layers, it achieves improved or on-par downstream performance with materially reduced memory footprint (Pan et al., 2024).