Mitigating the Alignment Tax of RLHF (2309.06256v4)
Abstract: LLMs acquire a wide range of abilities during pre-training, but aligning them with Reinforcement Learning from Human Feedback (RLHF) can cause forgetting of pretrained abilities, a phenomenon known as the alignment tax. To investigate this tax, we conducted experiments with existing RLHF algorithms on OpenLLaMA-3B, which revealed a pronounced alignment tax on NLP tasks. Moreover, although various techniques exist to mitigate forgetting, they are often at odds with RLHF performance, producing an alignment-forgetting trade-off. In this paper we show that model averaging, which simply interpolates between pre- and post-RLHF model weights, surprisingly achieves the strongest alignment-forgetting Pareto front among a wide range of competing methods. To understand its effectiveness, we offer theoretical insights into model averaging, revealing that it enhances the performance Pareto front by increasing feature diversity on the layers where tasks share overlapping feature spaces. Empirical evidence corroborates this analysis by showing the benefit of averaging low-level transformer layers. Building on the analysis, and on the observation that averaging different transformer layers leads to significantly different alignment-forgetting trade-offs, we propose Heterogeneous Model Averaging (HMA), which searches for distinct combination ratios across groups of model layers. HMA seeks to maximize alignment performance while incurring minimal alignment tax. We validate HMA across a range of RLHF algorithms on OpenLLaMA-3B and further extend our findings to Mistral-7B, evaluated with an open-source preference model and GPT-4. Code available here: https://github.com/avalonstrel/Mitigating-the-Alignment-Tax-of-RLHF.git.
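For concreteness, below is a minimal sketch of the weight interpolation described in the abstract: uniform model averaging between pre- and post-RLHF checkpoints, and a heterogeneous variant that assigns different mixing ratios to different transformer layers. It assumes two PyTorch state dicts with matching keys and LLaMA-style parameter naming; the function names, the per-layer ratio schedule, and the key-matching regex are illustrative assumptions rather than the paper's released implementation (see the linked repository for the actual code).

```python
# Sketch of (heterogeneous) model averaging between a pre-RLHF and a post-RLHF
# checkpoint. State dicts are assumed to share identical keys; the layer-name
# pattern assumes LLaMA-style keys such as "model.layers.12.self_attn.q_proj.weight".
import re


def average_weights(pre_sd, post_sd, alpha=0.5):
    """Uniform interpolation: theta = (1 - alpha) * theta_pre + alpha * theta_post."""
    return {k: (1 - alpha) * pre_sd[k] + alpha * post_sd[k] for k in pre_sd}


def heterogeneous_average(pre_sd, post_sd, layer_ratios, default_alpha=0.5):
    """Layer-wise interpolation in the spirit of HMA: transformer block i uses
    its own ratio layer_ratios[i]; parameters outside the blocks (embeddings,
    final norm, LM head) fall back to default_alpha."""
    merged = {}
    for k in pre_sd:
        m = re.search(r"layers\.(\d+)\.", k)  # extract the block index, if any
        alpha = layer_ratios.get(int(m.group(1)), default_alpha) if m else default_alpha
        merged[k] = (1 - alpha) * pre_sd[k] + alpha * post_sd[k]
    return merged


# Usage sketch (paths and ratios are placeholders, not the paper's tuned values):
# pre_sd  = torch.load("pre_rlhf_model.bin", map_location="cpu")
# post_sd = torch.load("rlhf_model.bin", map_location="cpu")
# merged  = heterogeneous_average(pre_sd, post_sd, layer_ratios={i: 0.3 for i in range(8)})
```

The per-layer ratios would then be chosen (e.g., by searching over a small grid) to trade off alignment reward against benchmark performance, which is the search HMA performs.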
Authors: Yong Lin, Hangyu Lin, Renjie Pi, Jipeng Zhang, Shizhe Diao, Haoxiang Wang, Han Zhao, Yuan Yao, Tong Zhang, Wei Xiong, Jianmeng Liu, Rui Pan, Wenbin Hu, Hanning Zhang, Hanze Dong, Nan Jiang, Heng Ji