- The paper introduces an end-to-end clustering strategy that trains multi-vector models to produce inherently clusterable token embeddings.
- CRISP consistently outperforms fixed-token pruning methods; its C8x32 configuration even slightly exceeds the unpruned baseline (54.5 vs. 54.3 average NDCG@10) while substantially compressing the representations.
- The approach demonstrates a denoising effect by consolidating semantically similar tokens, enabling efficient representation with minimal performance loss.
CRISP (Clustered Representations with Intrinsic Structure Pruning) (2505.11471) is a novel multi-vector training method designed to address the significant storage and computational overheads of models like ColBERT, while maintaining or improving their state-of-the-art retrieval performance. Multi-vector models represent queries and documents as sets of contextualized token-level embeddings, and their similarity is computed using Chamfer Similarity (MaxSim), defined as:
$$\text{Chamfer}(Q, D) = \sum_{q \in Q} \max_{x \in D} \langle q, x \rangle$$
where Q and D are the sets of query and document vectors, respectively. While this provides fine-grained expressiveness, generating numerous vectors per input scales up representation size dramatically, leading to increased storage needs and quadratic runtime complexity in the number of vectors for scoring.
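To make the scoring concrete, here is a minimal NumPy sketch of Chamfer/MaxSim scoring; the function name and array shapes are illustrative, not taken from the paper:

```python
import numpy as np

def chamfer_similarity(Q: np.ndarray, D: np.ndarray) -> float:
    """Chamfer / MaxSim score between a query matrix Q (m x d) and a
    document matrix D (n x d) of token embeddings."""
    # All pairwise inner products <q, x>: shape (m, n).
    sims = Q @ D.T
    # For each query vector, take its best-matching document vector,
    # then sum over the query vectors.
    return float(sims.max(axis=1).sum())

# Example: 32 query tokens and 128 document tokens in a 256-dim space.
rng = np.random.default_rng(0)
Q = rng.normal(size=(32, 256))
D = rng.normal(size=(128, 256))
print(chamfer_similarity(Q, D))
```

The pairwise similarity matrix makes the cost explicit: scoring a single pair requires |Q| × |D| inner products, which is exactly what motivates reducing the number of vectors per input.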
Existing approaches to mitigate this include quantizing embeddings (e.g., ColBERTv2), pruning tokens based on importance, or clustering embeddings after the model is trained (post-hoc clustering). Post-hoc clustering reduces the number of vectors by representing clusters with their centroids, but its effectiveness is limited by the fact that the original embeddings were not explicitly trained to be clusterable. Token pruning removes information entirely, which can be detrimental.
CRISP proposes to integrate clustering directly into the end-to-end training process. Instead of clustering frozen embeddings, CRISP trains the multi-vector model to produce embeddings that are inherently clusterable. During training, K-means clustering is applied to the token embeddings of both queries and documents, and the Chamfer loss is computed using the resulting cluster centroids. This forces the model to learn representations where semantically similar tokens are grouped together, enabling effective reduction of the number of vectors by using cluster centroids as the final representation.
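A simplified sketch of this cluster-then-score training step, assuming a PyTorch encoder and scikit-learn's K-means; the contrastive loss form, the helper names, and the trick of recomputing centroids as differentiable means over the live token embeddings are illustrative assumptions rather than the paper's exact implementation:

```python
import torch
from sklearn.cluster import KMeans

def cluster_centroids(tokens: torch.Tensor, k: int) -> torch.Tensor:
    """Cluster one sequence of token embeddings (n x d) into k centroids."""
    k = min(k, tokens.shape[0])
    # Assignments come from K-means on detached embeddings ...
    labels = KMeans(n_clusters=k, n_init=10).fit_predict(
        tokens.detach().cpu().numpy()
    )
    labels = torch.as_tensor(labels, device=tokens.device)
    # ... but the centroids are averaged from the live tensors, so the
    # encoder still receives gradients through the Chamfer loss.
    return torch.stack([tokens[labels == c].mean(dim=0) for c in range(k)])

def chamfer(Q: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    return (Q @ D.T).max(dim=1).values.sum()

def crisp_loss(query_tokens, pos_doc_tokens, neg_doc_tokens, kq=8, kd=32):
    """Contrastive loss over clustered representations (C8x32-style setup).
    The softmax-over-negatives form is an assumption for illustration."""
    q = cluster_centroids(query_tokens, kq)
    pos = chamfer(q, cluster_centroids(pos_doc_tokens, kd))
    negs = torch.stack([chamfer(q, cluster_centroids(d, kd))
                        for d in neg_doc_tokens])
    scores = torch.cat([pos.unsqueeze(0), negs])
    return -torch.log_softmax(scores, dim=0)[0]
```

Because the centroids are averages of the encoder's own token embeddings, the Chamfer loss computed on centroids back-propagates into the encoder, which is what pushes the model toward producing inherently clusterable representations.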
The paper evaluates CRISP by fine-tuning a multi-vector model based on the Gemma2B backbone on a large training dataset. Each pruning strategy is applied during this training phase, ensuring that the model learns representations suited to the chosen reduction method. The strategies compared include the following (a code sketch follows the list):
- Fixed-Token Pruning: selects a fixed number of vectors based on position (Tail Pruning, keeping the last k tokens) or uniform sampling (K-Spacing, keeping every K-th token).
- Clustering-Based Pruning (CRISP): applies K-means clustering and uses the resulting centroids as the reduced representation. Two variants are tested:
  - Fixed-Size Clustering: uses a predefined number of clusters (k_q for queries, k_d for documents).
  - Relative-Size Clustering: sets the number of clusters as a percentage of the original sequence length.
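The four reduction strategies can be summarized in a short dispatch function; the names, default parameters, and slicing conventions below are assumptions for illustration, not the paper's code:

```python
import numpy as np
from sklearn.cluster import KMeans

def kmeans_centroids(tokens: np.ndarray, k: int) -> np.ndarray:
    """Return k cluster centroids for an (n x d) token-embedding matrix."""
    return KMeans(n_clusters=k, n_init=10).fit(tokens).cluster_centers_

def reduce_vectors(tokens: np.ndarray, strategy: str, k: int = 8,
                   spacing: int = 4, ratio: float = 0.25) -> np.ndarray:
    """Reduce an (n x d) token-embedding matrix to fewer vectors."""
    n = tokens.shape[0]
    if strategy == "tail":               # Tail Pruning: keep the last k tokens.
        return tokens[-k:]
    if strategy == "k_spacing":          # K-Spacing: keep every spacing-th token.
        return tokens[::spacing]
    if strategy == "fixed_clusters":     # Fixed-Size Clustering: k centroids.
        return kmeans_centroids(tokens, min(k, n))
    if strategy == "relative_clusters":  # Relative-Size Clustering: clusters as
        return kmeans_centroids(tokens, max(1, round(ratio * n)))  # a fraction of n.
    raise ValueError(f"unknown strategy: {strategy}")
```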
The performance is evaluated using NDCG@10 on the BEIR benchmark, comparing against the unpruned multi-vector (MV) baseline, a single-vector (SV) baseline, and external state-of-the-art models.
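For reference, one common formulation of NDCG@10 (linear gains, log2 discount) can be computed as follows; this helper is a standard-definition sketch, not code from the paper or the BEIR toolkit:

```python
import numpy as np

def ndcg_at_k(relevances: list[float], k: int = 10) -> float:
    """NDCG@k for one query, given graded relevance labels of the
    retrieved documents in ranked order."""
    rel = np.asarray(relevances, dtype=float)[:k]
    discounts = 1.0 / np.log2(np.arange(2, rel.size + 2))
    dcg = float((rel * discounts).sum())
    ideal = np.sort(np.asarray(relevances, dtype=float))[::-1][:k]
    idcg = float((ideal * discounts).sum())
    return dcg / idcg if idcg > 0 else 0.0
```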
The experimental results demonstrate the advantages of CRISP:
- Fixed-token pruning methods (Tail Pruning, K-Spacing) generally perform significantly worse than the unpruned MV baseline, indicating that simply dropping or subsampling vectors degrades representation quality even when the model is trained with those constraints in place.
- CRISP models consistently and significantly outperform fixed-token pruning methods.
- The CRISP C8x32 configuration, which uses 8 query and 32 document clusters, slightly surpasses the unpruned MV baseline performance (54.5 vs 54.3 NDCG@10 on average) while achieving substantial compression (document representation size is reduced by ~2.9x, query by ~3.9x on average).
- A more aggressive variant, CRISP C4x8 (4 query, 8 document clusters), achieves much higher compression rates (document size reduced by ~11x, query by ~7.9x on average) with only a modest 3.6% drop in average NDCG@10.
- CRISP shows a "denoising" effect, improving performance on some datasets (e.g., ArguAna, Scidocs, NQ) even with aggressive compression. This is attributed to clustering consolidating semantically similar tokens, including filtering out noise from stopwords or repetitive phrases.
- Compared to post-hoc clustering methods reported in related work (2409.14683), CRISP offers considerably better compression-quality trade-offs and can compress both query and document representations, whereas some post-hoc methods only compress documents.
Implementation considerations include the sensitivity of model performance to hyperparameters such as the learning rate and L2 regularization, and the crucial role of task-specific instruction prefixes in adapting the base LLM (Gemma2B) to retrieval tasks. During training, the clustering step runs K-means on the embeddings within each mini-batch.
The paper concludes that CRISP provides an effective method for training efficient multi-vector retrieval models by learning intrinsically clusterable representations. This approach allows for significant representation size reduction with minimal or even positive impact on retrieval quality, bridging the gap between expensive multi-vector models and efficient single-vector models. A key limitation is that the number of clusters k must be fixed in advance, preventing the model from dynamically adjusting representation size based on input complexity. Future work could explore methods for adaptive cluster counts.