Bayesian Inference & Ensembles
Layer 3 recipes for conjugate updates, second-order variational steps, and ensemble data assimilation. All covariances are operators, so the updates inherit structured solves; all stochastic routines take explicit PRNG keys.
Bayesian linear regression
Closed-form Gaussian posterior updates — full covariance or diagonal-only — plus the marginal likelihood and expected log-likelihood that score them.
Structured linear algebra and Gaussian primitives for JAX.
blr_full_update(nat1: Float[Array, ' d'], nat2: Float[Array, 'd d'], grad: Float[Array, ' d'], hessian: Float[Array, 'd d'], lr: float, *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, ' d'], Float[Array, 'd d']]
Full-rank natural parameter BLR update step.
Computes the damped update for full-rank variational parameters:
nat2_{new} = (1 - lr) \cdot nat2 + lr \cdot (-\tfrac{1}{2}(-H))
\mu = solve(-2 \cdot nat2, nat1)
nat1_{new} = (1 - lr) \cdot nat1 + lr \cdot (grad - H \mu)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `nat1` | `Float[Array, ' d']` | Current natural location, shape `(d,)`. | required |
| `nat2` | `Float[Array, 'd d']` | Current natural precision matrix (eta2), shape `(d, d)`. | required |
| `grad` | `Float[Array, ' d']` | Gradient of the log-likelihood, shape `(d,)`. | required |
| `hessian` | `Float[Array, 'd d']` | Hessian of the log-likelihood (negative semi-definite for log-concave likelihoods), shape `(d, d)`. | required |
| `lr` | `float` | Learning rate / damping factor. | required |
| `solver` | `AbstractSolverStrategy \| None` | Optional solver strategy for structured linear algebra. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, ' d'], Float[Array, 'd d']]` | Tuple of the updated natural location, shape `(d,)`, and natural precision, shape `(d, d)`. |
Source code in src/gaussx/_inference/_blr.py
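A minimal NumPy sketch of the full-rank step above, written densely for readability; it is not the gaussx implementation, which routes the solve through `solver`. As an assumption for the check, `grad` and `hessian` are taken of the log-joint (log-likelihood plus a Gaussian log-prior) at the current mean. The objective is then quadratic, so one step with `lr = 1` starting from the prior should land exactly on the conjugate Bayesian linear regression posterior. The helper name `blr_full_update_np` is hypothetical.

```python
import numpy as np

def blr_full_update_np(nat1, nat2, grad, hessian, lr):
    """Dense sketch of the damped full-rank BLR step (nat2 stores -0.5 * precision)."""
    # Current mean from the natural parameters: mu = (-2 nat2)^{-1} nat1.
    mu = np.linalg.solve(-2.0 * nat2, nat1)
    # Target natural parameters from the second-order expansion.
    nat1_target = grad - hessian @ mu
    nat2_target = 0.5 * hessian          # = -0.5 * (-hessian)
    nat1_new = (1.0 - lr) * nat1 + lr * nat1_target
    nat2_new = (1.0 - lr) * nat2 + lr * nat2_target
    return nat1_new, nat2_new

# Conjugate check: Bayesian linear regression y = X w + noise.
rng = np.random.default_rng(0)
n, d, noise_var, prior_var = 50, 3, 0.1, 2.0
X = rng.normal(size=(n, d))
y = X @ np.array([1.0, -2.0, 0.5]) + np.sqrt(noise_var) * rng.normal(size=n)

# Start from the prior N(0, prior_var I): nat1 = 0, nat2 = -0.5 / prior_var * I.
nat1 = np.zeros(d)
nat2 = -0.5 / prior_var * np.eye(d)
mu = np.linalg.solve(-2.0 * nat2, nat1)

# Derivatives of the log-joint (log-likelihood + Gaussian log-prior) at mu.
grad = X.T @ (y - X @ mu) / noise_var - mu / prior_var
hessian = -(X.T @ X) / noise_var - np.eye(d) / prior_var

nat1_post, nat2_post = blr_full_update_np(nat1, nat2, grad, hessian, lr=1.0)
post_mean = np.linalg.solve(-2.0 * nat2_post, nat1_post)

# Closed-form conjugate posterior for comparison.
precision = X.T @ X / noise_var + np.eye(d) / prior_var
exact_mean = np.linalg.solve(precision, X.T @ y / noise_var)
```

With `lr < 1` the same function takes a damped step toward that posterior instead of jumping to it.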
blr_diag_update(nat1: Float[Array, ' d'], nat2_diag: Float[Array, ' d'], grad: Float[Array, ' d'], hessian_diag: Float[Array, ' d'], lr: float) -> tuple[Float[Array, ' d'], Float[Array, ' d']]
Diagonal natural parameter BLR update step.
Computes the damped update for diagonal variational parameters:
\mu = nat1 / (-2 \cdot nat2)
eta2_{target} = -\tfrac{1}{2}(-hessian\_diag) = 0.5 \cdot hessian\_diag
eta1_{target} = grad - hessian\_diag \cdot \mu
nat1_{new} = (1 - lr) \cdot nat1 + lr \cdot eta1_{target}
nat2_{new} = (1 - lr) \cdot nat2 + lr \cdot eta2_{target}
where nat2 (eta2) stores -\tfrac{1}{2} \lambda with
\lambda = -hessian\_diag (diagonal precision).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `nat1` | `Float[Array, ' d']` | Current natural location, shape `(d,)`. | required |
| `nat2_diag` | `Float[Array, ' d']` | Current diagonal natural precision (eta2), shape `(d,)`. | required |
| `grad` | `Float[Array, ' d']` | Gradient of the log-likelihood, shape `(d,)`. | required |
| `hessian_diag` | `Float[Array, ' d']` | Diagonal of the Hessian (non-positive for log-concave likelihoods), shape `(d,)`. | required |
| `lr` | `float` | Learning rate / damping factor. | required |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, ' d'], Float[Array, ' d']]` | Tuple of the updated natural location and diagonal natural precision, each of shape `(d,)`. |
Source code in src/gaussx/_inference/_blr.py
log_marginal_likelihood(loc: Float[Array, ' N'], cov_operator: lx.AbstractLinearOperator, y: Float[Array, ' N'], *, solver: AbstractSolverStrategy | None = None) -> Float[Array, '']
GP log marginal likelihood.
Computes:
log p(y) = -0.5 * (y-mu)^T K^{-1} (y-mu) - 0.5 * log|K| - N/2 * log(2pi)
Delegates to gaussx.gaussian_log_prob.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `loc` | `Float[Array, ' N']` | Prior mean, shape `(N,)`. | required |
| `cov_operator` | `AbstractLinearOperator` | Covariance operator K, shape `(N, N)`. | required |
| `y` | `Float[Array, ' N']` | Observations, shape `(N,)`. | required |
| `solver` | `AbstractSolverStrategy \| None` | Optional solver strategy. | `None` |
Returns:
| Type | Description |
|---|---|
| `Float[Array, '']` | Scalar log marginal likelihood. |
Source code in src/gaussx/_inference/_inference.py
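For reference, a minimal dense NumPy sketch of the formula above, using a Cholesky factor K = L L^T for both the quadratic form and the log-determinant (log|K| = 2 sum_i log L_ii). It illustrates the formula under that assumption and is not how `gaussx.gaussian_log_prob` is implemented; the helper name and the RBF-plus-noise test covariance are hypothetical.

```python
import numpy as np

def log_marginal_likelihood_np(loc, K, y):
    """log N(y | loc, K) via a Cholesky factor of a dense covariance."""
    L = np.linalg.cholesky(K)                   # K = L L^T, L lower-triangular
    alpha = np.linalg.solve(L, y - loc)         # L^{-1} (y - mu)
    quad = alpha @ alpha                        # (y - mu)^T K^{-1} (y - mu)
    logdet = 2.0 * np.sum(np.log(np.diag(L)))   # log|K|
    n = y.shape[0]
    return -0.5 * quad - 0.5 * logdet - 0.5 * n * np.log(2.0 * np.pi)

# Squared-exponential kernel plus observation noise on a 1-D grid.
x = np.linspace(0.0, 1.0, 20)
K = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / 0.2**2) + 0.1 * np.eye(20)
y = np.sin(2.0 * np.pi * x)
lml = log_marginal_likelihood_np(np.zeros(20), K, y)
```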
gaussian_expected_log_lik(y: Float[Array, ' N'], q_mu: Float[Array, ' N'], q_cov: lx.AbstractLinearOperator, noise: lx.AbstractLinearOperator, *, solver: AbstractSolverStrategy | None = None) -> Float[Array, '']
Expected log-likelihood E_q[log N(y | f, R)].
Computes:
E_q[log N(y|f,R)] = log N(y | q_mu, R) - 0.5 * tr(R^{-1} q_cov)
Core to variational inference ELBO computation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `y` | `Float[Array, ' N']` | Observations, shape `(N,)`. | required |
| `q_mu` | `Float[Array, ' N']` | Variational mean, shape `(N,)`. | required |
| `q_cov` | `AbstractLinearOperator` | Variational covariance operator, shape `(N, N)`. | required |
| `noise` | `AbstractLinearOperator` | Noise covariance operator R, shape `(N, N)`. | required |
| `solver` | `AbstractSolverStrategy \| None` | Optional solver strategy. | `None` |
Returns:
| Type | Description |
|---|---|
| `Float[Array, '']` | Scalar expected log-likelihood. |
Source code in src/gaussx/_inference/_inference.py
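The closed form follows from E_q[(y - f)^T R^{-1} (y - f)] = (y - q_mu)^T R^{-1} (y - q_mu) + tr(R^{-1} q_cov). A minimal NumPy sketch with dense matrices (hypothetical helper names, not the gaussx code) checks it against a Monte Carlo average of log N(y | f, R) over samples f drawn from q.

```python
import numpy as np

def gaussian_logpdf(y, mean, cov):
    """log N(y | mean, cov) for a dense covariance."""
    diff = y - mean
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (diff @ np.linalg.solve(cov, diff) + logdet + y.shape[0] * np.log(2.0 * np.pi))

def expected_log_lik_np(y, q_mu, q_cov, R):
    """Closed form of E_q[log N(y | f, R)] for q(f) = N(q_mu, q_cov)."""
    return gaussian_logpdf(y, q_mu, R) - 0.5 * np.trace(np.linalg.solve(R, q_cov))

rng = np.random.default_rng(1)
N = 4
A = rng.normal(size=(N, N))
q_cov = A @ A.T / N + 0.1 * np.eye(N)
q_mu = rng.normal(size=N)
R = 0.5 * np.eye(N)
y = rng.normal(size=N)
closed_form = expected_log_lik_np(y, q_mu, q_cov, R)

# Monte Carlo estimate: average log N(y | f_s, R) over samples f_s ~ q.
f = rng.multivariate_normal(q_mu, q_cov, size=200_000)
diff = y - f                                                     # (S, N)
quad = np.einsum("si,ij,sj->s", diff, np.linalg.inv(R), diff)    # per-sample quadratic form
mc = np.mean(-0.5 * (quad + np.linalg.slogdet(R)[1] + N * np.log(2.0 * np.pi)))
```

Dropping the trace term would shift the result by 0.5 tr(R^{-1} q_cov), which is several nats here, so the comparison is a meaningful check.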
Newton & natural-gradient updates
Second-order variational steps: Newton's method on the variational objective, Gauss-Newton curvature (exact diagonal or Hutchinson-estimated), damped natural-gradient steps, and the PSD projection that keeps Riemannian updates on the manifold.
newton_update(mean: Float[Array, ' N'], jacobian: Float[Array, ' N'], hessian: Float[Array, 'N N']) -> tuple[Float[Array, ' N'], Float[Array, 'N N']]
Convert a Newton step to natural pseudo-likelihood parameters.
Computes:
nat1 = jacobian - hessian @ mean
nat2 = -hessian
Used in Laplace/Newton-based approximate inference to convert function-space derivatives into site natural parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `mean` | `Float[Array, ' N']` | Current mean, shape `(N,)`. | required |
| `jacobian` | `Float[Array, ' N']` | First derivative of the log-likelihood, shape `(N,)`. | required |
| `hessian` | `Float[Array, 'N N']` | Second derivative of the log-likelihood (negative definite), shape `(N, N)`. | required |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, ' N'], Float[Array, 'N N']]` | Tuple `(nat1, nat2)` of site natural parameters, shapes `(N,)` and `(N, N)`. |
Source code in src/gaussx/_inference/_inference.py
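Note that newton_update returns nat2 = -hessian (the site precision itself), whereas the BLR helpers store -1/2 times the precision. A short NumPy sketch (hypothetical helper name) shows why the conversion is exact for a Gaussian likelihood log N(y | f, R): the derivatives at any expansion point give nat1 = R^{-1} y and nat2 = R^{-1}, independent of that point.

```python
import numpy as np

def newton_update_np(mean, jacobian, hessian):
    """nat1 = jacobian - hessian @ mean, nat2 = -hessian (dense sketch)."""
    return jacobian - hessian @ mean, -hessian

R = np.array([[0.5, 0.1], [0.1, 0.3]])   # noise covariance of log N(y | f, R)
R_inv = np.linalg.inv(R)
y = np.array([1.0, -1.0])
f0 = np.array([0.3, 0.7])                # arbitrary expansion point
jac = R_inv @ (y - f0)                   # gradient of log N(y | f, R) w.r.t. f at f0
hess = -R_inv                            # Hessian, constant in f
nat1, nat2 = newton_update_np(f0, jac, hess)
```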
damped_natural_update(nat1_old: Float[Array, ' d'], nat2_old: lx.AbstractLinearOperator | Float[Array, 'd d'], nat1_target: Float[Array, ' d'], nat2_target: lx.AbstractLinearOperator | Float[Array, 'd d'], lr: float = 1.0) -> tuple[Float[Array, ' d'], lx.AbstractLinearOperator | Float[Array, 'd d']]
Damped update in natural parameter space.
The universal primitive for iterative approximate inference (EP, VI, Newton, PL). Every method reduces to computing target natural parameters and applying this damped update:
nat1_{new} = (1 - lr) \cdot nat1_{old} + lr \cdot nat1_{target}
nat2_{new} = (1 - lr) \cdot nat2_{old} + lr \cdot nat2_{target}
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `nat1_old` | `Float[Array, ' d']` | Current natural location parameter. | required |
| `nat2_old` | `AbstractLinearOperator \| Float[Array, 'd d']` | Current natural precision-like parameter; either a dense array or a linear operator. | required |
| `nat1_target` | `Float[Array, ' d']` | Target natural location parameter. | required |
| `nat2_target` | `AbstractLinearOperator \| Float[Array, 'd d']` | Target natural precision-like parameter. | required |
| `lr` | `float` | Learning rate / damping factor. | `1.0` |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, ' d'], AbstractLinearOperator \| Float[Array, 'd d']]` | Tuple of the updated natural location and precision-like parameters. |
Source code in src/gaussx/_inference/_natural_gradient.py
gauss_newton_precision(jacobian: Float[Array, 'D_obs D_latent']) -> lx.AbstractLinearOperator
Gauss-Newton precision matrix J^T J.
For likelihoods with residual structure r(f), the Gauss-Newton
Hessian approximation is -J_r^T J_r which gives precision
\Lambda = J^T J (always PSD).
When D_{obs} < D_{latent}, returns a LowRankUpdate
to enable efficient Woodbury-based solves downstream.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `jacobian` | `Float[Array, 'D_obs D_latent']` | Jacobian of the residual, shape `(D_obs, D_latent)`. | required |
Returns:
| Type | Description |
|---|---|
| `AbstractLinearOperator` | PSD precision operator of shape `(D_latent, D_latent)`. |
Source code in src/gaussx/_inference/_natural_gradient.py
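When D_obs < D_latent, J^T J has rank at most D_obs and is singular on its own; the Woodbury payoff appears once it is added to a full-rank term such as a prior precision. The NumPy sketch below assumes, for illustration, a diagonal prior precision D and solves (D + J^T J) x = b through a (D_obs, D_obs) inner system instead of a (D_latent, D_latent) one. The helper and the prior are hypothetical, not part of gaussx.

```python
import numpy as np

def woodbury_solve(diag, J, b):
    """Solve (diag(diag) + J^T J) x = b using only a (D_obs, D_obs) inner system."""
    Dinv_b = b / diag                           # D^{-1} b
    Dinv_Jt = J.T / diag[:, None]               # D^{-1} J^T, shape (D_latent, D_obs)
    inner = np.eye(J.shape[0]) + J @ Dinv_Jt    # I + J D^{-1} J^T, shape (D_obs, D_obs)
    return Dinv_b - Dinv_Jt @ np.linalg.solve(inner, J @ Dinv_b)

rng = np.random.default_rng(2)
D_obs, D_latent = 3, 50
J = rng.normal(size=(D_obs, D_latent))
prior_prec = np.full(D_latent, 2.0)             # hypothetical diagonal prior precision
b = rng.normal(size=D_latent)

x_wood = woodbury_solve(prior_prec, J, b)
x_dense = np.linalg.solve(np.diag(prior_prec) + J.T @ J, b)
```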
ggn_diagonal(jacobian: Float[Array, 'N d']) -> Float[Array, ' d']
Generalized Gauss-Newton diagonal approximation.
Computes \mathrm{diag}(J^T J) = \sum_i J_{i,:}^2, the diagonal
of the Gauss-Newton Hessian approximation. Always non-negative,
guaranteeing PSD precision updates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `jacobian` | `Float[Array, 'N d']` | Jacobian matrix, shape `(N, d)`. | required |
Returns:
| Type | Description |
|---|---|
| `Float[Array, ' d']` | Diagonal of `J^T J`, shape `(d,)`. |
Source code in src/gaussx/_inference/_blr.py
hutchinson_hessian_diag(hvp_fn: Callable[[Float[Array, ' d']], Float[Array, ' d']], key: jax.Array, d: int, n_samples: int = 1, dtype: DTypeLike | None = None) -> Float[Array, ' d']
Stochastic Hessian diagonal via Hutchinson with Rademacher probes.
Estimates \mathrm{diag}(H) using the identity
\mathrm{diag}(H) = E[z \odot (H z)] where z is a
Rademacher random vector (entries \pm 1 with equal probability).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hvp_fn` | `Callable[[Float[Array, ' d']], Float[Array, ' d']]` | Hessian-vector product function mapping `v` to `H v`. | required |
| `key` | `Array` | PRNG key for random probe generation. | required |
| `d` | `int` | Dimension of the Hessian. | required |
| `n_samples` | `int` | Number of random probes; more samples give better estimates. | `1` |
| `dtype` | `DTypeLike \| None` | Floating-point dtype for the Rademacher probes. Defaults to the current JAX default floating dtype. | `None` |
Returns:
| Type | Description |
|---|---|
| `Float[Array, ' d']` | Estimated diagonal of the Hessian, shape `(d,)`. |
Source code in src/gaussx/_inference/_blr.py
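A NumPy sketch of the estimator, using a NumPy `Generator` in place of a JAX PRNG key and a hypothetical helper name. Each probe gives an unbiased estimate because E[z_i z_j] is 1 when i = j and 0 otherwise; for a diagonal H a single probe is already exact, since z_i^2 = 1.

```python
import numpy as np

def hutchinson_diag_np(hvp_fn, rng, d, n_samples=1):
    """Average z * (H z) over Rademacher probes z to estimate diag(H)."""
    z = rng.choice(np.array([-1.0, 1.0]), size=(n_samples, d))   # entries +/-1, equal probability
    hz = np.stack([hvp_fn(zi) for zi in z])                      # one Hessian-vector product per probe
    return np.mean(z * hz, axis=0)

rng = np.random.default_rng(3)
d = 5
A = rng.normal(size=(d, d))
H = A + A.T                                                      # symmetric test matrix
est = hutchinson_diag_np(lambda v: H @ v, rng, d, n_samples=50_000)

# For a diagonal H one probe suffices, because z_i ** 2 == 1.
H_diag = np.diag([1.0, -2.0, 3.0, 0.5, 4.0])
one_probe = hutchinson_diag_np(lambda v: H_diag @ v, rng, d, n_samples=1)
```

The variance of each coordinate's estimate is driven by the off-diagonal entries of its row, which is why strongly coupled Hessians need more probes.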
riemannian_psd_correction(hessian: Float[Array, 'd d'], site_precision: Float[Array, 'd d'], site_covariance: Float[Array, 'd d'], lr: float = 1.0) -> Float[Array, 'd d']
Riemannian gradient correction for PSD precision updates.
Adds a second-order correction that keeps the damped precision update positive definite, stabilizing Newton/EP/VI when the raw Hessian is indefinite:
G = site\_precision + hessian
H_{psd} = hessian - 0.5 \cdot lr \cdot G \cdot S \cdot G
where S is the site covariance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hessian` | `Float[Array, 'd d']` | Raw second derivative, shape `(d, d)`. | required |
| `site_precision` | `Float[Array, 'd d']` | Current site precision, shape `(d, d)`. | required |
| `site_covariance` | `Float[Array, 'd d']` | Current site covariance, shape `(d, d)`. | required |
| `lr` | `float` | Learning rate. | `1.0` |
Returns:
| Type | Description |
|---|---|
| `Float[Array, 'd d']` | Corrected Hessian, shape `(d, d)`. |
Source code in src/gaussx/_inference/_natural_gradient.py
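A dense NumPy sketch (hypothetical helper names), under the assumption that site_covariance is the inverse of site_precision. Writing Lambda for the site precision and using -H = Lambda - G, the damped update (1 - lr) Lambda + lr (-H_psd) simplifies to Lambda - lr G + (lr^2 / 2) G Lambda^{-1} G. Its whitened eigenvalues are ((1 - lr a)^2 + 1) / 2 > 0 for every eigenvalue a of Lambda^{-1/2} G Lambda^{-1/2}, and the check confirms positive definiteness for a deliberately indefinite raw Hessian.

```python
import numpy as np

def riemannian_psd_correction_np(hessian, site_precision, site_covariance, lr=1.0):
    """H_psd = H - 0.5 * lr * G S G with G = site_precision + H (dense sketch)."""
    G = site_precision + hessian
    return hessian - 0.5 * lr * G @ site_covariance @ G

def damped_precision(hessian, site_precision, site_covariance, lr):
    """Damped precision update (1 - lr) * Lambda + lr * (-H_psd)."""
    H_psd = riemannian_psd_correction_np(hessian, site_precision, site_covariance, lr)
    return (1.0 - lr) * site_precision + lr * (-H_psd)

rng = np.random.default_rng(4)
B = rng.normal(size=(4, 4))
site_precision = B @ B.T + np.eye(4)
site_covariance = np.linalg.inv(site_precision)   # assumed to be the inverse of site_precision
hessian = np.diag([3.0, -1.0, 2.0, -4.0])         # indefinite raw Hessian

# Without the correction, lr = 1 would set the precision to -hessian, which is indefinite.
min_eigs = {
    lr: np.linalg.eigvalsh(damped_precision(hessian, site_precision, site_covariance, lr)).min()
    for lr in (0.1, 0.5, 1.0)
}
```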
cavity_distribution(post_mean: Float[Array, ' N'], post_cov: lx.AbstractLinearOperator, site_nat1: Float[Array, ' N'], site_nat2: lx.AbstractLinearOperator, power: float = 1.0) -> tuple[Float[Array, ' N'], lx.AbstractLinearOperator]
Compute EP cavity distribution by removing a site.
Computes:
cav_prec = post_prec - power * site_nat2
cav_cov = inv(cav_prec)
cav_mean = cav_cov @ (post_prec @ post_mean - power * site_nat1)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `post_mean` | `Float[Array, ' N']` | Posterior mean, shape `(N,)`. | required |
| `post_cov` | `AbstractLinearOperator` | Posterior covariance operator. | required |
| `site_nat1` | `Float[Array, ' N']` | Site natural parameter (precision-weighted mean). | required |
| `site_nat2` | `AbstractLinearOperator` | Site natural parameter (precision). | required |
| `power` | `float` | Power EP fraction (default 1.0 for standard EP). | `1.0` |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, ' N'], AbstractLinearOperator]` | Tuple `(cav_mean, cav_cov)` of the cavity mean and covariance operator. |
Source code in src/gaussx/_inference/_inference.py
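A dense NumPy sketch of the three formulas above (hypothetical helper name; gaussx itself works with operators). Building a posterior as prior times a Gaussian site and then removing that site with power = 1 should return the prior exactly.

```python
import numpy as np

def cavity_np(post_mean, post_cov, site_nat1, site_nat2, power=1.0):
    """Dense cavity: remove `power` times the site from the posterior."""
    post_prec = np.linalg.inv(post_cov)
    cav_prec = post_prec - power * site_nat2
    cav_cov = np.linalg.inv(cav_prec)
    cav_mean = cav_cov @ (post_prec @ post_mean - power * site_nat1)
    return cav_mean, cav_cov

prior_mean = np.array([0.5, -1.0])
prior_cov = np.array([[1.0, 0.3], [0.3, 2.0]])
site_nat1 = np.array([1.0, 0.2])                    # site precision-weighted mean
site_nat2 = np.array([[2.0, 0.0], [0.0, 0.5]])      # site precision

# Posterior = prior x site: natural parameters add.
prior_prec = np.linalg.inv(prior_cov)
post_cov = np.linalg.inv(prior_prec + site_nat2)
post_mean = post_cov @ (prior_prec @ prior_mean + site_nat1)

cav_mean, cav_cov = cavity_np(post_mean, post_cov, site_nat1, site_nat2)
```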
trace_correction(K_xx: lx.AbstractLinearOperator, K_xz: Float[Array, 'N M'], K_zz: lx.AbstractLinearOperator, *, solver: AbstractSolveStrategy | None = None) -> Float[Array, '']
Trace term in Titsias collapsed ELBO.
Computes:
tr(K_xx) - tr(K_xz K_zz^{-1} K_xz^T)
This is the "trace correction" that penalizes the Nyström approximation error.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `K_xx` | `AbstractLinearOperator` | Full covariance, shape `(N, N)`. | required |
| `K_xz` | `Float[Array, 'N M']` | Cross-covariance, shape `(N, M)`. | required |
| `K_zz` | `AbstractLinearOperator` | Inducing covariance, shape `(M, M)`. | required |
| `solver` | `AbstractSolveStrategy \| None` | Optional solve strategy. | `None` |
Returns:
| Type | Description |
|---|---|
| `Float[Array, '']` | Scalar trace correction. |
Source code in src/gaussx/_inference/_inference.py
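A dense NumPy sketch with a squared-exponential kernel chosen for illustration (hypothetical helper names). The correction equals the trace of the Nyström residual K_xx - K_xz K_zz^{-1} K_xz^T, which is positive semi-definite, so the value is non-negative; it vanishes when the inducing points coincide with the data.

```python
import numpy as np

def rbf(a, b, lengthscale=0.3):
    """Squared-exponential kernel on 1-D inputs (illustrative choice)."""
    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale**2)

def trace_correction_np(K_xx, K_xz, K_zz):
    """tr(K_xx) - tr(K_xz K_zz^{-1} K_xz^T), via a solve against K_xz^T."""
    return np.trace(K_xx) - np.trace(K_xz @ np.linalg.solve(K_zz, K_xz.T))

x = np.linspace(0.0, 1.0, 30)
z = np.linspace(0.0, 1.0, 5)
tc = trace_correction_np(rbf(x, x), rbf(x, z), rbf(z, z) + 1e-8 * np.eye(5))

# With inducing points at the data, the Nystrom approximation is exact.
x_small = np.array([0.0, 0.5, 1.0])
K_small = rbf(x_small, x_small)
tc_exact = trace_correction_np(K_small, K_small, K_small)
```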
Ensemble covariances & Kalman gain
Empirical (cross-)covariances from ensemble members, with optional Bessel correction, and the ensemble Kalman gain built from them.
ensemble_covariance(particles: Float[Array, 'J N'], *, bessel: bool = False) -> LowRankUpdate
Empirical covariance from an ensemble as a low-rank operator.
Returns C = c X'^T X', where X' holds the particles centered on the ensemble
mean, with c = 1 / J when bessel=False (default, maximum likelihood) and
c = 1 / (J - 1) when bessel=True (unbiased / ensemble Kalman filter convention).
The result is a LowRankUpdate of rank <= J - 1 rather than a materialized
(N, N) matrix, which is efficient when J << N.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `particles` | `Float[Array, 'J N']` | Ensemble of shape `(J, N)`. | required |
| `bessel` | `bool` | If True, apply the `1 / (J - 1)` Bessel correction instead of `1 / J`. | `False` |
Returns:
| Type | Description |
|---|---|
| `LowRankUpdate` | A `LowRankUpdate` representing the `(N, N)` empirical covariance, with a zero base and low-rank term `c X'^T X'`. |
Source code in src/gaussx/_inference/_ensemble.py
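A NumPy sketch of the low-rank representation, assuming a hypothetical helper that returns the factor U = sqrt(c) X' so that C = U^T U; gaussx wraps the equivalent term in a LowRankUpdate instead. The dense product is formed here only to compare against `numpy.cov`.

```python
import numpy as np

def ensemble_covariance_factor(particles, bessel=False):
    """Return U = sqrt(c) X' so that the ensemble covariance is C = U^T U."""
    J = particles.shape[0]
    X_centered = particles - particles.mean(axis=0)   # X', shape (J, N); rows sum to zero
    c = 1.0 / (J - 1) if bessel else 1.0 / J
    return np.sqrt(c) * X_centered

rng = np.random.default_rng(5)
J, N = 8, 100
particles = rng.normal(size=(J, N))
U = ensemble_covariance_factor(particles, bessel=True)   # (J, N) factor, never (N, N)
C = U.T @ U                                              # dense, for checking only
```

Because the centered rows sum to zero, the factor has rank at most J - 1, which is the bound quoted above.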
ensemble_cross_covariance(particles_theta: Float[Array, 'J N'], particles_G: Float[Array, 'J M'], *, bessel: bool = False) -> Float[Array, 'N M']
Cross-covariance between two ensemble sets.
Computes C^{theta,G} = c sum_j (theta_j - theta_bar)(G_j - G_bar)^T, where
theta_bar and G_bar are the ensemble means, with c = 1 / J by default or
c = 1 / (J - 1) when bessel=True.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `particles_theta` | `Float[Array, 'J N']` | First ensemble, shape `(J, N)`. | required |
| `particles_G` | `Float[Array, 'J M']` | Second ensemble, shape `(J, M)`. | required |
| `bessel` | `bool` | If True, apply the `1 / (J - 1)` Bessel correction instead of `1 / J`. | `False` |
Returns:
| Type | Description |
|---|---|
| `Float[Array, 'N M']` | Cross-covariance array of shape `(N, M)`. |
Source code in src/gaussx/_inference/_ensemble.py
ensemble_kalman_gain(particles: Float[Array, 'J N'], obs_particles: Float[Array, 'J M'], obs_noise: lx.AbstractLinearOperator, *, solver: AbstractSolverStrategy | None = None, bessel: bool = True) -> Float[Array, 'N M']
Kalman gain from an ensemble and its image in observation space.
Computes K = C^{xH} (C^{HH} + R)^{-1}, where C^{xH} is the
state-observation cross-covariance and C^{HH} is the
observation-space ensemble covariance. The innovation covariance
S = C^{HH} + R is assembled as a LowRankUpdate so
solve_rows can use structural dispatch via the Woodbury identity.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `particles` | `Float[Array, 'J N']` | Prior ensemble in state space, shape `(J, N)`. | required |
| `obs_particles` | `Float[Array, 'J M']` | Prior ensemble in observation space, shape `(J, M)`. | required |
| `obs_noise` | `AbstractLinearOperator` | Observation error covariance operator, shape `(M, M)`. | required |
| `solver` | `AbstractSolverStrategy \| None` | Optional solver strategy. | `None` |
| `bessel` | `bool` | Defaults to True, unlike the lower-level covariance helpers, because this recipe follows the unbiased EnKF convention. Use False for maximum-likelihood recipes, which normalize by `1 / J`. | `True` |
Returns:
| Type | Description |
|---|---|
| `Float[Array, 'N M']` | Dense Kalman gain of shape `(N, M)`. |
Source code in src/gaussx/_inference/_ensemble.py
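A NumPy sketch of the formula above (hypothetical helper names; not the gaussx code path), assuming a diagonal R for the Woodbury variant. Because S = R + c Y'^T Y' with centered observation perturbations Y' of shape (J, M), the Woodbury identity replaces the (M, M) solve with a (J, J) one, which is the structure a LowRankUpdate exposes.

```python
import numpy as np

def enkf_gain_dense(particles, obs_particles, R):
    """K = C^{xH} (C^{HH} + R)^{-1} with the 1 / (J - 1) convention, dense solve."""
    J = particles.shape[0]
    X = particles - particles.mean(axis=0)            # (J, N) state perturbations
    Y = obs_particles - obs_particles.mean(axis=0)    # (J, M) observation perturbations
    c = 1.0 / (J - 1)
    C_xh = c * X.T @ Y                                # (N, M)
    S = c * Y.T @ Y + R                               # (M, M), symmetric
    return np.linalg.solve(S, C_xh.T).T               # C_xh S^{-1}

def enkf_gain_woodbury(particles, obs_particles, r_diag):
    """Same gain, inverting S = R + c Y^T Y through a (J, J) system; R = diag(r_diag)."""
    J = particles.shape[0]
    X = particles - particles.mean(axis=0)
    Y = obs_particles - obs_particles.mean(axis=0)
    c = 1.0 / (J - 1)
    Rinv_Yt = Y.T / r_diag[:, None]                   # R^{-1} Y^T, (M, J)
    inner = np.eye(J) / c + Y @ Rinv_Yt               # (1/c) I + Y R^{-1} Y^T, (J, J)
    S_inv = np.diag(1.0 / r_diag) - Rinv_Yt @ np.linalg.solve(inner, Rinv_Yt.T)
    return c * X.T @ Y @ S_inv

rng = np.random.default_rng(6)
J, N, M = 10, 20, 15
H = rng.normal(size=(M, N))                           # hypothetical linear observation operator
particles = rng.normal(size=(J, N))
obs_particles = particles @ H.T
r_diag = np.full(M, 0.5)
K_dense = enkf_gain_dense(particles, obs_particles, np.diag(r_diag))
K_wood = enkf_gain_woodbury(particles, obs_particles, r_diag)
```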
etkf_transform(obs_particles: Float[Array, 'J M'], y: Float[Array, ' M'], obs_noise: lx.AbstractLinearOperator, *, inflation: float = 1.0) -> tuple[Float[Array, ' J'], Float[Array, 'J J']]
Ensemble Transform Kalman Filter (ETKF) analysis weights.
Deterministic (perturbed-obs-free) ensemble square-root analysis in the
J-dimensional ensemble space (Bishop et al. 2001; Hunt et al. 2007).
With raw observation perturbations Y = H X'^f (columns are members) and
d = y - H x_bar^f,
where lambda is the (multiplicative) inflation and W is the
symmetric square root. The analysis ensemble is reconstructed as
The symmetric (eigendecomposition) square root -- not a Cholesky factor --
is required: because the observation perturbations are zero-mean, 1 is
an eigenvector of W with eigenvalue 1, which makes the transform
exactly mean-preserving (sum_j X'^a_j = 0).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `obs_particles` | `Float[Array, 'J M']` | Forecast ensemble in observation space, shape `(J, M)`. | required |
| `y` | `Float[Array, ' M']` | Observation vector, shape `(M,)`. | required |
| `obs_noise` | `AbstractLinearOperator` | Observation error covariance operator, shape `(M, M)`. | required |
| `inflation` | `float` | Multiplicative covariance inflation factor `lambda`. | `1.0` |
Returns:
| Type | Description |
|---|---|
| `Float[Array, ' J']` | Analysis mean weights in ensemble space, shape `(J,)`. |
| `Float[Array, 'J J']` | Symmetric square-root transform `W` applied to the forecast perturbations, shape `(J, J)`. |
Source code in src/gaussx/_inference/_ensemble.py
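A minimal NumPy sketch of standard ETKF weights in the form of Hunt et al. (2007), under stated assumptions: normalization by J - 1, no inflation (lambda = 1), a dense R, and hypothetical helper names. It does not reproduce gaussx's exact formulas. The checks confirm the properties described above: the ones vector is an eigenvector of W with eigenvalue 1, and for a linear observation operator the analysis mean and covariance match the Kalman update.

```python
import numpy as np

def etkf_weights_np(obs_particles, y, R):
    """Mean weights and symmetric square-root transform in ensemble space (no inflation)."""
    J = obs_particles.shape[0]
    y_bar = obs_particles.mean(axis=0)
    Y = (obs_particles - y_bar).T                     # (M, J), columns are member perturbations
    d = y - y_bar                                     # innovation
    Rinv_Y = np.linalg.solve(R, Y)                    # R^{-1} Y
    P_tilde = np.linalg.inv((J - 1) * np.eye(J) + Y.T @ Rinv_Y)
    w_mean = P_tilde @ (Rinv_Y.T @ d)                 # (R^{-1} Y)^T = Y^T R^{-1} for symmetric R
    evals, evecs = np.linalg.eigh((J - 1) * P_tilde)
    W = evecs @ np.diag(np.sqrt(evals)) @ evecs.T     # symmetric square root, not Cholesky
    return w_mean, W

rng = np.random.default_rng(7)
J, N, M = 12, 6, 4
H = rng.normal(size=(M, N))
particles = rng.normal(size=(J, N))
obs_particles = particles @ H.T                       # linear observation operator
R = 0.3 * np.eye(M)
y = rng.normal(size=M)

w_mean, W = etkf_weights_np(obs_particles, y, R)
x_bar = particles.mean(axis=0)
Xf = particles - x_bar                                # forecast perturbations, (J, N)
analysis = x_bar + (Xf.T @ (w_mean[:, None] + W)).T   # member j uses column j of W

# Reference Kalman update with the 1 / (J - 1) convention.
Yc = obs_particles - obs_particles.mean(axis=0)
P_f = Xf.T @ Xf / (J - 1)
K = (Xf.T @ Yc / (J - 1)) @ np.linalg.inv(Yc.T @ Yc / (J - 1) + R)
```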
Localization & inflation
The standard fixes for small-ensemble rank deficiency: Schur-product localization with a taper (Gaspari-Cohn by default) and multiplicative / RTPP / RTPS inflation.
localization_matrix(coords_a: Float[Array, 'Na D'], coords_b: Float[Array, 'Nb D'], c: float, metric: Callable[[Float[Array, 'Na D'], Float[Array, 'Nb D']], Float[Array, 'Na Nb']] = euclidean_distance) -> Float[Array, 'Na Nb']
Pairwise Gaspari-Cohn taper rho(dist(a_i, b_j); c).
Use this to build the rho_xy (state-obs) and rho_yy (obs-obs)
localization matrices consumed by localized_kalman_gain.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `coords_a` | `Float[Array, 'Na D']` | First set of points, shape `(Na, D)`. | required |
| `coords_b` | `Float[Array, 'Nb D']` | Second set of points, shape `(Nb, D)`. | required |
| `c` | `float` | Gaspari-Cohn compact-support radius. | required |
| `metric` | `Callable[[Float[Array, 'Na D'], Float[Array, 'Nb D']], Float[Array, 'Na Nb']]` | Pairwise distance function returning an `(Na, Nb)` distance array. | `euclidean_distance` |
Returns:
| Type | Description |
|---|---|
| `Float[Array, 'Na Nb']` | Localization matrix of shape `(Na, Nb)`. |
Source code in src/gaussx/_inference/_ensemble.py
localized_kalman_gain(particles: Float[Array, 'J N'], obs_particles: Float[Array, 'J M'], obs_noise: lx.AbstractLinearOperator, rho_xy: Float[Array, 'N M'], rho_yy: Float[Array, 'M M'], *, solver: AbstractSolverStrategy | None = None, bessel: bool = True) -> Float[Array, 'N M']
Ensemble Kalman gain with Hadamard (Schur-product) localization.
Computes
K = (rho_xy ∘ P_xy) (rho_yy ∘ P_yy + R)^{-1}
where ∘ is the Hadamard (elementwise) product, P_xy is the state-observation
cross-covariance, P_yy the observation-space ensemble covariance, and R the
observation error covariance. Tapering suppresses spurious long-range sample
correlations; because Gaspari-Cohn is positive-definite, the Schur product
theorem keeps rho_yy ∘ P_yy PSD, so the innovation covariance stays invertible.
This is the localized counterpart of ensemble_kalman_gain. Unlike
that routine, the Hadamard product destroys the low-rank structure, so the
innovation covariance is materialized densely and the solve costs
O(N M^2 + M^3). The unlocalized gain is recovered in the c -> inf limit
(rho_xy = rho_yy = 1).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `particles` | `Float[Array, 'J N']` | Prior ensemble in state space, shape `(J, N)`. | required |
| `obs_particles` | `Float[Array, 'J M']` | Prior ensemble in observation space, shape `(J, M)`. | required |
| `obs_noise` | `AbstractLinearOperator` | Observation error covariance operator, shape `(M, M)`. | required |
| `rho_xy` | `Float[Array, 'N M']` | State-observation localization matrix, shape `(N, M)`. | required |
| `rho_yy` | `Float[Array, 'M M']` | Observation-observation localization matrix, shape `(M, M)`. | required |
| `solver` | `AbstractSolverStrategy \| None` | Optional solver strategy for the dense innovation solve. | `None` |
| `bessel` | `bool` | Use the `1 / (J - 1)` (Bessel-corrected) normalization. | `True` |
Returns:
| Type | Description |
|---|---|
| `Float[Array, 'N M']` | Dense localized Kalman gain of shape `(N, M)`. |
Source code in src/gaussx/_inference/_ensemble.py
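A dense NumPy sketch of the localized gain (hypothetical helper names), using the 1 / (J - 1) convention and, purely for illustration, an exponential taper exp(-|r| / 0.1) in place of Gaspari-Cohn. Both tapers are positive-definite on the line, so the Schur product argument above applies. Setting both localization matrices to all ones recovers the unlocalized gain.

```python
import numpy as np

def localized_gain_np(particles, obs_particles, R, rho_xy, rho_yy):
    """K = (rho_xy * P_xy) (rho_yy * P_yy + R)^{-1} with Hadamard tapering (dense sketch)."""
    J = particles.shape[0]
    X = particles - particles.mean(axis=0)
    Y = obs_particles - obs_particles.mean(axis=0)
    P_xy = X.T @ Y / (J - 1)
    P_yy = Y.T @ Y / (J - 1)
    S = rho_yy * P_yy + R                  # dense (M, M) innovation covariance
    # S is symmetric, so K = B S^{-1} follows from solving S K^T = B^T.
    return np.linalg.solve(S, (rho_xy * P_xy).T).T

rng = np.random.default_rng(8)
J, N, M = 5, 30, 10
state_coords = np.linspace(0.0, 1.0, N)
obs_idx = np.linspace(0, N - 1, M).astype(int)        # observe a subset of grid points
obs_coords = state_coords[obs_idx]
particles = rng.normal(size=(J, N))
obs_particles = particles[:, obs_idx]
R = 0.2 * np.eye(M)

# Stand-in positive-definite taper (exponential kernel), not Gaspari-Cohn.
rho_xy = np.exp(-np.abs(state_coords[:, None] - obs_coords[None, :]) / 0.1)
rho_yy = np.exp(-np.abs(obs_coords[:, None] - obs_coords[None, :]) / 0.1)
K_loc = localized_gain_np(particles, obs_particles, R, rho_xy, rho_yy)

# With all-ones tapers the result should match the unlocalized dense gain.
K_ones = localized_gain_np(particles, obs_particles, R, np.ones((N, M)), np.ones((M, M)))
P_yy = np.cov(obs_particles, rowvar=False, ddof=1)
P_xy = (particles - particles.mean(axis=0)).T @ (obs_particles - obs_particles.mean(axis=0)) / (J - 1)
K_ref = P_xy @ np.linalg.inv(P_yy + R)
```

With J = 5 members, P_yy alone has rank 4 out of 10, which is exactly the small-ensemble regime where tapering matters.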
gaspari_cohn(r: Float[Array, '*shape'], c: float) -> Float[Array, '*shape']
Gaspari-Cohn (1999) fifth-order compactly-supported taper.
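The standard positive-definite, approximately Gaussian localization
function. With z = 2 |r| / c it is the piecewise-rational
\rho(z) = -\tfrac{1}{4} z^5 + \tfrac{1}{2} z^4 + \tfrac{5}{8} z^3 - \tfrac{5}{3} z^2 + 1 \quad (0 \le z \le 1)
\rho(z) = \tfrac{1}{12} z^5 - \tfrac{1}{2} z^4 + \tfrac{5}{8} z^3 + \tfrac{5}{3} z^2 - 5 z + 4 - \tfrac{2}{3 z} \quad (1 < z \le 2)
\rho(z) = 0 \quad (z > 2)
so rho(0) = 1 and rho = 0 for |r| >= c (c is the
compact-support radius, not a Gaussian length scale). Being piecewise, the
taper is only finitely differentiable at the knots z = 1, 2.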
Differentiability: the 2 / (3 z) term in the middle branch is guarded
with a safe denominator so reverse-mode gradients are finite at r = 0,
which would otherwise produce NaN through the standard jnp.where pitfall
(the unselected branch still contributes to the gradient).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `r` | `Float[Array, '*shape']` | Distances (any shape), e.g. a pairwise distance matrix. | required |
| `c` | `float` | Compact-support radius; the taper is zero for `\|r\| >= c`. | required |
Returns:
| Type | Description |
|---|---|
| `Float[Array, '*shape']` | Taper values in `[0, 1]`, same shape as `r`. |
Source code in src/gaussx/_inference/_ensemble.py
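A NumPy sketch of the piecewise formula (hypothetical helper name). The gaussx version is JAX with a gradient-safe denominator, which this sketch mimics only for the forward pass. It also builds a localization matrix for points on a line, as localization_matrix would with a Euclidean metric.

```python
import numpy as np

def gaspari_cohn_np(r, c):
    """Gaspari-Cohn taper with compact-support radius c, so rho = 0 for |r| >= c."""
    z = 2.0 * np.abs(np.asarray(r, dtype=float)) / c
    inner = -0.25 * z**5 + 0.5 * z**4 + 0.625 * z**3 - (5.0 / 3.0) * z**2 + 1.0
    z_safe = np.where(z > 0.0, z, 1.0)   # keeps 2 / (3 z) finite where the middle branch is unused
    middle = (z**5 / 12.0 - 0.5 * z**4 + 0.625 * z**3 + (5.0 / 3.0) * z**2
              - 5.0 * z + 4.0 - 2.0 / (3.0 * z_safe))
    return np.where(z <= 1.0, inner, np.where(z <= 2.0, middle, 0.0))

# Localization matrix for points on a line, with c = 0.3.
coords = np.linspace(0.0, 1.0, 50)
rho = gaspari_cohn_np(np.abs(coords[:, None] - coords[None, :]), c=0.3)
```

At the inner knot z = 1 both branches evaluate to 5/24, and at z = 2 the middle branch reaches 0, so the pieces join continuously.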
inflate_multiplicative(ensemble: Float[Array, 'J N'], factor: float) -> Float[Array, 'J N']
Multiplicative ensemble inflation about the mean.
Restores ensemble spread lost to sampling error / ensemble collapse by scaling
the perturbations about the mean: x_j <- x_bar + factor (x_j - x_bar).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `ensemble` | `Float[Array, 'J N']` | Ensemble of shape `(J, N)`. | required |
| `factor` | `float` | Multiplicative inflation factor applied to the perturbations about the ensemble mean. | required |
Returns:
| Type | Description |
|---|---|
| `Float[Array, 'J N']` | Inflated ensemble, shape `(J, N)`. |
Source code in src/gaussx/_inference/_ensemble.py
inflate_rtpp(posterior: Float[Array, 'J N'], prior: Float[Array, 'J N'], alpha: float) -> Float[Array, 'J N']
Relaxation to prior perturbations (RTPP; Zhang et al. 2004).
Relaxes posterior perturbations toward the prior perturbations while keeping
the posterior mean: x'^a <- (1 - alpha) x'^a + alpha x'^f, where the
perturbations are taken about each ensemble's own mean.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `posterior` | `Float[Array, 'J N']` | Analysis ensemble, shape `(J, N)`. | required |
| `prior` | `Float[Array, 'J N']` | Forecast ensemble, shape `(J, N)`. | required |
| `alpha` | `float` | Relaxation weight in `[0, 1]`. | required |
Returns:
| Type | Description |
|---|---|
| `Float[Array, 'J N']` | Relaxed analysis ensemble, shape `(J, N)`; the posterior mean is preserved. |
Source code in src/gaussx/_inference/_ensemble.py
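A NumPy sketch of RTPP (hypothetical helper name): alpha = 0 leaves the analysis unchanged, alpha = 1 replaces its perturbations with the forecast perturbations, and the analysis mean is kept throughout.

```python
import numpy as np

def inflate_rtpp_np(posterior, prior, alpha):
    """RTPP: blend analysis and forecast perturbations, keep the analysis mean."""
    post_mean = posterior.mean(axis=0)
    post_pert = posterior - post_mean            # x'^a about the analysis mean
    prior_pert = prior - prior.mean(axis=0)      # x'^f about the forecast mean
    return post_mean + (1.0 - alpha) * post_pert + alpha * prior_pert

rng = np.random.default_rng(9)
prior = rng.normal(size=(10, 3))
posterior = 0.5 * prior + 1.0                    # a contracted, shifted "analysis"
relaxed = inflate_rtpp_np(posterior, prior, alpha=0.7)
```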
inflate_rtps(posterior: Float[Array, 'J N'], prior: Float[Array, 'J N'], beta: float, eps: float = 1e-12) -> Float[Array, 'J N']
Relaxation to prior spread (RTPS; Whitaker & Hamill 2012).
Scales each posterior perturbation, per coordinate, so the analysis spread
relaxes back toward the prior spread:
x'^a <- x'^a [ (1 - beta) + beta sigma^f / sigma^a ], with sigma the
per-coordinate ensemble standard deviation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `posterior` | `Float[Array, 'J N']` | Analysis ensemble, shape `(J, N)`. | required |
| `prior` | `Float[Array, 'J N']` | Forecast ensemble, shape `(J, N)`. | required |
| `beta` | `float` | Relaxation weight in `[0, 1]`. | required |
| `eps` | `float` | Floor on the posterior std to avoid division by zero. | `1e-12` |
Returns:
| Type | Description |
|---|---|
| `Float[Array, 'J N']` | Spread-restored analysis ensemble, shape `(J, N)`; the posterior mean is preserved. |
Source code in src/gaussx/_inference/_ensemble.py
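A NumPy sketch of RTPS (hypothetical helper name). With beta = 1 the analysis spread is restored exactly to the forecast spread in every coordinate, while the analysis mean is unchanged.

```python
import numpy as np

def inflate_rtps_np(posterior, prior, beta, eps=1e-12):
    """RTPS: rescale analysis perturbations per coordinate toward the forecast spread."""
    post_mean = posterior.mean(axis=0)
    post_pert = posterior - post_mean
    # The same ddof is used for both spreads, so it cancels in the ratio.
    sigma_a = np.maximum(posterior.std(axis=0, ddof=1), eps)
    sigma_f = prior.std(axis=0, ddof=1)
    scale = (1.0 - beta) + beta * sigma_f / sigma_a   # shape (N,), broadcast over members
    return post_mean + scale * post_pert

rng = np.random.default_rng(10)
prior = rng.normal(size=(10, 3))
posterior = 0.4 * prior + rng.normal(scale=0.05, size=(10, 3))
restored = inflate_rtps_np(posterior, prior, beta=1.0)
```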
Distances
euclidean_distance(coords_a: Float[Array, 'Na D'], coords_b: Float[Array, 'Nb D']) -> Float[Array, 'Na Nb']
Pairwise Euclidean distances ||a_i - b_j||.
A default metric for localization_matrix. Builds on
stable_squared_distances and takes a gradient-safe square root so
zero distances (e.g. the diagonal of a self-distance matrix) do not produce
NaN gradients.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `coords_a` | `Float[Array, 'Na D']` | First set of points, shape `(Na, D)`. | required |
| `coords_b` | `Float[Array, 'Nb D']` | Second set of points, shape `(Nb, D)`. | required |
Returns:
| Type | Description |
|---|---|
| `Float[Array, 'Na Nb']` | Distance matrix of shape `(Na, Nb)`. |
Source code in src/gaussx/_inference/_ensemble.py
haversine_distance(coords_a: Float[Array, 'Na 2'], coords_b: Float[Array, 'Nb 2'], radius: float = 6371000.0) -> Float[Array, 'Na Nb']
Pairwise great-circle (haversine) distances on a sphere.
A metric for localization_matrix on geophysical grids.
Coordinates are (latitude, longitude) in radians.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `coords_a` | `Float[Array, 'Na 2']` | First set of points, shape `(Na, 2)`, as (latitude, longitude) in radians. | required |
| `coords_b` | `Float[Array, 'Nb 2']` | Second set of points, shape `(Nb, 2)`, as (latitude, longitude) in radians. | required |
| `radius` | `float` | Sphere radius in the units of the returned distance (default: the Earth's mean radius in meters). | `6371000.0` |
Returns:
| Type | Description |
|---|---|
| `Float[Array, 'Na Nb']` | Great-circle distance matrix of shape `(Na, Nb)`. |
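A NumPy sketch of the haversine formula d = 2 r arcsin(sqrt(h)), with h = sin^2(dlat / 2) + cos(lat_a) cos(lat_b) sin^2(dlon / 2), using the same (latitude, longitude)-in-radians convention (hypothetical helper name). Clipping h to [0, 1] guards arcsin against round-off.

```python
import numpy as np

def haversine_np(coords_a, coords_b, radius=6371000.0):
    """Pairwise great-circle distances; coords are (lat, lon) in radians."""
    lat_a, lon_a = coords_a[:, 0][:, None], coords_a[:, 1][:, None]   # (Na, 1)
    lat_b, lon_b = coords_b[:, 0][None, :], coords_b[:, 1][None, :]   # (1, Nb)
    h = (np.sin((lat_b - lat_a) / 2.0) ** 2
         + np.cos(lat_a) * np.cos(lat_b) * np.sin((lon_b - lon_a) / 2.0) ** 2)
    return 2.0 * radius * np.arcsin(np.sqrt(np.clip(h, 0.0, 1.0)))

pts = np.array([[0.0, 0.0], [0.5, 1.0], [-0.3, 2.0]])
D = haversine_np(pts, pts)   # (3, 3), zero diagonal, symmetric
```

For example, a quarter turn along the equator, from (0, 0) to (0, pi / 2), comes out as radius * pi / 2.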