Conjugate-Gradient Approximated Influence Functions
- Conjugate-Gradient Approximated Influence Functions are techniques that compute inverse Hessian–vector products via CG, addressing high computational costs in large models.
- They leverage parameter subspace restriction and automatic differentiation to efficiently estimate the impact of individual training samples on model predictions.
- Application in dataset pruning has shown a modest improvement in validation accuracy by removing harmful training examples identified through curvature-informed influence scores.
Conjugate-Gradient Approximated Influence Functions are a class of computational techniques for efficiently estimating the effect of individual training samples on the predictions or parameters of complex machine learning models, particularly when second-order curvature information is required but direct matrix inversion is computationally infeasible. These methods utilize the conjugate gradient (CG) algorithm to compute approximate inverse Hessian–vector products (IHVPs), a critical component in the calculation of influence functions, especially in the context of large-scale or parameter-efficient fine-tuning settings (Fein et al., 18 Jul 2025).
1. Theoretical Foundation and Motivation
Influence functions, as introduced in robust statistics, measure the sensitivity of an estimator to infinitesimal perturbations of the data distribution. In machine learning, they are often used to assess how a small change in, or removal of, a training point affects validation loss or model predictions. The canonical influence of a training example $z$ on a validation example $z_{\text{val}}$ is given by: $$\mathcal{I}(z, z_{\text{val}}) = -\nabla_\theta L(z_{\text{val}}, \hat\theta)^\top H_{\hat\theta}^{-1} \nabla_\theta L(z, \hat\theta),$$ where $L$ is the loss function, $\hat\theta$ the model parameters, and $H_{\hat\theta} = \frac{1}{n}\sum_{i=1}^n \nabla_\theta^2 L(z_i, \hat\theta)$ the Hessian of the empirical risk. Direct computation requires inverting the (potentially massive) Hessian matrix $H_{\hat\theta}$.
Conjugate-gradient approximated influence functions address the challenge of large-dimensional Hessian inversion by leveraging iterative linear solvers. Specifically, CG is employed to solve the system $(H_{\hat\theta} + \lambda I)\,x = v$ for a given vector $v$, where $\lambda$ is a Tikhonov damping parameter ensuring positive definiteness. The solution $x \approx (H_{\hat\theta} + \lambda I)^{-1} v$ is then used within the influence formula (Fein et al., 18 Jul 2025).
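The damped CG solve described above can be sketched in a few lines. This is a minimal illustration with a tiny explicit Hessian; in practice $H$ is never materialized and the Hessian–vector product callback comes from automatic differentiation. All names and numbers here are illustrative.

```python
def damped_cg(hvp, v, damping, tol=1e-10, max_iter=100):
    """Solve (H + damping*I) x = v given a Hessian-vector-product callback."""
    n = len(v)
    x = [0.0] * n                      # initial guess x_0 = 0
    r = list(v)                        # residual r_0 = v - A x_0 = v
    p = list(r)                        # initial search direction
    rs_old = sum(ri * ri for ri in r)
    for _ in range(max_iter):
        hp = hvp(p)
        # Apply the damped operator (H + lambda*I) p without forming H.
        ap = [hpi + damping * pi for hpi, pi in zip(hp, p)]
        alpha = rs_old / sum(pi * api for pi, api in zip(p, ap))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, ap)]
        rs_new = sum(ri * ri for ri in r)
        if rs_new < tol:
            break
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x

# Toy 2x2 "Hessian"; the hvp closure stands in for an autodiff HVP.
H = [[4.0, 1.0], [1.0, 3.0]]
hvp = lambda p: [sum(H[i][j] * p[j] for j in range(2)) for i in range(2)]
x = damped_cg(hvp, [1.0, 2.0], damping=0.1)
```

For a symmetric positive-definite system of dimension $n$, CG terminates in at most $n$ iterations in exact arithmetic, which is why the loop above exits almost immediately on this toy example.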
2. Computational Workflow and Practical Implementation
The practical computation of conjugate-gradient approximated influence functions involves several steps:
- Parameter Subspace Restriction: To make the approach tractable in high-dimensional neural networks, calculations are often confined to a learned subspace, such as the LoRA (Low-Rank Adaptation) parameters. This reduces the number of variables involved (e.g., to 0.12% of the full model) (Fein et al., 18 Jul 2025).
- Hessian–Vector Products: Explicit Hessian formation is avoided. Instead, one uses automatic differentiation (notably Pearlmutter's "double-backprop" trick) to compute Hessian–vector products as required by CG iterations.
- Iterative Solution via CG: The (damped) linear system $(H + \lambda I)\,x = v$ is solved iteratively. At each iteration $k$, standard CG updates are performed: $$\alpha_k = \frac{r_k^\top r_k}{p_k^\top A p_k}, \quad x_{k+1} = x_k + \alpha_k p_k, \quad r_{k+1} = r_k - \alpha_k A p_k, \quad p_{k+1} = r_{k+1} + \frac{r_{k+1}^\top r_{k+1}}{r_k^\top r_k}\, p_k,$$ where $r_k$ is the residual, $p_k$ the search direction, and $A p_k = (H + \lambda I)\,p_k$ the Hessian–vector product plus damping term.
- Influence Aggregation: For a given validation point $z_{\text{val}}$, solve $(H + \lambda I)\,s = \nabla_\theta L(z_{\text{val}}, \hat\theta)$ for $s$. Compute the influence on each training point $z_i$ as $\mathcal{I}(z_i, z_{\text{val}}) = -s^\top \nabla_\theta L(z_i, \hat\theta)$.
- Batch Processing and Pruning: Aggregate influence values across the validation set; sort training examples by mean influence to prioritize data curation activities such as pruning (Fein et al., 18 Jul 2025).
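The Hessian–vector products that drive the CG iterations above are computed with automatic differentiation in practice (Pearlmutter's double-backprop trick). As a framework-free stand-in, the sketch below approximates $Hv$ by a central difference of the gradient; the quadratic toy loss and all names are illustrative assumptions, not the paper's implementation.

```python
def hvp_finite_diff(grad_fn, theta, v, eps=1e-5):
    """Approximate H(theta) @ v via a central difference of the gradient:
    Hv ~= (grad(theta + eps*v) - grad(theta - eps*v)) / (2*eps)."""
    theta_plus = [t + eps * vi for t, vi in zip(theta, v)]
    theta_minus = [t - eps * vi for t, vi in zip(theta, v)]
    g_plus, g_minus = grad_fn(theta_plus), grad_fn(theta_minus)
    return [(gp - gm) / (2 * eps) for gp, gm in zip(g_plus, g_minus)]

# Toy quadratic loss L(theta) = 0.5 * theta^T H theta, so grad = H theta
# and the exact HVP with v = e_1 is the first column of H.
H = [[2.0, 0.5], [0.5, 1.0]]
grad = lambda th: [sum(H[i][j] * th[j] for j in range(2)) for i in range(2)]
hv = hvp_finite_diff(grad, [1.0, -1.0], [1.0, 0.0])  # ~ [2.0, 0.5]
```

A real implementation would replace `hvp_finite_diff` with an exact autodiff HVP, which costs roughly two backward passes and avoids the numerical error of finite differencing.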
The overall method enables influence computation even for modern LLMs with very large parameter counts, provided subspace restriction and efficient Hessian-vector products are employed.
3. Application to Preference Dataset Pruning
A core application presented in recent work is the use of conjugate-gradient approximated influence functions to filter and improve the quality of training data for reward models—in particular, for human preference datasets used in alignment and fine-tuning tasks (Fein et al., 18 Jul 2025).
- Use Case: In the TL;DR dataset (a set of human-labeled summary preferences), influence scores are used to quantify the effect of each training example on validation accuracy.
- Pruning Strategy: By removing training examples with the highest mean positive influence scores (i.e., those estimated to increase validation loss), performance on a held-out validation set is marginally improved (1.5% pairwise accuracy uplift after removing 10% of data).
- Comparative Analysis: Influence-based removal outperforms random pruning but is less effective than “gradient similarity” (a first-order approximation) for identifying helpful examples. However, curvature-informed influence functions are distinctly better at identifying harmful examples.
This demonstrates how influence-based dataset pruning, made feasible by CG approximations, can refine noisy datasets for better downstream model performance.
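The pruning step itself reduces to ranking by mean influence and dropping the top fraction. A minimal sketch, with made-up example identifiers and scores standing in for the output of the CG-approximated influence computation:

```python
def prune_by_influence(examples, mean_influence, frac=0.10):
    """Drop the fraction of examples with the largest mean influence
    (i.e., those estimated to be most harmful to validation loss)."""
    order = sorted(range(len(examples)),
                   key=lambda i: mean_influence[i], reverse=True)
    n_drop = int(len(examples) * frac)
    dropped = set(order[:n_drop])
    return [ex for i, ex in enumerate(examples) if i not in dropped]

examples = ["ex0", "ex1", "ex2", "ex3", "ex4",
            "ex5", "ex6", "ex7", "ex8", "ex9"]
scores = [0.1, -0.3, 2.5, 0.0, 0.7, -1.2, 0.4, 0.2, -0.1, 0.05]
kept = prune_by_influence(examples, scores, frac=0.10)  # drops "ex2"
```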
4. Computational and Methodological Considerations
Several computational properties and practical issues are associated with this approach:
- Damping and Numerical Stability: The inclusion of the damping term $\lambda I$ in the Hessian ensures the system $(H + \lambda I)\,x = v$ remains positive definite, preventing divergence or ill-conditioning in CG.
- Iteration and Batch Sizing: The convergence of CG depends on the Hessian’s conditioning and chosen tolerance; typical settings involve tens to hundreds of iterations per influence evaluation, though efficiency gains accrue when using parameter-efficient tuning schemes.
- Gradient Computation: Batched computation of gradients and Hessian-vector products is used to amortize the cost across samples, supporting scalable implementation with modern auto-diff frameworks.
- Ranking and Aggregation: Influence values are typically mean-aggregated over all validation samples, yielding a per-training-example “harmfulness” score suitable for automated filtering.
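The mean-aggregation step in the last bullet can be sketched directly. The layout assumed here (a matrix indexed as `influence[v][i]`, the influence of training example `i` on validation example `v`) and the values are illustrative:

```python
def mean_influence(influence_matrix):
    """Average per-training-example influence across all validation points,
    yielding one 'harmfulness' score per training example."""
    n_val = len(influence_matrix)
    n_train = len(influence_matrix[0])
    return [sum(influence_matrix[v][i] for v in range(n_val)) / n_val
            for i in range(n_train)]

# Two validation points, three training examples (illustrative values).
scores = mean_influence([[0.2, -0.5, 1.0],
                         [0.4, -0.1, 0.8]])
# per-example means: [0.3, -0.3, 0.9]
```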
5. Empirical Findings and Limitations
Reported experimental findings illustrate:
- Retraining Impact: Removing the top 10% most harmful training examples (by influence) yields a modest but consistent validation accuracy gain (1.5% on held-out evaluation) (Fein et al., 18 Jul 2025).
- Gradient Similarity vs. Influence: While gradient similarity excels at identifying helpful data, it underperforms for harmful data. Influence functions, incorporating curvature, are better tuned to detect detrimental examples.
- Curvature Importance: The differential utility of influence functions for harmful vs. helpful data implies that regions of high curvature (steep local loss landscape) are of special importance for data filtering, suggesting that incorporating second-order information is crucial in these contexts.
6. Implications and Future Directions
The use of conjugate-gradient approximated influence functions for dataset curation suggests that:
- Curvature Information Is Critical: Influence-based methods exploit curvature to better identify harmful data, whereas first-order methods capture overall alignment better for helpful examples.
- Scalability: Practical adoption in large models relies on efficient subspace restriction, batched computations, and damping techniques.
- Methodological Extensions: Future investigations may probe alternative second-order approximations, hybrid influence-gradient metrics, or application to larger and more diverse datasets to enhance generalizability.
These methods facilitate more robust and targeted dataset pruning strategies in the era of large pretrained models and parameter-efficient fine-tuning, with clear empirical support for their impact in practice (Fein et al., 18 Jul 2025).