- The paper introduces DeRa, a method that dynamically adjusts the regularization parameter during decoding, eliminating the need for costly retraining.
- It uses a per-token autoregressive approximation and a single mixing parameter, λ, to balance human-aligned output against the model's expressiveness during text generation.
- Experiments in summarization and hallucination mitigation demonstrate that DeRa achieves comparable quality to retrained models with significant computational savings.
Decoding-time Realignment of LLMs
Introduction
The study "Decoding-time Realignment of LLMs" presents a novel method referred to as Decoding-time Realignment (DeRa) aimed at dynamically adjusting the regularization strength of LLMs during the decoding process. LLMs, while proficient at generating text, often require alignment with human preferences to minimize undesirable outputs such as bias or factual inaccuracies. Typically, a trade-off exists between human alignment and maintaining the model's expressiveness, controlled via a regularization parameter, typically involving Kullback-Leibler (KL) divergence. Traditional approaches necessitate retraining models at various regularization levels to identify the optimal configuration, which is computationally expensive for large-scale models.
Background
Large language models (LLMs) are typically pretrained on extensive text corpora to predict the next token in a sequence, then finetuned with supervised learning on a curated dataset that targets desired behaviors. A further stage, Reinforcement Learning from Human Feedback (RLHF), aligns the model with human preferences by maximizing a learned reward while remaining close to the initial (reference) model, so that capabilities acquired during pretraining are preserved.
The regularization parameter β controls this balance. Larger β values keep the model close to the reference and limit how much alignment can change its behavior, while smaller values risk "reward hacking," where the model exploits flaws in the reward signal at the expense of overall quality. Traditional approaches search over a range of β values by retraining the model for each one, which is impractical under tight computational budgets.
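For concreteness, the usual KL-regularized alignment objective (standard in the RLHF literature; the notation here may differ slightly from the paper's) makes β's role explicit:

```latex
% KL-regularized alignment objective: maximize reward while a KL
% penalty, weighted by beta, keeps the policy pi close to the
% reference model pi_ref.
\max_{\pi}\;
\mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi(\cdot\mid x)}\!\big[r(x,y)\big]
\;-\; \beta\,\mathrm{KL}\!\big(\pi(\cdot\mid x)\,\big\|\,\pi_{\mathrm{ref}}(\cdot\mid x)\big)
```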
Proposed Method: Decoding-time Realignment
DeRa shifts this paradigm by performing realignment at decoding time, with no retraining. The key result is that models aligned at different regularization strengths are geometric mixtures of two fixed models, the aligned model and the reference model, modulated by a mixing weight λ. DeRa exploits this with an autoregressive approximation, so that adjusting λ at inference time changes the effective KL regularization strength during sequence generation, avoiding retraining expenses entirely.
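Concretely (reconstructing the statement above; the exact notation is an assumption), a model aligned at strength β/λ is a geometric mixture of the reference model and a model trained once at strength β:

```latex
% Geometric mixture: lambda = 1 recovers the aligned model pi_beta,
% lambda = 0 recovers the reference model pi_ref.
\pi_{\beta/\lambda}(y \mid x) \;\propto\;
\pi_{\mathrm{ref}}(y \mid x)^{\,1-\lambda}\,
\pi_{\beta}(y \mid x)^{\,\lambda}
```

Setting λ = 1 recovers the aligned model and λ = 0 the reference model; intermediate values interpolate between the two.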
A simple proposition gives per-token probabilities for this mixture in autoregressive form: at each step, the logits of the aligned and reference models are combined with weight λ, and the next token is sampled from the resulting distribution. Since λ can be varied freely at inference, practitioners are freed from costly trial-and-error sweeps over retrained models.
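A minimal sketch of this per-token approximation, assuming two Hugging Face transformers-style causal LMs (`aligned` and `reference`) that return `.logits`; the function name and decoding loop are illustrative, not the paper's code:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def dera_generate(aligned, reference, input_ids, lam, max_new_tokens=50):
    """Decode with DeRa-style realignment (sketch under stated assumptions).

    At each step, the per-token logits of the aligned and reference
    models are averaged with weight `lam`; sampling from the softmax of
    this average approximates the geometric mixture above.
    """
    for _ in range(max_new_tokens):
        z_aligned = aligned(input_ids).logits[:, -1, :]   # aligned-model logits, last position
        z_ref = reference(input_ids).logits[:, -1, :]     # reference-model logits, last position
        mixed = lam * z_aligned + (1.0 - lam) * z_ref     # lambda-weighted logit average
        probs = F.softmax(mixed, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids
```

Averaging logits with weight λ corresponds to the geometric mixture in probability space, up to per-token normalization by the softmax.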
Experiments and Results
Length-reward Task
DeRa's effectiveness was first demonstrated on a controlled summarization task with a hardcoded length reward that encourages summaries within a specific length range. Comparisons with models retrained at matching regularization strengths showed that DeRa closely mirrors their outputs, substantiating the approximation as a reliable tool for hyperparameter tuning.
Evidence from summarization and hallucination-mitigation tasks further underscores that DeRa produces high-quality outputs across a range of KL settings without repeated retraining. Comparisons on existing models (e.g., Zephyr-7B) show competitive performance, suggesting strong potential in efficiency-sensitive environments.
Figure 1: DeRa adjusts the alignment level of LLMs at decoding time, illustrated on Zephyr-7B models.
Computational Considerations
Although DeRa requires running two models at each decoding step, roughly doubling inference cost, the savings from avoiding repeated retraining more than compensate. When decoding cost matters, DeRa can also serve as a screening tool: sweep λ cheaply at decode time to find a good setting, then retrain a single model at the corresponding regularization strength for deployment, giving flexibility in balancing budget against performance targets.
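For illustration, a hypothetical λ sweep using the `dera_generate` sketch above: a single pair of models covers every effective strength β/λ, where the retraining approach would need one training run per setting (`tokenizer`, `prompt_ids`, `aligned`, and `reference` are assumed to come from the usual transformers setup):

```python
# Hypothetical sweep over mixing weights: one aligned model and one
# reference model cover all effective strengths beta/lambda, whereas
# retraining would require one run per setting.
for lam in (0.0, 0.25, 0.5, 0.75, 1.0):
    out = dera_generate(aligned, reference, prompt_ids, lam)
    print(f"lambda={lam}:", tokenizer.decode(out[0], skip_special_tokens=True))
```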
Conclusion
DeRa represents a meaningful advance in computational efficiency for LLM alignment. By removing the need for multiple retraining runs, it gives practitioners greater flexibility and resource efficiency, offering a pragmatic way to balance model performance with alignment goals. As LLMs evolve, the ability to adapt quickly through methods like DeRa will be crucial for meeting diverse application needs without prohibitive computational investment.