In-Context Learning for Extreme Multi-Label Classification (2401.12178v1)
Abstract: Multi-label classification problems with thousands of classes are hard to solve with in-context learning alone, as language models (LMs) might lack prior knowledge about the precise classes or how to assign them, and it is generally infeasible to demonstrate every class in a prompt. We propose a general program, $\texttt{Infer--Retrieve--Rank}$, that defines multi-step interactions between LMs and retrievers to efficiently tackle such problems. We implement this program using the $\texttt{DSPy}$ programming model, which specifies in-context systems in a declarative manner, and use $\texttt{DSPy}$ optimizers to tune it towards specific datasets by bootstrapping only tens of few-shot examples. Our primary extreme classification program, optimized separately for each task, attains state-of-the-art results across three benchmarks (HOUSE, TECH, TECHWOLF). We apply the same program to a benchmark with vastly different characteristics and attain competitive performance as well (BioDEX). Unlike prior work, our proposed solution requires no finetuning, is easily applicable to new tasks, alleviates prompt engineering, and requires only tens of labeled examples. Our code is public at https://github.com/KarelDO/xmc.dspy.
Summary
- The paper introduces IReRa, a three-step modular approach (Infer, Retrieve, Rank) that leverages in-context learning for extreme multi-label classification.
- It uses DSPy for automated prompt optimization, requiring far less labeled data than traditional finetuning.
- Experimental results show state-of-the-art performance on ESCO tasks and competitive outcomes on biomedical datasets using significantly fewer LM calls.
Here is a summary of the paper "In-Context Learning for Extreme Multi-Label Classification" (2401.12178):
The paper addresses the challenge of applying in-context learning (ICL) with large language models (LMs) to Extreme Multi-Label Classification (XMC) tasks, where the number of potential classes can reach tens of thousands. Standard ICL struggles here because LMs may lack prior knowledge of such specific and numerous classes, and it is impractical to include every class in a prompt. Existing XMC methods often rely on extensive finetuning that requires large labeled datasets, or on complex multi-step inference pipelines that need significant manual tuning (prompt engineering).
To overcome these limitations, the authors propose a general program called Infer--Retrieve--Rank (IReRa), designed specifically for XMC within an in-context learning framework. This program is implemented using the DSPy programming model, which allows for defining modular pipelines and optimizing them declaratively rather than through manual prompt engineering.
The IReRa program operates in three sequential steps for each input document:
- Infer: An initial LM processes the input text and predicts a set of relevant terms or queries that are likely related to the correct labels. This step essentially guides the subsequent retrieval process.
- Retrieve: A frozen retriever uses the predicted terms from the Infer step as queries to search over the large space of potential labels. It ranks all labels based on similarity (e.g., cosine similarity of embeddings). The paper notes that this similarity can be re-weighted using label prior probabilities if available.
- Rank: A second LM takes the original input document and the top candidate labels retrieved in the previous step, and reranks these candidates to produce the final set of predicted labels.
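The Retrieve step can be sketched in plain Python. This is a minimal sketch, not the paper's implementation: `toy_embed` is a hypothetical letter-count encoder standing in for a frozen sentence encoder such as all-mpnet-base-v2 or BioLORD, the label names are made up, and scoring each label by its maximum cosine similarity over all inferred queries is an assumption about how multiple queries are combined.

```python
import math
from collections import Counter
from typing import Callable, Sequence

Vector = list[float]

def cosine(u: Vector, v: Vector) -> float:
    """Cosine similarity of two dense vectors (0.0 if either vector is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def retrieve(queries: Sequence[str], labels: Sequence[str],
             embed: Callable[[str], Vector], k: int = 10) -> list[tuple[str, float]]:
    """Rank every label by its best cosine similarity to any inferred query; keep the top k.

    Assumption: max-over-queries aggregation (other aggregations are possible).
    """
    query_vecs = [embed(q) for q in queries]
    label_vecs = [embed(label) for label in labels]  # a real system would precompute and cache these
    scored = [
        (label, max((cosine(q, lv) for q in query_vecs), default=0.0))
        for label, lv in zip(labels, label_vecs)
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]

def toy_embed(text: str) -> Vector:
    """Hypothetical stand-in encoder: counts of the letters a-z (not a real sentence encoder)."""
    counts = Counter(ch for ch in text.lower() if "a" <= ch <= "z")
    return [float(counts[chr(ord("a") + i)]) for i in range(26)]

# Queries as the Infer step might produce them, searched over a tiny hypothetical label space.
labels = ["Python programming", "customer service", "project management"]
top = retrieve(["python", "programming"], labels, toy_embed, k=2)
print(top)
```

Only the top-k candidates are handed to the Rank LM, which is what keeps the final prompt small even when the label space has tens of thousands of entries.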
A key practical aspect of this approach is the use of DSPy for optimization. Instead of relying on manual prompt engineering, the IReRa program, defined with minimal "seed" prompts, is automatically optimized using only tens of labeled examples (about 50 validation examples in the experiments). DSPy's compilers, such as BootstrapFewShotWithRandomSearch, use a zero-shot "Teacher" LM (e.g., GPT-3.5 or GPT-4) to generate high-quality few-shot examples for the "Student" LMs (e.g., Llama-2-7b-chat or GPT-4) used in the Infer and Rank modules. The bootstrapping process keeps the demonstrations that perform best on the validation set, tuning the program's behavior without changing any LM weights or requiring a large labeled dataset for finetuning.
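To make the bootstrapping idea concrete, here is a heavily simplified, framework-free analogue of few-shot bootstrapping with random search. It is not DSPy's implementation: `teacher`, `student`, `exact_match`, and the uppercase-a-word task are hypothetical stand-ins chosen so the example runs. The teacher labels training inputs zero-shot, only traces that pass the metric become candidate demonstrations, and random subsets of those demonstrations are scored on the validation set to pick the best prompt.

```python
import random
from typing import Callable

Example = tuple[str, str]  # (input, gold output)

def bootstrap_demos(trainset: list[Example], teacher: Callable[[str], str],
                    metric: Callable[[str, str], bool]) -> list[Example]:
    """Run the teacher zero-shot on training inputs and keep only traces the metric accepts."""
    demos = []
    for x, gold in trainset:
        pred = teacher(x)
        if metric(gold, pred):
            demos.append((x, pred))
    return demos

def random_search(demos: list[Example], valset: list[Example],
                  student: Callable[[list[Example], str], str],
                  metric: Callable[[str, str], bool],
                  num_candidates: int = 8, k: int = 2, seed: int = 0) -> tuple[list[Example], float]:
    """Score random k-demo subsets on the validation set and return the best subset and score."""
    rng = random.Random(seed)
    best, best_score = [], -1.0
    for _ in range(num_candidates):
        candidate = rng.sample(demos, min(k, len(demos)))
        score = sum(metric(gold, student(candidate, x)) for x, gold in valset) / len(valset)
        if score > best_score:
            best, best_score = candidate, score
    return best, best_score

# Hypothetical toy task: uppercase a word.
exact_match = lambda gold, pred: gold == pred
teacher = lambda x: "Cat" if x == "cat" else x.upper()  # the teacher makes one mistake
student = lambda demos, x: x.upper() if demos else x     # the student needs >= 1 demo to follow the format

demos = bootstrap_demos([("dog", "DOG"), ("cat", "CAT"), ("emu", "EMU")], teacher, exact_match)
best, score = random_search(demos, [("owl", "OWL"), ("yak", "YAK")], student, exact_match)
print(demos, best, score)
```

The teacher's faulty "cat" trace is filtered out by the metric, so only verified demonstrations ever reach the student's prompt; no model weights change at any point.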
The authors evaluated IReRa on four XMC datasets: BioDEX (biomedical literature extraction with ~24k labels) and three ESCO skill extraction datasets (HOUSE, TECH, TECHWOLF, based on job vacancy snippets with ~14k labels). They used a configuration with Llama-2-7b-chat for Infer, GPT-3.5 as the Infer Teacher, a frozen retriever (BioLORD for BioDEX, all-mpnet-base-v2 for ESCO), and GPT-4 for Rank (both student and teacher).
The experimental results demonstrate the practical effectiveness of IReRa:
- It achieves state-of-the-art results on the ESCO skill extraction tasks, outperforming specialized finetuned systems while requiring orders of magnitude less labeled data and no finetuning.
- On the BioDEX dataset, which has different characteristics (longer inputs, requiring more complex inference), IReRa achieves competitive performance and shows consistent improvement from its modular steps (Infer, Retrieve, Rank, and optimization), even though it doesn't surpass the best finetuned system.
- The optimization process is shown to be a critical driver of performance across tasks.
The paper provides a detailed breakdown of the computational cost. Optimizing the full IReRa program requires approximately 1,500 student LM calls and 20 teacher LM calls and only tens of labeled examples, whereas the finetuned systems are trained on ~11,500 to ~555,000 labeled examples. At inference time, each input requires one call to the Infer LM (Llama-2), one call to the Retrieve module, and one call to the Rank LM (GPT-4). Because GPT-4 ranking may be costly for high-throughput applications, it is worth noting that an ablated version (Infer-Retrieve) using only open-source components remains competitive.
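The cost figures above reduce to simple arithmetic. The sketch below uses hypothetical helper names and the approximate numbers from the text: a one-time optimization cost of about 1,500 student plus 20 teacher LM calls, two LM calls per input (Infer and Rank), and roughly 50 labeled examples for IReRa versus 11,500 and 555,000 for the finetuned systems.

```python
OPTIMIZATION_LM_CALLS = 1_500 + 20  # approx. student + teacher calls, paid once per task

def total_lm_calls(num_inputs: int) -> int:
    """One-time optimization cost plus two LM calls per input (Infer and Rank).

    The Retrieve step is an embedding lookup, not an LM call, so it is not counted.
    """
    return OPTIMIZATION_LM_CALLS + 2 * num_inputs

def data_ratio(finetune_examples: int, irera_examples: int = 50) -> float:
    """How many times more labeled examples a finetuned system uses than IReRa's ~50."""
    return finetune_examples / irera_examples

print(total_lm_calls(1_000))                    # 3520
print(data_ratio(11_500), data_ratio(555_000))  # 230.0 11100.0
```

The one-time optimization cost is amortized over every subsequent input, so for large workloads the per-input Rank call dominates, which is why the GPT-4 ranking step is the main deployment cost.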
Practical Implementation Aspects:
- Modularity (DSPy): The use of DSPy Signatures and Modules allows developers to define the task requirements and program flow separately from the underlying LM calls and prompts. This makes the system more interpretable and adaptable.
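```python
import dspy

def extract_labels_from_strings(completion: str) -> list[str]:
    """Split a comma-separated LM completion into unique, non-empty labels (order preserved)."""
    labels = (label.strip() for label in completion.split(","))
    return list(dict.fromkeys(label for label in labels if label))

class InferRetrieveRank(dspy.Module):
    def __init__(self, infer_sig, rank_sig, retr, rank_topk: int = 50):
        super().__init__()
        self.infer = dspy.ChainOfThought(infer_sig)
        self.rank = dspy.ChainOfThought(rank_sig)
        self.retrieve = retr  # frozen retriever: list of queries -> labels, most relevant first
        self.rank_topk = rank_topk  # illustrative default: number of candidates shown to the Rank LM

    def forward(self, text: str) -> dspy.Prediction:
        # Step 1: Infer queries/terms; `.output` is the output field name used by the signatures
        queries = extract_labels_from_strings(self.infer(text=text).output)
        # Step 2: Retrieve candidate labels for the inferred terms, keeping only the top ones
        candidates = self.retrieve(queries)[: self.rank_topk]
        # Step 3: Rank the retrieved candidates with the second LM, conditioned on the input text
        ranked = self.rank(text=text, options=", ".join(candidates))
        return dspy.Prediction(labels=extract_labels_from_strings(ranked.output))
```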
- Minimal Seed Prompts: Adapting IReRa to a new task primarily involves defining simple DSPy Signatures for the Infer and Rank steps, specifying the input and output fields and a basic task description.
```python
import dspy

class NewTaskInferSignature(dspy.Signature):
    """Given an input, identify relevant concepts."""

    text = dspy.InputField(prefix="Input:")
    output = dspy.OutputField(prefix="Concepts:", desc="list of comma-separated concepts")

class NewTaskRankSignature(dspy.Signature):
    """Given an input and a list of candidate options, pick the most relevant ones."""

    text = dspy.InputField(prefix="Input:")
    options = dspy.InputField(prefix="Options:", desc="List of comma-separated options to choose from")
    output = dspy.OutputField(prefix="Relevant Options:", desc="list of comma-separated options")
```
- Optimization via Bootstrapping: The DSPy optimization process (e.g., BootstrapFewShotWithRandomSearch) automates the creation of effective few-shot prompts using a small validation set and a teacher model, significantly reducing manual prompt engineering effort.
```python
import dspy
from dspy.teleprompt import BootstrapFewShotWithRandomSearch

# LM clients. dspy.LM is the client in recent DSPy releases (>= 2.5); older releases
# used classes such as dspy.OpenAI. Model identifiers here are placeholders.
teacher_lm = dspy.LM("openai/gpt-3.5-turbo")  # or GPT-4
student_lm = dspy.LM("openai/gpt-4")          # or a served Llama-2-7b-chat endpoint

retriever = MyRetriever()  # hypothetical: list of queries -> list of labels, best first
unoptimized_program = InferRetrieveRank(NewTaskInferSignature, NewTaskRankSignature, retriever)

def rp_at_k_metric(example, pred, trace=None, k=10):
    """Rank-precision at K (assumed metric): gold hits in the top K / min(K, #gold labels)."""
    gold = set(example.labels)
    hits = sum(label in gold for label in pred.labels[:k])
    return hits / min(k, len(gold)) if gold else 0.0

# trainset / devset: hypothetical lists of dspy.Example(text=..., labels=[...]).with_inputs("text")
optimizer = BootstrapFewShotWithRandomSearch(
    metric=rp_at_k_metric,
    teacher_settings=dict(lm=teacher_lm),  # demonstrations are bootstrapped with the teacher LM
)

# Candidate programs are evaluated with the student LM configured here.
with dspy.settings.context(lm=student_lm):
    optimized_program = optimizer.compile(unoptimized_program, trainset=trainset, valset=devset)

# This compiles Infer and Rank jointly with a single student LM. The paper instead optimizes
# sequentially (first the Infer-Retrieve program, then Rank) and uses different LMs per step;
# reproducing that schedule needs extra wiring not shown here.

# Example of using the optimized program:
# prediction = optimized_program(text="Input document text...")
# print(prediction.labels)
```
- Retriever Choice: The choice of retriever (e.g., general sentence transformer, domain-specific embeddings) is a configurable component and can impact performance, especially for domain-specific tasks like BioDEX.
- Prior Integration: The method can incorporate label frequency priors to re-weight retrieval scores, which helps when the label distribution is skewed and prior statistics are available. With $s_i$ the retrieval score of label $i$ and $p_i$ its prior probability, the re-weighted score is $\tilde{s}_i = s_i \cdot \log_{10}(A \cdot p_i + 10)$. The hyperparameter $A$ controls the strength of this influence and can be tuned on a small validation set.
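The prior re-weighting formula is a one-liner. In this minimal sketch the label names, scores, and priors are hypothetical; in practice the priors would presumably be estimated from label frequencies in whatever labeled data is available, and $A$ would be chosen on the validation set.

```python
import math

def reweight_with_prior(scores: dict[str, float], priors: dict[str, float], A: float) -> dict[str, float]:
    """Return s~_i = s_i * log10(A * p_i + 10) for every label i (missing priors count as p_i = 0)."""
    return {label: s * math.log10(A * priors.get(label, 0.0) + 10) for label, s in scores.items()}

# Hypothetical retrieval scores and label priors.
scores = {"python programming": 0.80, "cobol programming": 0.82}
priors = {"python programming": 0.05, "cobol programming": 0.0001}

print(reweight_with_prior(scores, priors, A=0))     # A = 0: factor is log10(10) = 1, scores unchanged
print(reweight_with_prior(scores, priors, A=1000))  # the frequent label now outranks the rare one
```

Because $p_i \ge 0$, the multiplier is at least 1 whenever $A \ge 0$, so the prior only boosts frequent labels and never pushes a label below its raw similarity score; setting $A = 0$ recovers plain retrieval.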
In summary, IReRa provides a practical, data-efficient, and less brittle alternative to traditional finetuning and manual prompt engineering for XMC tasks by leveraging the modularity and optimization capabilities of the DSPy framework. It demonstrates that combining frozen LMs and retrievers in a structured program, guided by minimal labeled data and automatic optimization, can yield state-of-the-art performance. The main limitation for practical deployment in some scenarios might be the reliance on a powerful, potentially costly, closed-source LM like GPT-4 for the ranking step, though the Infer-Retrieve baseline offers a fully open-source alternative.