
Adaptive Gradient Surgery in Multi-Task Learning

Updated 25 November 2025
  • Adaptive Gradient Surgery is an optimization technique that dynamically removes conflicting gradient components and adaptively rescales updates to mitigate task interference.
  • It employs pairwise conflict-aware gradient projection and norm preservation to ensure stability and efficient convergence in multi-task deep learning.
  • AGS significantly improves performance in imbalanced and multi-objective settings by preventing destructive oscillations and model collapse.

Adaptive Gradient Surgery (AGS) is a class of optimization techniques for multi-task and multi-objective deep learning that stabilizes the convergence of shared-parameter models by dynamically modifying gradient updates to resolve conflicts between competing objectives. AGS extends the family of “gradient surgery” methods—such as PCGrad and related projection-based algorithms—by not only removing harmful components in conflicting gradients but also adaptively re-scaling the resulting update to preserve convergence rates and prevent model collapse. AGS is particularly effective in settings with severe class or signal imbalance, as in computational cytology, multi-task computer vision, and robust pretraining for downstream aggregation tasks (Acerbis et al., 18 Nov 2025).

1. Multi-Task Optimization and the Gradient Interference Problem

AGS addresses core challenges in multi-task learning, where a shared encoder f_\theta(\cdot) is trained to simultaneously minimize multiple objectives, often with highly disparate gradient directions and magnitudes. A canonical example is the SLAM-AGS framework for cytology, in which weakly supervised cluster-preserving contrastive losses are optimized on negative (healthy) tissue patches, while self-supervised contrastive (SimCLR-style) objectives are applied to positive (diseased) patches (Acerbis et al., 18 Nov 2025). The joint objective is

L(\theta) = L_{\text{Similarity}}(\theta) + L_{\text{SimCLR}}(\theta)

with L_{\text{Similarity}} and L_{\text{SimCLR}} applying to disjoint patch subsets.

The central optimization challenge is destructive gradient interference: gradients \mathbf{g}_1 = \nabla_\theta L_1 and \mathbf{g}_2 = \nabla_\theta L_2 may have negative cosine similarity (i.e., \langle \mathbf{g}_1, \mathbf{g}_2 \rangle < 0), so that naïve gradient summation leads to oscillations, signal cancellation, or premature dominance of one task. This phenomenon is widely observed in supervised and reinforcement learning settings, and is especially acute when witness rates are low or class distributions are highly skewed (Yu et al., 2020, Borsani et al., 6 Jun 2025).
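
As a concrete illustration of how such interference can be measured in practice, the sketch below computes the cosine similarity between two flattened task gradients with respect to the shared parameters; a negative value signals a conflict. The function and variable names are illustrative assumptions, not taken from the cited works.

import torch

def gradient_conflict(loss_a, loss_b, shared_params):
    # Per-task gradients of the shared parameters, keeping the graph so both
    # losses can be differentiated from the same forward pass.
    g_a = torch.autograd.grad(loss_a, shared_params, retain_graph=True)
    g_b = torch.autograd.grad(loss_b, shared_params, retain_graph=True)
    flat_a = torch.cat([g.reshape(-1) for g in g_a])
    flat_b = torch.cat([g.reshape(-1) for g in g_b])
    # Negative cosine similarity indicates destructive gradient interference.
    return torch.nn.functional.cosine_similarity(flat_a, flat_b, dim=0)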

2. Core Algorithm: Gradient Surgery with Adaptive Rescaling

AGS modifies the standard joint gradient update, \mathbf{g}_{\text{sum}} = \mathbf{g}_1 + \mathbf{g}_2, via a two-stage process that ensures (i) removal of direct conflicts and (ii) norm preservation:

(a) Conflict-Aware Gradient Projection:

Each task gradient is examined for pairwise conflicts. If \cos(\mathbf{g}_t, \mathbf{g}_u) < 0, the projection operation removes the component of \mathbf{g}_t along \mathbf{g}_u:

\tilde{\mathbf{g}}_t = \mathbf{g}_t - \frac{\langle \mathbf{g}_t, \mathbf{g}_u \rangle}{\|\mathbf{g}_u\|^2}\, \mathbf{g}_u

Otherwise, \tilde{\mathbf{g}}_t = \mathbf{g}_t.
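
As a toy example, take \mathbf{g}_1 = (1, 0) and \mathbf{g}_2 = (-3, 1), so that \langle \mathbf{g}_1, \mathbf{g}_2 \rangle = -3 < 0 and \|\mathbf{g}_2\|^2 = 10. The projection gives

\tilde{\mathbf{g}}_1 = (1, 0) - \frac{-3}{10}\,(-3, 1) = (0.1,\ 0.3),

which is orthogonal to \mathbf{g}_2: the component of \mathbf{g}_1 that would, to first order, increase the other task's loss has been removed.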

(b) Adaptive Norm Rescaling:

After conflict removal and aggregation, \mathbf{g}_{\text{pc}} = \tilde{\mathbf{g}}_1 + \tilde{\mathbf{g}}_2, the norm is compared to the pre-surgery gradient magnitude:

N_{\text{sum}} = \|\mathbf{g}_{\text{sum}}\|, \quad N_{\text{pc}} = \|\mathbf{g}_{\text{pc}}\|

If N_{\text{pc}} < N_{\text{sum}}, AGS rescales the projected gradient to preserve the original update magnitude:

\mathbf{g}_{\text{out}} = \mathbf{g}_{\text{pc}} \cdot \frac{N_{\text{sum}}}{N_{\text{pc}}}

Parameter updates use this norm-preserved gradient: \theta \leftarrow \theta - \eta\, \mathbf{g}_{\text{out}}. This prevents stagnation from over-shrinking projected gradients, a limitation of classical PCGrad (Yu et al., 2020).
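
Continuing the toy example from above: projecting \mathbf{g}_2 = (-3, 1) against \mathbf{g}_1 = (1, 0) gives \tilde{\mathbf{g}}_2 = (0, 1), so \mathbf{g}_{\text{pc}} = (0.1, 1.3) with N_{\text{pc}} = \sqrt{1.7} \approx 1.30, whereas \mathbf{g}_{\text{sum}} = (-2, 1) with N_{\text{sum}} = \sqrt{5} \approx 2.24. Since N_{\text{pc}} < N_{\text{sum}}, the update is rescaled by N_{\text{sum}} / N_{\text{pc}} \approx 1.72, giving \mathbf{g}_{\text{out}} \approx (0.17,\ 2.23) rather than the much smaller purely projected step.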

This mechanism removes only the deleterious (i.e., loss-increasing for the other task) gradient components, while the rescaling restores the original update magnitude, yielding both stability and efficiency (Acerbis et al., 18 Nov 2025).

3. Algorithmic Structure and Pseudocode

The AGS optimization loop for two objectives is as follows (Acerbis et al., 18 Nov 2025):

for minibatch in data:
    # Partition the minibatch into the patch subsets used by each loss
    compute z_i = head(f_theta(x_i)) for all views
    compute L1 = L_Similarity (negative patches), L2 = L_SimCLR (positive patches)
    g1 = grad(L1, theta); g2 = grad(L2, theta)
    g_sum = g1 + g2
    # (a) Conflict-aware projection, applied pairwise in both directions
    for t in {1, 2}:
        u = the other task
        if inner_product(g_t, g_u) < 0:
            tilde_g_t = g_t - (inner_product(g_t, g_u) / norm(g_u)**2) * g_u
        else:
            tilde_g_t = g_t
    g_pc = tilde_g_1 + tilde_g_2
    # (b) Adaptive norm rescaling to the pre-surgery magnitude
    if norm(g_pc) < norm(g_sum):
        g_out = g_pc * (norm(g_sum) / norm(g_pc))
    else:
        g_out = g_pc
    theta = theta - lr * g_out   # gradient step with the surged update

This procedure generalizes to K tasks by iteratively applying pairwise projections or using matrix projection techniques, as in PCGrad and MGDA (Yu et al., 2020).
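
A minimal PyTorch-style sketch of such a K-task AGS step, operating on flattened gradients of the shared parameters, is given below. The function name, the plain SGD update, and the fixed projection order are illustrative assumptions rather than the reference implementation.

import torch

def ags_step(losses, shared_params, lr=1e-3):
    # Flatten each task gradient over the shared parameters into one vector.
    grads = []
    for loss in losses:
        g = torch.autograd.grad(loss, shared_params, retain_graph=True)
        grads.append(torch.cat([p.reshape(-1) for p in g]))
    g_sum = torch.stack(grads).sum(dim=0)
    # (a) Conflict-aware projection against each other task's original gradient.
    projected = []
    for i, g_i in enumerate(grads):
        g_t = g_i.clone()
        for j, g_j in enumerate(grads):
            if i != j and torch.dot(g_t, g_j) < 0:
                g_t = g_t - (torch.dot(g_t, g_j) / g_j.pow(2).sum()) * g_j
        projected.append(g_t)
    g_pc = torch.stack(projected).sum(dim=0)
    # (b) Adaptive rescaling back to the pre-surgery update magnitude.
    n_sum, n_pc = g_sum.norm(), g_pc.norm()
    if n_pc < n_sum and n_pc > 0:
        g_pc = g_pc * (n_sum / n_pc)
    # Plain SGD update with the surged gradient.
    with torch.no_grad():
        offset = 0
        for p in shared_params:
            n = p.numel()
            p -= lr * g_pc[offset:offset + n].view_as(p)
            offset += n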

4. Theoretical and Empirical Stability Properties

AGS is designed to guarantee monotonic descent for all participating objectives in expectation (up to first order), clamping gradient conflicts without excessively shrinking the learning signal. Unlike pure regularization or weighted-sum approaches, AGS provides per-minibatch adaptivity: surgery and rescaling are performed on the fly, batch-wise, automatically tracking the evolving conflict dynamics as the network trains.
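
As a quick numerical sanity check of this first-order claim (an illustrative test, not taken from the cited paper), the two-task surgery and rescaling can be applied to random gradients and the resulting update verified to have a non-negative inner product with each task gradient:

import numpy as np

def project_away(gt, gu):
    # Remove the component of gt that conflicts with gu.
    d = np.dot(gt, gu)
    return gt - (d / np.dot(gu, gu)) * gu if d < 0 else gt

rng = np.random.default_rng(0)
for _ in range(1000):
    g1, g2 = rng.normal(size=50), rng.normal(size=50)
    g_pc = project_away(g1, g2) + project_away(g2, g1)
    n_sum, n_pc = np.linalg.norm(g1 + g2), np.linalg.norm(g_pc)
    g_out = g_pc * (n_sum / n_pc) if n_pc < n_sum else g_pc
    # Non-negative inner products: neither loss increases to first order.
    assert np.dot(g_out, g1) >= -1e-9 and np.dot(g_out, g2) >= -1e-9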

Empirically, application of AGS in SLAM-AGS pretraining yields smooth loss curves and avoids pathological dominance or collapse to a trivial regime, even with extreme data imbalances (witness rates down to 0.5%) (Acerbis et al., 18 Nov 2025). Direct stability benefits include:

  • Prevention of collapse to a single task solution
  • Avoidance of non-convergent oscillatory updates
  • Maintenance of learning efficiency by restoring original update norm after conflict removal

5. Comparison with Related Gradient Surgery Methods

| Method | Conflict Removal | Norm Preservation | Computational Complexity |
| PCGrad (Yu et al., 2020) | Pairwise projection, no rescale | No | Low |
| MGDA | QP-constrained joint direction | Indirect, via QP | High |
| CAGrad, IMTL | Adaptive weighting, shrinkage | Only partial | Moderate |
| AGS | Pairwise projection, rescale | Yes, explicit | Low–Moderate |

AGS's combination of simplicity (PCGrad-style projection) and explicit update-norm correction differentiates it from earlier gradient projection techniques, which can suffer from slowdowns due to gradient shrinkage or from the computational expense of multi-objective quadratic solvers. In AGS, the update magnitude is never smaller than that of the pre-surgery composite gradient; when no conflicts exist, the method reduces to standard summed-gradient descent (Acerbis et al., 18 Nov 2025).

6. Domain Applications and Measured Impact

The most prominent application of AGS to date is in Slide-Label-Aware Multi-task pretraining for computational cytology (Acerbis et al., 18 Nov 2025). In this domain, AGS enables robust pretraining even in regimes of extreme positive-instance sparsity:

  • On a bone-marrow cytology dataset, SLAM-AGS with AGS improved bag-level F1-Score by up to +40 points at 0.5% witness rate over non-surgical baselines.
  • Retrieval of the top 400 positive cells improved by ∼20 points over both pure self-supervision and weak-label contrastive pretraining.
  • Gains are maximized under severe imbalance, confirming that gradient interference—not just class sparsity—can critically limit performance when left unaddressed.

More broadly, AGS methodology is applicable to reinforcement learning (Yu et al., 2020), computer vision (Borsani et al., 6 Jun 2025), deformable registration (Dou et al., 2023), robust LLM fine-tuning (Yi et al., 10 Aug 2025), and multimodal representation learning (Yan et al., 17 Nov 2025), where multi-loss or multi-task setups frequently exhibit destructive interference.

7. Extensions, Limitations, and Prospects

AGS belongs to a broader family of conflict-resolution optimization techniques that operate directly at the gradient level, as opposed to merely adjusting loss weights. Potential extensions include:

  • Generalization to K > 2 objectives via iterative pairwise or simultaneous null-space projections
  • Layer-wise or adapter-level application in multi-branch architectures
  • Adaptive thresholds or partial projections to provide graded conflict removal
  • Combination with similarity- or momentum-weighted schemes for settings with high task heterogeneity (Borsani et al., 6 Jun 2025)

Limitations include increased per-step computation compared to naïve gradient summation (though modest for small K), and sensitivity to batch-wise gradient variance. However, empirical results indicate that adaptive norm restoration mitigates the convergence slowdowns sometimes observed in previous surgery methods.

In summary, Adaptive Gradient Surgery provides a robust, mechanically interpretable optimization architecture for stable and efficient multi-objective training, enabling effective simultaneous learning in settings characterized by scarcity, imbalance, or fundamentally conflicting supervision signals (Acerbis et al., 18 Nov 2025).
