DrJAX: Scalable and Differentiable MapReduce Primitives in JAX (2403.07128v2)

Published 11 Mar 2024 in cs.DC and cs.LG

Abstract: We present DrJAX, a JAX-based library designed to support large-scale distributed and parallel machine learning algorithms that use MapReduce-style operations. DrJAX leverages JAX's sharding mechanisms to enable native targeting of TPUs and state-of-the-art JAX runtimes, including Pathways. DrJAX embeds building blocks for MapReduce computations as primitives in JAX. This enables three key benefits. First, DrJAX computations can be translated directly to XLA HLO, enabling flexible integration with a wide array of ML training platforms. Second, DrJAX computations are fully differentiable. Last, DrJAX computations can be interpreted out to existing batch-processing compute systems, including traditional MapReduce systems like Apache Beam and cross-device compute systems like those powering federated learning applications. We show that DrJAX provides an easily programmable, performant, and scalable framework for parallelized algorithm development. DrJAX is available at https://github.com/google-research/google-research/tree/master/drjax.


Summary

  • The paper introduces DrJAX, a library that embeds MapReduce-style federated learning primitives in JAX to enable scalable, efficient distributed ML.
  • It leverages JAX's JIT compilation and sharding mechanisms to achieve near-constant weak scaling when training models of up to 8 billion parameters.
  • By preserving data-placement information in its computations, DrJAX bridges research prototypes and production federated learning systems.

DrJAX: Integrating Federated Learning Primitives into JAX for Scalable Distributed ML

Introduction

Progress in large-scale ML depends heavily on the ability to distribute computations across many compute nodes. This paper introduces DrJAX, a software library that embeds federated learning (FL) computations into JAX, using JAX's primitive mechanism to make federated computations scalable and efficient. Federated learning is an ML paradigm in which multiple clients collaborate to train a model without sharing raw data. Running such computations efficiently in data centers is essential both for accelerating FL research and for deploying FL algorithms in practice. DrJAX is designed for performance, scalability, and ease of programming: its computations translate directly to XLA HLO and can be interpreted by production cross-device federated compute systems.

System Design

DrJAX integrates federated computations into JAX as primitives. This design rests on two observations: most FL computations resemble ordinary distributed ML workloads, and federated automatic differentiation (AD) can be implemented by tracking data placement. Treating data locations as first-class citizens lets DrJAX manage federated values, distinguishing values placed on clients from values placed on the server. The library provides a small set of federated building blocks, including federated broadcast, federated map, and federated sum, which form the backbone of FL algorithms. These building blocks preserve placement information, allowing differentiation through federated computations without losing track of where data resides.
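
To make these building blocks concrete, the minimal sketch below expresses federated broadcast, map, and sum as plain JAX operations over an explicit leading clients dimension. The function names and the fixed client count are illustrative assumptions, not the DrJAX API.

```python
# Minimal sketch of MapReduce-style federated building blocks in plain JAX.
# Function names and NUM_CLIENTS are illustrative; this is not the DrJAX API.
import jax
import jax.numpy as jnp

NUM_CLIENTS = 4  # assumed number of clients in this toy example

def federated_broadcast(server_value):
    # Server -> clients: replicate a server-placed array along a new
    # leading clients dimension.
    return jnp.broadcast_to(server_value, (NUM_CLIENTS,) + server_value.shape)

def federated_map(fn, clients_value):
    # Clients -> clients: apply fn independently to each client's slice.
    return jax.vmap(fn)(clients_value)

def federated_sum(clients_value):
    # Clients -> server: aggregate by reducing over the clients dimension.
    return jnp.sum(clients_value, axis=0)

# One toy "round": broadcast a model, compute per-client updates, aggregate.
model = jnp.ones((3,))
updates = federated_map(lambda m: 0.5 * m, federated_broadcast(model))
aggregate = federated_sum(updates)  # shape (3,)
```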

Implementation

DrJAX represents federated values as JAX arrays with an extra leading dimension that encodes placement, which lets federated operations be expressed through JAX's primitives mechanism. Because federated computations are embedded directly in JAX, DrJAX inherits JIT compilation and AD, improving data-center performance and scalability. DrJAX also attaches sharding annotations that guide compilers such as GSPMD in distributing computations across devices. This approach supports efficient, scalable federated training of large models and extends to a broader range of parallel ML computations beyond FL.
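
The sketch below illustrates this representation: a clients-placed value is an array whose leading axis indexes clients, and a sharding constraint asks the compiler to split that axis across devices. The mesh construction and axis names are assumptions made for illustration, not DrJAX internals.

```python
# Sketch of the "extra leading dimension" representation plus a sharding
# hint for the compiler. Mesh setup and axis names are illustrative only.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("clients",))

num_clients = len(devices)
clients_value = jnp.ones((num_clients, 3))  # leading axis indexes clients

@jax.jit
def aggregate(x):
    # Ask the compiler to shard the clients axis across devices, then reduce
    # it away to produce a server-placed (replicated) result.
    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("clients")))
    return jnp.sum(x, axis=0)

server_value = aggregate(clients_value)  # shape (3,)
```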

Scalability and Efficiency

The paper presents empirical results showing that DrJAX enables efficient, scalable federated training of LLMs ranging from 350 million to 8 billion parameters. By sharding computations and relying on JAX's JIT compilation, DrJAX achieves near-constant weak scaling, a key indicator of its ability to handle large-scale federated workloads. The experiments also show that JIT-compiled DrJAX computations outperform naive for-loop implementations and that DrJAX's internal sharding annotations are necessary for reaching peak performance.
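
The toy comparison below illustrates why a single jit-compiled, vmapped round tends to scale better than a Python loop that dispatches one small computation per client. The sizes and update rule are arbitrary stand-ins, not the paper's benchmark setup.

```python
# Toy contrast between a naive Python loop over clients and a single
# jit-compiled, vmapped round. Sizes and timings are illustrative only.
import time
import jax
import jax.numpy as jnp

NUM_CLIENTS, DIM = 256, 1024
client_data = jnp.ones((NUM_CLIENTS, DIM))
model = jnp.zeros((DIM,))

def local_update(model, data):
    return model - 0.1 * (model - data)  # stand-in for local training

def round_loop(model, data):
    # Dispatches one small computation per client.
    total = jnp.zeros_like(model)
    for i in range(NUM_CLIENTS):
        total = total + local_update(model, data[i])
    return total / NUM_CLIENTS

@jax.jit
def round_vmapped(model, data):
    # One fused computation across all clients.
    return jnp.mean(jax.vmap(lambda d: local_update(model, d))(data), axis=0)

round_vmapped(model, client_data).block_until_ready()  # warm up compilation
t0 = time.perf_counter()
round_loop(model, client_data).block_until_ready()
t1 = time.perf_counter()
round_vmapped(model, client_data).block_until_ready()
t2 = time.perf_counter()
print(f"python loop: {t1 - t0:.4f}s, jit + vmap: {t2 - t1:.4f}s")
```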

Integration with Production Systems

A significant advantage of DrJAX is that its computations retain data-location information, which makes it possible to translate them into representations understood by production federated learning systems. Because DrJAX relies on JAX's primitives mechanism, the structure of a federated computation, including placement decisions and cross-machine communication patterns, is preserved in the traced program. Production systems can therefore interpret DrJAX computations directly, bridging the gap between research prototypes and deployable federated learning applications.
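
As a rough illustration of the mechanism, the sketch below registers a placement-aware aggregation as a custom JAX primitive using JAX's standard custom-primitive recipe. The primitive's name then remains visible in the traced program (jaxpr), where a separate interpreter could rewrite it for a cross-device runtime. This is not DrJAX's actual registration code.

```python
# Register a placement-aware aggregation as a JAX primitive so it stays
# visible in the jaxpr for later reinterpretation. Illustrative only.
import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import mlir

federated_sum_p = core.Primitive("federated_sum")

def federated_sum(clients_value):
    return federated_sum_p.bind(clients_value)

# Data-center semantics: reduce over the leading clients axis.
federated_sum_p.def_impl(lambda x: jnp.sum(x, axis=0))

# Shape rule: aggregation removes the clients axis.
federated_sum_p.def_abstract_eval(
    lambda x: core.ShapedArray(x.shape[1:], x.dtype))

# Lower to the same XLA computation as jnp.sum when JIT-compiling.
mlir.register_lowering(
    federated_sum_p,
    mlir.lower_fun(lambda x: jnp.sum(x, axis=0), multiple_results=False))

# The primitive remains visible by name in the traced program, so another
# backend could map it to, e.g., a secure aggregation call on real devices.
print(jax.make_jaxpr(federated_sum)(jnp.ones((4, 3))))
```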

Future Directions

While DrJAX makes significant strides in integrating federated learning within JAX, areas for further development include extending federated AD to non-linear communication primitives and supporting more complex, hierarchical data-placement strategies. Developing mature interpreters that translate DrJAX computations into the formats required by specific production platforms is another avenue for broadening DrJAX's applicability in real-world federated learning.

Conclusion

DrJAX is a notable contribution to distributed and federated machine learning, offering a scalable, efficient, and programmable framework for federated computations. By building on JAX and taking a principled approach to data placement and automatic differentiation, DrJAX opens new possibilities for developing and deploying FL algorithms. Its potential to accelerate research and to bridge the gap to production systems makes it a valuable tool for advancing federated learning.