- The paper introduces the CoRAG framework that trains LLMs to iteratively generate retrieval chains, achieving over 10 points improvement in exact match scores.
- The paper employs rejection sampling to automatically generate intermediate query chains, enabling multi-task training for sub-query, sub-answer, and final answer prediction.
- The paper demonstrates state-of-the-art performance on multi-hop QA and KILT benchmarks by adapting decoding strategies to control test-time compute.
The paper introduces the Chain-of-Retrieval Augmented Generation (CoRAG) framework, which trains LLMs to iteratively retrieve and reason over information before generating a final answer. Conventional Retrieval Augmented Generation (RAG) models typically perform a single retrieval step, which can be insufficient for complex queries. CoRAG addresses this limitation by enabling the model to dynamically reformulate queries based on the current state of retrieved information.
To train CoRAG, the authors employ rejection sampling to generate intermediate retrieval chains, thereby augmenting existing RAG datasets that usually contain only the final correct answer. At inference time, various decoding strategies control the model's test-time compute by adjusting the length and number of sampled retrieval chains. Experiments on several benchmarks demonstrate the efficacy of CoRAG, especially on multi-hop question answering tasks, where CoRAG improves exact match (EM) scores by more than 10 points over strong baselines. On the Knowledge Intensive Language Tasks (KILT) benchmark, CoRAG establishes new state-of-the-art performance across a diverse set of knowledge-intensive tasks. The paper also includes analyses of CoRAG's scaling behavior.
The CoRAG framework is shown in Figure 1. The key components of CoRAG include retrieval chain generation using rejection sampling, model training with augmented datasets, and test-time compute scaling strategies.
Most RAG datasets provide only a query Q and a final answer A, without intermediate retrieval steps. The paper presents an automated method for generating retrieval chains through rejection sampling. Each chain consists of sub-queries Q1:L={Q1,Q2,…,QL} and corresponding sub-answers A1:L, where L is the maximum chain length. Each sub-query Qi=LLM(Q<i,A<i,Q) is generated by sampling from an LLM conditioned on the original query Q and the preceding sub-queries and sub-answers. To generate the sub-answer Ai, the top-k most relevant documents D1:k(i) are retrieved using a text retriever with Qi as the search query, and an LLM is then prompted to generate the answer Ai=LLM(Qi,D1:k(i)). This process iterates until the chain reaches the maximum length L or Ai matches the correct answer A.
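The iterative chain-generation loop can be sketched as follows. This is a minimal illustration, not the paper's implementation: `llm_generate_subquery`, `llm_generate_subanswer`, and `retrieve_topk` are hypothetical stand-ins for the paper's LLM prompting and retrieval calls, and the answer-matching check is simplified to a normalized string comparison.

```python
def sample_chain(query, answer, max_len, llm_generate_subquery,
                 llm_generate_subanswer, retrieve_topk, k=5):
    """Sample one retrieval chain (sub-queries and sub-answers) for a QA pair.

    The callables are assumptions standing in for LLM/retriever access:
      llm_generate_subquery(Q, Q_<i, A_<i) -> Q_i
      retrieve_topk(Q_i, k)                -> top-k documents for Q_i
      llm_generate_subanswer(Q_i, docs)    -> A_i
    """
    sub_queries, sub_answers = [], []
    for _ in range(max_len):
        # Q_i = LLM(Q_<i, A_<i, Q): condition on the chain sampled so far.
        q_i = llm_generate_subquery(query, sub_queries, sub_answers)
        docs = retrieve_topk(q_i, k)             # D_1:k^(i)
        a_i = llm_generate_subanswer(q_i, docs)  # A_i = LLM(Q_i, D_1:k^(i))
        sub_queries.append(q_i)
        sub_answers.append(a_i)
        if a_i.strip().lower() == answer.strip().lower():
            break  # stop early once a sub-answer matches the gold answer
    return sub_queries, sub_answers
```

In practice many such chains are sampled per training instance, and the selection step described next keeps only the best one.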
To assess the quality of a retrieval chain, the log-likelihood of the correct answer, logP(A∣Q,Q1:L,A1:L), is calculated conditioned on the chain information. The retrieval chain with the highest log-likelihood score is then selected to augment the original question answering (QA)-only dataset.
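The rejection step itself is just an argmax over sampled chains. In this sketch, `answer_logprob` is a hypothetical scorer backed by the LLM that returns logP(A∣Q,Q1:L,A1:L) for a candidate chain.

```python
def select_best_chain(chains, answer_logprob):
    """Keep the chain under which the correct answer is most likely,
    i.e. the one maximizing log P(A | Q, Q_1:L, A_1:L).

    answer_logprob(chain) is an assumed scorer that queries the LLM for
    the conditional log-likelihood of the gold answer given the chain."""
    return max(chains, key=answer_logprob)
```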
Each training instance in the augmented dataset is represented as a tuple (Q,A,Q1:L,A1:L), along with the top-k retrieved documents for query Q and each sub-query. An LLM is fine-tuned on this augmented dataset using the standard next-token prediction objective in a multi-task learning framework. The model is simultaneously trained on three tasks: next sub-query prediction, sub-answer prediction, and final answer prediction. The same prompt templates used in retrieval chain generation are also used in model training, with the addition of the top retrieved documents D1:k for the original query Q as input for the final answer prediction task. The loss functions for each of the tasks are:
Lsub_query=−logP(Qi∣Q,Q<i,A<i),
Lsub_answer=−logP(Ai∣Qi,D1:k(i)), and
Lfinal_answer=−logP(A∣Q,Q1:L,A1:L,D1:k).
Where:
- Lsub_query is the loss for sub-query prediction,
- P(Qi∣Q,Q<i,A<i) is the probability of the i-th sub-query given the original query and the preceding sub-queries and sub-answers,
- Lsub_answer is the loss for sub-answer prediction,
- P(Ai∣Qi,D1:k(i)) is the probability of the i-th sub-answer given the i-th sub-query and the top-k retrieved documents,
- Lfinal_answer is the loss for final answer prediction,
- P(A∣Q,Q1:L,A1:L,D1:k) is the probability of the final answer given the original query, all sub-queries and sub-answers, and the top-k retrieved documents.
The cross-entropy loss is computed only for the target output tokens.
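One augmented instance thus yields multiple training examples, one per task. The sketch below shows this expansion; the prompt layouts are illustrative placeholders, not the paper's actual templates.

```python
def build_training_examples(q, answer, sub_queries, sub_answers,
                            docs_per_step, final_docs):
    """Expand one augmented instance (Q, A, Q_1:L, A_1:L) into the three
    multi-task example types. Prompt dicts here are illustrative; in real
    training they would be rendered with the paper's templates, and the
    next-token loss would be masked to the target tokens only."""
    examples = []
    # Next sub-query prediction: L_sub_query = -log P(Q_i | Q, Q_<i, A_<i)
    for i, q_i in enumerate(sub_queries):
        history = list(zip(sub_queries[:i], sub_answers[:i]))
        examples.append({"task": "sub_query",
                         "prompt": {"query": q, "history": history},
                         "target": q_i})
    # Sub-answer prediction: L_sub_answer = -log P(A_i | Q_i, D_1:k^(i))
    for q_i, a_i, docs in zip(sub_queries, sub_answers, docs_per_step):
        examples.append({"task": "sub_answer",
                         "prompt": {"sub_query": q_i, "docs": docs},
                         "target": a_i})
    # Final answer: L_final_answer = -log P(A | Q, Q_1:L, A_1:L, D_1:k)
    examples.append({"task": "final_answer",
                     "prompt": {"query": q,
                                "chain": list(zip(sub_queries, sub_answers)),
                                "docs": final_docs},
                     "target": answer})
    return examples
```

A chain of length L therefore contributes 2L + 1 examples: L sub-query targets, L sub-answer targets, and one final-answer target.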
Given a trained CoRAG model, several decoding strategies are proposed to control the trade-off between model performance and test-time compute, where test-time compute is measured by the total number of tokens consumed. The decoding strategies are:
- Greedy Decoding: Uses greedy decoding to generate L sub-queries and their corresponding sub-answers sequentially. The final answer is generated using the same prompt template as during training.
- Best-of-N Sampling: Samples N retrieval chains with a temperature of 0.7, then selects the best chain to generate the final answer. The conditional log-likelihood of "No relevant information found" is used as a penalty score for each chain, and the chain with the lowest penalty is chosen.
- Tree Search: Implements a breadth-first search (BFS) variant with retrieval chain rollouts. At each step, the current state is expanded by sampling several sub-queries. For each expanded state, multiple rollouts are performed, and the average penalty score of these rollouts is computed. The state with the lowest average penalty score is retained for further expansion.
The maximum length of the retrieval chain L can be adjusted across all decoding strategies to control test-time compute. For best-of-N sampling, the number of sampled chains N provides an alternative way to scale test-time compute. In tree search, the number of rollouts and expansion size are additional hyperparameters.
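The best-of-N selection rule can be sketched in a few lines. Here `sample_chain` is a hypothetical stand-in for sampling one full retrieval chain at temperature 0.7, and `chain_penalty` stands in for the conditional log-likelihood of "No relevant information found" given the chain; a chain whose retrieval steps actually surfaced useful evidence makes that string unlikely, so the chain with the lowest penalty is kept.

```python
def best_of_n_decode(sample_chain, chain_penalty, n=4):
    """Best-of-N decoding sketch (assumed callables, not the paper's code):
      sample_chain()       -> one retrieval chain sampled at temperature 0.7
      chain_penalty(chain) -> log P("No relevant information found" | chain)
    Returns the chain with the lowest penalty, which is then used to
    generate the final answer."""
    chains = [sample_chain() for _ in range(n)]
    return min(chains, key=chain_penalty)
```

Tree search follows the same scoring idea, but applies the average penalty over several rollouts from each expanded state rather than scoring complete chains directly.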
The CoRAG framework is evaluated using two sets of benchmarks: (1) a collection of multi-hop QA datasets, including 2WikiMultihopQA, HotpotQA, Bamboogle, and MuSiQue, and (2) the KILT benchmark, which encompasses a wide range of knowledge-intensive tasks. The multi-hop QA datasets assess the model's multi-hop reasoning capability, while the KILT benchmark evaluates the framework's generalization across diverse tasks. The open-source Llama-3.1-8B-Instruct model is prompted to perform rejection sampling for each training dataset, and E5-large is used as the text retriever for intermediate retrieval steps. The retrieval corpus is the English Wikipedia from KILT, containing approximately 36 million passages. The selected retrieval chains are used to augment the original QA-only datasets for model training.
The evaluation metrics include exact match (EM) and F1 scores for the multi-hop QA datasets. For the KILT benchmark, the model's predictions are submitted to the official evaluation server, and the downstream metrics are reported on the hidden test set. Ablation studies on the KILT benchmark report public validation set results.
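For reference, EM and F1 are conventionally computed over normalized answer strings. The sketch below uses the common SQuAD-style normalization (lowercasing, stripping punctuation and articles); the paper's exact normalization may differ.

```python
import re
import string
from collections import Counter

def normalize(text):
    """SQuAD-style normalization: lowercase, drop punctuation and
    articles (a/an/the), collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(pred, gold):
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize(pred) == normalize(gold))

def f1_score(pred, gold):
    """Token-level F1 between normalized prediction and gold answer."""
    pred_toks, gold_toks = normalize(pred).split(), normalize(gold).split()
    overlap = sum((Counter(pred_toks) & Counter(gold_toks)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_toks)
    recall = overlap / len(gold_toks)
    return 2 * precision * recall / (precision + recall)
```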
Full-parameter fine-tuning is conducted on the augmented datasets, starting from the Llama-3.1-8B-Instruct checkpoint. Two separate models are trained: one for the multi-hop QA datasets and another for the KILT benchmark. The multi-hop QA dataset comprises 125k training instances, while the KILT benchmark includes 660k instances after sub-sampling. The model is fine-tuned for 1 epoch with a maximum sequence length of 3k tokens. For the KILT benchmark, an E5-Mistral retriever and a RankLLaMA re-ranker are fine-tuned on the respective training set to improve ranking quality.
In Table 1, CoRAG-8B is compared against several models, including few-shot Llama-3.1-8B-Instruct, GPT-4o, Self-RAG-7B, ITER-RETGEN, DRAG, IterDRAG, and Search-o1-32B. A fine-tuned Llama-8B baseline using the E5-large retriever is also included. CoRAG-8B outperforms all baselines except on the Bamboogle dataset, despite building on a weaker base LLM than several of the compared systems.
Table 2 presents strong systems on the KILT benchmark, including KILT-RAG, SEAL, Atlas-11B, RA-DIT 65B, and FiD with RS. The CoRAG-8B model achieves new state-of-the-art performance across all tasks, except for FEVER.
The model allows for scaling test-time compute to potentially improve performance without updating model weights. The retrieval chain length L and the number of sampled chains N for best-of-N sampling are the key factors. Increasing the retrieval chain length L results in substantial performance improvements when L is small, but the gains diminish as L increases. Increasing N for best-of-N sampling yields mixed effects depending on the dataset. The Pareto frontier between the EM score and token consumption approximately follows a log-linear trajectory for up to 128k tokens, but the scaling behavior varies across different datasets.
The paper explores self-improvement through iterative training. A trained CoRAG model can generate new sets of retrieval chains. However, the results are mixed, with performance improvements on the 2WikiMultihopQA dataset but declines on other datasets, suggesting that instruction-tuned LLMs already possess a strong ability to generate high-quality retrieval chains.
The influence of different text retrievers at test time is also investigated. Even when the E5-large dense retriever is replaced with weaker alternatives (E5-base or BM25), consistent performance gains are observed as more test-time compute is invested. On the generator side, Llama-3B achieves performance very close to the 8B model, whereas Llama-1B exhibits a noticeable drop.
The paper analyzes whether the chain-of-retrieval mechanism always helps. Multi-hop QA datasets are specifically designed to evaluate complex reasoning capabilities and are expected to benefit from the chain-of-retrieval mechanism. In contrast, for tasks where a single retrieval step is typically sufficient, the advantage tends to be marginal. This implies that decoding strategies should adapt to the complexity of the query.
Instead of always performing L retrieval steps, a model variant is explored that learns to stop at test time. After each retrieval step, the model is prompted to predict whether the information gathered so far suffices to answer the query. Adjusting the logit bias of the "Yes" token controls the early-stopping behavior. While early stopping saves some of the token budget, it comes at the cost of performance degradation.
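The stopping decision reduces to a biased comparison of token logits. In this sketch, `yes_logit` and `no_logit` are hypothetical logits the model assigns to the "Yes"/"No" tokens when asked whether the gathered information suffices; the bias term is the knob described above.

```python
def should_stop(yes_logit, no_logit, yes_bias=0.0):
    """Early-stopping decision sketch: a positive yes_bias makes the model
    stop retrieval sooner (saving tokens), a negative bias makes it
    retrieve longer. The logits are assumed inputs, not a real model API."""
    return (yes_logit + yes_bias) > no_logit
```

Sweeping `yes_bias` traces out the token-savings vs. accuracy trade-off reported in the paper.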
The paper introduces CoRAG, a framework that trains LLMs to conduct iterative retrieval and reasoning to answer complex queries. The intermediate retrieval chains are automatically generated via rejection sampling, eliminating the need for manual annotation. The paper shows that CoRAG-8B achieves state-of-the-art performance on both multi-hop QA datasets and the KILT benchmark.