Parameter Estimation
What components can we use to learn the parameters?
In this section, we will walk through the parameter estimation problem as it typically arises in geoscience applications.
Problem Setup¶
In the parameter estimation case, we are interested in estimating the parameters of a spatial operator. This is a standard problem in machine learning whereby we are given a dataset, $\mathcal{D}$, and we are interested in finding the best model, $\mathcal{M}$, that fits the data.
In the case of the geosciences, we are given a discrete set of observations

$$\mathcal{D} = \left\{ \mathbf{y}_t \right\}_{t=1}^{T}, \qquad \mathbf{y}_t \in \mathbb{R}^{D_y},$$

where $T$ is the number of time steps and $D_y$ is the dimensionality of each observation.
These observations could be, for example, a:
- 1D time series of sea surface temperature values at a particular location
- 2D+T spatiotemporal time series of sea surface height within the Gulf Stream.
Our task is to find some parameterized model, $\mathcal{M}_{\boldsymbol{\theta}}$, that is able to fit this sequence of observations.
Dynamical System¶
For many purposes, we want to find the best dynamical model that fits this sequence of observations. Fortunately, much of physics is governed by dynamical models from which we can draw inspiration. The typical formulation of a dynamical system is a description of how the spatial state changes with respect to time. We can write it like so:

$$\mathbf{u}(t) = \boldsymbol{ODESolve}\left(\boldsymbol{F}, \mathbf{u}_0, \boldsymbol{\theta}, t\right), \qquad \mathbf{y}(t) = \boldsymbol{h}\left(\mathbf{u}(t); \boldsymbol{\theta}\right) + \boldsymbol{\varepsilon},$$

where $\boldsymbol{ODESolve}$ is a parameterized ODE solver which takes an initial state, $\mathbf{u}_0$, and outputs the state, $\mathbf{u}(t)$, at time, $t$.
$\boldsymbol{h}$ is now a parameterized function that transforms the state, $\mathbf{u}$, to the observation, $\mathbf{y}$.
The time domain, $\mathcal{T}$, is typically defined on the positive real number line, $\mathbb{R}^+$.
For convenience, we often consider it to be bounded between $0$ and $T$, i.e., $t \in [0, T]$.
Parameter Posterior¶
We will rewrite the parameter posterior given our formulation:

$$p(\boldsymbol{\theta} \mid \mathcal{D}) = \frac{p(\mathcal{D} \mid \boldsymbol{\theta})\, p(\boldsymbol{\theta})}{p(\mathcal{D})},$$

where $p(\mathcal{D})$ is the evidence, $p(\mathcal{D}) = \int p(\mathcal{D} \mid \boldsymbol{\theta})\, p(\boldsymbol{\theta})\, d\boldsymbol{\theta}$.
Due to our definition of the Gaussian likelihood, implicitly defined in the observation model (4), we can work with the posterior up to a normalization constant, which allows for simpler inference. We can write this as

$$p(\boldsymbol{\theta} \mid \mathcal{D}) \propto \exp\left(-\boldsymbol{U}(\boldsymbol{\theta})\right),$$

where $\boldsymbol{U}(\boldsymbol{\theta})$ is the energy function. This shows that the posterior distribution is proportional to the exponential of the negative energy function, $\boldsymbol{U}(\boldsymbol{\theta})$.
Loss Function¶
We define the solution as the best parameters, $\boldsymbol{\theta}^*$, that maximize the posterior distribution, $p(\boldsymbol{\theta} \mid \mathcal{D})$. However, given equation (6), we can simply minimize the energy function, $\boldsymbol{U}(\boldsymbol{\theta})$, i.e., the negative log-posterior.
In this example, we will use a loss function that takes the expectation over the data distribution, $p(\mathcal{D})$. We can write this as

$$\boldsymbol{L}(\boldsymbol{\theta}) = \mathbb{E}_{\mathbf{y} \sim p(\mathcal{D})}\left[ -\log p(\mathbf{y} \mid \boldsymbol{\theta}) \right].$$

We don't have the true data distribution; we only have some discrete samples along the time domain. So we can empirically approximate this expectation as

$$\boldsymbol{L}(\boldsymbol{\theta}) \approx \frac{1}{T} \sum_{t=1}^{T} -\log p(\mathbf{y}_t \mid \boldsymbol{\theta}).$$

Lastly, from the definition of our dynamical system in equation (4), we can see that the likelihood is given by the dynamical model and the observation model, i.e., a composition of the two functions. We can write down the loss function as

$$\boldsymbol{L}(\boldsymbol{\theta}) = \frac{1}{T} \sum_{t=1}^{T} \left\| \mathbf{y}_t - \boldsymbol{h}\left(\boldsymbol{ODESolve}\left(\boldsymbol{F}, \mathbf{u}_0, \boldsymbol{\theta}, t\right); \boldsymbol{\theta}\right) \right\|_2^2.$$
Note: these are very broad assumptions about the data likelihood term. We could introduce further assumptions to account for uncertainty, such as a prior on the parameters or a diagonal/full covariance matrix for the noise level.
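As an illustrative sketch of how those extra assumptions would enter (assuming, for example, a Gaussian prior $\boldsymbol{\theta} \sim \mathcal{N}(\mathbf{0}, \sigma_{\boldsymbol{\theta}}^2\mathbf{I})$ and an observation-noise covariance $\boldsymbol{\Sigma}$), the loss picks up a Mahalanobis weighting and a regularization term:

$$\boldsymbol{L}(\boldsymbol{\theta}) = \frac{1}{2}\sum_{t=1}^{T} \left(\mathbf{y}_t - \hat{\mathbf{y}}_t\right)^\top \boldsymbol{\Sigma}^{-1} \left(\mathbf{y}_t - \hat{\mathbf{y}}_t\right) + \frac{1}{2\sigma_{\boldsymbol{\theta}}^2}\left\|\boldsymbol{\theta}\right\|_2^2, \qquad \hat{\mathbf{y}}_t = \boldsymbol{h}\left(\boldsymbol{ODESolve}\left(\boldsymbol{F}, \mathbf{u}_0, \boldsymbol{\theta}, t\right); \boldsymbol{\theta}\right),$$

which reduces (up to a constant factor) to the mean-squared-error loss above when $\boldsymbol{\Sigma} = \sigma^2\mathbf{I}$ and the prior is flat.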
Dynamical Model¶
The dynamical system shown in equation (4) is the cornerstone of ODEs and PDEs. It describes the spatiotemporal evolution of a field.
The spatial operator, $\boldsymbol{F}$, consists of the set of all possible combinations of linear and/or non-linear operators. These are typically discretized numerically, e.g., with finite difference, finite volume, finite element, or pseudospectral methods. The solution to this system can be written using the 2nd fundamental theorem of calculus:

$$\mathbf{u}(t_1) = \mathbf{u}(t_0) + \int_{t_0}^{t_1} \boldsymbol{F}\left(\mathbf{u}(\tau); \boldsymbol{\theta}\right) d\tau.$$
This equation involves evaluating an integral.
In practice, there are many ways to evaluate this integral numerically.
For example, we could use a Taylor expansion, which is what Euler's method does, or we could use a quadrature method, which is what Runge-Kutta methods do.
Regardless of the method chosen, most methods do not directly calculate the difference between $\mathbf{u}(t_0)$ and $\mathbf{u}(t_1)$, especially if the time horizon is very large.
They typically proceed in an "autoregressive" way by incrementally applying a time stepper recursively from $t_0$ to $t_1$.
So first, we define the increment operator (time stepper) for the solution to the dynamical system,

$$\mathbf{u}(t + \Delta t) = \mathbf{u}(t) + \int_{t}^{t + \Delta t} \boldsymbol{F}\left(\mathbf{u}(\tau); \boldsymbol{\theta}\right) d\tau.$$

Now, we can apply it incrementally,

$$\mathbf{u}(t_0) \rightarrow \mathbf{u}(t_0 + \Delta t) \rightarrow \mathbf{u}(t_0 + 2\Delta t) \rightarrow \cdots \rightarrow \mathbf{u}(t_1),$$

where we arrive at the full solution in equation (12).
For the purposes of discussing the parameter estimation problem, we don't need to focus on the underlying method of solving the ODE.
So for the remainder of this note, we will use the symbol $\boldsymbol{ODESolve}$ to denote the ODE-solver operator, which takes the initial condition (or multiple initial conditions) and produces the solution, $\mathbf{u}$, to the ODE with the spatial operator, $\boldsymbol{F}$, and the time stepper at the specified time steps, $\mathbf{t}$.
Pseudo-Code¶
Let's initialize all of the pieces that we are going to need from the ODE in equation (11). First, we need to initialize the parameterized spatial operator, $\boldsymbol{F}$.
# initialize inputs
params: PyTree = ...
F: Callable = ...
For this section, we are not concerned with the particular form of the function, because it is not important for this discussion. In the following sections, we will consider what form it will take.
Recall the equation for a single time step, equation (13).
We can write some pseudo-code to define our custom TimeStepper
like so:
# initialize integral solver, e.g. Euler, Runge-Kutta, Adams-Bashforth
integral_solver: Callable = ...

def time_stepper(u: Array, params: PyTree, t0: float, t1: float) -> Array:
    # calculate the increment (the integral of F between t0 and t1)
    u_increment = integral_solver(F, u, params, t0, t1)
    # add the increment to the previous state
    return u + u_increment
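For concreteness, here is a minimal sketch of what the integral_solver could look like, using either a single forward-Euler increment or a classic Runge-Kutta 4 increment (this assumes the spatial operator has the signature F(u, params), consistent with how it is called elsewhere in this note):

def euler_increment(F, u, params, t0, t1):
    # forward-Euler approximation of the integral of F over [t0, t1]
    dt = t1 - t0
    return dt * F(u, params)

def rk4_increment(F, u, params, t0, t1):
    # classic 4th-order Runge-Kutta approximation of the integral over [t0, t1]
    dt = t1 - t0
    k1 = F(u, params)
    k2 = F(u + 0.5 * dt * k1, params)
    k3 = F(u + 0.5 * dt * k2, params)
    k4 = F(u + dt * k3, params)
    return (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)

Either one can be passed in as the integral_solver above, e.g. integral_solver = rk4_increment.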
Here, we are only calculating the solution to the ODE between $t_0$ and $t_1$. To calculate the full solution to the ODE from equation (12), we can do it manually by defining a time vector, $\mathbf{t}$, with all of the time steps where we want our output state, $\mathbf{u}$.
We can also initialize our state, $\mathbf{u}_0$.
# initialize state
u0: Array = ...
# initialize time step and time steps
dt = ...
time_steps = jnp.arange(0, T, dt)
Now we can apply our time_stepper
function recursively.
# start from the initial condition
u = u0
u_solutions: List = []
# loop through list of time steps
for t in time_steps:
    # apply a single time step
    u: Array = time_stepper(u, params, t, t + dt)
    # store the solution
    u_solutions.append(u)
# stack the solutions along the time dimension
u_solutions: Array["T"] = jnp.stack(u_solutions, axis=0)
However, most modern ODE-solver packages have this functionality built in, so we only have to call them on the initial condition.
# initialize time steps
dt = 0.01
# do everything in one shot.
u: Array["T-1"] = package.time_stepper(F, u0, params, t0=0, t1=T, dt=dt)
Tip: Sometimes there is advanced functionality to output the solution at different time intervals than the ones we march at. For example, we may want to step forward at a finer time step but output less frequently to match the observations.
# initialize time steps
dt = 0.01
# time steps for saving the output vector
dt_saved = 0.1
saved_time_steps = jnp.arange(0, T, dt_saved)
# do everything in one shot.
u: Array["T-1"] = package.time_stepper(F, u0, params, t0=0, t1=T, dt=dt, saveas=saved_time_steps)
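As a concrete illustration, the same one-shot pattern could be written with the diffrax library (one possible choice; the generic package.time_stepper above is just a placeholder, and wrapping F as a time-dependent vector field is an assumption about its signature):

import diffrax as dfx

# wrap the spatial operator F(u, params) as a diffrax vector field f(t, u, args)
term = dfx.ODETerm(lambda t, u, args: F(u, args))
solver = dfx.Tsit5()
# save the solution only at the (coarser) observation times
saveat = dfx.SaveAt(ts=saved_time_steps)
sol = dfx.diffeqsolve(term, solver, t0=0.0, t1=T, dt0=dt, y0=u0, args=params, saveat=saveat)
u = sol.ys  # shape: (len(saved_time_steps), ...)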
Parameter Learning¶
There are many cases where we have a prior belief about the underlying dynamical system that would fit the observations. However, oftentimes there are unknown parameters within the dynamical model itself. We can use the same learning scheme shown above to try to fit the best parameters, $\boldsymbol{\theta}$, given the observations, $\mathbf{y}$.
Pseudo-Code¶
First, we need to get our dataset of observations.
# get observations
y_obs: Array["T"] = ...
ts: Array["T"] = ...
Then, we need to define our PDE and its parameters. In this section, we do not need to care explicitly about the PDE we choose. We will outline a few concrete ODEs/PDEs in the next section.
# initialize pde rhs function, e.g. Lorenz-63, Lorenz-96, QG
params: PyTree = ...
pde_rhs: Callable = ...
Now, we need to initialize our loss function and learning step.
# time step and where to save the output
dt = ...
t0, t1 = ts[0], ts[-1]
saveas: Array["T-1"] = ts[1:]

# define the loss function
def loss_fn(y: Array, y_hat: Array) -> Array:
    return jnp.mean((y_hat - y) ** 2)

# define the learning step (integrate, then compare against observations)
def learning_step(params: PyTree, y_obs: Array) -> Array:
    y_hat: Array["T-1"] = dfx.integrate(pde_rhs, params, y_obs[0], t0, t1, dt, saveas)
    loss: Array[""] = loss_fn(y_obs[1:], y_hat)
    return loss
And now, we initialize our optimizer.
# initialize optimizer
learning_rate = 1e-3
optimizer = optax.sgd(learning_rate=learning_rate)
# initialize optimizer state
opt_state = optimizer.init(params)
Now, we can loop through to optimize the parameters.
# loop through epochs
num_epochs = ...
for iepoch in range(num_epochs):
    # calculate the loss and gradients wrt params
    loss_value, grads = jax.value_and_grad(learning_step)(params, y_obs)
    # update optimizer state
    updates, opt_state = optimizer.update(grads, opt_state, params)
    # update parameters with new state
    params = optax.apply_updates(params, updates)
This can get a little cumbersome, so we can refactor this a bit using more refined APIs.
# initialize the solver
max_iterations = 1_000
solver = jaxopt.LBFGS(fun=learning_step, maxiter=max_iterations)
# run solver
sol: PyTree = solver.run(params, y_obs=y_obs)
# extract parameters
new_params: PyTree = sol.params
Example: Lorenz-96¶
We can write the dynamical model for the two-level Lorenz-96 equation.
There are a few parameters within this formulation, such as the forcing, $F$, the coupling strength, $h$, and the spatial- and temporal-scale ratios, $b$ and $c$.
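A minimal sketch of the right-hand side for the standard two-level Lorenz-96 formulation in JAX (the parameter names and the flattened layout of the fast variables are implementation choices for illustration):

import jax.numpy as jnp

def lorenz96_two_level_rhs(state, params):
    # state = (x, y): K slow variables and K*J fast variables (flattened, k-major)
    x, y = state
    K, J = x.shape[0], y.shape[0] // x.shape[0]
    forcing, h, b, c = params["forcing"], params["h"], params["b"], params["c"]
    # slow variables: advection - damping + forcing - coupling to the fast variables
    dx = (jnp.roll(x, -1) - jnp.roll(x, 2)) * jnp.roll(x, 1) - x + forcing
    dx = dx - (h * c / b) * y.reshape(K, J).sum(axis=-1)
    # fast variables: advection - damping + coupling to the slow variables
    dy = -c * b * jnp.roll(y, -1) * (jnp.roll(y, -2) - jnp.roll(y, 1)) - c * y
    dy = dy + (h * c / b) * jnp.repeat(x, J)
    return dx, dy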
Example: Quasi-Geostrophic Equation¶
There are a few parameters within this formulation, which include the Rossby parameter, $\beta$, the viscosity, $\nu$, and the linear drag coefficient, $\mu$.
Pseudo-Code¶
params: PyTree = ...
forcing_fn: Callable = ...

def qg_equation_of_motion(q, params):
    # invert the potential vorticity for the streamfunction
    psi = elliptical_inversion(q, beta=params.rossby_radius, method="cg")
    # geostrophic velocities from the streamfunction
    u, v = geostrophic.velocities(psi)
    # right-hand-side terms
    rhs_adv = advection_2D(q, u, v)
    rhs_beta = geostrophic.beta_plane(q, beta=params.beta)
    rhs_diffusion = diffusion_2D(q, viscosity=params.viscosity)
    forcing = forcing_fn(q)
    return - rhs_adv + rhs_beta + rhs_diffusion + forcing
Hybrid Models¶
First, we need to choose our parameterized spatial operator, $\boldsymbol{F}_{\boldsymbol{\theta}}$.
From this formulation, we can consider three types of models that are found within the literature.
Dynamical Model. In this case, the spatial operator is fully specified by the physics, and we have a strong assumption about the underlying dynamics that can fit the observations. We do not add any parameterizations. This can be written as a classical dynamical model given as the solution to an ODE or PDE. In the case of PDEs, this could include a model like the QG model or the SWM.
Surrogate Model. In this case, the spatial operator is fully parameterized, and we have very weak assumptions about the underlying dynamics that can describe the observations. The system dynamics are unknown and we cannot formulate our problem as a PDE.
Hybrid Model. In this case, the spatial operator is partially specified by the physics and partially parameterized, and we assume that the system dynamics are partially known, so we can formulate portions of our problem (spatially, temporally, or both) as a PDE and the other portion as a parameterized function.
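Schematically, the three cases above differ only in the choice of the right-hand-side function that we hand to the ODE solver (a sketch, where dyn_model_rhs and neural_net are placeholders for a physical model and a learned function):

# (1) dynamical model: physics only
def rhs_dynamical(u, params):
    return dyn_model_rhs(u, params)

# (2) surrogate model: fully parameterized, e.g. a neural network
def rhs_surrogate(u, params):
    return neural_net(u, params)

# (3) hybrid model: physics plus a parameterized correction
def rhs_hybrid(u, params):
    dyn_params, nn_params = params
    return dyn_model_rhs(u, dyn_params) + neural_net(u, nn_params)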
Note: there is a blurred line between a pure dynamical model and a surrogate model. For example, a parameterized model can come in many forms (see Table 1 for examples). One could argue that trying to find the parameters of a forcing function that follows a particular form, e.g. linear, periodic, or polynomial, could be considered learning a forcing function.
This formulation is based on the paper by [Chen et al., 2021].
Pseudo-Code¶
This pseudo-code will be very similar to the section introducing parameter learning. However, in that section, we did not care about the specific model or its parameters; in this case, we do care about both.
First, we need to define our PDE rhs and the associated parameters.
# initialize pde rhs function, e.g. Lorenz-63, Lorenz-96, QG
dyn_params: PyTree = ...
dyn_model_rhs: Callable = ...
Next, we need to define our parameterization. As mentioned above, we have a range of possible choices we can make for the architecture, e.g., linear, basis function, or a neural network.
# initialize neural network model
parameterization_params: PyTree = ...
parameterization_model: Callable = ...
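As one concrete (hypothetical) choice for the parameterization, we could use a small fully-connected network written directly in JAX; the layer sizes are arbitrary, and the first and last sizes are assumed to match the flattened state dimension:

import jax
import jax.numpy as jnp

def init_mlp_params(key, sizes=(64, 128, 64)):
    # one (weights, bias) pair per layer
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
        for k, m, n in zip(keys, sizes[:-1], sizes[1:])
    ]

def mlp_parameterization(state, params):
    # flatten the state, pass it through the MLP, and reshape back
    x = state.reshape(-1)
    for W, b in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return (x @ W + b).reshape(state.shape)

parameterization_params: PyTree = init_mlp_params(jax.random.PRNGKey(0))
parameterization_model: Callable = mlp_parameterization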
Now, the equation of motion (as shown in equation (18)) will be a combination of the two, where they are weighted by a parameter, $\alpha$.
# concatenate the parameters
params = (dyn_params, parameterization_params)

# create the hybrid equation of motion
def equation_of_motion(state: Array, params: PyTree, alpha: float = 0.5) -> Array:
    # unpack the parameters
    dyn_params, parameterization_params = params
    # dynamical model equation of motion --> update state
    new_state: Array = alpha * dyn_model_rhs(state, dyn_params)
    # parameterization --> correction
    correction: Array = (1 - alpha) * parameterization_model(state, parameterization_params)
    # update state with the correction
    new_state = new_state + correction
    return new_state
The remainder is the exact same training loop that was presented in the earlier pseudo-code section for the parameter learning.
Spatial Parameterization¶
- Denoising, Calibration, Forcing Term
Subgrid Parameterization¶
This example is very similar to the parameterization example that was listed above. However, it is distinct because we are assuming that the missing physics is captured by the high-resolution simulations.
# define pde model
pde_model: Callable = ...
# define subgrid parameterization term
nn_model: Callable = ...
This example was inspired by [Frezat et al., 2022; Ross et al., 2023; Srinivasan et al., 2023].
Surrogate Models¶
This is known as a Neural ODE [Kidger, 2022; Chen et al., 2018] within the literature.
Offline Learning¶
In the above examples, we were using a fully differentiable dynamical model to learn the forcing term. Alternatively, we could simply train the parameterizations on pre-computed simulation data. We call this offline learning because we are not running any dynamical models during training; we are simply learning the parameterization from pairwise data. Naturally, since we call this offline, all of the examples above underneath the hybrid modeling section would be considered online learning in some communities.
Here, the pairwise data points come from a twin experiment.
Note: This will be orders of magnitude faster because we do not have to go through a full ODESolver function. However, we can imagine there are some downsides to this method. The biggest con is that it is unclear how well we can simulate the missing physics that we expect within the real system.
Pseudo-Code¶
The rest of the code can use the same training loop that we saw in the above section.
Example: Parameterization¶
This parameterization could be classified as a forcing function.
Pseudo-Code¶
# initialize PDE rhs + parameters
dyn_model_params: PyTree = ...
dyn_model_rhs: Callable[[Array["H W"], ...], Array["H W"]] = ...
forcing_fn: Callable = ...

# run a full simulation with and without the forcing term
dyn_sol_forcing: Array["T H W"] = package.integrate(dyn_model_rhs, dyn_model_params, forcing_fn, ...)
dyn_sol: Array["T H W"] = package.integrate(dyn_model_rhs, dyn_model_params, ...)

# create dataset (the forcing error is the learning target)
forcing_err: Array["T H W"] = dyn_sol_forcing - dyn_sol

# initialize parameterization + params
params: PyTree = ...
parameterization_fn: Callable[[Array["H W"], ...], Array["H W"]] = ...

# define the loss function
def loss_fn(y: Array, y_hat: Array) -> Array:
    return jnp.mean((y_hat - y) ** 2)

def learning_step(params: PyTree, dyn_sol: Array, forcing_err: Array) -> Array:
    # vectorize the operation over the time dimension
    forcing_err_hat: Array["T H W"] = jax.vmap(parameterization_fn, in_axes=(0, None))(dyn_sol, params)
    # compute loss
    loss: Array[""] = loss_fn(forcing_err, forcing_err_hat)
    return loss
Example: Subgrid Parameterization¶
Pseudo-Code¶
# initialize PDEs (high- and low-resolution models)
hires_dyn_model: Callable[[Array["H W"], ...], Array["H W"]] = ...
lores_dyn_model: Callable[[Array["h w"], ...], Array["h w"]] = ...

# run a full simulation at each resolution
hires_sol: Array["T H W"] = package.integrate(hires_dyn_model, ...)
lores_sol: Array["T h w"] = package.integrate(lores_dyn_model, ...)

# filter & downsample/upscale/coarse-grain the high-resolution solution
hires_sol_corrupt: Array["T H W"] = filter_fn(hires_sol, ...)
hires_sol_corrupt: Array["T h w"] = downscale_fn(hires_sol_corrupt, ...)

# create dataset (the subgrid error is the learning target)
lores_err: Array["T h w"] = hires_sol_corrupt - lores_sol

# initialize parameterization + params
params: PyTree = ...
parameterization_fn: Callable[[Array["h w"], ...], Array["h w"]] = ...

# define the loss function
def loss_fn(y: Array, y_hat: Array) -> Array:
    return jnp.mean((y_hat - y) ** 2)

def learning_step(params: PyTree, lores_sol: Array, lores_err: Array) -> Array:
    # vectorize the operation over the time dimension
    lores_err_hat: Array["T h w"] = jax.vmap(parameterization_fn, in_axes=(0, None))(lores_sol, params)
    # compute loss
    loss: Array[""] = loss_fn(lores_err, lores_err_hat)
    return loss
Example: Surrogate Models¶
In this example, we are going to learn a fully parameterized spatial operator that will map the state from time $t$ to $t + \Delta t$. This can be labeled a forecasting problem using a spatial operator that works as an autoregressive function.
# initialize dynamical model
dyn_model_params: PyTree = ...
dyn_model_rhs: Callable[[Array["H W"], ...], Array["H W"]] = ...

# run a full simulation
u_sim: Array["T H W"] = package.integrate(dyn_model_rhs, dyn_model_params, ...)

# initialize spatial operator + params
params: PyTree = ...
spatial_operator: Callable[[Array["H W"], ...], Array["H W"]] = ...

# define learning step
def learning_step(params: PyTree, u_sim: Array) -> Array[""]:
    # vectorize the operation over the time dimension (except the last state)
    u_hat: Array["T-1 H W"] = jax.vmap(spatial_operator, in_axes=(0, None))(u_sim[:-1], params)
    # compute loss against the next state
    loss: Array[""] = loss_fn(u_sim[1:], u_hat)
    return loss
Model Uncertainty¶
We can take a completely probabilistic approach to this problem.
There are parallels to some algorithms which are nonlinear extensions to the Kalman Filter, e.g., Extended Kalman Filter (EKF), Unscented Kalman Filter (UKF), and the Assumed Density Filter (ADF). In addition, there are also parallels to the Ensemble Kalman Filter (EnsKF).
There are also connections to methods that try to learn a reduced order model (ROM), i.e., a transformation from the state space, $\mathbf{u} \in \mathbb{R}^{D_u}$, to a latent representation, $\mathbf{z} \in \mathbb{R}^{D_z}$, where $D_z \ll D_u$. This has connections to Koopman theory [Brunton et al., 2021], which postulates that there exists some non-linear transformation whereby the underlying dynamics are linear. There are some methods which try to directly learn a linear reduced order space, like Dynamic Mode Decomposition (DMD) [Tu et al., 2014; Schmid, 2022] or operator inference [Qian et al., 2021]. These linear approximations can easily be plugged into the Kalman Filter framework to account for some uncertainty. There are similar methods in the machine learning community which directly try to learn the transformation via flow-like models, e.g., the Kalman variational autoencoder [Gunnarsson et al., 2022] or the normalizing Kalman Filter. The paper on dynamical variational autoencoders [Girin et al., 2021] is a great review of the family of methods available.
- Chen, Y., Sanz-Alonso, D., & Willett, R. (2021). Auto-differentiable Ensemble Kalman Filters. arXiv. 10.48550/ARXIV.2107.07687
- Frezat, H., Sommer, J. L., Fablet, R., Balarac, G., & Lguensat, R. (2022). A Posteriori Learning for Quasi-Geostrophic Turbulence Parametrization. Journal of Advances in Modeling Earth Systems, 14(11). 10.1029/2022ms003124
- Ross, A., Li, Z., Perezhogin, P., Fernandez-Granda, C., & Zanna, L. (2023). Benchmarking of Machine Learning Ocean Subgrid Parameterizations in an Idealized Model. Journal of Advances in Modeling Earth Systems, 15(1). 10.1029/2022ms003258
- Srinivasan, K., Chekroun, M. D., & McWilliams, J. C. (2023). Turbulence closure with small, local neural networks: Forced two-dimensional and β-plane flows. arXiv. 10.48550/ARXIV.2304.05029
- Kidger, P. (2022). On Neural Differential Equations. arXiv. 10.48550/ARXIV.2202.02435