Home Page

Authors and Affiliations

  • J. Emmanuel Johnson (CNRS, MEOM)
  • Takaya Uchida (FSU)

A simple package for building simple differentiable ocean model approximations in JAX.


Motivation

Other Packages

PhiFlow. This package is an excellent example of how to interface machine learning and physical simulation. In fact, it is very close to what we want to have in the future. Our package is a single instantiation of that idea, as their scope is much wider than ours. In addition, we do not offer anything with neural networks, but we still try to maintain compatibility.

pyqg / pyqg-jax. These packages work almost exclusively with the quasi-geostrophic equations, and they only use spectral methods. We were motivated by the community behind them; however, we wanted something more general than what they offer.

jaxdf. This package was inspirational for defining spatial discretizations.

jax-cfd. This package has been a great inspiration for thinking about fluid simulations and machine learning. However, we felt that it was very complex and did not cater to people who want to tinker with it at different levels of granularity. The internals are optimized but very difficult to read, which is a barrier for people with less coding experience.

Various QG and SW Model Codes. There are many examples of simplified models implemented in Python. Some examples of QG models are a 1 Layer QG, a Spectral QG, and a Stacked QG, which are all implemented in PyTorch. However, each of them is a blueprint. In fact, we would like to see even more models implemented, and we hope that some of them can be integrated into this platform so that we can continue to grow as a community and share the latest and greatest improvements.


Installation

pip

We can install it directly via pip from the GitHub repository:

pip install "git+https://github.com/jejjohnson/jaxsw.git"

Cloning

We can also clone the git repository

git clone https://github.com/jejjohnson/jaxsw.git
cd jaxsw

poetry

The easiest way to get started is to use poetry, which also installs all of the necessary dev packages:

poetry install

pip

We can also install via pip:

pip install .

Conda

We also have a conda environment with all of the equivalent dependencies.

conda env create -f environments/jax_linux_cpu.yaml
conda activate jaxsw

Tutorials

Anatomy of a PDE

This goes through a full example looking at the components of the jaxsw framework. We explicitly describe the components of the PDE which will be important for this package. This includes:

  1. Domain
  2. Params
  3. Initial Condition Operators
  4. Boundary Operators
  5. Spatial Operators
  6. RHS Operators
  7. Time Steppers

To accomplish this, I will showcase how we can use other libraries as the backbone for many canonical PDEs, e.g. FiniteDiffX, jaxdf, and kernex for spatial discretizations and Diffrax for time stepping schemes. In addition, I will do my best to take advantage of some of the native elements of JAX, e.g. vmap and scan.
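To make the seven components above concrete, here is a minimal sketch in plain JAX for 1D linear advection on a periodic domain. It does not use the jaxsw API; every name in it (`init_u`, `bc_periodic`, `ddx_backward`, `rhs`, `integrate`) and every numerical value is a hypothetical choice made for illustration only.

```python
import jax
import jax.numpy as jnp

# 1. Domain: a periodic 1D grid on [0, 2) with nx points (illustrative values)
nx, x_min, x_max = 64, 0.0, 2.0
dx = (x_max - x_min) / nx
x = x_min + dx * jnp.arange(nx)

# 2. Params: a constant, positive advection speed
c = 1.0

# 3. Initial condition operator: a Gaussian bump centred in the domain
def init_u(x):
    return jnp.exp(-((x - 1.0) ** 2) / 0.05)

# 4. Boundary operator: periodic padding with one ghost cell per side
def bc_periodic(u):
    return jnp.pad(u, 1, mode="wrap")

# 5. Spatial operator: first-order backward difference (upwind for c > 0)
def ddx_backward(u_padded, dx):
    return (u_padded[1:-1] - u_padded[:-2]) / dx

# 6. RHS operator: du/dt = -c du/dx
def rhs(u, c, dx):
    return -c * ddx_backward(bc_periodic(u), dx)

# 7. Time stepper: forward Euler, rolled out with jax.lax.scan
def integrate(u0, c, dx, dt, n_steps):
    def step(u, _):
        return u + dt * rhs(u, c, dx), None

    u_final, _ = jax.lax.scan(step, u0, None, length=n_steps)
    return u_final

dt = 0.5 * dx / c  # time step chosen so that the CFL number is 0.5
u0 = init_u(x)
u_final = integrate(u0, c, dx, dt, n_steps=32)
```

Keeping each component as its own small function means any one of them (for example the spatial operator) can be swapped out without touching the rest.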


3 APIs for PDEs (TODO)

In this three-part tutorial, we describe three APIs for implementing PDEs:

  • Low-Level/Researcher: We showcase the functional API, which offers a high level of granularity and control.
  • Mid-Level/Engineer: We showcase the operator API, which offers a medium level of granularity and control.
  • High-Level/Casual: We showcase the prebuilt models, which need minimal intervention to get started with PDEs.

Spatial Operators (TODO)

This tutorial goes through how to do some spatial operations using this library. We will look at how we can define a simple geometry and then choose among various ways to compute derivatives, such as finite differences (e.g. via slicing or convolutions) or spectral methods. In addition, we will look at some simple procedures for building staggered grids, which can greatly improve the accuracy of these methods.
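As a rough illustration of these options, the sketch below computes the derivative of sin(x) on a periodic grid in three ways: a central difference via slicing, the same stencil via a convolution, and a spectral derivative via the FFT. It uses only `jax.numpy`, not the jaxsw operators, and the grid size is an arbitrary choice.

```python
import math
import jax.numpy as jnp

nx = 128
dx = 2 * math.pi / nx
x = dx * jnp.arange(nx)
u = jnp.sin(x)

# Finite difference via slicing: central difference on the interior points
dudx_slice = (u[2:] - u[:-2]) / (2 * dx)

# The same stencil as a convolution. jnp.convolve flips the kernel,
# so the stencil [-1, 0, 1] is written in reverse order here.
kernel = jnp.array([1.0, 0.0, -1.0]) / (2 * dx)
dudx_conv = jnp.convolve(u, kernel, mode="valid")

# Spectral derivative: multiply by i*k in Fourier space (periodic domain only)
k = 2 * math.pi * jnp.fft.fftfreq(nx, d=dx)
dudx_spectral = jnp.real(jnp.fft.ifft(1j * k * jnp.fft.fft(u)))
```

The two finite difference variants return values on the interior points only, so a boundary treatment is still needed to recover the full grid, whereas the spectral variant assumes periodicity and returns all `nx` points.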


Grid Operators

This tutorial goes through how we handle grid operations. This is very useful for implementing staggered grids for different fields. We showcase how we can use the grid operators to move between the domains along the staggered grid.
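A minimal sketch of the idea in 1D, assuming a field stored on the grid nodes that we want to move to the cell centers. The function names `node_to_center` and `diff_node_to_center` are hypothetical and are not taken from the jaxsw API.

```python
import jax.numpy as jnp

nx, dx = 8, 0.5

# Node locations (nx + 1 points) and cell-center locations (nx points)
x_node = dx * jnp.arange(nx + 1)
x_center = 0.5 * (x_node[1:] + x_node[:-1])

def node_to_center(u_node):
    # Linear interpolation: average the two neighbouring nodes
    return 0.5 * (u_node[1:] + u_node[:-1])

def diff_node_to_center(u_node, dx):
    # Difference of neighbouring nodes, which naturally lives on the centers
    return (u_node[1:] - u_node[:-1]) / dx

u_node = x_node**2
u_center = node_to_center(u_node)
dudx_center = diff_node_to_center(u_node, dx)
```

For this quadratic test field, the staggered difference reproduces the exact derivative 2x at the cell centers, which is one reason staggering can improve accuracy relative to a one-sided difference on a single grid.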


Boundary Operators (TODO)

This tutorial goes through how we handle boundary operations. These are some of the most important components of PDEs, as they govern a large portion of the dynamics and stability of the system. We showcase how we can build custom operators using staple methods such as periodic, Dirichlet and Neumann boundary conditions.
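One common way to express these three conditions is to pad the field with ghost cells before applying a stencil. The sketch below assumes a cell-centered 1D field with the boundary sitting halfway between the edge cell and its ghost cell; the function names are hypothetical and this is not necessarily how jaxsw implements them.

```python
import jax.numpy as jnp

def pad_periodic(u):
    # Ghost cells copy the values from the opposite end of the domain
    return jnp.pad(u, 1, mode="wrap")

def pad_dirichlet(u, value=0.0):
    # Ghost cells are set so that the average of ghost and edge cell
    # equals the prescribed boundary value
    left = 2.0 * value - u[0]
    right = 2.0 * value - u[-1]
    return jnp.concatenate([left[None], u, right[None]])

def pad_neumann(u):
    # Zero-gradient condition: ghost cells copy the edge cells
    return jnp.pad(u, 1, mode="edge")

u = jnp.array([1.0, 2.0, 3.0])
u_periodic = pad_periodic(u)
u_dirichlet = pad_dirichlet(u, value=0.0)
u_neumann = pad_neumann(u)
```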


TimeSteppers (TODO)

This tutorial goes through how to do time stepping with JAX. I'll show how this can be accomplished from scratch using native JAX. We also look at Diffrax, which allows us to remove a lot of the complexity.
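A from-scratch version might look like the sketch below: a classical fourth-order Runge-Kutta step rolled out with `jax.lax.scan`. The helper names `rk4_step` and `integrate` are hypothetical, and the test problem dy/dt = -y is only a stand-in for a real model right-hand side.

```python
import jax
import jax.numpy as jnp

def rk4_step(rhs, y, t, dt):
    # Classical fourth-order Runge-Kutta update
    k1 = rhs(t, y)
    k2 = rhs(t + 0.5 * dt, y + 0.5 * dt * k1)
    k3 = rhs(t + 0.5 * dt, y + 0.5 * dt * k2)
    k4 = rhs(t + dt, y + dt * k3)
    return y + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

def integrate(rhs, y0, t0, dt, n_steps):
    # Roll out n_steps with lax.scan, storing the state after every step
    def body(carry, _):
        y, t = carry
        y_next = rk4_step(rhs, y, t, dt)
        return (y_next, t + dt), y_next

    (y_final, _), ys = jax.lax.scan(body, (y0, t0), None, length=n_steps)
    return y_final, ys

# Toy problem: exponential decay, dy/dt = -y with y(0) = 1
def decay(t, y):
    return -y

y_final, ys = integrate(decay, jnp.array(1.0), 0.0, 0.01, 100)
```

Using `scan` instead of a Python loop keeps the whole rollout as a single traced operation, which matters once the integration is wrapped in `jit` or differentiated.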


Examples

These are more in-depth tutorials about how we can use this package for various canonical PDEs that are necessary for understanding and constructing simple differentiable ocean models.

Lorenz ODEs

I look at the canonical Lorenz ODEs. In particular, we look at the Lorenz-63, the Lorenz-96 and the two-level Lorenz-96 ODEs.
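For reference, the right-hand sides of the first two systems can be sketched in a few lines of JAX. The parameter defaults below are the commonly used values (sigma = 10, rho = 28, beta = 8/3 for Lorenz-63 and F = 8 for Lorenz-96); the function names are hypothetical and not taken from the package.

```python
import jax
import jax.numpy as jnp

def lorenz63_rhs(state, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    # dx/dt = sigma (y - x), dy/dt = x (rho - z) - y, dz/dt = x y - beta z
    x, y, z = state
    return jnp.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])

def lorenz96_rhs(x, forcing=8.0):
    # dx_i/dt = (x_{i+1} - x_{i-2}) x_{i-1} - x_i + F, with periodic indices
    return (jnp.roll(x, -1) - jnp.roll(x, 2)) * jnp.roll(x, 1) - x + forcing

# vmap evaluates the RHS for a small ensemble of states at once
ensemble = jnp.array([[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]])
ensemble_tendency = jax.vmap(lorenz63_rhs)(ensemble)
```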


12 Steps to Navier-Stokes

In this tutorial, I go step by step through the original 12 Steps to Navier-Stokes that was created by Lorena Barba. This includes going through elements of typical PDEs such as advection, diffusion and elliptic equations.


Quasi-Geostrophic Equations

In this tutorial, we look at the quasi-geostrophic (QG) equations and demonstrate how we can use elements of this package.


Shallow Water Equations

In this tutorial, we look at the shallow water (SW) equations and demonstrate how we can use elements of this package.


Learning (TODO)

In this set of tutorials, we will look at how one can use these differentiable models to do some learning. We will look at parameter estimation, state estimation and the joint bi-level optimization scheme. Some applications will include hybrid models for parameterizations and inverse problems for interpolation schemes.
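As a toy illustration of parameter estimation only, the sketch below recovers a decay rate by differentiating through a time-stepping loop with `jax.grad`. The scalar decay model, the function names, the learning rate and the iteration count are all assumptions chosen to keep the example small; they are not part of the package.

```python
import jax
import jax.numpy as jnp

def simulate(k, dt=0.01, n_steps=100):
    # Forward Euler rollout of dy/dt = -k y, starting from y(0) = 1
    y0 = jnp.asarray(1.0, dtype=jnp.float32)

    def step(y, _):
        y_next = y + dt * (-k * y)
        return y_next, y_next

    _, ys = jax.lax.scan(step, y0, None, length=n_steps)
    return ys

# Synthetic observations generated with a known "true" parameter
k_true = 1.5
obs = simulate(k_true)

def loss(k):
    # Mean squared misfit between the simulated and observed trajectories
    return jnp.mean((simulate(k) - obs) ** 2)

grad_loss = jax.jit(jax.grad(loss))

# Plain gradient descent from a deliberately wrong initial guess
k_est = 0.5
learning_rate = 2.0
for _ in range(200):
    k_est = k_est - learning_rate * grad_loss(k_est)
```

State estimation follows the same pattern, with the gradient taken with respect to the initial condition instead of the parameter.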