Multi-Term Adam (MTAdam)
- Multi-Term Adam (MTAdam) is an adaptive optimization algorithm that automatically rescales gradients from different loss terms to ensure balanced influence per layer.
- It uses per-term moment estimation and dynamic per-layer balancing to stabilize multi-objective training without manual loss-weight tuning.
- MTAdam aggregates updates with a max-based second moment for robust step-size control, demonstrating improved performance on multi-loss benchmarks.
Multi-Term Adam (MTAdam) is an adaptive optimization algorithm designed to address the challenges of dynamically balancing multiple loss terms in deep neural network training. Unlike standard Adam, which operates on a single overall loss, MTAdam automatically rescales the magnitude of gradients arising from each distinct loss term so that their influence is balanced per network layer and throughout training. This procedure removes the need for manual loss-weight tuning, improves robustness to poor initial loss weighting, and stabilizes optimization in multi-objective and adversarial settings (Malkiel et al., 2020).
1. Motivation and Problem Statement
Modern deep learning pipelines frequently involve composite objectives of the form

$$\mathcal{L}(\theta) = \sum_{i=1}^{I} \lambda_i \, \ell_i(\theta),$$

where each $\ell_i$ represents a distinct loss term and the $\lambda_i$ are scalar loss weights. Hand-tuning the loss weights is time-consuming, inflexible to changes during training, and may not generalize across architectures or datasets. Furthermore, different network layers can exhibit disparate sensitivity to gradients from each loss, so a single global weight is insufficient to ensure local balance. This is especially problematic for adversarial losses or in settings like multi-task learning and conditional GANs, where the optimal tradeoff between loss components is highly dynamic. MTAdam addresses these issues by enforcing that for every network layer, all loss terms produce gradients of (roughly) equal magnitude, while still leveraging Adam's adaptive moment estimation.
2. Notation and Per-Term Moment Estimation
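Let $\ell_1, \dots, \ell_I$ be the loss terms, $\theta$ the full parameter vector, $g_{i,t} = \nabla_\theta \ell_i(\theta_{t-1})$ the raw gradient for term $i$ at time $t$, and $\theta[j]$ the $j$-th parameter. Layers are indexed by $l$, each being a subset of parameters.
For each loss term $i$ and parameter $j$, MTAdam maintains separate first- and second-moment estimates analogous to those in Adam:
- First moment: $m_{i,t}[j] = \beta_1 \, m_{i,t-1}[j] + (1-\beta_1) \, g_{i,t}[j]$
- Second moment: $v_{i,t}[j] = \beta_2 \, v_{i,t-1}[j] + (1-\beta_2) \, g_{i,t}[j]^2$
- Bias correction: $\hat m_{i,t}[j] = m_{i,t}[j] / (1-\beta_1^t)$ and $\hat v_{i,t}[j] = v_{i,t}[j] / (1-\beta_2^t)$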
With a single loss term ($I = 1$), and therefore no balancing to perform, these updates exactly recover Adam.
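The per-term moment updates listed above can be sketched in plain Python. This is a minimal illustration rather than the authors' implementation: the function name `update_moments`, the nested-list layout (`grads[i][j]` for term `i` and parameter `j`), and the default decay rates (the usual Adam defaults) are assumptions made for the example.

```python
def update_moments(m, v, grads, t, beta1=0.9, beta2=0.999):
    """One step of per-term Adam-style moment tracking.

    m, v  : m[i][j] and v[i][j], running moments for term i, parameter j
    grads : grads[i][j], gradient of loss term i w.r.t. parameter j
    t     : 1-based step counter, used for bias correction
    Returns bias-corrected (m_hat, v_hat); m and v are updated in place.
    """
    m_hat, v_hat = [], []
    for i, g_i in enumerate(grads):
        for j, g in enumerate(g_i):
            m[i][j] = beta1 * m[i][j] + (1 - beta1) * g
            v[i][j] = beta2 * v[i][j] + (1 - beta2) * g * g
        # Bias correction compensates for the zero initialisation.
        m_hat.append([x / (1 - beta1 ** t) for x in m[i]])
        v_hat.append([x / (1 - beta2 ** t) for x in v[i]])
    return m_hat, v_hat


# Two loss terms, three parameters, first step (t = 1).
grads = [[1.0, -2.0, 0.5], [0.1, 0.0, -0.3]]
m = [[0.0] * 3 for _ in grads]
v = [[0.0] * 3 for _ in grads]
m_hat, v_hat = update_moments(m, v, grads, t=1)
```

At the first step the bias correction cancels the decay factor, so each term's corrected first moment equals its gradient and each corrected second moment equals the squared gradient, up to floating-point rounding.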
3. Per-Layer Dynamic Balancing
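To compare the impact of different losses at the layer level, MTAdam computes per-layer $L_2$-norms of each term's gradient, writing $g^{l}_{i,t}$ for the entries of $g_{i,t}$ that belong to layer $l$:

$$\left\| g^{l}_{i,t} \right\|_2 = \sqrt{\sum_{j \in l} g_{i,t}[j]^2},$$

and maintains exponentially weighted averages for each:

$$n^{l}_{i,t} = \beta_3 \, n^{l}_{i,t-1} + (1-\beta_3) \left\| g^{l}_{i,t} \right\|_2 .$$

Initialization uses $n^{l}_{i,0} = 1$ to avoid division by zero. Dynamic balancing coefficients are constructed to locally rescale each loss term so that, after rescaling, all losses' gradients have similar norm:

$$\tilde g^{l}_{i,t} = \frac{n^{l}_{1,t}}{n^{l}_{i,t}} \, g^{l}_{i,t}.$$

By anchoring to term $1$, MTAdam maintains $\| \tilde g^{l}_{i,t} \|_2 \approx n^{l}_{1,t}$ for all $i$ at every layer $l$. The rescaled gradient $\tilde g_{i,t}$ then takes the place of the raw gradient in the per-term moment estimates.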
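The balancing step described above can be sketched as follows. This is an illustrative sketch under stated assumptions, not reference code: the function name `balance_gradients`, the representation of a layer as a list of parameter indices, the zero-based anchor index (term `0` plays the role of the first term), and the default `beta3=0.9` are all choices made for the example.

```python
import math


def balance_gradients(grads, layers, n, beta3=0.9):
    """Rescale each term's per-layer gradient toward the anchor term's running norm.

    grads  : grads[i][j], gradient of loss term i w.r.t. parameter j
    layers : layers[l] is the list of parameter indices belonging to layer l
    n      : n[i][l], running average of the gradient norm, initialised to 1.0
    Returns a new list of balanced gradients; n is updated in place.
    """
    balanced = [list(g_i) for g_i in grads]
    for l, idx in enumerate(layers):
        # Refresh every term's running norm first so the anchor n[0][l] is current.
        for i, g_i in enumerate(grads):
            norm = math.sqrt(sum(g_i[j] ** 2 for j in idx))
            n[i][l] = beta3 * n[i][l] + (1 - beta3) * norm
        # Rescale each term by the ratio of the anchor's running norm to its own.
        for i in range(len(grads)):
            scale = n[0][l] / n[i][l]
            for j in idx:
                balanced[i][j] = scale * grads[i][j]
    return balanced


# Two terms whose gradients differ greatly in scale; two layers.
layers = [[0, 1], [2]]
grads = [[3.0, 4.0, 2.0], [0.03, 0.04, -0.5]]
n = [[1.0, 1.0], [1.0, 1.0]]
for _ in range(200):  # constant gradients, so the running norms settle
    balanced = balance_gradients(grads, layers, n)
```

Once the running norms have settled, the second term's gradient in each layer has approximately the same norm as the anchor term's, while its direction within the layer is unchanged. The anchor term itself is left as is, because its scale factor is exactly one.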
4. Aggregated Parameter Update Mechanism
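For each parameter, MTAdam forms a single update by aggregating the per-loss contributions. Importantly, the second moment used in the denominator is taken as the maximum across all loss terms for each parameter, providing robust step-size control in high-variance regimes:

$$\hat v^{\max}_{t}[j] = \max_{i} \, \hat v_{i,t}[j],$$

and the parameter update is

$$\theta_t[j] = \theta_{t-1}[j] - \alpha \, \frac{\sum_{i=1}^{I} \hat m_{i,t}[j]}{\sqrt{\hat v^{\max}_{t}[j]} + \epsilon},$$

or equivalently in vector form,

$$\theta_t = \theta_{t-1} - \alpha \, \Big( \sum_{i=1}^{I} \hat m_{i,t} \Big) \oslash \Big( \sqrt{\hat v^{\max}_{t}} + \epsilon \Big),$$

where the square root and the division $\oslash$ are applied element-wise. This "worst-case" denominator prevents overshooting in any direction where gradients are volatile under any loss term.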
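A small sketch of the aggregation described above, assuming the bias-corrected per-term moments are already available as nested lists. The function name `aggregate_update` and the default values of `alpha` and `eps` (the usual Adam defaults) are assumptions for illustration only.

```python
import math


def aggregate_update(theta, m_hat, v_hat, alpha=1e-3, eps=1e-8):
    """Apply one aggregated step.

    theta : current parameter values, theta[j]
    m_hat : m_hat[i][j], bias-corrected first moment of term i
    v_hat : v_hat[i][j], bias-corrected second moment of term i
    The numerator sums first moments over terms; the denominator uses the
    largest second moment over terms for each parameter.
    """
    new_theta = []
    for j, th in enumerate(theta):
        m_sum = sum(m_i[j] for m_i in m_hat)
        v_max = max(v_i[j] for v_i in v_hat)
        new_theta.append(th - alpha * m_sum / (math.sqrt(v_max) + eps))
    return new_theta


# Two terms, two parameters.
theta = [0.0, 1.0]
m_hat = [[1.0, -2.0], [0.5, 0.0]]
v_hat = [[1.0, 4.0], [0.25, 0.0]]
new_theta = aggregate_update(theta, m_hat, v_hat, alpha=0.1)
```

For the first parameter the summed first moment is 1.5 and the largest second moment is 1.0, so the step is close to 0.15. With a single loss term the sum and the maximum both reduce to that term's own moments, which gives the standard Adam step.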
5. Hyperparameter Recommendations
MTAdam reuses canonical Adam hyperparameters for ease of adoption by practitioners and compatibility with existing setups:
- $\alpha$: base step size (typical Adam values apply, e.g., those used for CNNs)
- $\beta_1$: first moment decay rate
- $\beta_2$: second moment decay rate
- $\beta_3$: per-layer magnitude decay rate
- $\epsilon$: denominator stability
This design ensures MTAdam introduces minimal new tuning overhead beyond standard Adam settings.
6. Empirical Evaluation Results
MTAdam was evaluated on:
- A controlled ten-term unbalanced MNIST classifier (with each digit as a separate loss) where randomly sampled per-class weights in were used. MTAdam achieved accuracy and avoided the underfitting of low-weight classes observed with traditional Adam, RMSProp, or SGD using the same (unbalanced) weights.
- Image-to-image translation (pix2pix, CycleGAN) and super-resolution (SRGAN), where MTAdam, initialized with uniform loss weights, consistently matched or outperformed the best hand-tuned Adam baselines from respective papers. Competing optimizers with unbalanced weights suffered from either mode collapse (GANs) or degraded FID, PSNR, and SSIM scores.
- Ablation studies confirmed that the per-layer dynamic re-scaling, use of the first-term anchor, and the max-variance denominator are each essential to MTAdam's stability and convergence.
7. Relationship to Other Adam Variants and Extensions
MTAdam generalizes Adam by introducing per-term, per-parameter moment tracking, dynamic per-layer balancing, and a conservative variance-based denominator for the update rule (Malkiel et al., 2020). Unlike Adam extensions that target generalization through regularization or integration-based smoothing, such as Multiple Integral Adam (MIAdam) (Jin et al., 2024), MTAdam focuses on multi-term loss handling and per-layer balance rather than on promoting flat minima or filtering noise. Both lines of work manipulate moment estimates, but their targets (loss balance versus landscape smoothing) and underlying mechanisms differ fundamentally.
MTAdam is positioned as a universal drop-in optimizer for multi-loss deep learning scenarios, dispensing with the need for manual loss-weight searches while retaining the convergence speed and adaptivity of Adam. Its instance-specific, dynamic, and per-layer balancing framework represents a distinct methodological advance in optimization for composite deep learning objectives (Malkiel et al., 2020).