Equinox: neural networks in JAX via callable PyTrees and filtered transformations (2111.00254v1)

Published 30 Oct 2021 in cs.LG and cs.PL

Abstract: JAX and PyTorch are two popular Python autodifferentiation frameworks. JAX is based around pure functions and functional programming. PyTorch has popularised the use of an object-oriented (OO) class-based syntax for defining parameterised functions, such as neural networks. That this seems like a fundamental difference means current libraries for building parameterised functions in JAX have either rejected the OO approach entirely (Stax) or have introduced OO-to-functional transformations, multiple new abstractions, and been limited in the extent to which they integrate with JAX (Flax, Haiku, Objax). Either way this OO/functional difference has been a source of tension. Here, we introduce `Equinox', a small neural network library showing how a PyTorch-like class-based approach may be admitted without sacrificing JAX-like functional programming. We provide two main ideas. One: parameterised functions are themselves represented as `PyTrees', which means that the parameterisation of a function is transparent to the JAX framework. Two: we filter a PyTree to isolate just those components that should be treated when transforming (`jit', `grad' or `vmap'-ing) a higher-order function of a parameterised function -- such as a loss function applied to a model. Overall Equinox resolves the above tension without introducing any new programmatic abstractions: only PyTrees and transformations, just as with regular JAX. Equinox is available at \url{https://github.com/patrick-kidger/equinox}.

Citations (88)

Summary

  • The paper introduces Equinox, a novel library that merges JAX’s functional style with PyTorch’s OO design via callable PyTrees.
  • Its methodology employs immutable class instances as parameterized functions, enabling seamless integration with JAX transformations like jit and grad.
  • The work simplifies neural network development by filtering PyTrees, so that non-differentiable or frozen components can be excluded from transformations such as jit and grad, improving model composability.

Equinox: Neural Networks in JAX via Callable PyTrees and Filtered Transformations

This paper introduces Equinox, a neural network library that adeptly combines the strengths of JAX's functional programming style and PyTorch's object-oriented syntax. The key contribution of Equinox lies in addressing the OO/functional tension present in existing JAX-based neural network libraries, without introducing unnecessary complexity. This is achieved by representing parameterized functions as PyTrees, maintaining the benefits of functional programming while incorporating the readability of a class-based syntax.

Background and Motivation

JAX is a widely used Python framework for autodifferentiation, built around pure functions and functional programming. Central to JAX are two concepts: PyTrees and transformations such as jit, grad, and vmap. These transformations provide just-in-time compilation, differentiation, and vectorization of pure functions, but integrating an object-oriented approach, as popularized by PyTorch, presents challenges.
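
For concreteness, here is a minimal sketch of this pure-function style; the function, parameter, and data names are illustrative and not taken from the paper.

    # Plain JAX: a pure loss function whose parameters are passed explicitly
    # as a PyTree (here, a dict of arrays), then transformed with jit and grad.
    import jax
    import jax.numpy as jnp

    def loss(params, x, y):
        pred = params["weight"] * x + params["bias"]
        return jnp.mean((pred - y) ** 2)

    params = {"weight": jnp.array(2.0), "bias": jnp.array(0.5)}
    x = jnp.linspace(0.0, 1.0, 8)
    y = 3.0 * x + 1.0

    # grad differentiates with respect to the first argument; the result is a
    # PyTree of gradients with the same dict structure as params.
    grads = jax.jit(jax.grad(loss))(params, x, y)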

Previous libraries have attempted to reconcile these paradigms either by rejecting the OO approach, as in Stax, or by introducing OO-to-functional transformations and multiple new abstractions while integrating only partially with JAX, as in Flax, Haiku, and Objax. Equinox proposes a novel solution, demonstrating that a PyTorch-like API can be employed without sacrificing the functional paradigm inherent in JAX.

Equinox: A Unified Approach

Equinox's primary contribution is its use of callable PyTrees, allowing class instances to represent parameterized functions. These instances are immutable, consistent with JAX's principles. Each class, corresponding to a function family, is registered as a custom PyTree node type. This representation ensures that parameterization is transparent to JAX, enabling direct use of JAX transformations.
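
A rough sketch of how this looks in code is given below; the class name, field names, and sizes are illustrative rather than taken from the paper, and the example assumes Equinox's eqx.Module base class.

    import equinox as eqx
    import jax

    class Linear(eqx.Module):
        # Declaring fields makes instances dataclass-like and immutable; the
        # class is registered as a PyTree node type, so instances are PyTrees.
        weight: jax.Array
        bias: jax.Array

        def __init__(self, in_size, out_size, key):
            wkey, bkey = jax.random.split(key)
            self.weight = jax.random.normal(wkey, (out_size, in_size))
            self.bias = jax.random.normal(bkey, (out_size,))

        def __call__(self, x):
            return self.weight @ x + self.bias

    model = Linear(3, 2, jax.random.PRNGKey(0))
    # The instance is an ordinary PyTree whose leaves are the parameter arrays,
    # so it can be passed straight through jit/grad/vmap-ed functions.
    print(jax.tree_util.tree_leaves(model))

Because the model is simply a PyTree of its own parameters, a loss function that takes the model as its first argument can be differentiated like any other JAX function.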

This approach circumvents the need for complex abstractions and OO-to-functional transforms. It rests on two foundational ideas:

  1. Parameterized Functions as Data: Functions are represented as immutable class instances, integrating seamlessly with JAX’s transformations. This approach prioritizes simplicity and composability, offering an elegant syntax without introducing new abstractions.
  2. Filtering PyTrees: Filtering handles PyTree leaves that are not JAX arrays (for example, activation functions or boolean flags) and lets selected components be excluded from autodifferentiation or JIT compilation. This is particularly useful for non-differentiable attributes or frozen parts of a model; a sketch follows this list.
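
The sketch below illustrates a filtered gradient computation. It assumes the filtered-transformation helpers currently exported by Equinox (eqx.filter_jit, eqx.filter_grad) together with the eqx.nn.MLP layer; the paper's original naming of these helpers may differ, but the idea is the same: only array-valued leaves are traced and differentiated, while other leaves (activation functions, flags, frozen components) pass through untouched.

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    model = eqx.nn.MLP(in_size=2, out_size=1, width_size=8, depth=2,
                       key=jax.random.PRNGKey(0))

    @eqx.filter_jit   # JIT-compile, tracing only the array-valued leaves
    @eqx.filter_grad  # differentiate only the array leaves of the first argument
    def compute_grads(model, x, y):
        pred = jax.vmap(model)(x)  # the model itself is vmapped over the batch
        return jnp.mean((pred - y) ** 2)

    x = jnp.ones((16, 2))
    y = jnp.zeros((16, 1))
    grads = compute_grads(model, x, y)
    # grads is a PyTree with the same structure as model: gradients at the
    # array leaves, and None at the leaves that were filtered out.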

Practical and Theoretical Implications

Equinox's simplicity and compatibility with JAX make it a notable choice for researchers and practitioners who want object-oriented model definitions alongside the benefits of functional programming. Because it adds no abstractions beyond PyTrees and transformations, it integrates readily with existing JAX projects, keeping the learning curve shallow and reducing opportunities for user error.

Theoretically, Equinox underscores the feasibility of blending seemingly contradictory paradigms by treating parameterized functions as data. This establishes a foundation for further research and development of hybrid programming models within autodifferentiation frameworks.

Future Developments

Given its innovative yet straightforward approach, Equinox sets the stage for further exploration into more advanced model building techniques and frameworks that leverage callable PyTrees. Future developments might include:

  • Enhanced support for complex models that utilize diverse data types.
  • Streamlined integration with other libraries to expand JAX's functional ecosystem.
  • Exploration of performance optimizations specific to the Equinox design.

Ultimately, while Equinox offers a compelling solution to a longstanding design challenge in JAX, it also opens avenues for the development of more versatile and user-friendly autodifferentiation frameworks.