DrJAX: Scalable and Differentiable MapReduce Primitives in JAX (2403.07128v2)
Abstract: We present DrJAX, a JAX-based library designed to support large-scale distributed and parallel machine learning algorithms that use MapReduce-style operations. DrJAX leverages JAX's sharding mechanisms to enable native targeting of TPUs and state-of-the-art JAX runtimes, including Pathways. DrJAX embeds building blocks for MapReduce computations as primitives in JAX. This yields three key benefits. First, DrJAX computations can be translated directly to XLA HLO, enabling flexible integration with a wide array of ML training platforms. Second, DrJAX computations are fully differentiable. Finally, DrJAX computations can be interpreted out to existing batch-processing compute systems, including traditional MapReduce systems like Apache Beam and cross-device compute systems like those powering federated learning applications. We show that DrJAX provides an easily programmable, performant, and scalable framework for parallelized algorithm development. DrJAX is available at \url{https://github.com/google-research/google-research/tree/master/drjax}.
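To make the abstract concrete, below is a minimal, pure-JAX sketch of the MapReduce pattern DrJAX builds on: broadcast a server model across a set of clients, map a local training step over them, and reduce (average) the results, then differentiate through the whole round. This is not DrJAX's API; DrJAX registers analogous broadcast/map/reduce operations as JAX primitives so they lower to XLA HLO and shard across devices, which this sketch does not attempt to reproduce. All names here (`local_sgd_step`, `fed_avg_round`, `NUM_CLIENTS`, etc.) are illustrative.

```python
# A hedged, pure-JAX sketch of one federated averaging round expressed as
# broadcast -> map -> reduce, using only standard JAX transformations.
import jax
import jax.numpy as jnp

NUM_CLIENTS = 8  # illustrative placement size, not a DrJAX constant

def local_loss(model, batch):
    x, y = batch
    return jnp.mean((x @ model - y) ** 2)

def local_sgd_step(model, batch, lr=0.1):
    # One local training step on a single client's batch.
    return model - lr * jax.grad(local_loss)(model, batch)

def fed_avg_round(model, client_batches):
    # Broadcast: replicate the server model along a clients axis.
    models = jnp.broadcast_to(model, (NUM_CLIENTS,) + model.shape)
    # Map: apply the local update independently per client.
    updated = jax.vmap(local_sgd_step)(models, client_batches)
    # Reduce: average client models back into a single server model.
    return jnp.mean(updated, axis=0)

key = jax.random.PRNGKey(0)
xs = jax.random.normal(key, (NUM_CLIENTS, 4, 2))  # per-client features
ys = jax.random.normal(key, (NUM_CLIENTS, 4))     # per-client targets
model = jnp.zeros((2,))

# The round is an ordinary JAX computation: it can be jitted (lowered to
# XLA HLO) and differentiated end to end, as the abstract claims.
new_model = jax.jit(fed_avg_round)(model, (xs, ys))
round_jacobian = jax.jacobian(fed_avg_round)(model, (xs, ys))
```

End-to-end differentiability of such rounds is what enables algorithms like hyperparameter optimization or meta-learning across rounds (cf. the federated automatic differentiation reference below).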
- Martín Abadi et al. TensorFlow: A system for large-scale machine learning. In Kimberly Keeton and Timothy Roscoe, editors, 12th USENIX Symposium on Operating Systems Design and Implementation, OSDI 2016, Savannah, GA, USA, November 2-4, 2016, pages 265–283. USENIX Association, 2016. URL https://www.usenix.org/conference/osdi16/technical-sessions/presentation/abadi.
- Paul Barham et al. Pathways: Asynchronous distributed dataflow for ML. In Diana Marculescu, Yuejie Chi, and Carole-Jean Wu, editors, Proceedings of Machine Learning and Systems 2022, MLSys 2022, Santa Clara, CA, USA, August 29 - September 1, 2022. mlsys.org, 2022. URL https://proceedings.mlsys.org/paper/2022/hash/98dce83da57b0395e163467c9dae521b-Abstract.html.
- Friedrich L. Bauer. Computational graphs and rounding error. SIAM Journal on Numerical Analysis, 11(1):87–96, 1974.
- Atılım Güneş Baydin et al. Automatic differentiation in machine learning: a survey. Journal of Machine Learning Research, 18(153):1–43, 2018. URL http://jmlr.org/papers/v18/17-468.html.
- Daniel J. Beutel et al. Flower: A friendly federated learning research framework. CoRR, abs/2007.14390, 2020. URL https://arxiv.org/abs/2007.14390.
- Keith Bonawitz et al. Towards federated learning at scale: System design. In Ameet Talwalkar, Virginia Smith, and Matei Zaharia, editors, Proceedings of Machine Learning and Systems 2019, MLSys 2019, Stanford, CA, USA, March 31 - April 2, 2019. mlsys.org, 2019. URL https://proceedings.mlsys.org/book/271.pdf.
- Keith Bonawitz et al. Practical secure aggregation for privacy-preserving machine learning. In Proceedings of the 2017 ACM SIGSAC Conference on Computer and Communications Security, pages 1175–1191, 2017.
- James Bradbury et al. JAX: composable transformations of Python+NumPy programs, 2018. URL http://github.com/google/jax.
- Konstantin Burlachenko et al. FL_PyTorch: Optimization research simulator for federated learning. In Proceedings of the 2nd ACM International Workshop on Distributed Machine Learning, pages 1–7, 2021.
- Zachary Charles and Keith Rush. Iterated vector fields and conservatism, with applications to federated learning. In Sanjoy Dasgupta and Nika Haghtalab, editors, Proceedings of The 33rd International Conference on Algorithmic Learning Theory, volume 167 of Proceedings of Machine Learning Research, pages 130–147. PMLR, 29 Mar–01 Apr 2022. URL https://proceedings.mlr.press/v167/charles22a.html.
- Zachary Charles et al. Federated select: A primitive for communication- and memory-efficient federated learning. arXiv preprint arXiv:2208.09432, 2022.
- Zachary Charles et al. Towards federated foundation models: Scalable dataset pipelines for group-structured learning. In Alice Oh, Tristan Naumann, Amir Globerson, Kate Saenko, Moritz Hardt, and Sergey Levine, editors, Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023, 2023. URL http://papers.nips.cc/paper_files/paper/2023/hash/662bb9c4dcc96aeaac8e7cd3fc6a0add-Abstract-Datasets_and_Benchmarks.html.
- Igor Babuschkin et al. The DeepMind JAX Ecosystem, 2020. URL http://github.com/google-deepmind.
- Dimitrios Dimitriadis et al. FLUTE: A scalable, extensible framework for high-performance federated learning simulations. CoRR, abs/2203.13789, 2022. doi: 10.48550/ARXIV.2203.13789. URL https://doi.org/10.48550/arXiv.2203.13789.
- Arthur Douillard et al. DiLoCo: Distributed low-communication training of language models. CoRR, abs/2311.08105, 2023. doi: 10.48550/ARXIV.2311.08105. URL https://doi.org/10.48550/arXiv.2311.08105.
- Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. In International Conference on Machine Learning, pages 1126–1135. PMLR, 2017.
- Chaoyang He et al. FedML: A research library and benchmark for federated machine learning, July 2020.
- Alex Ingerman and Krzys Ostrowski. Introducing TensorFlow Federated, March 2019. URL https://blog.tensorflow.org/2019/03/introducing-tensorflow-federated.html.
- Yihan Jiang et al. Improving federated learning personalization via model agnostic meta learning. CoRR, abs/1909.12488, 2019. URL http://arxiv.org/abs/1909.12488.
- Alexia Jolicoeur-Martineau et al. PopulAtion Parameter Averaging (PAPA). arXiv preprint arXiv:2304.03094, 2023.
- Peter Kairouz et al. Advances and open problems in federated learning. Foundations and Trends® in Machine Learning, 14(1–2):1–210, 2021.
- Fan Lai et al. FedScale: Benchmarking model and system performance of federated learning at scale. In International Conference on Machine Learning (ICML), 2022.
- Dmitry Lepikhin et al. GShard: Scaling giant models with conditional computation and automatic sharding. CoRR, abs/2006.16668, 2020. URL https://arxiv.org/abs/2006.16668.
- Margaret Li et al. Branch-Train-Merge: Embarrassingly parallel training of expert language models. arXiv preprint arXiv:2208.03306, 2022.
- Yang Liu et al. FATE: An industrial grade platform for collaborative learning with data protection. Journal of Machine Learning Research, 22(226):1–6, 2021. URL http://jmlr.org/papers/v22/20-815.html.
- Brendan McMahan et al. Communication-efficient learning of deep networks from decentralized data. In Aarti Singh and Xiaojin (Jerry) Zhu, editors, Proceedings of the 20th International Conference on Artificial Intelligence and Statistics, AISTATS 2017, 20-22 April 2017, Fort Lauderdale, FL, USA, volume 54 of Proceedings of Machine Learning Research, pages 1273–1282. PMLR, 2017. URL http://proceedings.mlr.press/v54/mcmahan17a.html.
- H. Brendan McMahan et al. Learning differentially private recurrent language models. In 6th International Conference on Learning Representations, ICLR 2018, Vancouver, BC, Canada, April 30 - May 3, 2018, Conference Track Proceedings. OpenReview.net, 2018. URL https://openreview.net/forum?id=BJ0hF1Z0b.
- Fraser Mince et al. The grand illusion: The myth of software portability and implications for ML progress. In Alice Oh, Tristan Naumann, Amir Globerson, Kate Saenko, Moritz Hardt, and Sergey Levine, editors, Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023, 2023. URL http://papers.nips.cc/paper_files/paper/2023/hash/42c40aff7814e9796266e12053b1c610-Abstract-Conference.html.
- Adam Paszke et al. PyTorch: An imperative style, high-performance deep learning library. In Hanna M. Wallach, Hugo Larochelle, Alina Beygelzimer, Florence d'Alché-Buc, Emily B. Fox, and Roman Garnett, editors, Advances in Neural Information Processing Systems 32: Annual Conference on Neural Information Processing Systems 2019, NeurIPS 2019, December 8-14, 2019, Vancouver, BC, Canada, pages 8024–8035, 2019. URL https://proceedings.neurips.cc/paper/2019/hash/bdbca288fee7f92f2bfa9f7012727740-Abstract.html.
- Matthias Paulik et al. Federated evaluation and tuning for on-device personalization: System design & applications. CoRR, abs/2102.08503, 2021. URL https://arxiv.org/abs/2102.08503.
- Sashank J. Reddi et al. Adaptive federated optimization. In International Conference on Learning Representations, 2021.
- Jae Hun Ro et al. FedJAX: Federated learning simulation with JAX. arXiv preprint arXiv:2108.02117, 2021.
- Keith Rush et al. Federated automatic differentiation. arXiv preprint arXiv:2301.07806, 2023.
- Yuanzhong Xu et al. GSPMD: General and scalable parallelization for ML computation graphs. CoRR, abs/2105.04663, 2021. URL https://arxiv.org/abs/2105.04663.
- Alexander Ziller et al. PySyft: A library for easy federated learning. In Federated Learning Systems: Towards Next-Generation AI, pages 111–139, 2021.