Task-Relevant Parameter & Token Selection (TR‑PTS)
- Task-Relevant Parameter and Token Selection (TR‑PTS) is a framework that improves model efficiency by fine-tuning only high-impact parameters and processing only the most informative tokens, both selected by statistical criteria.
- It uses Fisher Information to identify critical parameters and employs attention scores to select informative tokens, ensuring targeted updates and reduced computational cost.
- Empirical results show significant gains in accuracy on vision benchmarks while updating a minimal fraction of model parameters and tokens.
Task-Relevant Parameter and Token Selection (TR‑PTS) refers to a principled framework for improving the efficiency and accuracy of large neural models, particularly Vision Transformers (ViTs), by selectively fine-tuning just the most task-informative parameters and dynamically selecting or merging only the most discriminative input tokens. Unlike conventional parameter-efficient fine-tuning (PEFT) or uniform token reduction, TR‑PTS is explicitly task-driven, leveraging information-theoretic criteria and attention statistics to concentrate model capacity on task-relevant subspaces. This approach directly addresses computational, memory, and inference bottlenecks in downstream adaptation of large pre-trained models by focusing updates on those parameters and tokens that statistically contribute most to task performance.
1. Task-Relevant Parameter Selection via Fisher Information
A central innovation in TR‑PTS is the use of the Fisher Information Matrix (FIM) to quantify each parameter’s importance for the target task. Given a model with parameters $\theta$, the FIM is defined as

$$F(\theta) = \mathbb{E}_{p(y \mid x;\, \theta)}\!\left[\nabla_\theta \log p(y \mid x; \theta)\, \nabla_\theta \log p(y \mid x; \theta)^{\top}\right],$$

and is typically approximated via the diagonal of the squared gradients of the cross-entropy loss over the training data:

$$\hat{F}_i = \frac{1}{N} \sum_{n=1}^{N} \left(\nabla_{\theta_i}\, \mathcal{L}(x_n, y_n; \theta)\right)^2 .$$

Parameters with high FIM diagonal entries are those which, when perturbed, induce large changes in the model’s output distribution, and are thus most relevant to task adaptation.
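As a concrete illustration, the following PyTorch sketch estimates the diagonal FIM by averaging squared mini-batch gradients of the cross-entropy loss. The function name `estimate_diag_fisher` and the batch-level (rather than per-sample) approximation are assumptions for illustration, not necessarily the paper's exact procedure.

```python
import torch
import torch.nn.functional as F

def estimate_diag_fisher(model, loader, device="cpu"):
    """Approximate the diagonal of the Fisher Information Matrix by
    averaging squared gradients of the cross-entropy loss.
    (Illustrative sketch; the exact TR-PTS procedure may differ.)"""
    model.eval()
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()
              if p.requires_grad}
    n_batches = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        model.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
        n_batches += 1
    return {n: f / max(n_batches, 1) for n, f in fisher.items()}
```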
TR‑PTS implements a layer-wise selection scheme. First, the top $M\%$ of parameters with the highest FIM values across the entire network are selected, yielding a global candidate set $\mathcal{S}$. Layer $l$’s importance is then computed as the fraction of these top parameters that fall within it,

$$r_l = \frac{|\mathcal{S} \cap \Theta_l|}{|\mathcal{S}|},$$

where $\Theta_l$ denotes the parameter set of layer $l$. This ratio sets the per-layer tuning budget: the number of tunable connections allotted to each neuron in layer $l$ scales with $r_l$, and each neuron’s set of updated parameters is selected according to its largest FIM entries, as in the sketch below. This allocation ensures that task-discriminative parameters anywhere in the model are prioritized for tuning, while all others are frozen, thus reducing both the memory and the computational footprint.
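A minimal sketch of this selection step, assuming the `fisher` dictionary produced by the previous snippet; the global thresholding and the `keep_ratio` default (mirroring the 0.34% tunable fraction reported below) are illustrative choices, not the paper's exact allocation rule.

```python
import torch

def build_masks(fisher, keep_ratio=0.0034):
    """Select the globally top-scoring parameters by FIM value, then
    materialize a per-tensor binary mask. (Sketch; allocation details
    such as per-neuron budgeting are assumed.)"""
    # Global threshold: FIM value of the k-th largest entry overall.
    all_scores = torch.cat([f.flatten() for f in fisher.values()])
    k = max(1, int(keep_ratio * all_scores.numel()))
    threshold = torch.topk(all_scores, k).values.min()
    # Per-tensor masks: 1 where the parameter is tuned, 0 where frozen.
    return {n: (f >= threshold).float() for n, f in fisher.items()}
```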
2. Task-Relevant Token Selection Scheme
TR‑PTS incorporates a token selection mechanism designed for ViTs. The selection is guided by the attention scores from the [CLS] token, which aggregates global semantic information. For each patch token $i$ (with key $k_i$), the attention score relative to the [CLS] token (with query $q_{\text{cls}}$) is

$$a_i = \mathrm{softmax}_i\!\left(\frac{q_{\text{cls}}\, k_i^{\top}}{\sqrt{d}}\right),$$

where $d$ is the key dimension.
A selection ratio $\rho$ is specified, and the $\lceil \rho N \rceil$ tokens with the highest scores are preserved, denoted $\mathcal{T}_{\text{keep}}$.
Instead of discarding the remaining tokens (the set $\mathcal{T}_{\text{drop}}$), TR‑PTS merges them through an attention-weighted average,

$$\tilde{t} = \sum_{i \in \mathcal{T}_{\text{drop}}} \frac{a_i}{\sum_{j \in \mathcal{T}_{\text{drop}}} a_j}\, t_i,$$

yielding a compact sequence $\big[\, t_{\text{cls}},\ \{t_i\}_{i \in \mathcal{T}_{\text{keep}}},\ \tilde{t} \,\big]$,
which is passed to all subsequent layers. This process provides substantial reductions in FLOPs and memory while retaining or even enhancing task accuracy compared to simple downsampling, as redundant spatial information is aggregated rather than discarded.
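The token path can be sketched as follows in PyTorch. The function name `select_and_merge`, the tensor shapes, and the single-merged-token layout are assumptions for illustration, not the reference implementation.

```python
import torch

def select_and_merge(tokens, cls_attn, keep_ratio=0.5):
    """Keep the top-scoring patch tokens by [CLS] attention and merge
    the rest into one attention-weighted token. (Illustrative sketch.)

    tokens:   (B, N+1, D) sequence, index 0 is the [CLS] token
    cls_attn: (B, N) attention of [CLS] over the N patch tokens
    """
    B, n_plus_1, D = tokens.shape
    N = n_plus_1 - 1
    k = max(1, int(keep_ratio * N))
    patches = tokens[:, 1:, :]                                  # (B, N, D)
    top = cls_attn.topk(k, dim=1)                               # kept indices
    keep = torch.gather(patches, 1,
                        top.indices.unsqueeze(-1).expand(B, k, D))
    # Complement set: tokens to merge (True where dropped).
    mask = torch.ones(B, N, dtype=torch.bool, device=tokens.device)
    mask.scatter_(1, top.indices, False)
    drop_attn = cls_attn.masked_fill(~mask, 0.0)
    w = drop_attn / drop_attn.sum(dim=1, keepdim=True).clamp_min(1e-8)
    merged = torch.einsum("bn,bnd->bd", w, patches).unsqueeze(1)  # (B, 1, D)
    return torch.cat([tokens[:, :1, :], keep, merged], dim=1)
```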
3. Joint Parameter and Token Optimization
TR‑PTS distinguishes itself by co-optimizing parameter selection and token selection. After each forward pass with the refined tokens, only the task-relevant parameters (as defined by the FIM allocation and captured in a binary mask $m$) are updated:

$$\theta \leftarrow \theta - \eta\, \big(m \odot \nabla_{\theta} \mathcal{L}\big),$$

where $\odot$ denotes elementwise multiplication and $\eta$ is the learning rate. This joint strategy allows the model to iteratively emphasize and refine the representation of task-discriminative features within both token and parameter subspaces. In effect, the parameter mask and the refined tokens reinforce each other during fine-tuning, yielding improved adaptation with dramatically reduced resource usage.
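In practice, this masked update amounts to zeroing the gradients of frozen parameters before the optimizer step, as in the following sketch (assuming the `masks` dictionary built earlier):

```python
def masked_step(model, masks, optimizer):
    """Apply m ⊙ ∇L before stepping, so only FIM-selected parameters move.
    (Sketch; assumes `masks` maps parameter names to 0/1 tensors.)"""
    for name, p in model.named_parameters():
        if p.grad is not None and name in masks:
            p.grad.mul_(masks[name])  # zero gradients of frozen entries
    optimizer.step()
    optimizer.zero_grad()
```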
4. Empirical Performance and Efficiency
TR‑PTS reports state-of-the-art results on several standard vision benchmarks. On VTAB-1k, TR‑PTS attains a top-1 accuracy of 75.92%, a 10.35% improvement over full fine-tuning, while updating only 0.34% of the model parameters. On the FGVC datasets, average accuracy reaches 91.94%, outperforming full fine-tuning by 3.40% and surpassing other parameter-efficient methods while requiring significantly less compute and storage.
These gains are attributable to TR‑PTS’s selective update and preservation mechanisms: computational and memory costs are lowered due to both the minimized set of tunable connections and the smaller set of active tokens in major computation blocks. The approach is also robust, as merging redundant tokens via attention-weighted averaging maintains information fidelity.
5. Algorithmic Summary Table
| Component | Selection Criterion | Resulting Effect |
| --- | --- | --- |
| Parameter selection | Top FIM diagonal entries per layer | Only the most sensitive connections are tuned |
| Token selection | Highest [CLS]-attention scores | Most informative tokens preserved; the rest merged |
| Joint optimization | Interleaved via parameter mask and refined tokens | Task-driven gains in efficiency and accuracy |
6. Broader Relevance and Code Availability
TR‑PTS provides a general paradigm for efficient, task-driven fine-tuning in large neural architectures, directly addressing the limitations of task-agnostic PEFT and simple token dropping. The open-source implementation is available at https://github.com/synbol/TR-PTS, facilitating reproduction and application to various downstream settings.
The principal significance of TR‑PTS lies in its demonstration that a principled, information-theoretic framework for combined parameter and token selection can yield models that are both more accurate and more efficient than full model tuning or uniform reduction. The framework’s modularity also suggests compatibility with future advances in parameter-efficient tuning and adaptive input sampling strategies.