
A research framework for writing differentiable PDE discretizations in JAX (2111.05218v1)

Published 9 Nov 2021 in cs.LG and physics.comp-ph

Abstract: Differentiable simulators are an emerging concept with applications in several fields, from reinforcement learning to optimal control. Their distinguishing feature is the ability to calculate analytic gradients with respect to the input parameters. Like neural networks, which are constructed by composing several building blocks called layers, a simulation often requires computing the output of an operator that can itself be decomposed into elementary units chained together. While each layer of a neural network represents a specific discrete operation, the same operator can have multiple representations, depending on the discretization employed and the research question that needs to be addressed. Here, we propose a simple design pattern to construct a library of differentiable operators and discretizations, by representing operators as mappings between families of continuous functions, parametrized by finite vectors. We demonstrate the approach on an acoustic optimization problem, where the Helmholtz equation is discretized using Fourier spectral methods, and differentiability is demonstrated using gradient descent to optimize the speed of sound of an acoustic lens. The proposed framework is open-sourced and available at https://github.com/ucl-bug/jaxdf.


Summary

  • The paper introduces the jaxdf framework that decouples PDE formulation from discretization, enhancing flexibility and enabling automatic differentiation.
  • The paper demonstrates how operator composition across various discretization families, such as polynomial and Fourier methods, supports efficient gradient computation.
  • The paper validates the framework with numerical experiments, including acoustic lens optimization and heat equation solutions, underscoring its practical utility.

An Examination of a Framework for Differentiable PDE Discretizations in JAX

This paper presents a methodological framework for creating differentiable partial differential equation (PDE) discretizations using JAX, a Python library for composable function transformations, including automatic differentiation (AD). The framework, referred to as jaxdf (JAX-Discretization Framework), decouples the mathematical formulation of a problem from its discretization, enabling the construction of a diverse library of differentiable operators useful in applications ranging from reinforcement learning to optimal control.

Framework Design and Implementation

The paper emphasizes the significance of differentiable simulators, which are increasingly important in fields such as inverse problems, system identification, and optimization under uncertainty. By enabling the computation of analytical gradients with respect to input parameters, differentiable simulators can be embedded within broader machine learning models. However, existing simulators are frequently tied to a specific discretization, limiting their flexibility and scope. The framework addresses this limitation by facilitating the rapid creation of customized representations and designs, including physics-driven neural network layers and bespoke physics loss functions, while maintaining the speed and flexibility required for research.

At the core of the framework is a software model that enables the translation of a PDE, defined over continuous functions, into a program that manipulates finite numerical values—a process known as discretization. Operators, often nonlinear, are implemented as mappings between discretization families, which are essentially parametrized function spaces. By representing operators in terms of discretization families, the framework allows for easy swapping of discretizations, thereby offering crucial flexibility.
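
As an illustration of this pattern, the following is a minimal sketch in plain JAX, not the actual jaxdf API: a discretized field is modeled as a pair of a finite parameter vector and a function that evaluates the represented continuous function at arbitrary coordinates, and an operator maps one such pair to another.

```python
import jax.numpy as jnp

def linear_interp_family(grid):
    """Piecewise-linear discretization family on a fixed 1D grid."""
    def get_fun(params, x):
        # Evaluate the continuous function encoded by the nodal values.
        return jnp.interp(x, grid, params)
    return get_fun

def scale(field, c):
    """An operator: multiply a field by a constant, returning a new field."""
    params, get_fun = field
    new_get_fun = lambda p, x: c * get_fun(p, x)
    return (params, new_get_fun)

grid = jnp.linspace(0.0, 1.0, 64)
params = jnp.sin(2 * jnp.pi * grid)        # finite parameter vector
u = (params, linear_interp_family(grid))   # field: parameters + evaluator

v = scale(u, 3.0)
print(v[1](v[0], 0.25))  # evaluate 3*sin(2*pi*x) at x = 0.25
```

Because the discretization is carried alongside the parameters, a different family (for example, a Fourier series) can be substituted without touching the operator's definition.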

Operator Composition and Discretization Families

To make the design concrete, the authors work through several examples of discretization families and operator compositions. Traditional polynomial representations and Fourier spectral methods exemplify the programmable discretization families within jaxdf. For instance, the derivative operator takes a different discrete form depending on the chosen family, such as polynomial, Fourier, or PINN-based, as sketched below. This flexibility allows operators to be defined consistently across different discretization spaces.
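
As a hedged sketch (function names are illustrative, not jaxdf's), the same continuous derivative operator can be rendered in two families acting on nodal values sampled on a periodic, uniformly spaced grid:

```python
import jax.numpy as jnp

def derivative_fourier(u, dx):
    """Spectral derivative: multiply by ik in Fourier space."""
    k = 2 * jnp.pi * jnp.fft.fftfreq(u.shape[0], d=dx)  # wavenumbers
    return jnp.real(jnp.fft.ifft(1j * k * jnp.fft.fft(u)))

def derivative_fd(u, dx):
    """Second-order central finite differences with periodic wrap-around."""
    return (jnp.roll(u, -1) - jnp.roll(u, 1)) / (2 * dx)

n = 128
dx = 2 * jnp.pi / n
x = jnp.arange(n) * dx
u = jnp.sin(x)  # nodal values: one parametrization of sin(x)

# Both render the same operator d/dx in different discrete forms:
print(jnp.max(jnp.abs(derivative_fourier(u, dx) - jnp.cos(x))))  # near machine precision
print(jnp.max(jnp.abs(derivative_fd(u, dx) - jnp.cos(x))))       # O(dx^2) error
```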

The framework supports efficient, differentiable composition of operators. Through function composition, distinct operators can be chained seamlessly, leveraging JAX's transformations. This permits compound operations that can be adapted to specific contexts and subjected to AD, enabling gradient-based optimization.
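
A minimal sketch of this idea, reusing the spectral derivative from the previous example: a Laplacian is built by chaining the derivative with itself, and jax.grad then differentiates a scalar functional of the result with respect to the field parameters.

```python
import jax
import jax.numpy as jnp

def derivative(u, dx):
    k = 2 * jnp.pi * jnp.fft.fftfreq(u.shape[0], d=dx)
    return jnp.real(jnp.fft.ifft(1j * k * jnp.fft.fft(u)))

def laplacian(u, dx):
    # Operators chain by plain function composition, like network layers.
    return derivative(derivative(u, dx), dx)

def loss(u, dx):
    # A scalar functional of the composed operator's output.
    return jnp.sum(laplacian(u, dx) ** 2)

n = 64
dx = 2 * jnp.pi / n
u = jnp.sin(jnp.arange(n) * dx)

grad_u = jax.grad(loss)(u, dx)  # AD through the whole composition
print(grad_u.shape)             # (64,)
```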

Numerical Experiments and Practical Application

The paper describes a numerical experiment grounded in an acoustic optimization problem: the Helmholtz equation is discretized via Fourier spectral methods, and the framework is used to optimize the speed-of-sound map of an acoustic lens by gradient descent. The experiment demonstrates automatic differentiation through the solver, the benefits of rapid prototyping, and the framework's practical utility.
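
The optimization follows the standard JAX pattern sketched below. Note that simulate here is a hypothetical toy stand-in for the paper's Fourier-spectral Helmholtz solve, kept trivial so the sketch runs end-to-end; only the loop structure reflects the approach.

```python
import jax
import jax.numpy as jnp

def simulate(c):
    # Toy stand-in forward model mapping a sound-speed map to a field.
    # In the paper this role is played by a Helmholtz solve.
    return jnp.sin(c) / c

def loss(c, target):
    # Mismatch with a desired field (e.g. a focal pattern).
    return jnp.mean((simulate(c) - target) ** 2)

c = jnp.full((32,), 1.5)                  # initial sound-speed map
target = simulate(jnp.full((32,), 2.0))   # desired field

grad_fn = jax.jit(jax.grad(loss))         # AD through the simulator
for _ in range(200):
    c = c - 0.1 * grad_fn(c, target)      # plain gradient descent

print(loss(c, target))                    # should approach zero
```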

Further, the framework's versatility is demonstrated through seamless discretization swapping, illustrated by solving a heat equation using both finite differences and a Fourier discretization. Boundary conditions are currently handled implicitly by the chosen discretization (for example, periodic boundaries in Fourier methods); explicit boundary-condition support, such as absorbing boundaries for wave phenomena, is left as an enhancement for future iterations of the framework.
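
A hedged sketch of such swapping on the 1D heat equation u_t = kappa * u_xx, assuming periodic boundaries: the explicit Euler time stepper is written once against an abstract laplacian argument, and either discretization is passed in unchanged.

```python
import jax.numpy as jnp

def laplacian_fd(u, dx):
    """Second-order finite-difference Laplacian, periodic boundaries."""
    return (jnp.roll(u, -1) - 2 * u + jnp.roll(u, 1)) / dx**2

def laplacian_fourier(u, dx):
    """Fourier spectral Laplacian: multiply by -k^2 in frequency space."""
    k = 2 * jnp.pi * jnp.fft.fftfreq(u.shape[0], d=dx)
    return jnp.real(jnp.fft.ifft(-(k**2) * jnp.fft.fft(u)))

def solve_heat(u0, dx, laplacian, kappa=0.1, dt=1e-4, steps=1000):
    """Explicit Euler time stepping of u_t = kappa * u_xx."""
    u = u0
    for _ in range(steps):
        u = u + dt * kappa * laplacian(u, dx)
    return u

n = 128
dx = 2 * jnp.pi / n
u0 = jnp.sin(jnp.arange(n) * dx)

# Same solver, two discretizations, swapped without rewriting the PDE:
u_fd = solve_heat(u0, dx, laplacian_fd)
u_sp = solve_heat(u0, dx, laplacian_fourier)
print(jnp.max(jnp.abs(u_fd - u_sp)))  # small when both are accurate
```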

Conclusion and Future Directions

The authors provide an open-source repository (https://github.com/ucl-bug/jaxdf) containing the framework, inviting contributions and extensions across research domains. Future work is expected to add boundary-condition implementations and transformations between discretization types within the same operator, broadening the framework's applicability and usability.

Overall, the paper provides a compelling account of a flexible, customizable system for differentiable PDE discretizations in JAX. This work serves as a valuable resource for researchers looking to leverage differentiable simulators in scientific machine learning, contributing to the integration of discrete mathematical operators within machine learning architectures.
