
TAROT: Targeted Data Selection via Optimal Transport (2412.00420v2)

Published 30 Nov 2024 in cs.LG, cs.CV, and stat.ML

Abstract: We propose TAROT, a targeted data selection framework grounded in optimal transport theory. Previous targeted data selection methods primarily rely on influence-based greedy heuristics to enhance domain-specific performance. While effective on limited, unimodal data (i.e., data following a single pattern), these methods struggle as target data complexity increases. Specifically, in multimodal distributions, these heuristics fail to account for multiple inherent patterns, leading to suboptimal data selection. This work identifies two primary factors contributing to this limitation: (i) the disproportionate impact of dominant feature components in high-dimensional influence estimation, and (ii) the restrictive linear additive assumptions inherent in greedy selection strategies. To address these challenges, TAROT incorporates whitened feature distance to mitigate dominant feature bias, providing a more reliable measure of data influence. Building on this, TAROT uses whitened feature distance to quantify and minimize the optimal transport distance between the selected data and target domains. Notably, this minimization also facilitates the estimation of optimal selection ratios. We evaluate TAROT across multiple tasks, including semantic segmentation, motion prediction, and instruction tuning. Results consistently show that TAROT outperforms state-of-the-art methods, highlighting its versatility across various deep learning tasks. Code is available at https://github.com/vita-epfl/TAROT.

Summary

  • The paper introduces TAROT, a framework that uses an improved whitened feature distance to select a subset of candidate data by minimizing the optimal transport distance.
  • It employs gradient embedding, dimensionality reduction, and Cholesky whitening to overcome biases in traditional influence estimation methods.
  • Experimental validations show TAROT’s effectiveness in semantic segmentation, motion prediction, and instruction tuning, achieving superior performance with reduced data.

This paper introduces TAROT (Targeted Data Selection via Optimal Transport), a framework designed to select the most relevant data from a large candidate pool to improve model performance on a specific target domain. The authors identify limitations in previous targeted data selection methods, which often rely on influence-based greedy heuristics. These heuristics can be suboptimal, especially when target data distributions are complex and multimodal, due to (i) the disproportionate impact of dominant features in high-dimensional influence estimation and (ii) restrictive linear additive assumptions in greedy selection strategies.

TAROT addresses these challenges by first proposing an improved influence estimation metric called whitened feature distance ($d^w_{\mathcal{Z}}$). This metric is designed to mitigate bias from dominant feature components. The process involves several steps (a code sketch follows the list):

  1. Gradient Embedding: For each data sample $z = (x, y)$, an embedding $\phi(z)$ is created by summing its loss gradients $\nabla \mathcal{L}(z; \theta_i)$ with respect to model parameters $\theta_i$ over several training checkpoints:

    $\phi(z) = \sum_{i=1}^{T} \nabla \mathcal{L}(z; \theta_i)$

  2. Dimensionality Reduction: High-dimensional gradients are projected to a lower-dimensional space using random projections.

    $\phi(z)^{\text{proj}}_i = \mathcal{P}^{\top} \cdot \phi(z)_i$

  3. Whitening: The projected gradients are centered, and then Cholesky whitening is applied. This involves computing the covariance matrix $\Sigma$ of the centered projected gradients, performing Cholesky decomposition ($\Sigma = LL^\top$), and then transforming the gradients:

    $\phi(z)^w = L^{-1} \tilde{\phi}(z)^{\text{proj}}$

    This step decorrelates features and scales them to have unit variance.

  4. Normalization: Each whitened gradient vector is normalized to unit length:

    $\hat{\phi}(z)_i = \frac{\phi(z)^w_i}{\|\phi(z)^w_i\|_2}$

  5. Distance Calculation: The whitened feature distance between two samples $z$ and $z'$ is the L2 norm of the difference between their normalized, whitened gradient embeddings:

    $d^w_{\mathcal{Z}}(z, z') = \|\hat{\phi}(z) - \hat{\phi}(z')\|_2$
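
To make the pipeline concrete, here is a minimal NumPy sketch of steps 2–5, assuming the summed per-checkpoint gradients are already stacked into a matrix. The projection matrix, the small diagonal jitter, and the function names are illustrative assumptions, not taken from the TAROT codebase:

```python
# Minimal sketch of the whitened feature distance pipeline (steps 2-5).
import numpy as np

rng = np.random.default_rng(0)

def whitened_embeddings(grads, proj_dim=512):
    """grads: (n_samples, d) summed loss gradients over checkpoints."""
    n, d = grads.shape
    # Step 2: random projection to a lower-dimensional space.
    P = rng.standard_normal((d, proj_dim)) / np.sqrt(proj_dim)
    proj = grads @ P
    # Step 3: center, then Cholesky-whiten (Sigma = L L^T, phi^w = L^{-1} phi~).
    centered = proj - proj.mean(axis=0, keepdims=True)
    Sigma = np.cov(centered, rowvar=False) + 1e-6 * np.eye(proj_dim)  # jitter for stability
    L = np.linalg.cholesky(Sigma)
    white = np.linalg.solve(L, centered.T).T
    # Step 4: normalize each whitened vector to unit length.
    return white / np.linalg.norm(white, axis=1, keepdims=True)

def whitened_feature_distance(emb_a, emb_b):
    """Step 5: pairwise L2 distances between normalized, whitened embeddings."""
    diff = emb_a[:, None, :] - emb_b[None, :, :]
    return np.linalg.norm(diff, axis=-1)
```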

This $d^w_{\mathcal{Z}}$ serves as the cost function $c(x, y)$ for calculating the Optimal Transport (OT) distance between the distributions of the selected candidate data and the target data. The core idea of TAROT is to select a subset of candidate data $\mathcal{D}_s$ from the full candidate pool $\mathcal{D}_c$ such that the OT distance to the target dataset $\mathcal{D}_t$ is minimized:

$d_{\mathrm{OT}}(\mathcal{D}_s, \mathcal{D}_t) = \min_{\pi \in \Pi(\alpha_s, \alpha_t)} \int_{\mathcal{Z} \times \mathcal{Z}} d^w_{\mathcal{Z}}(z, z') \, d\pi(z, z'),$

where $\alpha_s$ and $\alpha_t$ are the empirical distributions of $\mathcal{D}_s$ and $\mathcal{D}_t$. The Sinkhorn algorithm is used for efficient OT computation.
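
Since the text recommends the POT library for the Sinkhorn algorithm, this objective can be sketched in a few lines; the helper reuses `whitened_feature_distance` from the sketch above, and the regularization strength `reg` is an assumed hyperparameter:

```python
# Hedged sketch: entropic OT distance between selected and target sets,
# using the whitened feature distance as the ground cost.
import numpy as np
import ot  # POT: pip install POT

def ot_distance(emb_selected, emb_target, reg=0.01):
    M = whitened_feature_distance(emb_selected, emb_target)  # cost matrix
    a = np.full(len(emb_selected), 1.0 / len(emb_selected))  # empirical alpha_s
    b = np.full(len(emb_target), 1.0 / len(emb_target))      # empirical alpha_t
    return ot.sinkhorn2(a, b, M, reg)  # Sinkhorn approximation of d_OT
```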

Since solving the subset selection problem to minimize OT distance is combinatorial, TAROT employs a greedy approach. Two selection schemes are proposed:

  1. Fixed-Size Selection:
    • Aims to select a predetermined number of samples $S$.
    • Iteratively, for $k = 1, 2, \dots$, it considers adding the $k$-th nearest candidate samples (based on $d^w_{\mathcal{Z}}$) from $\mathcal{D}_c$ for each sample in $\mathcal{D}_t$ to the selected set $\mathcal{D}_s$.
    • If adding all $k$-th nearest neighbors would exceed $S$, it uses the OT dual potential $\phi_i$ to rank these candidates and selects the top $S - |\mathcal{D}_s|$ ones. The potential $\phi_i$ estimates the benefit of adding sample $z_i$ to $\mathcal{D}_s$ in terms of minimizing the OT distance to $\mathcal{D}_t$.
    • The pseudocode is available in Algorithm 1 of the supplementary material.
  2. OT-Distance Minimization Selection (OTM):
    • Aims to find an optimal selection ratio by directly minimizing the OT distance (see the sketch after this list).
    • Uses $k$-fold cross-validation on the target dataset $\mathcal{D}_t$.
    • In each fold, $1/k$ of $\mathcal{D}_t$ is used for guiding selection (finding nearest neighbors from $\mathcal{D}_c$), and the OT distance is evaluated against the remaining $(k-1)/k$ of $\mathcal{D}_t$.
    • Samples are iteratively added (nearest neighbors first). The process stops when adding more samples starts to increase the OT distance to the validation part of $\mathcal{D}_t$.
    • The final selected set is the union of selections from all $k$ folds (typically $k = 10$).
    • The pseudocode is available in Algorithm 2 of the supplementary material.
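
The following sketch conveys the spirit of the OTM scheme under stated simplifications: it grows the selection in fixed-size chunks rather than adding per-sample $k$-th nearest neighbors as in the paper's Algorithm 2, and it reuses the helpers from the earlier sketches:

```python
# Illustrative OTM-style selection: one fold guides nearest-neighbor
# expansion, the rest of the target set validates the OT distance, and
# selection stops when the validated distance starts to rise.
import numpy as np

def otm_select(emb_cand, emb_target, k=10, chunk=100, reg=0.01):
    n_t = len(emb_target)
    rng = np.random.default_rng(0)
    folds = np.array_split(rng.permutation(n_t), k)
    selected = set()
    for guide_idx in folds:
        val_idx = np.setdiff1d(np.arange(n_t), guide_idx)
        # Rank candidates by distance to their nearest guide sample.
        d = whitened_feature_distance(emb_cand, emb_target[guide_idx]).min(axis=1)
        order = np.argsort(d)
        chosen, best = order[:0], np.inf
        for end in range(chunk, len(order) + chunk, chunk):
            subset = order[:end]
            cur = ot_distance(emb_cand[subset], emb_target[val_idx], reg)
            if cur > best:  # adding more samples starts to hurt: stop
                break
            best, chosen = cur, subset
        selected.update(chosen.tolist())  # union over folds
    return sorted(selected)
```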

Additionally, TAROT proposes Data Weighting with OT Potential. After selecting $\mathcal{D}_s$, the OT dual potentials (described above) for each selected sample $z_i \in \mathcal{D}_s$ are calculated. These potentials are scaled to positive integer weights $w_i$, indicating how many times a sample should be repeated within a training epoch. The sum of weights $\sum_i w_i$ can be set to a customizable repetition factor $R$.
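
A minimal sketch of this weighting step, assuming a simple shift-normalize-round rescaling (the paper's exact scaling rule is not reproduced here):

```python
import numpy as np

def potentials_to_weights(potentials, R):
    """Scale OT dual potentials to positive integer repetition weights.

    Assumed scheme: shift potentials to be non-negative, normalize, and
    round so the weights sum to roughly the repetition factor R.
    """
    shifted = potentials - potentials.min() + 1e-12  # strictly positive
    probs = shifted / shifted.sum()
    weights = np.maximum(1, np.round(probs * R)).astype(int)  # >= 1 repeat each
    return weights
```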

Implementation and Application:

  • Gradient Feature Extraction: Requires a pre-trained model on the candidate (or a mix of candidate and target) data to extract gradients. The paper uses gradients from later checkpoints; for example, in semantic segmentation, DeepLabV3 with a ResNet50 backbone was used, with gradients from the last four checkpoints. A sketch of this step follows the list.
  • Computational Cost: The main costs involve gradient computation, random projection, whitening (Cholesky decomposition of an $N \times N$ covariance matrix, where $N$ is the projected dimension), and OT distance calculation (the Sinkhorn algorithm, typically $O(N_c N_t \log(\cdot))$ for a dense cost matrix, or faster with specialized solvers if the cost matrix is sparse, where $N_c, N_t$ are dataset sizes). The greedy selection involves multiple OT computations.
  • Libraries: The random projection leverages efficient CUDA implementations (e.g., from TRAK). Standard libraries like NumPy/SciPy can be used for the Cholesky decomposition, and the POT (Python Optimal Transport) library for the Sinkhorn algorithm.
  • Code: The authors provide code at https://github.com/vita-epfl/TAROT.
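
As referenced above, the per-sample gradient embedding can be sketched with `torch.func`; the model, loss, and checkpoint format here are placeholder assumptions, not the TAROT implementation:

```python
# Sketch: sum per-sample loss gradients over the later training checkpoints.
import torch
from torch.func import functional_call, grad

def sample_gradient(model, params, x, y):
    """Flattened loss gradient for one sample under a given parameter set."""
    def loss_fn(p):
        out = functional_call(model, p, (x.unsqueeze(0),))
        return torch.nn.functional.cross_entropy(out, y.unsqueeze(0))
    g = grad(loss_fn)(params)
    return torch.cat([v.flatten() for v in g.values()])

def gradient_embedding(model, checkpoint_paths, x, y):
    """phi(z): sum of gradients across checkpoints (assumed saved as state dicts)."""
    total = None
    for path in checkpoint_paths:
        params = {k: v.detach() for k, v in torch.load(path).items()}
        g = sample_gradient(model, params, x, y)
        total = g if total is None else total + g
    return total
```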

Experimental Validation:

TAROT was evaluated on influence estimation quality and three diverse downstream tasks:

  1. Influence Estimation: Using the Linear Data Modeling Score (LDS) on CIFAR-10 and nuScenes, the proposed $d^w_{\mathcal{Z}}$ (with ZCA whitening for this specific comparison, to align with TRAK's evaluation) outperformed TRAK, indicating better influence estimation. Qualitative results (Fig. 3) show $d^w_{\mathcal{Z}}$ identifies semantically relevant helpful/detracting examples.
  2. Semantic Segmentation:
    • Task: Selecting synthetic data from GTA5 (candidate) to improve performance on real-world Cityscapes (target) using DeepLabV3.
    • Results: TAROT outperformed baselines (LESS, DsDm, Random). TAROT-OTM selected ~24% of GTA5 data and achieved higher mIoU than training on full GTA5 or at other selection ratios. t-SNE visualizations (Fig. 5) showed TAROT captures the target distribution's complexity better than methods like DsDm.
  3. Motion Prediction:
    • Task: Selecting data from WOMD, Argoverse 2, and nuPlan (candidates) to improve performance on nuScenes (target) using Wayformer, with data selection guided by a smaller model, AutoBots.
    • Results: TAROT consistently outperformed baselines. TAROT-OTM selected ~29.8% of the data and surpassed the performance of training on the full candidate set. Wayformer trained on TAROT-selected data achieved 1st place on the nuScenes leaderboard (Table 1), demonstrating significant improvement with less data. This also highlighted TAROT's transferability (selection by a small model benefits a larger one).
  4. Instruction Tuning for LLMs:
    • Task: Selecting data from a mix (Flan V2, CoT, etc.) for MMLU and BBH target tasks using Llama-3.1-8B and Qwen-2.5-7B.
    • Results: TAROT outperformed baselines and full dataset training, without needing subtask labels (unlike LESS). TAROT-OTM was particularly efficient, selecting less than 0.5% of data while matching or exceeding the performance of methods using 5% data. Transferability across LLM architectures was also demonstrated.

Practical Implications:

  • TAROT offers a robust method for curating datasets for specific tasks, potentially reducing training costs and improving model performance, especially when dealing with large, diverse, and potentially noisy candidate datasets.
  • The whitened feature distance provides a more reliable measure of data similarity/influence in the context of model gradients.
  • The OTM selection scheme can automatically determine a near-optimal subset size, removing the need for manual tuning of selection ratios.
  • The framework is versatile, demonstrating strong performance across computer vision (segmentation, motion prediction) and NLP (instruction tuning).
  • Data weighting based on OT potentials can further fine-tune the training process by emphasizing more critical samples.

Limitations:

  • The paper notes that TAROT might overfit if the target dataset is very small or highly specific, potentially harming generalization. Future work could involve incorporating distribution diversification into the selection objective.
  • The computational overhead of gradient extraction and repeated OT calculations might be considerable for extremely large datasets or models, though random projection helps mitigate some of this.