- The paper introduces TAROT, a framework that uses an improved whitened feature distance to select a subset of candidate data by minimizing the optimal transport distance.
- It employs gradient embedding, dimensionality reduction, and Cholesky whitening to overcome biases in traditional influence estimation methods.
- Experimental validations show TAROT’s effectiveness in semantic segmentation, motion prediction, and instruction tuning, achieving superior performance with reduced data.
This paper introduces TAROT (Targeted Data Selection via Optimal Transport), a framework designed to select the most relevant data from a large candidate pool to improve model performance on a specific target domain. The authors identify limitations in previous targeted data selection methods, which often rely on influence-based greedy heuristics. These heuristics can be suboptimal, especially when target data distributions are complex and multimodal, due to (i) the disproportionate impact of dominant features in high-dimensional influence estimation and (ii) restrictive linear additive assumptions in greedy selection strategies.
TAROT addresses these challenges by first proposing an improved influence estimation metric called the whitened feature distance ($d_{\mathcal{Z}}^{w}$), designed to mitigate bias from dominant feature components. The process involves several steps (a minimal code sketch of the pipeline follows the list):
- Gradient Embedding: For each data sample $z=(x,y)$, an embedding $\phi(z)$ is created by summing its loss gradients $\nabla L(z;\theta_i)$ with respect to the model parameters $\theta_i$ over several training checkpoints:
$$\phi(z) = \sum_{i=1}^{T} \nabla L(z; \theta_i)$$
- Dimensionality Reduction: High-dimensional gradients are projected to a lower-dimensional space using random projections:
$$\phi(z)_i^{\text{proj}} = \mathcal{P}^{\top} \phi(z)_i$$
- Whitening: The projected gradients are centered, and then Cholesky whitening is applied. This involves computing the covariance matrix $\Sigma$ of the centered projected gradients, performing the Cholesky decomposition $\Sigma = LL^{\top}$, and then transforming the gradients:
$$\phi(z)^{w} = L^{-1}\,\tilde{\phi}(z)^{\text{proj}},$$
where $\tilde{\phi}(z)^{\text{proj}}$ denotes the centered projected gradient. This step decorrelates the features and scales them to unit variance.
- Normalization: Each whitened gradient vector is normalized to unit length:
$$\hat{\phi}(z)_i = \frac{\phi(z)_i^{w}}{\|\phi(z)_i^{w}\|_2}$$
- Distance Calculation: The whitened feature distance between two samples $z$ and $z'$ is the L2 norm of the difference between their normalized, whitened gradient embeddings:
$$d_{\mathcal{Z}}^{w}(z,z') = \|\hat{\phi}(z) - \hat{\phi}(z')\|_2$$
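As a minimal sketch of the steps above (assuming the random projection has already produced a per-sample gradient matrix `grads`, and using an illustrative `eps` regularizer for numerical stability; function and variable names are not from the paper):

```python
import numpy as np
from scipy.spatial.distance import cdist

def whiten_and_normalize(grads: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Center, Cholesky-whiten, and L2-normalize projected gradient embeddings.

    `grads` has one row per sample (n_samples x proj_dim). In practice the
    whitening statistics would be estimated over the pooled candidate and
    target embeddings so that both sets share the same transform.
    """
    centered = grads - grads.mean(axis=0, keepdims=True)
    cov = np.cov(centered, rowvar=False) + eps * np.eye(grads.shape[1])   # Sigma
    L = np.linalg.cholesky(cov)                                            # Sigma = L L^T
    whitened = np.linalg.solve(L, centered.T).T                            # phi^w = L^{-1} phi_tilde
    return whitened / np.linalg.norm(whitened, axis=1, keepdims=True)      # unit length

def whitened_feature_distance(feats_a: np.ndarray, feats_b: np.ndarray) -> np.ndarray:
    # Pairwise L2 distances between normalized, whitened embeddings (d_Z^w).
    return cdist(feats_a, feats_b, metric="euclidean")
```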
This $d_{\mathcal{Z}}^{w}$ serves as the ground cost $c(z,z')$ for the Optimal Transport (OT) distance between the distributions of the selected candidate data and the target data. The core idea of TAROT is to select a subset $D_s$ of the full candidate pool $D_c$ such that the OT distance to the target dataset $D_t$ is minimized:
$$d_{\mathrm{OT}}(D_s, D_t) = \min_{\pi \in \Pi(\alpha_s, \alpha_t)} \int_{\mathcal{Z}\times\mathcal{Z}} d_{\mathcal{Z}}^{w}(z, z')\, \mathrm{d}\pi(z, z'),$$
where $\alpha_s$ and $\alpha_t$ are the empirical distributions of $D_s$ and $D_t$. The Sinkhorn algorithm is used for efficient OT computation.
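A hedged sketch of this computation with the POT (Python Optimal Transport) library, using the entropic Sinkhorn solver `ot.sinkhorn2`; the uniform marginals and the regularization value are assumptions, and `cost` is taken to be the pairwise $d_{\mathcal{Z}}^{w}$ matrix from the sketch above:

```python
import numpy as np
import ot  # POT: Python Optimal Transport

def ot_distance(cost: np.ndarray, reg: float = 0.05) -> float:
    """Entropic-regularized OT distance for a precomputed d_Z^w cost matrix."""
    n_s, n_t = cost.shape
    alpha_s = np.full(n_s, 1.0 / n_s)   # empirical distribution of the selected set
    alpha_t = np.full(n_t, 1.0 / n_t)   # empirical distribution of the target set
    return float(ot.sinkhorn2(alpha_s, alpha_t, cost, reg))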
Since solving the subset selection problem to minimize OT distance is combinatorial, TAROT employs a greedy approach. Two selection schemes are proposed:
- Fixed-Size Selection:
- Aims to select a predetermined number of samples S.
- Iteratively, for $k=1,2,\dots$, it considers adding the $k$-th nearest candidate samples (based on $d_{\mathcal{Z}}^{w}$) from $D_c$ for each sample in $D_t$ to the selected set $D_s$.
- If adding all $k$-th nearest neighbors would exceed $S$, it uses the OT dual potential $\phi_i$ to rank these candidates and selects the top $S - |D_s|$ of them. The potential $\phi_i$ estimates the benefit of adding sample $z_i$ to $D_s$ in terms of minimizing the OT distance to $D_t$.
- The pseudocode is available in Algorithm 1 of the supplementary material; a simplified sketch is given after this list.
- OT-Distance Minimization Selection (OTM):
- Aims to find an optimal selection ratio by directly minimizing the OT distance.
- Uses k-fold cross-validation on the target dataset Dt.
- In each fold, $1/k$ of $D_t$ is used to guide selection (finding nearest neighbors from $D_c$), and the OT distance is evaluated against the remaining $(k-1)/k$ of $D_t$.
- Samples are added iteratively (nearest neighbors first). The process stops when adding more samples starts to increase the OT distance to the validation part of $D_t$.
- The final selected set is the union of selections from all k folds (typically k=10).
- The pseudocode is available in Algorithm 2 of the supplementary material; simplified sketches of both selection schemes follow below.
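A simplified Python sketch of the fixed-size scheme (not the authors' exact Algorithm 1): candidates are added round by round as $k$-th nearest neighbors of the target samples, and the final overflowing round is ranked by a caller-supplied OT-potential function. Here `potential_fn` is a hypothetical helper returning one dual-potential value per candidate index, and the ranking direction is an assumption.

```python
import numpy as np

def fixed_size_selection(dist: np.ndarray, budget: int, potential_fn) -> list:
    """dist: d_Z^w matrix of shape (n_candidates, n_targets); budget: S."""
    n_candidates, _ = dist.shape
    order = np.argsort(dist, axis=0)             # per target column: candidates, nearest first
    selected, chosen = [], np.zeros(n_candidates, dtype=bool)
    for k in range(n_candidates):                # round k adds the k-th nearest neighbors
        new = [int(i) for i in order[k] if not chosen[i]]
        new = list(dict.fromkeys(new))           # deduplicate while keeping order
        if len(selected) + len(new) > budget:
            # Rank the overflowing round by OT dual potential and keep only what fits.
            pot = potential_fn(new)              # one potential value per candidate index (assumed)
            new = [i for _, i in sorted(zip(pot, new))][: budget - len(selected)]
        selected.extend(new)
        for i in new:
            chosen[i] = True
        if len(selected) >= budget:
            return selected
    return selected
```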
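And a correspondingly simplified sketch of the OTM scheme (Algorithm 2 in the supplementary material is the authoritative version): each fold's guide split drives a nearest-neighbor ordering, candidates are added until the OT distance to the validation split stops decreasing, and the per-fold selections are unioned. `ot_dist` stands for any callable wrapping the Sinkhorn computation shown earlier; its exact form and the stopping rule here are assumptions.

```python
import numpy as np

def otm_selection(dist: np.ndarray, ot_dist, n_folds: int = 10, seed: int = 0) -> set:
    """dist: d_Z^w matrix (n_candidates x n_targets); ot_dist(cand_idx, tgt_idx) -> float."""
    n_candidates, n_targets = dist.shape
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(n_targets), n_folds)
    union = set()
    for f in range(n_folds):
        guide = folds[f]                                            # 1/k of the target set
        val = np.concatenate([folds[j] for j in range(n_folds) if j != f])
        order = np.argsort(dist[:, guide].min(axis=1))              # nearest candidates first
        best, selected = np.inf, []
        for idx in order:
            trial = selected + [int(idx)]
            d = ot_dist(trial, val)
            if d > best:                                            # OT distance started rising: stop
                break
            best, selected = d, trial
        union.update(selected)
    return union
```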
Additionally, TAROT proposes Data Weighting with OT Potential. After selecting $D_s$, the OT dual potential (introduced above) of each selected sample $z_i \in D_s$ is computed. These potentials are scaled to positive integer weights $w_i$ indicating how many times a sample should be repeated within a training epoch; the sum of weights $\sum_i w_i$ can be set to a customizable repetition factor $R$.
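One possible way to turn dual potentials into integer repetition weights is sketched below; the paper only specifies that potentials are scaled to positive integers summing to a customizable factor $R$, so the shift-and-normalize scheme here is an assumption.

```python
import numpy as np

def potentials_to_weights(potentials: np.ndarray, repetition_factor: int) -> np.ndarray:
    """Map OT dual potentials of selected samples to positive integer repeat counts."""
    shifted = potentials - potentials.min()                  # make all values non-negative
    total = shifted.sum()
    probs = shifted / total if total > 0 else np.full(len(shifted), 1.0 / len(shifted))
    # Round to at least 1 repetition; the sum is only approximately R after rounding.
    weights = np.maximum(1, np.round(probs * repetition_factor)).astype(int)
    return weights                                            # repeat sample i weights[i] times per epoch
```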
Implementation and Application:
- Gradient Feature Extraction: Requires a model pre-trained on the candidate data (or a mix of candidate and target data) to extract gradients. The paper uses gradients from later checkpoints; for example, in semantic segmentation, DeepLabV3 with a ResNet-50 backbone was used, with gradients from the last four checkpoints.
- Computational Cost: The main costs involve gradient computation, random projection, whitening (Cholesky decomposition of an $N \times N$ covariance matrix, where $N$ is the projected dimension), and OT distance calculation (Sinkhorn algorithm, typically $O(N_c N_t \log(\cdot))$ for a dense cost matrix, or faster with specialized solvers if the cost matrix is sparse, where $N_c, N_t$ are the dataset sizes). The greedy selection involves multiple OT computations.
- Libraries: The random projection leverages efficient CUDA implementations (e.g., from TRAK); standard libraries such as NumPy/SciPy can be used for the Cholesky decomposition, and the POT (Python Optimal Transport) library for the Sinkhorn algorithm. A toy projection sketch follows this list.
- Code: The authors provide code at https://github.com/vita-epfl/TAROT.
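For illustration only, a plain Gaussian random projection in NumPy can stand in for the CUDA projector used in practice; the dimensions below are deliberately small toy values, whereas real per-sample gradients have millions of entries and are projected on GPU in chunks.

```python
import numpy as np

rng = np.random.default_rng(0)
param_dim, proj_dim = 10_000, 256                  # toy sizes; real gradient dims are far larger
P = rng.standard_normal((param_dim, proj_dim)) / np.sqrt(proj_dim)

def project_gradient(flat_grad: np.ndarray) -> np.ndarray:
    # phi(z)^proj = P^T . phi(z), applied to each sample's flattened loss gradient
    return flat_grad @ P
```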
Experimental Validation:
TAROT was evaluated on three diverse tasks:
- Influence Estimation: Using the Linear Data Modeling Score (LDS) on CIFAR-10 and nuScenes, the proposed $d_{\mathcal{Z}}^{w}$ (with ZCA whitening in this comparison to align with TRAK's evaluation) outperformed TRAK, indicating better influence estimation. Qualitative results (Fig. 3) show that $d_{\mathcal{Z}}^{w}$ identifies semantically relevant helpful/detracting examples.
- Semantic Segmentation:
- Task: Selecting synthetic data from GTA5 (candidate) to improve performance on real-world Cityscapes (target) using DeepLabV3.
- Results: TAROT outperformed baselines (LESS, DsDm, Random). TAROT-OTM selected ~24% of the GTA5 data and achieved higher mIoU than training on full GTA5 or at other selection ratios. t-SNE visualizations (Fig. 5) show that TAROT captures the complexity of the target distribution better than methods such as DsDm.
- Motion Prediction:
- Task: Selecting data from WOMD, Argoverse 2, and nuPlan (candidates) to improve performance on nuScenes (target) using Wayformer, with data selection guided by a smaller model, AutoBots.
- Results: TAROT consistently outperformed baselines. TAROT-OTM selected ~29.8% of the data and surpassed the performance of training on the full candidate set. Wayformer trained on TAROT-selected data achieved 1st place on the nuScenes leaderboard (Table 1), demonstrating significant improvement with less data. This also highlights TAROT's transferability: selection guided by a small model benefits a larger one.
- Instruction Tuning for LLMs:
- Task: Selecting data from a mix (Flan V2, CoT, etc.) for MMLU and BBH target tasks using Llama-3.1-8B and Qwen-2.5-7B.
- Results: TAROT outperformed baselines and full-dataset training, without needing subtask labels (unlike LESS). TAROT-OTM was particularly efficient, selecting less than 0.5% of the data while matching or exceeding the performance of methods that use 5%. Transferability across LLM architectures was also demonstrated.
Practical Implications:
- TAROT offers a robust method for curating datasets for specific tasks, potentially reducing training costs and improving model performance, especially when dealing with large, diverse, and potentially noisy candidate datasets.
- The whitened feature distance provides a more reliable measure of data similarity/influence in the context of model gradients.
- The OTM selection scheme can automatically determine a near-optimal subset size, removing the need for manual tuning of selection ratios.
- The framework is versatile, demonstrating strong performance across computer vision (segmentation, motion prediction) and NLP (instruction tuning).
- Data weighting based on OT potentials can further fine-tune the training process by emphasizing more critical samples.
Limitations:
- The paper notes that TAROT might overfit if the target dataset is very small or highly specific, potentially harming generalization. Future work could involve incorporating distribution diversification into the selection objective.
- The computational overhead of gradient extraction and repeated OT calculations might be considerable for extremely large datasets or models, though random projection helps mitigate some of this.