
Uncertainty Propagation through Nonlinear Functions

When a Gaussian random variable $x \sim \mathcal{N}(\mu, \Sigma)$ passes through a nonlinear function $f$, the output $y = f(x)$ is generally not Gaussian. Different methods approximate the output distribution $p(y)$ with varying fidelity.

Motivation: uncertain inputs in Gaussian processes

The problem of propagating uncertainty through nonlinear maps arises naturally in Gaussian process (GP) models with uncertain inputs. At test time the input $x_*$ may itself be uncertain, for example:

  • In multi-step-ahead prediction with GP dynamics models (Girard et al., 2003; Deisenroth & Rasmussen, 2011), the predicted state at time $t$ becomes the uncertain input at time $t+1$.
  • In GP-based policy search (PILCO; Deisenroth & Rasmussen, 2011), the controller must evaluate a GP whose input is a belief state.
  • In heteroscedastic and input-noise models (McHutchon & Rasmussen, 2011), observation noise in the inputs induces effective output noise that depends on the local gradient of the GP posterior mean.

Given a GP posterior $f \mid \mathcal{D}$ and an input distribution $x_* \sim \mathcal{N}(\mu_*, \Sigma_*)$, we seek the predictive moments

$$\mathbb{E}[f(x_*)] = \int f(x)\,\mathcal{N}(x \mid \mu_*, \Sigma_*)\,dx,$$
$$\operatorname{Var}[f(x_*)] = \int \bigl(f(x) - \mathbb{E}[f(x_*)]\bigr)^2 \,\mathcal{N}(x \mid \mu_*, \Sigma_*)\,dx.$$

For the squared-exponential kernel these integrals have analytic forms (Quinonero-Candela et al., 2003), but for general kernels or composed models they are intractable — and the methods compared below become essential.
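For a deterministic $f$ and a scalar input, both integrals can be approximated to high accuracy with Gauss–Hermite quadrature, which gives a noise-free reference for the comparisons below. This is a minimal NumPy sketch that assumes NumPy's probabilists' Hermite rule `numpy.polynomial.hermite_e.hermegauss`. The helper `gh_moments` is hypothetical and not part of gaussx, and `f_np` is a NumPy copy of the test function $f$ defined later in this notebook:

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def f_np(x):
    """NumPy copy of the notebook's wiggly test function."""
    return np.sin(1.5 * x) * np.exp(-0.15 * x**2) + 0.3 * x


def gh_moments(f, mu, sigma2, deg=64):
    """Mean and variance of f(x) for scalar x ~ N(mu, sigma2) by Gauss-Hermite quadrature.

    hermegauss returns nodes and weights for the weight function exp(-z**2 / 2);
    normalising the weights to sum to one turns them into a rule for E over N(0, 1).
    """
    z, w = hermegauss(deg)
    w = w / w.sum()
    fx = f(mu + np.sqrt(sigma2) * z)  # evaluate f at the transformed nodes
    mean = np.sum(w * fx)
    var = np.sum(w * (fx - mean) ** 2)
    return mean, var


mean, var = gh_moments(f_np, mu=1.0, sigma2=0.6)
print(f"mean = {mean:.4f}, std = {np.sqrt(var):.4f}")
```

For the input used below ($\mu = 1$, $\sigma^2 = 0.6$), the result lands close to the pure Monte Carlo reference in the results table (mean ≈ 0.74, std ≈ 0.51). Tensor-product quadrature grids grow exponentially with the input dimension, so this reference is practical only for small $N$.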

Methods compared

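This notebook compares six approaches. Five are Gaussian approximations available in gaussx. The sixth is a pure Monte Carlo histogram, computed directly with JAX, that serves as the reference: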

| Method | Points | Key idea |
|---|---|---|
| Taylor 1st | 1 (mean only) | Linearize $f$ at the mean |
| Taylor 2nd | 1 + Hessian | Add curvature correction (mean only) |
| Taylor 2nd (var) | 1 + Hessian | Curvature correction (mean + variance) |
| Unscented | $2N+1$ sigma points | Deterministic quadrature |
| Monte Carlo | $S$ random samples | Empirical moments |
| Pure MC | $S$ random samples | Direct histogram (ground truth) |

We propagate a 1-D Gaussian through a deliberately wiggly function to expose how each method handles nonlinearity.

from __future__ import annotations

import warnings


warnings.filterwarnings("ignore", message=r".*IProgress.*")

import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import matplotlib.pyplot as plt

import gaussx


jax.config.update("jax_enable_x64", True)

The nonlinear function

We pick a function with inflection points, strong local curvature, and a damped oscillation riding on a linear trend. These features challenge linearisation-based methods:

$$f(x) = \sin(1.5\,x) \cdot \exp(-0.15\,x^2) + 0.3\,x$$
def f_nonlinear(x):
    """Wiggly nonlinear scalar function."""
    return jnp.sin(1.5 * x) * jnp.exp(-0.15 * x**2) + 0.3 * x


# Vectorised for plotting
f_vec = jax.vmap(f_nonlinear)

x_grid = jnp.linspace(-5, 5, 500)
y_grid = f_vec(x_grid)
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(x_grid, y_grid, "k-", lw=2)
ax.set_xlabel("x")
ax.set_ylabel("f(x)")
ax.set_title("Nonlinear function")
ax.axhline(0, c="grey", lw=0.5)
ax.grid(True, which="major", alpha=0.3)
ax.grid(True, which="minor", alpha=0.1)
ax.minorticks_on()
plt.tight_layout()
plt.show()
(Figure: “Nonlinear function”, $f(x)$ plotted against $x$ on $[-5, 5]$.)

Input distribution

A 1-D Gaussian centred right where the function has strong curvature:

$$x \sim \mathcal{N}(\mu = 1.0,\; \sigma^2 = 0.6)$$
mu_x = 1.0
sigma2_x = 0.6

state = gaussx.GaussianState(
    mean=jnp.array([mu_x]),
    cov=lx.MatrixLinearOperator(jnp.array([[sigma2_x]]), lx.positive_semidefinite_tag),
)

Mathematical background

All methods below approximate the first two moments of $y = f(x)$ when $x \sim \mathcal{N}(\mu, \Sigma)$.

Taylor 1st order (linearisation / EKF)

Expand $f$ to first order around the mean:

$$f(x) \approx f(\mu) + J\,(x - \mu), \qquad J = \left.\frac{\partial f}{\partial x}\right|_{x=\mu}.$$

Under this linear approximation the output is exactly Gaussian:

$$\mu_y = f(\mu), \qquad \Sigma_y = J\,\Sigma\,J^\top.$$

This is the update used in the extended Kalman filter (EKF) (Anderson & Moore, 1979). It is exact when $f$ is affine but ignores curvature, so the mean estimate is generally biased when $f''(\mu) \neq 0$. For a scalar input the leading bias term is $\tfrac{1}{2} f''(\mu)\,\sigma^2$.
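This is a minimal JAX sketch of the linearisation for a vector-valued $f$, using `jax.jacfwd` for the Jacobian. The helper `linearize_moments` is hypothetical and independent of gaussx's `TaylorIntegrator`. It only makes the two formulas concrete:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def linearize_moments(f, mu, Sigma):
    """First-order (EKF-style) moments of y = f(x) for x ~ N(mu, Sigma)."""
    J = jax.jacfwd(f)(mu)  # Jacobian at the mean, shape (M, N)
    return f(mu), J @ Sigma @ J.T  # mu_y = f(mu), Sigma_y = J Sigma J^T


def f_1d(x):
    """The notebook's test function, wrapped as a map from shape (1,) to (1,)."""
    x0 = x[0]
    return jnp.atleast_1d(jnp.sin(1.5 * x0) * jnp.exp(-0.15 * x0**2) + 0.3 * x0)


m_y, S_y = linearize_moments(f_1d, jnp.array([1.0]), jnp.array([[0.6]]))
print(float(m_y[0]), float(jnp.sqrt(S_y[0, 0])))
```

For this notebook's input, the sketch gives mean ≈ 1.1586 and std ≈ 0.1036, matching the Taylor 1st row of the results table. The standard deviation is just $|f'(1)|\,\sigma$ with $f'(1) \approx 0.134$.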

Taylor 2nd order

Including the second-order term of the Taylor expansion gives a correction to the mean via the Hessian $H_i = \nabla^2 f_i(\mu)$:

$$\mu_{y,i} = f_i(\mu) + \tfrac{1}{2}\operatorname{tr}(H_i\,\Sigma).$$

By default, the covariance remains the first-order expression $\Sigma_y = J\Sigma J^\top$. This “second-order EKF” (mean-only correction) appears in Athans et al. (1968).

Optionally, with correct_variance=True, the covariance also receives a second-order correction using fourth Gaussian moments:

$$(\Sigma_y)_{ij} \mathrel{+}= \tfrac{1}{2}\operatorname{tr}(H_i\,\Sigma\,H_j\,\Sigma).$$

This can improve accuracy for mildly nonlinear functions but may overshoot for strongly nonlinear ones.
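Both corrections can be written with `jax.hessian`, which returns one $N \times N$ Hessian per output. The sketch below is a hypothetical helper, not gaussx's implementation. Its `correct_variance` argument only mirrors the flag name used above:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def taylor2_moments(f, mu, Sigma, correct_variance=False):
    """Second-order Taylor moments of y = f(x) for x ~ N(mu, Sigma)."""
    J = jax.jacfwd(f)(mu)  # (M, N)
    H = jax.hessian(f)(mu)  # (M, N, N): H[i] is the Hessian of output i
    # mu_y,i = f_i(mu) + 0.5 * tr(H_i Sigma)
    mean = f(mu) + 0.5 * jnp.einsum("iab,ba->i", H, Sigma)
    cov = J @ Sigma @ J.T
    if correct_variance:
        # (Sigma_y)_ij += 0.5 * tr(H_i Sigma H_j Sigma)
        cov = cov + 0.5 * jnp.einsum("iab,bc,jcd,da->ij", H, Sigma, H, Sigma)
    return mean, cov


def f_1d(x):
    """The notebook's test function, wrapped as a map from shape (1,) to (1,)."""
    x0 = x[0]
    return jnp.atleast_1d(jnp.sin(1.5 * x0) * jnp.exp(-0.15 * x0**2) + 0.3 * x0)


mu, Sigma = jnp.array([1.0]), jnp.array([[0.6]])
m2, S2 = taylor2_moments(f_1d, mu, Sigma)
m2v, S2v = taylor2_moments(f_1d, mu, Sigma, correct_variance=True)
print(float(m2[0]), float(jnp.sqrt(S2[0, 0])), float(jnp.sqrt(S2v[0, 0])))
```

For a quadratic $f$ the variance-corrected version is exact, because all derivatives beyond the second vanish. For example, $f(x) = x^2$ with $x \sim \mathcal{N}(1, 0.6)$ gives mean 1.6 and variance 3.12. For the wiggly test function, the large curvature $f''(1) \approx -2.17$ both pulls the mean down and inflates the corrected spread.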

Unscented transform (UT)

Instead of differentiating $f$, the UT (Julier & Uhlmann, 1997) evaluates it at a minimal set of sigma points $\{\mathcal{X}_i\}_{i=0}^{2N}$ chosen to match the mean and covariance of $x$ exactly:

$$\mathcal{X}_0 = \mu, \qquad \mathcal{X}_i = \mu + \bigl(\sqrt{(N+\lambda)\,\Sigma}\bigr)_i, \qquad \mathcal{X}_{N+i} = \mu - \bigl(\sqrt{(N+\lambda)\,\Sigma}\bigr)_i,$$

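for $i = 1, \dots, N$, where $\lambda = \alpha^2(N + \kappa) - N$ and $(\sqrt{M})_i$ denotes the $i$-th column of a matrix square root of $M$ (e.g. its Cholesky factor). The output moments are weighted sums:

$$\mu_y = \sum_i w_i^{(m)}\,f(\mathcal{X}_i), \qquad \Sigma_y = \sum_i w_i^{(c)}\bigl(f(\mathcal{X}_i) - \mu_y\bigr)\bigl(f(\mathcal{X}_i) - \mu_y\bigr)^\top,$$

with the standard scaled-UT weights $w_0^{(m)} = \lambda/(N+\lambda)$, $w_0^{(c)} = w_0^{(m)} + (1 - \alpha^2 + \beta)$, and $w_i^{(m)} = w_i^{(c)} = 1/\bigl(2(N+\lambda)\bigr)$ for $i = 1, \dots, 2N$ (Wan & van der Merwe, 2000).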

The UT captures nonlinear effects up to third order for Gaussian inputs (Julier & Uhlmann, 2004) and requires no derivatives of $f$. It underpins the unscented Kalman filter (UKF) (Wan & van der Merwe, 2000).
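This is a minimal NumPy sketch of the scaled UT. It assumes the standard weights $w_0^{(m)} = \lambda/(N+\lambda)$, $w_0^{(c)} = w_0^{(m)} + 1 - \alpha^2 + \beta$ and $w_i^{(m)} = w_i^{(c)} = 1/(2(N+\lambda))$, and it uses the Cholesky factor as the matrix square root. The helper `unscented_moments` is hypothetical, not gaussx's `UnscentedIntegrator`:

```python
import numpy as np


def unscented_moments(f, mu, Sigma, alpha=1.0, beta=2.0, kappa=0.0):
    """Scaled unscented transform of y = f(x), x ~ N(mu, Sigma), with standard weights."""
    n = mu.shape[0]
    lam = alpha**2 * (n + kappa) - n
    L = np.linalg.cholesky((n + lam) * Sigma)  # columns are the sigma-point offsets
    X = np.vstack([mu, mu + L.T, mu - L.T])  # (2n+1, n) sigma points
    wm = np.full(2 * n + 1, 1.0 / (2 * (n + lam)))
    wc = wm.copy()
    wm[0] = lam / (n + lam)
    wc[0] = lam / (n + lam) + (1.0 - alpha**2 + beta)
    Y = np.array([f(x) for x in X])  # (2n+1, m) propagated points
    mean = wm @ Y
    D = Y - mean
    cov = (wc[:, None] * D).T @ D  # sum_i wc_i D_i D_i^T
    return mean, cov


def f_np(x):
    """The notebook's test function on a length-1 input vector (NumPy)."""
    return np.atleast_1d(np.sin(1.5 * x[0]) * np.exp(-0.15 * x[0] ** 2) + 0.3 * x[0])


m_y, S_y = unscented_moments(f_np, np.array([1.0]), np.array([[0.6]]))
print(m_y[0], np.sqrt(S_y[0, 0]))
```

With $\alpha = 1$, $\beta = 2$, $\kappa = 0$ and $N = 1$, we get $\lambda = 0$. The centre point therefore has mean weight 0 but covariance weight 2. The sketch returns mean ≈ 0.6085 and std ≈ 0.8062, the same values as the Unscented row of the results table.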

Monte Carlo moment matching

Draw $S$ samples $x^{(s)} \sim \mathcal{N}(\mu, \Sigma)$, push them through $f$, and compute empirical moments:

$$\mu_y \approx \frac{1}{S}\sum_{s=1}^S f(x^{(s)}), \qquad \Sigma_y \approx \frac{1}{S}\sum_{s=1}^S \bigl(f(x^{(s)}) - \mu_y\bigr)\bigl(f(x^{(s)}) - \mu_y\bigr)^\top.$$

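The mean estimate is unbiased. The $1/S$ covariance estimate carries a small $O(1/S)$ bias, which dividing by $S-1$ instead removes. The Monte Carlo error shrinks as $O(1/\sqrt{S})$ regardless of dimension, making it the method of choice when $N$ is large or $f$ is very nonlinear. The exactly moment-matched Gaussian, which Monte Carlo approximates, is the Gaussian $q$ that minimises $\mathrm{KL}(p \,\|\, q)$. It is used for “assumed density filtering” in GP dynamics models (Deisenroth & Rasmussen, 2011).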
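This is a vectorised NumPy sketch of Monte Carlo moment matching, using the Cholesky factor of $\Sigma$ to draw correlated samples. It assumes `f` maps a batch of inputs of shape $(S, N)$ to outputs of shape $(S, M)$. The helper `mc_moments` is hypothetical, not gaussx's `MonteCarloIntegrator`:

```python
import numpy as np


def mc_moments(f, mu, Sigma, n_samples, seed=0):
    """Monte Carlo moment matching for y = f(x), x ~ N(mu, Sigma)."""
    rng = np.random.default_rng(seed)
    L = np.linalg.cholesky(Sigma)
    eps = rng.standard_normal((n_samples, mu.shape[0]))
    X = mu + eps @ L.T  # rows are samples from N(mu, Sigma)
    Y = f(X)  # (S, M)
    mean = Y.mean(axis=0)
    D = Y - mean
    cov = D.T @ D / n_samples  # 1/S estimator, matching the formula above
    return mean, cov


def f_batch(X):
    """The notebook's test function applied elementwise to a batch of shape (S, 1)."""
    return np.sin(1.5 * X) * np.exp(-0.15 * X**2) + 0.3 * X


mu, Sigma = np.array([1.0]), np.array([[0.6]])
for S in (1_000, 10_000, 100_000):
    m, C = mc_moments(f_batch, mu, Sigma, S)
    se = np.sqrt(C[0, 0] / S)  # standard error of the mean estimate
    print(f"S={S:>7d}  mean={m[0]:.4f}  std={np.sqrt(C[0, 0]):.4f}  s.e.={se:.4f}")
```

Each tenfold increase in $S$ shrinks the standard error by about $\sqrt{10} \approx 3.2$, which is the $O(1/\sqrt{S})$ rate quoted above.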

Run all six methods

# The state is 1-D (mean of shape (1,)), so wrap the scalar f as a (1,) -> (1,) map.
def f_1d(x):
    return jnp.atleast_1d(f_nonlinear(x[0]))


# --- Taylor 1st order (EKF) ---
taylor1 = gaussx.TaylorIntegrator(order=1)
res_t1 = taylor1.integrate(f_1d, state)

# --- Taylor 2nd order (mean correction only, default) ---
taylor2 = gaussx.TaylorIntegrator(order=2)
res_t2 = taylor2.integrate(f_1d, state)

# --- Taylor 2nd order (mean + variance correction) ---
taylor2v = gaussx.TaylorIntegrator(order=2, correct_variance=True)
res_t2v = taylor2v.integrate(f_1d, state)

# --- Unscented transform ---
ut = gaussx.UnscentedIntegrator(alpha=1.0, beta=2.0, kappa=0.0)
res_ut = ut.integrate(f_1d, state)

# --- Monte Carlo integrator (moment matching) ---
mc = gaussx.MonteCarloIntegrator(n_samples=50_000, key=jr.key(42))
res_mc = mc.integrate(f_1d, state)

# --- Pure MC: raw histogram as ground truth ---
key = jr.key(123)
n_pure = 200_000
x_samples = mu_x + jnp.sqrt(sigma2_x) * jr.normal(key, (n_pure,))
y_samples = f_vec(x_samples)
pure_mc_mean = jnp.mean(y_samples)
pure_mc_var = jnp.var(y_samples)
def _moments(res):
    m = float(res.state.mean[0])
    v = float(res.state.cov.as_matrix()[0, 0])
    return m, v


methods = {
    "Taylor 1st": _moments(res_t1),
    "Taylor 2nd": _moments(res_t2),
    "Taylor 2nd (var)": _moments(res_t2v),
    "Unscented": _moments(res_ut),
    "MC (50k)": _moments(res_mc),
    "Pure MC (200k)": (float(pure_mc_mean), float(pure_mc_var)),
}

print(f"{'Method':<18s} {'Mean':>8s} {'Std':>8s}")
print("-" * 36)
for name, (m, v) in methods.items():
    print(f"{name:<18s} {m:8.4f} {jnp.sqrt(v):8.4f}")
Method                 Mean      Std
------------------------------------
Taylor 1st           1.1586   0.1036
Taylor 2nd           0.5085   0.1036
Taylor 2nd (var)     0.5085   0.9251
Unscented            0.6085   0.8062
MC (50k)             0.7388   0.5023
Pure MC (200k)       0.7358   0.5074

Visualisation

The main plot shows:

  • Bottom x-axis: the input Gaussian $p(x)$ as a filled curve.
  • Left y-axis: the nonlinear function $f(x)$ (black curve).
  • Right side: the output Gaussian approximations from each method, and the true output histogram from pure MC.

This makes it easy to see how the input bell curve “warps” through $f$, and how closely each Gaussian approximation's location and spread match the resulting skewed output.

from matplotlib.patches import FancyArrowPatch
from scipy.stats import norm


fig, ax = plt.subplots(figsize=(9, 6))

# --- Nonlinear function (clip to x < x_right so it doesn't bleed) ---
x_fn = jnp.linspace(-4, 3.5, 400)
y_fn = f_vec(x_fn)
ax.plot(x_fn, y_fn, "k-", lw=2.5, label="$f(x)$", zorder=5)

# --- Input Gaussian (along x-axis) ---
x_pdf = norm.pdf(x_grid, loc=mu_x, scale=jnp.sqrt(sigma2_x))
# Scale so the peak is visually ~0.8 units tall on the y-axis
x_pdf_scaled = x_pdf / x_pdf.max() * 0.8
y_base = -2.0  # shift below the function
ax.fill_between(
    x_grid,
    y_base,
    y_base + x_pdf_scaled,
    color="C0",
    alpha=0.3,
    zorder=2,
)
ax.plot(x_grid, y_base + x_pdf_scaled, "C0-", lw=1.5, zorder=3)
ax.text(
    mu_x,
    y_base - 0.15,
    r"$p(x)$",
    ha="center",
    fontsize=12,
    color="C0",
)

# --- Output distributions (horizontal Gaussians) ---
colours = ["C1", "C2", "C6", "C3", "C4", "C5"]
x_right = 4.2
width = 1.0

y_range = jnp.linspace(-2.5, 3.0, 400)

for i, (name, (m, v)) in enumerate(methods.items()):
    s = jnp.sqrt(v)
    pdf_vals = norm.pdf(y_range, loc=m, scale=s)
    pdf_scaled = pdf_vals / pdf_vals.max() * width
    x_pos = x_right + i * 0.9

    ax.fill_betweenx(
        y_range,
        x_pos,
        x_pos + pdf_scaled,
        color=colours[i],
        alpha=0.35,
        zorder=2,
    )
    ax.plot(
        x_pos + pdf_scaled,
        y_range,
        color=colours[i],
        lw=1.5,
        zorder=3,
    )
    ax.plot(x_pos, m, "o", color=colours[i], ms=5, zorder=6)
    ax.text(
        x_pos + width * 0.5,
        3.3,
        name,
        ha="center",
        fontsize=8,
        color=colours[i],
        rotation=45,
        fontweight="bold",
    )

# --- Pure MC histogram (faint, behind the Gaussians) ---
hist_x_pos = x_right + 5 * 0.9
counts, bin_edges = jnp.histogram(y_samples, bins=80, density=True)
bin_centres = 0.5 * (bin_edges[:-1] + bin_edges[1:])
hist_scaled = counts / counts.max() * width
ax.barh(
    bin_centres,
    hist_scaled,
    left=hist_x_pos,
    height=bin_edges[1] - bin_edges[0],
    color="C5",
    alpha=0.2,
    zorder=1,
)

# --- Arrow showing the "propagation" ---
arrow = FancyArrowPatch(
    (mu_x, f_nonlinear(jnp.array(mu_x))),
    (x_right - 0.3, float(res_ut.state.mean[0])),
    arrowstyle="-|>",
    color="grey",
    lw=1.5,
    connectionstyle="arc3,rad=0.2",
    zorder=4,
)
ax.add_patch(arrow)
ax.text(
    (mu_x + x_right - 0.3) / 2,
    float(res_ut.state.mean[0]) + 0.5,
    "propagate\nuncertainty",
    ha="center",
    fontsize=9,
    color="grey",
)

# --- Formatting ---
ax.set_xlim(-4, x_right + 6 * 0.9 + 1.5)
ax.set_ylim(y_base - 0.4, 4.0)
ax.set_xlabel("x", fontsize=12)
ax.set_ylabel("y = f(x)", fontsize=12)
ax.set_title(
    r"Uncertainty propagation: input $\to$ nonlinear $f$ $\to$ output",
    fontsize=13,
)
ax.axhline(0, c="grey", lw=0.4)
ax.legend(loc="upper left", fontsize=9)
ax.grid(True, which="major", alpha=0.3)
ax.grid(True, which="minor", alpha=0.1)
ax.minorticks_on()
plt.tight_layout()
plt.show()
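(Figure: “Uncertainty propagation: input $\to$ nonlinear $f$ $\to$ output”, showing $p(x)$, $f(x)$, the output Gaussian of each method, and the pure MC histogram.)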

Observations

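  1. Taylor 1st order uses only the local slope at $\mu$: it returns a Gaussian centred at $f(\mu)$ with standard deviation $|f'(\mu)|\,\sigma$. Here $f'(1) \approx 0.13$, so the predicted spread (0.10) is far narrower than the reference 0.51. Because $f$ curves strongly at $\mu$, the mean (1.16) also misses the downward shift to about 0.74.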

  2. Taylor 2nd order adds a Hessian correction to the mean ($\mu_y \mathrel{+}= \tfrac{1}{2}\operatorname{tr}(H \Sigma)$), moving it from 1.16 to 0.51. This is closer to the reference 0.74, but now too low. The covariance is still first-order, so the standard deviation stays at 0.10.

  3. Taylor 2nd (var) additionally corrects the covariance using fourth Gaussian moments. For this wiggly function the correction overshoots (standard deviation 0.93 against the reference 0.51), a common failure mode for strongly nonlinear functions.

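  4. Unscented transform evaluates $f$ at $2N+1 = 3$ deterministic sigma points. Without any derivatives, it gives the closest mean of the deterministic methods (0.61 against 0.74). Its spread is less accurate: with $\alpha = 1$, $\beta = 2$, $\kappa = 0$, the standard scaled-UT weights put a covariance weight of 2 on the centre point, which inflates the standard deviation to 0.81 against 0.51.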

  5. Monte Carlo (moment matching) converges to the true moments with enough samples. With 50k samples, the mean and variance match the pure MC histogram closely.

  6. The pure MC histogram reveals that the true output distribution is skewed, with a long lower tail. No Gaussian approximation can capture this. Every method above returns a Gaussian, and exact moment matching, which the Monte Carlo integrator approximates, gives the Gaussian closest to the true output in the $\mathrm{KL}(p \,\|\, q)$ sense. That makes it a natural fit for downstream linear-Gaussian inference (Kalman updates, variational inference, etc.).
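The sample skewness $\mathbb{E}[(y - \mu_y)^3]/\sigma_y^3$ quantifies the skew. It is zero for every Gaussian, so it measures part of what all five approximations discard. This is a short NumPy sketch on fresh samples; the seed and sample size are arbitrary choices, not the notebook's:

```python
import numpy as np

rng = np.random.default_rng(123)
x = 1.0 + np.sqrt(0.6) * rng.standard_normal(200_000)  # x ~ N(1, 0.6)
y = np.sin(1.5 * x) * np.exp(-0.15 * x**2) + 0.3 * x  # y = f(x)

z = (y - y.mean()) / y.std()
skewness = np.mean(z**3)  # third standardised moment; 0 for any Gaussian
print(f"sample skewness = {skewness:.2f}")
```

A clearly negative value reflects the long lower tail. Outputs are capped near the local maximum $f \approx 1.16$ (at $x \approx 1.06$), while inputs below the mean map to much smaller values.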

Error comparison

ref_mean, ref_var = float(pure_mc_mean), float(pure_mc_var)

fig, axes = plt.subplots(1, 2, figsize=(10, 3.5))

names = list(methods.keys())[:-1]
mean_errors = [abs(methods[n][0] - ref_mean) for n in names]
var_errors = [abs(methods[n][1] - ref_var) for n in names]
colours_bar = colours[:5]

axes[0].bar(
    names,
    mean_errors,
    color=colours_bar,
    alpha=0.7,
    edgecolor="k",
    lw=0.5,
)
axes[0].set_ylabel("|mean error|")
axes[0].set_title("Mean error vs. pure MC reference")
axes[0].tick_params(axis="x", rotation=25)
axes[0].grid(True, which="major", alpha=0.3)
axes[0].grid(True, which="minor", alpha=0.1)
axes[0].minorticks_on()

axes[1].bar(
    names,
    var_errors,
    color=colours_bar,
    alpha=0.7,
    edgecolor="k",
    lw=0.5,
)
axes[1].set_ylabel("|variance error|")
axes[1].set_title("Variance error vs. pure MC reference")
axes[1].tick_params(axis="x", rotation=25)
axes[1].grid(True, which="major", alpha=0.3)
axes[1].grid(True, which="minor", alpha=0.1)
axes[1].minorticks_on()

plt.tight_layout()
plt.show()
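(Figure: bar charts of “Mean error vs. pure MC reference” and “Variance error vs. pure MC reference” for the five Gaussian approximations.)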

When to use each method

| Method | Cost | Best for |
|---|---|---|
| Taylor 1st | $O(N^2)$ Jacobian | Mildly nonlinear, real-time EKF |
| Taylor 2nd | $O(N^3)$ Hessian | Moderate nonlinearity, small $N$ |
| Unscented | $O(N)$ function evals | Good default; no derivatives needed |
| Monte Carlo | $O(S)$ function evals | Complex functions, large $N$, reference |

All four methods are available in gaussx via the AbstractIntegrator interface, so you can swap them with a single line change:

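# Choose one integrator; all of them share the same integrate(fn, state) call.
integrator = gaussx.TaylorIntegrator(order=1)
# integrator = gaussx.UnscentedIntegrator()
# integrator = gaussx.MonteCarloIntegrator(n_samples=10_000, key=jr.key(0))

result = integrator.integrate(lambda x: jnp.atleast_1d(f_nonlinear(x[0])), state)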

Connection to GP prediction with uncertain inputs

In a trained GP with posterior mean $m(x)$ and variance $v(x)$, the law of total variance decomposes the predictive variance under input uncertainty as (Girard et al., 2003, Eq. 2.40):

$$\operatorname{Var}[y_*] \;=\; \underbrace{\mathbb{E}_{x_*}[v(x_*)]}_{\text{noise / model variance}} \;+\; \underbrace{\operatorname{Var}_{x_*}[m(x_*)]}_{\text{input uncertainty}}.$$

The first term averages the GP’s own uncertainty over plausible inputs. The second is the output variance that the methods above approximate when applied to the posterior mean function $m$.
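The decomposition is the law of total variance applied to $y_* = m(x_*) + \sqrt{v(x_*)}\,\varepsilon$ with $\varepsilon \sim \mathcal{N}(0, 1)$. This is a quick NumPy check in which `m` and `v` are hypothetical stand-ins for a trained GP's posterior mean and variance; they are not fitted to any data:

```python
import numpy as np


def m(x):
    """Hypothetical posterior mean (reuses the notebook's test function)."""
    return np.sin(1.5 * x) * np.exp(-0.15 * x**2) + 0.3 * x


def v(x):
    """Hypothetical posterior variance, growing away from the origin."""
    return 0.05 + 0.02 * x**2


rng = np.random.default_rng(0)
S = 400_000
x = 1.0 + np.sqrt(0.6) * rng.standard_normal(S)  # x_* ~ N(1, 0.6)

# Right-hand side: E[v(x_*)] + Var[m(x_*)]
decomposed = v(x).mean() + m(x).var()

# Left-hand side: sample y_* directly and take its variance
y = m(x) + np.sqrt(v(x)) * rng.standard_normal(S)
direct = y.var()
print(f"E[v] + Var[m] = {decomposed:.4f},  Var[y] = {direct:.4f}")
```

Here $\mathbb{E}[v(x_*)] = 0.05 + 0.02\,(\mu_*^2 + \sigma_*^2) = 0.082$ exactly, so any remaining gap between the two printed values is Monte Carlo noise.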

For the squared-exponential (RBF) kernel, both expectations can be evaluated in closed form because the required Gaussian-times-kernel integrals are themselves Gaussian (Quinonero-Candela et al., 2003; Deisenroth, 2010, Appendix A). For other kernels (Matern, periodic, neural-network, etc.) or for compositions of GPs (deep GPs), the integrals are intractable and the numerical methods in this notebook are the standard approach.
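In 1-D the closed form is easy to check. The RBF kernel $k(x, x') = \sigma_f^2 \exp\!\bigl(-(x - x')^2/(2\ell^2)\bigr)$ is an unnormalised Gaussian in $x$, so for $x_* \sim \mathcal{N}(\mu_*, \sigma_*^2)$

$$\mathbb{E}[k(x_*, x_i)] = \sigma_f^2 \sqrt{\frac{\ell^2}{\ell^2 + \sigma_*^2}}\, \exp\!\Bigl(-\frac{(\mu_* - x_i)^2}{2(\ell^2 + \sigma_*^2)}\Bigr).$$

The expected posterior mean of $m(x) = \sum_i \beta_i k(x, x_i)$ is then $\sum_i \beta_i\,\mathbb{E}[k(x_*, x_i)]$. The sketch below compares this with Gauss–Hermite quadrature. The training inputs, weights $\beta_i$ and hyperparameters are made up for illustration:

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def rbf(x, xp, lengthscale=0.8, signal_var=1.0):
    """Squared-exponential kernel k(x, x') in 1-D."""
    return signal_var * np.exp(-0.5 * (x - xp) ** 2 / lengthscale**2)


def expected_rbf(mu, s2, xp, lengthscale=0.8, signal_var=1.0):
    """Closed-form E[k(x, x')] for x ~ N(mu, s2)."""
    l2 = lengthscale**2
    return signal_var * np.sqrt(l2 / (l2 + s2)) * np.exp(-0.5 * (mu - xp) ** 2 / (l2 + s2))


# Hypothetical posterior mean m(x) = sum_i beta_i k(x, x_i)
x_train = np.array([-2.0, -0.5, 0.7, 1.8, 3.0])
beta = np.array([0.4, -1.1, 0.9, 0.3, -0.6])
mu_star, s2_star = 1.0, 0.6

closed_form = beta @ expected_rbf(mu_star, s2_star, x_train)

# Reference value: Gauss-Hermite quadrature of the same expectation
z, w = hermegauss(60)
w = w / w.sum()
xs = mu_star + np.sqrt(s2_star) * z
quadrature = w @ (rbf(xs[:, None], x_train[None, :]) @ beta)
print(closed_form, quadrature)
```

Setting $\sigma_*^2 = 0$ recovers $k(\mu_*, x_i)$, i.e. the ordinary GP mean at a deterministic input.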

The PILCO algorithm (Deisenroth & Rasmussen, 2011) chains these moment-matching steps across a planning horizon $T$ to perform gradient-based policy search in continuous state spaces. It uses analytic moment matching for both the RBF-kernel GP dynamics model and the policy.

References

  • Anderson, B. D. O. & Moore, J. B. (1979). Optimal Filtering. Prentice-Hall.
  • Athans, M., Wishner, R. P., & Bertolini, A. (1968). Suboptimal state estimation for continuous-time nonlinear systems from discrete noisy measurements. IEEE Trans. Automatic Control, 13(5), 504–514.
  • Deisenroth, M. P. (2010). Efficient Reinforcement Learning using Gaussian Processes. PhD thesis, Karlsruhe Institute of Technology.
  • Deisenroth, M. P. & Rasmussen, C. E. (2011). PILCO: A model-based and data-efficient approach to policy search. Proc. ICML, 465–472.
  • Girard, A., Rasmussen, C. E., Quinonero-Candela, J., & Murray-Smith, R. (2003). Gaussian process priors with uncertain inputs – application to multiple-step ahead time series forecasting. Proc. NeurIPS 15.
  • Julier, S. J. & Uhlmann, J. K. (1997). A new extension of the Kalman filter to nonlinear systems. Proc. AeroSense, 182–193.
  • Julier, S. J. & Uhlmann, J. K. (2004). Unscented filtering and nonlinear estimation. Proc. IEEE, 92(3), 401–422.
  • McHutchon, A. & Rasmussen, C. E. (2011). Gaussian process training with input noise. Proc. NeurIPS 24.
  • Quinonero-Candela, J., Girard, A., Larsen, J., & Rasmussen, C. E. (2003). Propagation of uncertainty in Bayesian kernel models – application to multiple-step ahead forecasting. Proc. ICASSP.
  • Wan, E. A. & van der Merwe, R. (2000). The unscented Kalman filter for nonlinear estimation. Proc. IEEE Adaptive Systems for Signal Processing, Communications, and Control Symposium, 153–158.