Conflict-Averse Gradient Descent for Multi-task Learning (2110.14048v2)

Published 26 Oct 2021 in cs.LG and cs.AI

Abstract: The goal of multi-task learning is to enable more efficient learning than single task learning by sharing model structures for a diverse set of tasks. A standard multi-task learning objective is to minimize the average loss across all tasks. While straightforward, using this objective often results in much worse final performance for each task than learning them independently. A major challenge in optimizing a multi-task model is the conflicting gradients, where gradients of different task objectives are not well aligned so that following the average gradient direction can be detrimental to specific tasks' performance. Previous work has proposed several heuristics to manipulate the task gradients for mitigating this problem. But most of them lack convergence guarantee and/or could converge to any Pareto-stationary point. In this paper, we introduce Conflict-Averse Gradient descent (CAGrad) which minimizes the average loss function, while leveraging the worst local improvement of individual tasks to regularize the algorithm trajectory. CAGrad balances the objectives automatically and still provably converges to a minimum over the average loss. It includes the regular gradient descent (GD) and the multiple gradient descent algorithm (MGDA) in the multi-objective optimization (MOO) literature as special cases. On a series of challenging multi-task supervised learning and reinforcement learning tasks, CAGrad achieves improved performance over prior state-of-the-art multi-objective gradient manipulation methods.

Authors (5)
  1. Bo Liu (484 papers)
  2. Xingchao Liu (28 papers)
  3. Xiaojie Jin (50 papers)
  4. Peter Stone (184 papers)
  5. Qiang Liu (405 papers)
Citations (237)

Summary

An Overview of Conflict-Averse Gradient Descent for Multi-task Learning

This paper introduces Conflict-Averse Gradient Descent (CAGrad), a new optimization method for multi-task learning (MTL). The authors focus on conflicting gradients, which arise when the standard MTL objective of minimizing the average loss across tasks is optimized naively: gradients from different tasks can point in opposing directions and interfere destructively, degrading per-task performance. CAGrad resolves these conflicts directly in the update direction while preserving convergence guarantees for the average-loss objective.
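
In this setting, two task gradients $g_i$ and $g_j$ are typically said to conflict when their inner product is negative, so that a gradient-descent step for one task locally increases the other task's loss (this is the standard first-order notion used in the gradient-manipulation literature; the notation here is mine):

\[ \langle g_i, g_j \rangle < 0 \;\Longrightarrow\; L_j(\theta - \alpha g_i) \;\approx\; L_j(\theta) - \alpha \langle g_j, g_i \rangle \;>\; L_j(\theta), \]

i.e. stepping along $-g_i$ to reduce task $i$'s loss makes task $j$'s loss worse to first order.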

Main Contributions and Methodology

CAGrad adjusts the update direction at each step by maximizing the worst-case (minimum) local improvement over all tasks, while constraining the direction to stay close to the average gradient so that convergence to a minimum of the average loss is not compromised. This balances the task objectives automatically and generalizes both plain Gradient Descent (GD) and the Multiple Gradient Descent Algorithm (MGDA), subsuming them as special cases.
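
Concretely, the per-step subproblem can be written as follows (paraphrasing the paper's formulation in my own notation, where $g_1,\dots,g_K$ are the task gradients, $g_0 = \frac{1}{K}\sum_i g_i$ is the average gradient, and $c \ge 0$ is a hyperparameter):

\[ d^\star \;=\; \arg\max_{d \in \mathbb{R}^m} \; \min_{1 \le i \le K} \langle g_i, d \rangle \quad \text{s.t.} \quad \lVert d - g_0 \rVert \le c\, \lVert g_0 \rVert, \qquad \theta \leftarrow \theta - \alpha\, d^\star . \]

Setting $c = 0$ forces $d^\star = g_0$ and recovers plain GD on the average loss; letting $c$ grow large relaxes the constraint and the direction approaches the one chosen by MGDA.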

The core algorithmic step is to find an update direction that maximizes its minimum inner product with the task gradients, subject to remaining within a prescribed distance of the average gradient. Rather than optimizing over the high-dimensional parameter space, CAGrad solves this through a dual problem over a vector of task weights whose dimension equals the number of tasks, which keeps the per-step optimization inexpensive.
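
To make the procedure concrete, here is a minimal NumPy/SciPy sketch of one way to compute the CAGrad direction from a stack of flattened per-task gradients. It illustrates the dual described above and is not the authors' reference implementation; the function name, solver choice, and default value of c are assumptions.

```python
import numpy as np
from scipy.optimize import minimize


def cagrad_direction(grads: np.ndarray, c: float = 0.5) -> np.ndarray:
    """Sketch of the CAGrad update direction (illustrative, not the official code).

    grads: array of shape (K, m), one flattened task gradient per row.
    Returns the direction d used in the update theta <- theta - lr * d.
    """
    K = grads.shape[0]
    g0 = grads.mean(axis=0)                      # average gradient
    sqrt_phi = c * np.linalg.norm(g0)            # radius of the trust region around g0

    GG = grads @ grads.T                         # K x K Gram matrix of task gradients
    Gg0 = grads @ g0                             # inner products <g_i, g0>

    def dual(w):
        # Dual objective over simplex weights: <g_w, g0> + sqrt_phi * ||g_w||,
        # where g_w = sum_i w_i g_i (evaluated via the Gram matrix).
        gw_norm = np.sqrt(max(float(w @ GG @ w), 1e-12))
        return float(w @ Gg0) + sqrt_phi * gw_norm

    w0 = np.ones(K) / K
    res = minimize(
        dual, w0,
        bounds=[(0.0, 1.0)] * K,
        constraints={"type": "eq", "fun": lambda w: w.sum() - 1.0},
    )

    gw = res.x @ grads                           # weighted combination of task gradients
    gw_norm = max(float(np.linalg.norm(gw)), 1e-12)
    return g0 + (sqrt_phi / gw_norm) * gw
```

In a multi-task training loop one would backpropagate each task loss separately to obtain the rows of grads, call a routine like this, and apply the returned direction in place of the averaged gradient in the usual optimizer step.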

Theoretical and Empirical Insights

The convergence analysis shows that for any fixed constant $0 \leq c < 1$, CAGrad still converges to stationary points of the average loss $L_0$, so the regularized trajectory does not sacrifice the original objective. Empirically, the method performs robustly across several benchmarks, often outperforming prior state-of-the-art gradient-manipulation methods on multi-task supervised and reinforcement learning problems. The results substantiate the theoretical claims by demonstrating improved learning efficiency and task-specific performance with CAGrad.
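
Paraphrasing the guarantee in my own notation: any fixed point of the CAGrad update with $0 \leq c < 1$ satisfies $\nabla L_0(\theta) = 0$, i.e. it is a stationary point of the average loss, whereas for $c \geq 1$ the method instead converges to Pareto-stationary points of the individual task losses, matching the behavior of MGDA-style methods.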

Implications and Future Directions

Practically, CAGrad can serve as a drop-in replacement for naive gradient averaging in MTL training: it handles conflicting gradients explicitly, which is particularly useful with highly non-linear models or large sets of tasks. Theoretically, CAGrad opens new avenues for exploring complex multi-objective optimization landscapes, especially in AI applications where tasks are interdependent.

The paper leaves room for future work on objectives beyond the average-loss framework, for example main objectives that weight tasks by dynamic importance, which could broaden CAGrad's practical applicability in more specialized or evolving environments.

Conclusion

By introducing a principled optimization approach that effectively mitigates the detrimental effects of conflicting gradients, CAGrad represents a significant advancement in multi-task learning. This method not only further solidifies the understanding of MTL optimization dynamics but also offers an efficient, theoretically sound, and empirically validated solution to a well-acknowledged problem. This work aligns with ongoing developments in AI, marking a step forward in responsive and adaptive learning systems.