- The paper introduces a novel two-stage training method that integrates chain-of-thought prompting with Direct Preference Optimization to enhance text reranking performance.
- The approach leverages iterative reasoning with a sliding window strategy to efficiently rerank large document sets while maintaining state-of-the-art benchmark results.
- The proposed method effectively preserves the LLM’s general language capabilities, outperforming baselines and mitigating degradation seen in traditional fine-tuning approaches.
ChainRank-DPO is a novel approach for fine-tuning LLMs for text reranking while aiming to preserve their general-purpose capabilities. Traditional supervised fine-tuning (SFT) for reranking, as seen in models like RankVicuna and RankZephyr, often degrades the LLM's broader reasoning and generation abilities. ChainRank-DPO addresses this by combining a Chain-of-Thought (CoT) prompting strategy with a two-stage pipeline of SFT followed by Direct Preference Optimization (DPO).
The core idea is to frame listwise reranking as an iterative CoT process. Instead of emitting a ranked list in one shot, the model is prompted to select the most relevant passage from the remaining candidates at each step, making its reasoning explicit. This CoT approach is implemented with a prompt template that instructs the model to list the selected passages at every step until all passages are ordered (Figure \ref{fig:user prompt}).
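As a rough illustration only (not the paper's exact template; the instruction wording and the "[i]" labeling scheme below are assumptions), such an iterative CoT reranking prompt might be assembled as follows:

```python
# Hypothetical sketch of an iterative CoT reranking prompt.
# The wording and label format are assumptions, not the paper's exact template.
def build_cot_rerank_prompt(query: str, passages: list[str]) -> str:
    numbered = "\n".join(f"[{i + 1}] {p}" for i, p in enumerate(passages))
    return (
        f"I will provide {len(passages)} passages, labeled [1] to [{len(passages)}].\n"
        f"Rank them by relevance to the query: \"{query}\".\n"
        "Work step by step: at each step, choose the most relevant passage among the\n"
        "remaining ones and append its label to the ranking, e.g.\n"
        "Step 1: [3]\n"
        "Step 2: [3] > [7]\n"
        "Continue until every passage has been placed, then output the final ranking.\n\n"
        f"{numbered}"
    )
```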
The training of ChainRank-DPO consists of two stages:
- Stage 1: Supervised Fine-Tuning (SFT): The base model, LLaMA3-8B-Instruct, is fine-tuned on queries paired with candidate passages and the corresponding ranking orders generated by teacher models (GPT-3.5 and GPT-4). The training data, based on MS MARCO v1, comprises 35k instances from GPT-3.5 and 5k from GPT-4, 90% of which is used for SFT. This stage teaches the model both the CoT reranking format and the reranking task itself (a sketch of how one training target might be assembled appears after this list). Full fine-tuning of the 8B-parameter model ran for three epochs on four NVIDIA A100 80GB GPUs, taking approximately 39 hours.
- Stage 2: Chain DPO: This stage further refines the model using a DPO objective adapted for sequential reasoning, trained on the remaining 10% of the data. The ChainRank-SFT model generates multiple ranking predictions for each prompt, and these predictions are compared against the ground-truth (teacher) ranking. A preference dataset of tuples $(x, s_w, s_l, s_o)$ is constructed, where $x$ is the prompt, $s_o$ is the initial sequence of steps on which the model's output overlaps with the ground truth, $s_w$ is the continuation after $s_o$ that follows the ground truth, and $s_l$ is the continuation after $s_o$ that diverges from it. The DPO objective is modified to maximize the likelihood of the correct continuation $s_w$ while minimizing the likelihood of the incorrect continuation $s_l$, conditioned on the shared prefix $s_o$; this teaches the model to correct errors in its step-by-step reasoning (see the preference-construction sketch after this list). This stage is trained for one epoch on four NVIDIA A100 80GB GPUs, taking about eight hours.
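As a minimal sketch of the SFT target format, the following shows how a CoT-style training example could be built from a teacher ranking; the field names and step format are illustrative assumptions rather than the paper's released data schema.

```python
# Hypothetical shape of one CoT-SFT training example.
# `teacher_ranking` holds 1-based passage labels from the teacher (GPT-3.5 / GPT-4), best first.
def make_sft_example(prompt: str, teacher_ranking: list[int]) -> dict:
    steps = []
    for k in range(1, len(teacher_ranking) + 1):
        partial = " > ".join(f"[{i}]" for i in teacher_ranking[:k])
        steps.append(f"Step {k}: {partial}")
    # The completion reproduces the CoT format: the partial ranking after every selection step.
    return {"prompt": prompt, "completion": "\n".join(steps)}
```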
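The preference construction for Chain DPO can be sketched as below; the function and field names are hypothetical, and rankings are assumed to be lists of passage labels with the best passage first.

```python
# Sketch of building one Chain DPO preference tuple (x, s_w, s_l, s_o).
def make_preference_tuple(prompt: str, model_ranking: list[int], teacher_ranking: list[int]) -> dict:
    # s_o: the longest initial run of steps on which the model agrees with the teacher.
    overlap = 0
    while (overlap < len(model_ranking)
           and overlap < len(teacher_ranking)
           and model_ranking[overlap] == teacher_ranking[overlap]):
        overlap += 1
    s_o = teacher_ranking[:overlap]   # shared correct prefix (conditioning context)
    s_w = teacher_ranking[overlap:]   # preferred continuation: follows the teacher ranking
    s_l = model_ranking[overlap:]     # dispreferred continuation: diverges from the teacher
    return {"x": prompt, "s_w": s_w, "s_l": s_l, "s_o": s_o}
```

Under this construction, a natural reading of the modified objective is the standard DPO loss applied with the shared prefix appended to the prompt, i.e. contrasting $\pi_\theta(s_w \mid x, s_o)$ against $\pi_\theta(s_l \mid x, s_o)$ relative to the reference policy.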
For practical application, ChainRank-DPO uses a sliding window strategy to rerank a large number of passages (e.g., 100 per query), processing documents in chunks (e.g., window size 20, stride 10). This is necessary due to the LLM's context window limitations, even with LLaMA3's 8192 tokens.
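A minimal sketch of the sliding-window pass, assuming the back-to-front traversal used by RankGPT-style listwise rerankers and a generic `rerank_window` callable that returns a window of passages reordered by relevance:

```python
# Sliding-window reranking sketch (window/stride defaults follow the paper's 20/10 setting).
def sliding_window_rerank(query: str, passages: list[str], rerank_window,
                          window: int = 20, stride: int = 10) -> list[str]:
    ranked = list(passages)
    end = len(ranked)
    while end > 0:
        start = max(0, end - window)
        # Rerank the current window in place; relevant passages bubble toward the front.
        ranked[start:end] = rerank_window(query, ranked[start:end])
        if start == 0:
            break
        end -= stride
    return ranked
```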
The model's performance was evaluated on standard information retrieval benchmarks (TREC DL19 and DL20, BEIR datasets like NFC, COVID, FIQA) using nDCG@10, and on the Massive Multitask Language Understanding (MMLU) benchmark to assess general capabilities.
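For reference, one common formulation of nDCG@10 for a single query (using exponential gain; the exact evaluation tooling used in the paper is not assumed here) can be computed as follows, where `qrels` maps document ids to graded relevance judgments.

```python
import math

# nDCG@10 sketch for one query: `ranked_ids` is the system ordering, `qrels` maps doc id -> graded relevance.
def ndcg_at_10(ranked_ids: list[str], qrels: dict[str, int], k: int = 10) -> float:
    def dcg(gains):
        return sum((2 ** g - 1) / math.log2(i + 2) for i, g in enumerate(gains))
    system_gains = [qrels.get(doc_id, 0) for doc_id in ranked_ids[:k]]
    ideal = dcg(sorted(qrels.values(), reverse=True)[:k])
    return dcg(system_gains) / ideal if ideal > 0 else 0.0
```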
Key results demonstrate the practical effectiveness of ChainRank-DPO:
- Reranking Performance: ChainRank-SFT and ChainRank-DPO consistently outperform baselines like RankZephyr and even the teacher model RankGPT4 on TREC DL19 and DL20, as well as BEIR datasets (Table \ref{table:main}). ChainRank-DPO shows further improvements over ChainRank-SFT, highlighting the value of the DPO stage.
- General Capability Preservation: Crucially, ChainRank-SFT and ChainRank-DPO maintain MMLU scores comparable to the original LLaMA3-8B-Instruct model. In contrast, RankVicuna shows a significant drop, and RankZephyr completely loses its general text generation abilities (Table \ref{table:main}, Appendix \ref{sec:LLM task}). This indicates that the ChainRank approach successfully balances task-specific performance with general model integrity.
- Inference Cost: The paper explores the trade-off between performance and inference cost by adjusting the step interval at which the model outputs the ranking order during evaluation. While generating every step is necessary for training, inference can be faster by generating ranks less frequently, albeit with a slight performance drop (Figure \ref{fig:comparison}, Appendix \ref{CoT strategy}). The total inference latency for reranking 100 passages with a 20/10 window/stride is shown to be competitive, and can be reduced further with parallel processing.
Practical implementation considerations include the computational resources required for fine-tuning (multiple high-end GPUs), the need for carefully formatted CoT prompts, and the sliding window strategy for handling large document sets, which introduces overhead compared to models that can process all documents simultaneously. The context window size of the base LLM remains a practical limitation on the number of passages that can be reranked within a single window.
In summary, ChainRank-DPO offers a practical way to harness the reasoning capabilities of LLMs for text reranking: the model is explicitly guided through a step-by-step ranking process, and that process is then refined with a specialized DPO objective. The result is state-of-the-art reranking performance without sacrificing the model's general language understanding and generation skills.