Estimation

Estimation is the public namespace for recursive filtering, smoothing, and the first optimization-based estimation primitives in Contrax.

The namespace includes:

  • kalman() and kalman_gain() for linear Gaussian filtering and steady-state estimator design
  • kalman_predict(), kalman_update(), kalman_step() for one-step runtime loops
  • ekf() and ukf() for nonlinear filtering
  • ekf_predict(), ekf_update(), ekf_step() for one-step nonlinear loops
  • rts() and uks() for offline smoothing
  • innovation_diagnostics(), likelihood_diagnostics(), ukf_diagnostics(), innovation_rms(), and smoother_diagnostics() for routine estimator-health summaries
  • mhe_objective(), mhe(), mhe_warm_start(), and soft_quadratic_penalty() for fixed-window optimization-based estimation
  • positive_exp(), positive_softplus(), spd_from_cholesky_raw(), and diagonal_spd() for lightweight constrained parameterization

Minimal Example

import jax.numpy as jnp
import contrax as cx

sys = cx.dss(
    jnp.array([[1.0, 0.1], [0.0, 1.0]]),
    jnp.array([[0.0], [0.1]]),
    jnp.eye(2),
    jnp.zeros((2, 1)),
    dt=0.1,
)
result = cx.kalman(
    sys,
    Q_noise=1e-3 * jnp.eye(2),
    R_noise=1e-2 * jnp.eye(2),
    ys=jnp.zeros((20, 2)),
)
Figure: estimation pipeline from model and measurements through filtering, smoothing, and moving-horizon estimation.

The estimation surface is layered: batch filters produce filtered trajectories, smoothers revisit them offline, and MHE uses the same model and noise assumptions in an explicit fixed-horizon objective.

Conventions

Contrax uses the update-first batch convention for kalman(), ekf(), and ukf(): (x0, P0) is the prior on x_0, and the first scan step updates with y_0.

Estimator Equations

As noted under Conventions, the pair (x0, P0) is the prior on x_0: each scan step first updates with y_k and only then predicts the prior on x_{k+1}.

For linear discrete models, the state and observation maps are

\[ x_{k+1} = A x_k + B u_k + w_k, \qquad y_k = C x_k + D u_k + v_k \]

with process noise covariance Q and measurement noise covariance R.

Linear Kalman Filter

kalman() implements the standard discrete Gaussian filter:

\[ \hat{y}_k^- = C \hat{x}_k^- + D u_k, \qquad S_k = C P_k^- C^\top + R \]
\[ K_k = P_k^- C^\top S_k^{-1} \]
\[ \hat{x}_k^+ = \hat{x}_k^- + K_k \bigl(y_k - \hat{y}_k^-\bigr), \qquad P_k^+ = (I - K_k C) P_k^- \]
\[ \hat{x}_{k+1}^- = A \hat{x}_k^+ + B u_k, \qquad P_{k+1}^- = A P_k^+ A^\top + Q \]
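The update-then-predict recursion above can be sketched directly in NumPy (a dependency-light mirror of the jax.numpy implementation; like kalman(), the gain is computed with a linear solve rather than an explicit inverse):

```python
import numpy as np

def kalman_update_predict(A, B, C, D, Q, R, x_prior, P_prior, u, y):
    """One update-then-predict step of the linear Kalman filter."""
    # Update: fuse y_k into the prior on x_k.
    y_pred = C @ x_prior + D @ u
    S = C @ P_prior @ C.T + R                      # innovation covariance
    K = np.linalg.solve(S.T, (P_prior @ C.T).T).T  # gain via linear solve
    x_post = x_prior + K @ (y - y_pred)
    P_post = (np.eye(len(x_prior)) - K @ C) @ P_prior
    # Predict: form the prior on x_{k+1}.
    x_next = A @ x_post + B @ u
    P_next = A @ P_post @ A.T + Q
    return x_post, P_post, x_next, P_next
```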

Extended Kalman Filter

For nonlinear models

\[ x_{k+1} = f(t_k, x_k, u_k) + w_k, \qquad y_k = h(t_k, x_k, u_k) + v_k \]

ekf() linearizes the transition and observation maps with JAX Jacobians:

\[ H_k = \left.\frac{\partial h}{\partial x}\right|_{(t_k, \hat{x}_k^-, u_k)}, \qquad F_k = \left.\frac{\partial f}{\partial x}\right|_{(t_k, \hat{x}_k^+, u_k)} \]
\[ \hat{y}_k^- = h(t_k, \hat{x}_k^-, u_k), \qquad S_k = H_k P_k^- H_k^\top + R \]
\[ K_k = P_k^- H_k^\top S_k^{-1} \]
\[ \hat{x}_k^+ = \hat{x}_k^- + K_k \bigl(y_k - \hat{y}_k^-\bigr), \qquad P_k^+ = (I - K_k H_k) P_k^- \]
\[ \hat{x}_{k+1}^- = f(t_k, \hat{x}_k^+, u_k), \qquad P_{k+1}^- = F_k P_k^+ F_k^\top + Q \]

With the one-step ekf_update(..., num_iter > 1) helper, the measurement map is relinearized around the current update iterate while the prediction covariance is held fixed (an iterated EKF update).
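The iterated relinearization can be sketched in NumPy (a mirror of the num_iter > 1 update logic, with a hand-supplied Jacobian standing in for jax.jacfwd):

```python
import numpy as np

def iterated_ekf_update(h, jac_h, x_pred, P_pred, y, R, num_iter=3):
    """Iterated EKF update: relinearize h around the current iterate,
    keeping the prediction covariance P_pred fixed."""
    x_lin = x_pred
    for _ in range(num_iter):
        H = jac_h(x_lin)
        S = H @ P_pred @ H.T + R
        K = np.linalg.solve(S.T, (P_pred @ H.T).T).T
        # Residual evaluated at the linearization point, pulled back to
        # the predicted mean through the local Jacobian.
        innov = y - h(x_lin) + H @ (x_lin - x_pred)
        x_lin = x_pred + K @ innov
    P_post = (np.eye(len(x_pred)) - K @ H) @ P_pred
    return x_lin, P_post, innov

# Scalar example: y = x^2 + noise, prior mean 1.0, true state near 2.0.
h = lambda x: x**2
jac_h = lambda x: np.diag(2 * x)
x, P, _ = iterated_ekf_update(h, jac_h, np.array([1.0]), np.eye(1),
                              np.array([4.0]), 1e-2 * np.eye(1), num_iter=5)
```

With a single pass this reduces to the ordinary EKF update; extra passes help when the observation map is strongly curved around the prior.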

Unscented Kalman Filter

ukf() avoids local Jacobians and instead propagates sigma points. Given dimension \(n\) and scaling parameter \(\lambda\),

\[ \chi^{(0)} = \hat{x}, \qquad \chi^{(i)} = \hat{x} \pm \left[\sqrt{(n + \lambda) P}\right]_i \]

At the update stage, sigma points from the current prior \((\hat{x}_k^-, P_k^-)\) are pushed through the observation map:

\[ \gamma_k^{(i)} = h(t_k, \chi_k^{(i)}, u_k), \qquad \hat{y}_k^- = \sum_i W_i^{(m)} \gamma_k^{(i)} \]
\[ S_k = \sum_i W_i^{(c)} \bigl(\gamma_k^{(i)} - \hat{y}_k^-\bigr) \bigl(\gamma_k^{(i)} - \hat{y}_k^-\bigr)^\top + R \]
\[ P_{xy,k} = \sum_i W_i^{(c)} \bigl(\chi_k^{(i)} - \hat{x}_k^-\bigr) \bigl(\gamma_k^{(i)} - \hat{y}_k^-\bigr)^\top, \qquad K_k = P_{xy,k} S_k^{-1} \]
\[ \hat{x}_k^+ = \hat{x}_k^- + K_k \bigl(y_k - \hat{y}_k^-\bigr), \qquad P_k^+ = P_k^- - K_k S_k K_k^\top \]

For prediction, sigma points from \((\hat{x}_k^+, P_k^+)\) are pushed through the transition map:

\[ \xi_k^{(i)} = f(t_k, \chi_k^{(i)}, u_k), \qquad \hat{x}_{k+1}^- = \sum_i W_i^{(m)} \xi_k^{(i)} \]
\[ P_{k+1}^- = \sum_i W_i^{(c)} \bigl(\xi_k^{(i)} - \hat{x}_{k+1}^-\bigr) \bigl(\xi_k^{(i)} - \hat{x}_{k+1}^-\bigr)^\top + Q \]
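A NumPy sketch of sigma-point generation and the unscented mean (the weight and square-root choices here are one common convention; the exact scaling parameters used by ukf() are not restated in this section):

```python
import numpy as np

def sigma_points(x, P, lam):
    """Generate the 2n+1 sigma points and their mean weights."""
    n = len(x)
    L = np.linalg.cholesky((n + lam) * P)   # one valid matrix square root
    pts = np.vstack([x, x + L.T, x - L.T])  # rows: x, x +/- columns of L
    Wm = np.full(2 * n + 1, 1.0 / (2 * (n + lam)))
    Wm[0] = lam / (n + lam)
    return pts, Wm

# Pushing sigma points through a linear map recovers the linear mean exactly.
x = np.array([1.0, -2.0])
P = np.array([[2.0, 0.5], [0.5, 1.0]])
pts, Wm = sigma_points(x, P, lam=1.0)
A = np.array([[0.9, 0.1], [0.0, 0.8]])
mean = Wm @ (pts @ A.T)  # weighted mean of transformed points
```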

Smoothers

rts() is the linear Rauch-Tung-Striebel smoother on top of KalmanResult:

\[ G_k = P_k^+ A^\top (P_{k+1}^-)^{-1} \]
\[ \hat{x}_k^s = \hat{x}_k^+ + G_k \bigl(\hat{x}_{k+1}^s - \hat{x}_{k+1}^-\bigr) \]
\[ P_k^s = P_k^+ + G_k \bigl(P_{k+1}^s - P_{k+1}^-\bigr) G_k^\top \]

uks() applies the same backward recursion to the unscented filter, but replaces the linear cross-covariance \(P_k^+ A^\top\) with an unscented cross-covariance:

\[ P_{x_k x_{k+1}} = \sum_i W_i^{(c)} \bigl(\chi_k^{(i)} - \hat{x}_k^+\bigr) \bigl(\xi_k^{(i)} - \hat{x}_{k+1}^-\bigr)^\top \]
\[ G_k = P_{x_k x_{k+1}} (P_{k+1}^-)^{-1} \]
\[ \hat{x}_k^s = \hat{x}_k^+ + G_k \bigl(\hat{x}_{k+1}^s - \hat{x}_{k+1}^-\bigr), \qquad P_k^s = P_k^+ + G_k \bigl(P_{k+1}^s - P_{k+1}^-\bigr) G_k^\top \]
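One backward RTS step, sketched in NumPy; note that if the smoothed estimate at k+1 equals the filtered prediction, the smoother leaves the filtered estimate at k unchanged:

```python
import numpy as np

def rts_step(A, x_post, P_post, x_pred_next, P_pred_next,
             x_smooth_next, P_smooth_next):
    """One backward Rauch-Tung-Striebel smoothing step."""
    # Smoother gain G = P_post A^T (P_pred_next)^{-1}, via a linear solve.
    G = np.linalg.solve(P_pred_next.T, (P_post @ A.T).T).T
    x_s = x_post + G @ (x_smooth_next - x_pred_next)
    P_s = P_post + G @ (P_smooth_next - P_pred_next) @ G.T
    return x_s, P_s
```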

Moving-Horizon Estimation

mhe_objective() and mhe() expose the fixed-window optimization form:

\[ J_{\mathrm{MHE}} = \|x_0 - x_{\mathrm{prior}}\|_{P^{-1}}^2 + \sum_{k=0}^{T-1}\|x_{k+1} - f(t_k, x_k, u_k)\|_{Q^{-1}}^2 + \sum_{k=0}^{T}\|y_k - h(t_k, x_k, u_k)\|_{R^{-1}}^2 \]

This is the optimization sibling of the recursive filters: the same model and noise assumptions appear, but they are expressed as an explicit horizon cost.

For rolling-window use:

  • mhe_warm_start(xs, ...) shifts a previous solution forward by one step to build the next initial guess
  • soft_quadratic_penalty(residuals, weight) is a small helper for soft state, terminal, or envelope penalties inside extra_cost
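The rolling-window pattern can be sketched as follows (hypothetical NumPy reimplementations of the shift-and-extrapolate warm start and the quadratic penalty, not the library functions themselves):

```python
import numpy as np

def warm_start_shift(xs, f):
    """Shift a previous window solution forward one step; extrapolate
    the new final state with the model (one common choice)."""
    return np.vstack([xs[1:], f(xs[-1])[None, :]])

def soft_quadratic_penalty_np(residuals, weight):
    """Simple soft penalty: weight * sum of squared residuals."""
    return weight * np.sum(np.square(residuals))

f = lambda x: np.array([x[0] + 0.1 * x[1], x[1]])  # toy transition map
xs = np.array([[0.0, 1.0], [0.1, 1.0], [0.2, 1.0]])
xs_next = warm_start_shift(xs, f)  # next window's initial guess
```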

One-Step Helper Signatures

The one-step helpers follow the order: model, state (x, P), runtime data (u/y), noise matrices (Q/R).

# predict: advances state by one step
x_pred, P_pred = cx.ekf_predict(model, x, P, u, Q_noise)
x_pred, P_pred = cx.kalman_predict(sys, x, P, Q_noise, u=None)  # u optional

# update: fuses a new measurement
x, P, innov = cx.ekf_update(model, x_pred, P_pred, y, R_noise)
x, P, innov = cx.kalman_update(sys, x_pred, P_pred, y, R_noise, u=None)

# combined predict-update step
x, P, innov = cx.ekf_step(model, x, P, u, y, Q_noise, R_noise, observation=h)
x, P, innov = cx.kalman_step(sys, x, P, y, Q_noise, R_noise, u=None)

All step and update helpers accept has_measurement=False to skip the measurement update, which is useful for streams with missing samples inside jax.lax.scan.
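The missing-measurement gating can be sketched in NumPy with an arithmetic select (mirroring the jax.lax.cond gating in the helpers; the helper internals here are an illustration, not the library code):

```python
import numpy as np

def gated_step(x, P, y, has_measurement, A, C, Q, R):
    """Predict, then apply the measurement update only when a sample
    exists, without Python branching on the flag."""
    # Predict.
    x_pred, P_pred = A @ x, A @ P @ A.T + Q
    # Candidate update.
    S = C @ P_pred @ C.T + R
    K = np.linalg.solve(S.T, (P_pred @ C.T).T).T
    innov = y - C @ x_pred
    x_upd = x_pred + K @ innov
    P_upd = (np.eye(len(x)) - K @ C) @ P_pred
    # Gate: select updated or predicted quantities arithmetically.
    m = float(has_measurement)
    return (m * x_upd + (1 - m) * x_pred,
            m * P_upd + (1 - m) * P_pred,
            m * innov)
```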

Model Inputs

For nonlinear estimation, the recommended path is a reusable system model such as NonlinearSystem or PHSSystem.

This keeps simulation, linearization, and estimation on one shared model object instead of redefining the same dynamics in several places.

The intended public workflow is:

sys_c = cx.nonlinear_system(f_continuous, output=h, dt=None)
sys_d = cx.sample_system(sys_c, dt=0.1)
result = cx.ukf(sys_d, Q_noise, R_noise, ys, us, x0, P0)
smoothed = cx.uks(sys_d, result, Q_noise, us)

ukf() returns UKFResult rather than KalmanResult because sigma-point filtering exposes additional public intermediates: predicted measurements, innovation covariances, per-step likelihood terms, and the one-step prediction quantities used by uks().

For continuous dynamics with discrete observations, sample_system() is the main library bridge. It integrates the continuous vector field over one sample interval and returns a discrete-time NonlinearSystem that can be passed directly to ekf() or ukf().
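The continuous-to-discrete bridge can be sketched as a one-interval integrator. The NumPy RK4 step below illustrates the idea; the actual integration scheme used by sample_system() is an assumption here:

```python
import numpy as np

def sample_rk4(f_continuous, dt):
    """Turn a continuous vector field x' = f(x, u) into a discrete-time
    map over one sample interval via a single RK4 step."""
    def f_discrete(x, u):
        k1 = f_continuous(x, u)
        k2 = f_continuous(x + 0.5 * dt * k1, u)
        k3 = f_continuous(x + 0.5 * dt * k2, u)
        k4 = f_continuous(x + dt * k3, u)
        return x + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return f_discrete

decay = lambda x, u: -x          # x' = -x; exact step is exp(-dt) * x
step = sample_rk4(decay, dt=0.1)
```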

Transform Behavior

Use the one-step helpers when measurements arrive in a runtime loop or when the filter needs to live inside a larger jax.lax.scan.

The important transform contracts are:

  • kalman(), ekf(), and ukf() are batch scans over fixed-shape sequences
  • the one-step helpers expose missing-measurement handling without Python branching on traced values
  • mhe_objective() is a pure cost function over an explicit candidate trajectory, which makes it suitable for JAX-native optimization loops

Numerical Notes

kalman_gain() is the estimator dual of discrete LQR and returns the measurement-update gain used by kalman_update().

rts() requires the same Q_noise used by the forward kalman() pass so the smoother uses the same process model. Pass zeros for deterministic dynamics.

mhe_objective() is the lowest-level optimization-based estimator surface in the library. mhe() is a thin solver wrapper, not a full nonlinear programming framework.

Diagnostics

Contrax includes a lightweight diagnostics layer for routine estimation health checks:

  • innovation_diagnostics(innovations, innovation_covariances) computes per-step normalized innovation squared (NIS), innovation norms, and covariance conditioning summaries
  • likelihood_diagnostics(log_likelihood_terms) summarizes innovation-form Gaussian log-likelihood terms
  • ukf_diagnostics(result) builds the standard innovation and likelihood summaries directly from a UKFResult
  • innovation_rms(result) gives a quick residual-scale summary for either KalmanResult or UKFResult
  • smoother_diagnostics(smoothed, filtered) checks whether the smoother satisfied its covariance-reduction guarantee (P_smooth ≤ P_filtered) and reports the worst-case violation, state correction magnitudes, and PSD health of the smoothed covariances. Accepts RTSResult paired with either KalmanResult or UKFResult.

These helpers are intentionally lightweight. They are well suited to filter tuning, regression checks, and quick workflow validation, but they are not a substitute for a full statistical consistency study.
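As a quick sanity check of the NIS interpretation: for a consistent filter the innovations are distributed N(0, S_k), so the mean NIS is approximately the measurement dimension p. A NumPy sketch with synthetic consistent innovations:

```python
import numpy as np

rng = np.random.default_rng(0)
p = 2
S = np.array([[2.0, 0.3], [0.3, 1.0]])   # a fixed innovation covariance
L = np.linalg.cholesky(S)
innovations = rng.standard_normal((4000, p)) @ L.T  # samples from N(0, S)

# Per-step normalized innovation squared: v_k^T S^{-1} v_k.
nis = np.einsum("ti,ij,tj->t", innovations, np.linalg.inv(S), innovations)
print(nis.mean())  # close to p for consistent innovations
```

A mean NIS well above p suggests the filter is overconfident (R or Q too small); well below p suggests it is underconfident.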

Parameterization Helpers

Contrax also exposes a small constrained-parameterization layer for estimation and tuning workflows:

  • positive_exp(raw) and positive_softplus(raw) map unconstrained values to positive ones
  • spd_from_cholesky_raw(raw) maps an unconstrained square matrix to an SPD matrix through a lower-triangular Cholesky factor
  • diagonal_spd(raw_diagonal) builds a diagonal SPD matrix from unconstrained diagonal parameters

These are intentionally low-level building blocks rather than a parameter framework. They exist so downstream repos do not need to reinvent the same positivity and covariance-parameterization transforms.
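The SPD parameterization idea can be sketched in NumPy. The softplus diagonal transform below is one standard choice and an assumption about spd_from_cholesky_raw()'s internals, not a restatement of them:

```python
import numpy as np

def positive_softplus_np(raw):
    """Map an unconstrained value to a strictly positive one."""
    return np.log1p(np.exp(raw))

def spd_from_cholesky_raw_np(raw):
    """Unconstrained square matrix -> SPD matrix via a lower-triangular
    Cholesky factor with a positive diagonal."""
    L = np.tril(raw, k=-1) + np.diag(positive_softplus_np(np.diag(raw)))
    return L @ L.T

raw = np.array([[0.3, 9.0], [-0.5, -1.2]])  # strict upper entries ignored
Q = spd_from_cholesky_raw_np(raw)           # symmetric positive definite
```

Because the map is smooth and unconstrained in `raw`, it composes cleanly with gradient-based covariance tuning loops.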

contrax.estimation

contrax.estimation — Kalman, EKF, UKF, MHE, and RTS helpers.

ekf(model_or_f: DynamicsLike, Q_noise: Array, R_noise: Array, ys: Array, us: Array, x0: Array, P0: Array, *, observation: ObservationFn | None = None) -> KalmanResult

Run an extended Kalman filter for nonlinear dynamics and observations.

Uses automatic Jacobians of the transition and observation functions rather than requiring a hand-derived local linearization.

This is the main nonlinear recursive estimation primitive in the current library. It works well with JAX-defined model functions, but the API is intentionally simple and does not yet expose multiple EKF variants.

Parameters:

  • model_or_f (DynamicsLike, required): Nonlinear system model or dynamics function.
  • Q_noise (Array, required): Process noise covariance. Shape: (n, n).
  • R_noise (Array, required): Measurement noise covariance. Shape: (p, p).
  • ys (Array, required): Measurement sequence. Shape: (T, p).
  • us (Array, required): Input sequence. Shape: (T, m).
  • x0 (Array, required): Initial mean. Shape: (n,).
  • P0 (Array, required): Initial covariance. Shape: (n, n).
  • observation (ObservationFn | None, default None): Observation function h(x) for plain dynamics callables. Omit this when passing a system model with its own output map.

Returns:

  • KalmanResult: A bundle containing filtered means, covariances, and innovations.

Examples:

>>> import jax.numpy as jnp
>>> import contrax as cx
>>> def f(x, u):
...     return jnp.array([x[0] + 0.1 * x[1], x[1] + u[0]])
>>> def h(x):
...     return x[:1]
>>> result = cx.ekf(
...     f,
...     Q_noise=1e-3 * jnp.eye(2),
...     R_noise=1e-2 * jnp.eye(1),
...     ys=jnp.zeros((10, 1)),
...     us=jnp.zeros((10, 1)),
...     x0=jnp.zeros(2),
...     P0=jnp.eye(2),
...     observation=h,
... )
Source code in contrax/_ekf.py
def ekf(
    model_or_f: DynamicsLike,
    Q_noise: Array,
    R_noise: Array,
    ys: Array,
    us: Array,
    x0: Array,
    P0: Array,
    *,
    observation: ObservationFn | None = None,
) -> KalmanResult:
    """Run an extended Kalman filter for nonlinear dynamics and observations.

    Uses automatic Jacobians of the transition and observation functions rather
    than requiring a hand-derived local linearization.

    This is the main nonlinear recursive estimation primitive in the current
    library. It works well with JAX-defined model functions, but the API is
    intentionally simple and does not yet expose multiple EKF variants.

    Args:
        model_or_f: Nonlinear system model or dynamics function.
        Q_noise: Process noise covariance. Shape: `(n, n)`.
        R_noise: Measurement noise covariance. Shape: `(p, p)`.
        ys: Measurement sequence. Shape: `(T, p)`.
        us: Input sequence. Shape: `(T, m)`.
        x0: Initial mean. Shape: `(n,)`.
        P0: Initial covariance. Shape: `(n, n)`.
        observation: Observation function `h(x)` for plain dynamics callables.
            Omit this when passing a system model with its own output map.

    Returns:
        [KalmanResult][contrax.types.KalmanResult]: A bundle containing
            filtered means, covariances, and innovations.

    Examples:
        >>> import jax.numpy as jnp
        >>> import contrax as cx
        >>> def f(x, u):
        ...     return jnp.array([x[0] + 0.1 * x[1], x[1] + u[0]])
        >>> def h(x):
        ...     return x[:1]
        >>> result = cx.ekf(
        ...     f,
        ...     Q_noise=1e-3 * jnp.eye(2),
        ...     R_noise=1e-2 * jnp.eye(1),
        ...     ys=jnp.zeros((10, 1)),
        ...     us=jnp.zeros((10, 1)),
        ...     x0=jnp.zeros(2),
        ...     P0=jnp.eye(2),
        ...     observation=h,
        ... )
    """
    if _is_system_model(model_or_f):
        if observation is not None:
            raise ValueError("ekf(sys, ...) does not accept observation=.")
        obs = model_or_f
    else:
        if observation is None:
            raise ValueError(
                "ekf(f, ..., observation=...) requires an observation function."
            )
        obs = observation

    dt = _system_dt(model_or_f)
    ts = jnp.arange(ys.shape[0], dtype=x0.dtype)
    if dt is not None:
        ts = ts * dt

    # Update-then-predict: x0/P0 is the prior on x_0.
    def update_predict(carry, inputs):
        x_prior, P_prior = carry
        y, u, t = inputs
        x_post, P_post, innov = ekf_update(
            obs,
            x_prior,
            P_prior,
            y,
            R_noise,
            u=u,
            t=t,
        )
        x_next, P_next = ekf_predict(model_or_f, x_post, P_post, u, Q_noise, t=t)
        return (x_next, P_next), (x_post, P_post, innov)

    _, (x_hats, Ps, innovations) = jax.lax.scan(update_predict, (x0, P0), (ys, us, ts))
    return KalmanResult(x_hat=x_hats, P=Ps, innovations=innovations)

ekf_predict(model_or_f: DynamicsLike, x: Array, P: Array, u: Array, Q_noise: Array, *, t: Array | float = 0.0) -> tuple[Array, Array]

Predict one extended Kalman filter step.

Parameters:

  • model_or_f (DynamicsLike, required): Nonlinear system model or dynamics function mapping (x, u) to the next state.
  • x (Array, required): Filtered mean at the previous time step. Shape: (n,).
  • P (Array, required): Filtered covariance at the previous time step. Shape: (n, n).
  • u (Array, required): Current input. Shape: (m,).
  • Q_noise (Array, required): Process noise covariance. Shape: (n, n).
  • t (Array | float, default 0.0): Current sample time.

Returns:

  • tuple[Array, Array]: Tuple (x_pred, P_pred) with the predicted mean and covariance.

Source code in contrax/_ekf.py
def ekf_predict(
    model_or_f: DynamicsLike,
    x: Array,
    P: Array,
    u: Array,
    Q_noise: Array,
    *,
    t: Array | float = 0.0,
) -> tuple[Array, Array]:
    """Predict one extended Kalman filter step.

    Args:
        model_or_f: Nonlinear system model or dynamics function mapping
            `(x, u)` to the next state.
        x: Filtered mean at the previous time step. Shape: `(n,)`.
        P: Filtered covariance at the previous time step. Shape: `(n, n)`.
        u: Current input. Shape: `(m,)`.
        Q_noise: Process noise covariance. Shape: `(n, n)`.
        t: Current sample time.

    Returns:
        Tuple `(x_pred, P_pred)` with the predicted mean and covariance.
    """
    f = _coerce_dynamics(model_or_f)
    x_pred = f(x, u, t)
    F = jax.jacfwd(f, argnums=0)(x, u, t)
    P_pred = _symmetrize(F @ P @ F.T + Q_noise)
    return x_pred, P_pred

ekf_step(model_or_f: DynamicsLike, x: Array, P: Array, u: Array, y: Array, Q_noise: Array, R_noise: Array, *, t: Array | float = 0.0, has_measurement: Array | bool = True, num_iter: int = 1, observation: ObservationFn | None = None) -> tuple[Array, Array, Array]

Run one predict-update step of an extended Kalman filter.

Parameters:

  • model_or_f (DynamicsLike, required): Nonlinear system model or dynamics function.
  • x (Array, required): Filtered mean at the previous time step. Shape: (n,).
  • P (Array, required): Filtered covariance at the previous time step. Shape: (n, n).
  • u (Array, required): Current input. Shape: (m,).
  • y (Array, required): Measurement vector. Shape: (p,).
  • Q_noise (Array, required): Process noise covariance. Shape: (n, n).
  • R_noise (Array, required): Measurement noise covariance. Shape: (p, p).
  • t (Array | float, default 0.0): Current sample time.
  • has_measurement (Array | bool, default True): Whether to apply the measurement update.
  • num_iter (int, default 1): Number of iterated-EKF measurement linearization passes.
  • observation (ObservationFn | None, default None): Observation function h(x) for plain dynamics callables. Omit this when passing a system model with its own output map.

Returns:

  • tuple[Array, Array, Array]: Tuple (x, P, innovation) after the step.

Source code in contrax/_ekf.py
def ekf_step(
    model_or_f: DynamicsLike,
    x: Array,
    P: Array,
    u: Array,
    y: Array,
    Q_noise: Array,
    R_noise: Array,
    *,
    t: Array | float = 0.0,
    has_measurement: Array | bool = True,
    num_iter: int = 1,
    observation: ObservationFn | None = None,
) -> tuple[Array, Array, Array]:
    """Run one predict-update step of an extended Kalman filter.

    Args:
        model_or_f: Nonlinear system model or dynamics function.
        x: Filtered mean at the previous time step. Shape: `(n,)`.
        P: Filtered covariance at the previous time step. Shape: `(n, n)`.
        u: Current input. Shape: `(m,)`.
        y: Measurement vector. Shape: `(p,)`.
        Q_noise: Process noise covariance. Shape: `(n, n)`.
        R_noise: Measurement noise covariance. Shape: `(p, p)`.
        has_measurement: Whether to apply the measurement update.
        num_iter: Number of iterated-EKF measurement linearization passes.
        observation: Observation function `h(x)` for plain dynamics callables.
            Omit this when passing a system model with its own output map.

    Returns:
        Tuple `(x, P, innovation)` after the step.
    """
    if _is_system_model(model_or_f):
        if observation is not None:
            raise ValueError("ekf_step(sys, ...) does not accept observation=.")
        obs = model_or_f
    else:
        if observation is None:
            raise ValueError(
                "ekf_step(f, ..., observation=...) requires an observation function."
            )
        obs = observation
    x_pred, P_pred = ekf_predict(model_or_f, x, P, u, Q_noise, t=t)
    return ekf_update(
        obs,
        x_pred,
        P_pred,
        y,
        R_noise,
        u=u,
        t=t,
        has_measurement=has_measurement,
        num_iter=num_iter,
    )

ekf_update(model_or_h: ObservationLike, x_pred: Array, P_pred: Array, y: Array, R_noise: Array, *, u: Array | None = None, t: Array | float = 0.0, has_measurement: Array | bool = True, num_iter: int = 1) -> tuple[Array, Array, Array]

Update one extended Kalman filter step.

Parameters:

  • model_or_h (ObservationLike, required): Nonlinear system model or observation function.
  • x_pred (Array, required): Predicted mean. Shape: (n,).
  • P_pred (Array, required): Predicted covariance. Shape: (n, n).
  • y (Array, required): Measurement vector. Shape: (p,).
  • R_noise (Array, required): Measurement noise covariance. Shape: (p, p).
  • u (Array | None, default None): Current input used by the observation model. Shape: (m,).
  • t (Array | float, default 0.0): Current sample time.
  • has_measurement (Array | bool, default True): Whether to apply the measurement update.
  • num_iter (int, default 1): Number of iterated-EKF measurement linearization passes. 1 gives the ordinary EKF update.

Returns:

  • tuple[Array, Array, Array]: Tuple (x, P, innovation) with the filtered mean, filtered covariance, and final measurement residual. Missing-measurement steps return the prediction and a zero innovation.

Source code in contrax/_ekf.py
def ekf_update(
    model_or_h: ObservationLike,
    x_pred: Array,
    P_pred: Array,
    y: Array,
    R_noise: Array,
    *,
    u: Array | None = None,
    t: Array | float = 0.0,
    has_measurement: Array | bool = True,
    num_iter: int = 1,
) -> tuple[Array, Array, Array]:
    """Update one extended Kalman filter step.

    Args:
        model_or_h: Nonlinear system model or observation function.
        x_pred: Predicted mean. Shape: `(n,)`.
        P_pred: Predicted covariance. Shape: `(n, n)`.
        y: Measurement vector. Shape: `(p,)`.
        R_noise: Measurement noise covariance. Shape: `(p, p)`.
        u: Current input used by the observation model. Shape: `(m,)`.
        t: Current sample time.
        has_measurement: Whether to apply the measurement update.
        num_iter: Number of iterated-EKF measurement linearization passes.
            `1` gives the ordinary EKF update.

    Returns:
        Tuple `(x, P, innovation)` with the filtered mean, filtered covariance,
        and final measurement residual. Missing-measurement steps return the
        prediction and a zero innovation.
    """
    if num_iter < 1:
        raise ValueError("num_iter must be at least 1")

    n = x_pred.shape[0]
    if u is None:
        u = jnp.zeros((0,), dtype=x_pred.dtype)
    h = _coerce_observation(
        model_or_h,
        None if _is_system_model(model_or_h) else model_or_h,
    )

    # Reference: Jazwinski (1970), "Stochastic Processes and Filtering
    # Theory"; the iterated EKF relinearizes the measurement model around the
    # current update iterate while keeping the prediction covariance fixed.
    def one_iteration(x_lin, _):
        H = jax.jacfwd(h, argnums=0)(x_lin, u, t)
        S = H @ P_pred @ H.T + R_noise
        K = jnp.linalg.solve(S.T, (P_pred @ H.T).T).T
        innov = y - h(x_lin, u, t) + H @ (x_lin - x_pred)
        x_next = x_pred + K @ innov
        return x_next, (x_next, H, K, innov)

    x_final, (_, Hs, Ks, innovations) = jax.lax.scan(
        one_iteration,
        x_pred,
        xs=None,
        length=num_iter,
    )
    H_final = Hs[-1]
    K_final = Ks[-1]
    innov_final = innovations[-1]
    P_final = _symmetrize((jnp.eye(n, dtype=P_pred.dtype) - K_final @ H_final) @ P_pred)

    def apply_update():
        return x_final, P_final, innov_final

    def skip_update():
        return x_pred, P_pred, jnp.zeros_like(y)

    return jax.lax.cond(
        _measurement_flag(has_measurement),
        apply_update,
        skip_update,
    )

innovation_diagnostics(innovations: Array, innovation_covariances: Array) -> InnovationDiagnostics

Summarize innovation magnitude and covariance conditioning.

The returned NIS values are the per-step normalized innovation squared terms v_k^T S_k^{-1} v_k. They are useful for routine filter health inspection, but Contrax does not present them as a formal standalone consistency certificate.

Source code in contrax/_estimation_diagnostics.py
def innovation_diagnostics(
    innovations: Array,
    innovation_covariances: Array,
) -> InnovationDiagnostics:
    """Summarize innovation magnitude and covariance conditioning.

    The returned NIS values are the per-step normalized innovation squared
    terms `v_k^T S_k^{-1} v_k`. They are useful for routine filter health
    inspection, but Contrax does not present them as a formal standalone
    consistency certificate.
    """

    def step_nis(innovation, covariance):
        solved = jnp.linalg.solve(covariance, innovation[:, None])
        return (innovation[None, :] @ solved).squeeze()

    nis = jax.vmap(step_nis)(innovations, innovation_covariances)
    innovation_norms = jnp.linalg.norm(innovations, axis=-1)
    innovation_cov_condition_numbers = jax.vmap(jnp.linalg.cond)(innovation_covariances)
    all_values = (
        jnp.isfinite(nis)
        & jnp.isfinite(innovation_norms)
        & jnp.isfinite(innovation_cov_condition_numbers)
    )
    return InnovationDiagnostics(
        nis=nis,
        innovation_norms=innovation_norms,
        innovation_cov_condition_numbers=innovation_cov_condition_numbers,
        mean_nis=jnp.mean(nis),
        max_nis=jnp.max(nis),
        mean_innovation_norm=jnp.mean(innovation_norms),
        max_innovation_norm=jnp.max(innovation_norms),
        max_innovation_cov_condition_number=jnp.max(innovation_cov_condition_numbers),
        nonfinite=jnp.logical_not(jnp.all(all_values)),
    )

innovation_rms(result: KalmanResult | UKFResult) -> Array

Return the RMS innovation magnitude for a filtered result.

Source code in contrax/_estimation_diagnostics.py
def innovation_rms(result: KalmanResult | UKFResult) -> Array:
    """Return the RMS innovation magnitude for a filtered result."""
    return jnp.sqrt(jnp.mean(result.innovations**2))

kalman(sys: DiscLTI, Q_noise: Array, R_noise: Array, ys: Array, x0: Array | None = None, P0: Array | None = None) -> KalmanResult

Run a linear discrete-time Kalman filter.

Filters a measurement sequence for a DiscLTI system using the standard predict-update recursion expressed as a pure lax.scan.

This is the baseline linear-Gaussian estimation path in Contrax. Innovation gains are computed with linear solves rather than explicit inverses, and the scan structure makes small differentiable covariance-tuning loops practical.

Parameters:

  • sys (DiscLTI, required): Discrete-time linear system.
  • Q_noise (Array, required): Process noise covariance. Shape: (n, n).
  • R_noise (Array, required): Measurement noise covariance. Shape: (p, p).
  • ys (Array, required): Measurement sequence. Shape: (T, p).
  • x0 (Array | None, default None): Optional initial mean. Shape: (n,). Defaults to zeros.
  • P0 (Array | None, default None): Optional initial covariance. Shape: (n, n). Defaults to the identity matrix.

Returns:

  • KalmanResult: A bundle containing filtered means x_hat, filtered covariances P, and measurement innovations innovations.

Examples:

>>> import jax.numpy as jnp
>>> import contrax as cx
>>> sys = cx.dss(
...     jnp.array([[1.0, 0.1], [0.0, 1.0]]),
...     jnp.array([[0.0], [0.1]]),
...     jnp.eye(2),
...     jnp.zeros((2, 1)),
...     dt=0.1,
... )
>>> result = cx.kalman(
...     sys,
...     Q_noise=1e-3 * jnp.eye(2),
...     R_noise=1e-2 * jnp.eye(2),
...     ys=jnp.zeros((20, 2)),
... )
Source code in contrax/_kalman.py
def kalman(
    sys: DiscLTI,
    Q_noise: Array,  # (n, n) process noise covariance
    R_noise: Array,  # (p, p) measurement noise covariance
    ys: Array,  # (T, p) measurement sequence
    x0: Array | None = None,
    P0: Array | None = None,
) -> KalmanResult:
    """Run a linear discrete-time Kalman filter.

    Filters a measurement sequence for a
    [DiscLTI][contrax.systems.DiscLTI] system using the standard
    predict-update recursion expressed as a pure `lax.scan`.

    This is the baseline linear-Gaussian estimation path in Contrax. Innovation
    gains are computed with linear solves rather than explicit inverses, and
    the scan structure makes small differentiable covariance-tuning loops
    practical.

    Args:
        sys: Discrete-time linear system.
        Q_noise: Process noise covariance. Shape: `(n, n)`.
        R_noise: Measurement noise covariance. Shape: `(p, p)`.
        ys: Measurement sequence. Shape: `(T, p)`.
        x0: Optional initial mean. Shape: `(n,)`. Defaults to zeros.
        P0: Optional initial covariance. Shape: `(n, n)`. Defaults to the
            identity matrix.

    Returns:
        [KalmanResult][contrax.types.KalmanResult]: A bundle containing
            filtered means `x_hat`, filtered covariances `P`, and measurement
            innovations `innovations`.

    Examples:
        >>> import jax.numpy as jnp
        >>> import contrax as cx
        >>> sys = cx.dss(
        ...     jnp.array([[1.0, 0.1], [0.0, 1.0]]),
        ...     jnp.array([[0.0], [0.1]]),
        ...     jnp.eye(2),
        ...     jnp.zeros((2, 1)),
        ...     dt=0.1,
        ... )
        >>> result = cx.kalman(
        ...     sys,
        ...     Q_noise=1e-3 * jnp.eye(2),
        ...     R_noise=1e-2 * jnp.eye(2),
        ...     ys=jnp.zeros((20, 2)),
        ... )
    """
    n = sys.A.shape[0]
    if x0 is None:
        x0 = jnp.zeros(n, dtype=ys.dtype)
    if P0 is None:
        P0 = jnp.eye(n, dtype=ys.dtype)

    # Update-then-predict: x0/P0 is the prior on x_0.
    # Each scan step updates with y_k first (using the current prior), then
    # predicts forward to give the prior on x_{k+1}.
    def update_predict(carry, y):
        x_prior, P_prior = carry
        x_post, P_post, innov = kalman_update(sys, x_prior, P_prior, y, R_noise)
        x_next, P_next = kalman_predict(sys, x_post, P_post, Q_noise)
        return (x_next, P_next), (x_post, P_post, innov)

    _, (x_hats, Ps, innovations) = jax.lax.scan(update_predict, (x0, P0), ys)
    return KalmanResult(x_hat=x_hats, P=Ps, innovations=innovations)

kalman_gain(sys: DiscLTI, Q_noise: Array, R_noise: Array) -> KalmanGainResult

Design a steady-state Kalman filter gain for a discrete LTI system.

This is the estimator dual of discrete LQR. It solves the DARE for (A.T, C.T, Q_noise, R_noise) and returns the measurement-update gain used by kalman_update().

Parameters:

  • sys (DiscLTI, required): Discrete-time linear system.
  • Q_noise (Array, required): Process noise covariance. Shape: (n, n).
  • R_noise (Array, required): Measurement noise covariance. Shape: (p, p).

Returns:

  • KalmanGainResult: A bundle with update gain K, steady-state predicted covariance P, and estimator error-dynamics poles.

Examples:

>>> import jax.numpy as jnp
>>> import contrax as cx
>>> sys = cx.dss(
...     jnp.array([[1.0, 0.1], [0.0, 1.0]]),
...     jnp.array([[0.0], [0.1]]),
...     jnp.eye(2),
...     jnp.zeros((2, 1)),
...     dt=0.1,
... )
>>> design = cx.kalman_gain(sys, 1e-3 * jnp.eye(2), 1e-2 * jnp.eye(2))
Source code in contrax/_kalman.py
def kalman_gain(
    sys: DiscLTI,
    Q_noise: Array,
    R_noise: Array,
) -> KalmanGainResult:
    """Design a steady-state Kalman filter gain for a discrete LTI system.

    This is the estimator dual of discrete LQR. It solves the DARE for
    `(A.T, C.T, Q_noise, R_noise)` and returns the measurement-update gain used
    by `kalman_update()`.

    Args:
        sys: Discrete-time linear system.
        Q_noise: Process noise covariance. Shape: `(n, n)`.
        R_noise: Measurement noise covariance. Shape: `(p, p)`.

    Returns:
        [KalmanGainResult][contrax.types.KalmanGainResult]: A bundle with
            update gain `K`, steady-state predicted covariance `P`, and
            estimator error-dynamics poles.

    Examples:
        >>> import jax.numpy as jnp
        >>> import contrax as cx
        >>> sys = cx.dss(
        ...     jnp.array([[1.0, 0.1], [0.0, 1.0]]),
        ...     jnp.array([[0.0], [0.1]]),
        ...     jnp.eye(2),
        ...     jnp.zeros((2, 1)),
        ...     dt=0.1,
        ... )
        >>> design = cx.kalman_gain(sys, 1e-3 * jnp.eye(2), 1e-2 * jnp.eye(2))
    """
    P_pred = dare(sys.A.T, sys.C.T, Q_noise, R_noise).S
    innovation_cov = sys.C @ P_pred @ sys.C.T + R_noise
    K = jnp.linalg.solve(innovation_cov.T, (P_pred @ sys.C.T).T).T
    poles = jnp.linalg.eigvals(sys.A - sys.A @ K @ sys.C)
    return KalmanGainResult(K=K, P=P_pred, poles=poles)

kalman_predict(sys: DiscLTI, x: Array, P: Array, Q_noise: Array, u: Array | None = None) -> tuple[Array, Array]

Predict one linear Kalman filter step.

Parameters:

  • sys (DiscLTI): Discrete-time linear system. Required.
  • x (Array): Filtered mean at the previous time step. Shape: (n,). Required.
  • P (Array): Filtered covariance at the previous time step. Shape: (n, n). Required.
  • Q_noise (Array): Process noise covariance. Shape: (n, n). Required.
  • u (Array | None): Optional input for the transition. Shape: (m,). If omitted, the input term is zero. Default: None.

Returns:

  • tuple[Array, Array]: Tuple (x_pred, P_pred) with the predicted mean and covariance.

Source code in contrax/_kalman.py
def kalman_predict(
    sys: DiscLTI,
    x: Array,
    P: Array,
    Q_noise: Array,
    u: Array | None = None,
) -> tuple[Array, Array]:
    """Predict one linear Kalman filter step.

    Args:
        sys: Discrete-time linear system.
        x: Filtered mean at the previous time step. Shape: `(n,)`.
        P: Filtered covariance at the previous time step. Shape: `(n, n)`.
        Q_noise: Process noise covariance. Shape: `(n, n)`.
        u: Optional input for the transition. Shape: `(m,)`. If omitted, the
            input term is zero.

    Returns:
        Tuple `(x_pred, P_pred)` with the predicted mean and covariance.
    """
    u = _zero_input(sys, x) if u is None else u
    x_pred = sys.A @ x + sys.B @ u
    P_pred = _symmetrize(sys.A @ P @ sys.A.T + Q_noise)
    return x_pred, P_pred
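The predict step is plain linear propagation, so it can be checked standalone. A minimal NumPy sketch of the same equations for the double-integrator model used throughout these docs (zero-input case, with the same symmetrization the implementation applies):

```python
import numpy as np

# x_pred = A x, P_pred = A P A^T + Q, then symmetrize
A = np.array([[1.0, 0.1], [0.0, 1.0]])
Q = 1e-3 * np.eye(2)

x = np.array([1.0, 0.5])
P = np.eye(2)

x_pred = A @ x                      # predicted mean (zero input)
P_pred = A @ P @ A.T + Q            # predicted covariance
P_pred = (P_pred + P_pred.T) / 2    # mirror the _symmetrize() step
```

Here `x_pred` is `[1.05, 0.5]` and `P_pred` stays exactly symmetric, which is the invariant the symmetrization step protects under repeated prediction.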

kalman_step(sys: DiscLTI, x: Array, P: Array, y: Array, Q_noise: Array, R_noise: Array, u: Array | None = None, *, has_measurement: Array | bool = True) -> tuple[Array, Array, Array]

Run one predict-update step of a linear Kalman filter.

This is the online counterpart to kalman(). It is a pure array function, so service loops can call it directly and batch workflows can place it inside jax.lax.scan.

Parameters:

  • sys (DiscLTI): Discrete-time linear system. Required.
  • x (Array): Filtered mean at the previous time step. Shape: (n,). Required.
  • P (Array): Filtered covariance at the previous time step. Shape: (n, n). Required.
  • y (Array): Measurement vector. Shape: (p,). Required.
  • Q_noise (Array): Process noise covariance. Shape: (n, n). Required.
  • R_noise (Array): Measurement noise covariance. Shape: (p, p). Required.
  • u (Array | None): Optional input. Shape: (m,). Default: None.
  • has_measurement (Array | bool): Whether to apply the measurement update. Default: True.

Returns:

  • tuple[Array, Array, Array]: Tuple (x, P, innovation) after the step.

Source code in contrax/_kalman.py
def kalman_step(
    sys: DiscLTI,
    x: Array,
    P: Array,
    y: Array,
    Q_noise: Array,
    R_noise: Array,
    u: Array | None = None,
    *,
    has_measurement: Array | bool = True,
) -> tuple[Array, Array, Array]:
    """Run one predict-update step of a linear Kalman filter.

    This is the online counterpart to `kalman()`. It is a pure array function,
    so service loops can call it directly and batch workflows can place it
    inside `jax.lax.scan`.

    Args:
        sys: Discrete-time linear system.
        x: Filtered mean at the previous time step. Shape: `(n,)`.
        P: Filtered covariance at the previous time step. Shape: `(n, n)`.
        y: Measurement vector. Shape: `(p,)`.
        Q_noise: Process noise covariance. Shape: `(n, n)`.
        R_noise: Measurement noise covariance. Shape: `(p, p)`.
        u: Optional input. Shape: `(m,)`.
        has_measurement: Whether to apply the measurement update.

    Returns:
        Tuple `(x, P, innovation)` after the step.
    """
    x_pred, P_pred = kalman_predict(sys, x, P, Q_noise, u)
    return kalman_update(
        sys,
        x_pred,
        P_pred,
        y,
        R_noise,
        u,
        has_measurement=has_measurement,
    )
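For intuition, the predict-update recursion that kalman_step performs can be written as a plain NumPy loop. This is a hand-rolled sketch of the same math on the double-integrator example, not the Contrax implementation (no missing-measurement handling, no symmetrization):

```python
import numpy as np

A = np.array([[1.0, 0.1], [0.0, 1.0]])
C = np.eye(2)
Q = 1e-3 * np.eye(2)
R = 1e-2 * np.eye(2)

x, P = np.zeros(2), np.eye(2)
ys = np.zeros((20, 2))  # constant zero measurements

for y in ys:
    # Predict
    x = A @ x
    P = A @ P @ A.T + Q
    # Update
    S = C @ P @ C.T + R              # innovation covariance
    K = P @ C.T @ np.linalg.inv(S)   # Kalman gain
    x = x + K @ (y - C @ x)
    P = (np.eye(2) - K @ C) @ P
```

With zero measurements and a zero prior the mean stays at zero while the covariance contracts toward its steady-state value, the same fixed point kalman_gain() computes via the DARE.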

kalman_update(sys: DiscLTI, x_pred: Array, P_pred: Array, y: Array, R_noise: Array, u: Array | None = None, *, has_measurement: Array | bool = True) -> tuple[Array, Array, Array]

Update one linear Kalman filter step.

has_measurement=False returns the predicted state and covariance without applying a measurement update. The flag may be a JAX scalar boolean, making this helper usable inside jax.lax.scan for streams with missing samples.

Parameters:

  • sys (DiscLTI): Discrete-time linear system. Required.
  • x_pred (Array): Predicted mean. Shape: (n,). Required.
  • P_pred (Array): Predicted covariance. Shape: (n, n). Required.
  • y (Array): Measurement vector. Shape: (p,). Required.
  • R_noise (Array): Measurement noise covariance. Shape: (p, p). Required.
  • u (Array | None): Optional input for direct-feedthrough measurement models. Shape: (m,). If omitted, the feedthrough term is zero. Default: None.
  • has_measurement (Array | bool): Whether to apply the measurement update. Default: True.

Returns:

  • tuple[Array, Array, Array]: Tuple (x, P, innovation) with the filtered mean, filtered covariance, and measurement residual. Missing-measurement steps return a zero innovation with the same shape as y.

Source code in contrax/_kalman.py
def kalman_update(
    sys: DiscLTI,
    x_pred: Array,
    P_pred: Array,
    y: Array,
    R_noise: Array,
    u: Array | None = None,
    *,
    has_measurement: Array | bool = True,
) -> tuple[Array, Array, Array]:
    """Update one linear Kalman filter step.

    `has_measurement=False` returns the predicted state and covariance without
    applying a measurement update. The flag may be a JAX scalar boolean, making
    this helper usable inside `jax.lax.scan` for streams with missing samples.

    Args:
        sys: Discrete-time linear system.
        x_pred: Predicted mean. Shape: `(n,)`.
        P_pred: Predicted covariance. Shape: `(n, n)`.
        y: Measurement vector. Shape: `(p,)`.
        R_noise: Measurement noise covariance. Shape: `(p, p)`.
        u: Optional input for direct-feedthrough measurement models. Shape:
            `(m,)`. If omitted, the feedthrough term is zero.
        has_measurement: Whether to apply the measurement update.

    Returns:
        Tuple `(x, P, innovation)` with the filtered mean, filtered covariance,
        and measurement residual. Missing-measurement steps return a zero
        innovation with the same shape as `y`.
    """
    u = _zero_input(sys, x_pred) if u is None else u
    n = x_pred.shape[0]
    y_pred = sys.C @ x_pred + sys.D @ u
    S = sys.C @ P_pred @ sys.C.T + R_noise
    K = jnp.linalg.solve(S.T, (P_pred @ sys.C.T).T).T
    innov = y - y_pred
    x_new = x_pred + K @ innov
    P_new = _symmetrize((jnp.eye(n, dtype=P_pred.dtype) - K @ sys.C) @ P_pred)

    def apply_update():
        return x_new, P_new, innov

    def skip_update():
        return x_pred, P_pred, jnp.zeros_like(y)

    return jax.lax.cond(
        _measurement_flag(has_measurement),
        apply_update,
        skip_update,
    )
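Outside a traced JAX scan, the missing-measurement branch is just a conditional. A standalone NumPy sketch of the update equations and the skip path (no feedthrough term, for brevity):

```python
import numpy as np

def update(C, x_pred, P_pred, y, R, has_measurement=True):
    """Sketch of the Kalman measurement update with a skip path."""
    if not has_measurement:
        # Skip path: keep the prediction, zero innovation (same shape as y)
        return x_pred, P_pred, np.zeros_like(y)
    S = C @ P_pred @ C.T + R                 # innovation covariance
    K = P_pred @ C.T @ np.linalg.inv(S)      # Kalman gain
    innov = y - C @ x_pred
    x = x_pred + K @ innov
    P = (np.eye(len(x_pred)) - K @ C) @ P_pred
    return x, P, innov

C = np.eye(2)
x, P, innov = update(C, np.ones(2), np.eye(2), np.zeros(2), 1e-2 * np.eye(2))
x_skip, P_skip, innov_skip = update(
    C, np.ones(2), np.eye(2), np.zeros(2), 1e-2 * np.eye(2), has_measurement=False
)
```

The Contrax version uses jax.lax.cond instead of a Python `if` so the flag can be a traced scalar boolean inside jax.lax.scan.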

likelihood_diagnostics(log_likelihood_terms: Array) -> LikelihoodDiagnostics

Summarize per-step innovation-form log-likelihood terms.

Source code in contrax/_estimation_diagnostics.py
def likelihood_diagnostics(log_likelihood_terms: Array) -> LikelihoodDiagnostics:
    """Summarize per-step innovation-form log-likelihood terms."""
    return LikelihoodDiagnostics(
        log_likelihood_terms=log_likelihood_terms,
        total_log_likelihood=jnp.sum(log_likelihood_terms),
        mean_log_likelihood=jnp.mean(log_likelihood_terms),
        min_log_likelihood=jnp.min(log_likelihood_terms),
        max_log_likelihood=jnp.max(log_likelihood_terms),
        nonfinite=jnp.logical_not(jnp.all(jnp.isfinite(log_likelihood_terms))),
    )
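For context, the per-step term for a Gaussian innovation ν with covariance S is conventionally -½(p log 2π + log det S + νᵀS⁻¹ν). A standalone NumPy sketch of one such term (this is the textbook convention, stated here as an assumption rather than a description of the Contrax internals):

```python
import numpy as np

S = np.array([[1.01, 0.0], [0.0, 1.01]])   # innovation covariance
innov = np.array([0.1, -0.2])              # measurement residual

p = innov.shape[0]
term = -0.5 * (
    p * np.log(2 * np.pi)
    + np.log(np.linalg.det(S))
    + innov @ np.linalg.solve(S, innov)    # Mahalanobis part
)
```

Collecting such terms over a trajectory and feeding them to likelihood_diagnostics() gives the total, mean, and extreme log-likelihood summaries plus a non-finite flag.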

mhe(f: Callable[..., Array], h: Callable[..., Array], xs_init: Array, us: Array, ys: Array, x_prior: Array, P_prior: Array, Q_noise: Array, R_noise: Array, params: object | None = None, extra_cost: Callable[..., Array] | None = None, solver: AbstractMinimiser | None = None, max_steps: int = 256) -> MHEResult

Solve a fixed-window moving-horizon-estimation problem.

Minimises mhe_objective over the candidate state trajectory xs_init using an optimistix solver. The default solver is optx.LBFGS(rtol=1e-6, atol=1e-6), which is well-suited to the deterministic, smooth, fixed-horizon objectives that arise in MHE.

The optimised terminal state x_hat = xs[-1] is the natural input to a downstream controller or observer. For a linear-Gaussian model, the full trajectory xs matches the RTS smoother output within solver tolerance.

Parameters:

  • f (Callable[..., Array]): Transition function. Called as f(x, u) when params=None, and f(x, u, params) otherwise. Required.
  • h (Callable[..., Array]): Observation function. Called as h(x) when params=None, and h(x, params) otherwise. Required.
  • xs_init (Array): Initial trajectory guess. Shape: (T, n). Using the Kalman filtered means as a warm start is recommended. Required.
  • us (Array): Input sequence. Shape: (T - 1, m). Required.
  • ys (Array): Measurement sequence aligned with xs. Shape: (T, p). Required.
  • x_prior (Array): Prior mean for the start of the window. Shape: (n,). Required.
  • P_prior (Array): Arrival covariance. Shape: (n, n). Required.
  • Q_noise (Array): Process noise covariance. Shape: (n, n). Required.
  • R_noise (Array): Measurement noise covariance. Shape: (p, p). Required.
  • params (object | None): Optional parameters passed to f, h, and extra_cost. Default: None.
  • extra_cost (Callable[..., Array] | None): Optional callable adding additional smooth costs. Called as extra_cost(xs, us, ys) when params=None, and extra_cost(xs, us, ys, params) otherwise. Default: None.
  • solver (AbstractMinimiser | None): Optimistix minimiser. Defaults to optx.LBFGS(rtol=1e-6, atol=1e-6). Default: None.
  • max_steps (int): Maximum solver iterations. Default: 256.

Returns:

  • MHEResult: Estimated trajectory, terminal state, final cost, and convergence flag.

Examples:

>>> import jax.numpy as jnp
>>> import contrax as cx
>>> sys = cx.dss(
...     jnp.array([[0.8]]), jnp.zeros((1, 1)),
...     jnp.array([[1.0]]), jnp.zeros((1, 1)), dt=1.0,
... )
>>> f = lambda x, u: sys.A @ x
>>> h = lambda x: sys.C @ x
>>> ys = jnp.array([[0.1], [0.4], [0.6], [0.5], [0.4]])
>>> result = cx.mhe(
...     f, h,
...     xs_init=jnp.zeros((5, 1)),
...     us=jnp.zeros((4, 1)),
...     ys=ys,
...     x_prior=jnp.zeros(1),
...     P_prior=jnp.eye(1),
...     Q_noise=0.05 * jnp.eye(1),
...     R_noise=0.2 * jnp.eye(1),
... )
Source code in contrax/_mhe.py
def mhe(
    f: Callable[..., Array],
    h: Callable[..., Array],
    xs_init: Array,
    us: Array,
    ys: Array,
    x_prior: Array,
    P_prior: Array,
    Q_noise: Array,
    R_noise: Array,
    params: object | None = None,
    extra_cost: Callable[..., Array] | None = None,
    solver: optx.AbstractMinimiser | None = None,
    max_steps: int = 256,
) -> MHEResult:
    """Solve a fixed-window moving-horizon-estimation problem.

    Minimises `mhe_objective` over the candidate state trajectory `xs_init`
    using an `optimistix` solver.  The default solver is
    `optx.LBFGS(rtol=1e-6, atol=1e-6)`, which is well-suited to the
    deterministic, smooth, fixed-horizon objectives that arise in MHE.

    The optimised terminal state `x_hat = xs[-1]` is the natural input to a
    downstream controller or observer.  For a linear-Gaussian model, the
    full trajectory `xs` matches the RTS smoother output within solver
    tolerance.

    Args:
        f: Transition function.  Called as `f(x, u)` when `params=None`,
            and `f(x, u, params)` otherwise.
        h: Observation function.  Called as `h(x)` when `params=None`,
            and `h(x, params)` otherwise.
        xs_init: Initial trajectory guess.  Shape: `(T, n)`.  Using the
            Kalman filtered means as a warm start is recommended.
        us: Input sequence.  Shape: `(T - 1, m)`.
        ys: Measurement sequence aligned with `xs`.  Shape: `(T, p)`.
        x_prior: Prior mean for the start of the window.  Shape: `(n,)`.
        P_prior: Arrival covariance.  Shape: `(n, n)`.
        Q_noise: Process noise covariance.  Shape: `(n, n)`.
        R_noise: Measurement noise covariance.  Shape: `(p, p)`.
        params: Optional parameters passed to `f`, `h`, and `extra_cost`.
        extra_cost: Optional callable adding additional smooth costs.
            Called as `extra_cost(xs, us, ys)` when `params=None`, and
            `extra_cost(xs, us, ys, params)` otherwise.
        solver: Optimistix minimiser.  Defaults to
            `optx.LBFGS(rtol=1e-6, atol=1e-6)`.
        max_steps: Maximum solver iterations.  Defaults to 256.

    Returns:
        [MHEResult][contrax.types.MHEResult]: Estimated trajectory, terminal
            state, final cost, and convergence flag.

    Examples:
        >>> import jax.numpy as jnp
        >>> import contrax as cx
        >>> sys = cx.dss(
        ...     jnp.array([[0.8]]), jnp.zeros((1, 1)),
        ...     jnp.array([[1.0]]), jnp.zeros((1, 1)), dt=1.0,
        ... )
        >>> f = lambda x, u: sys.A @ x
        >>> h = lambda x: sys.C @ x
        >>> ys = jnp.array([[0.1], [0.4], [0.6], [0.5], [0.4]])
        >>> result = cx.mhe(
        ...     f, h,
        ...     xs_init=jnp.zeros((5, 1)),
        ...     us=jnp.zeros((4, 1)),
        ...     ys=ys,
        ...     x_prior=jnp.zeros(1),
        ...     P_prior=jnp.eye(1),
        ...     Q_noise=0.05 * jnp.eye(1),
        ...     R_noise=0.2 * jnp.eye(1),
        ... )
    """
    if solver is None:
        solver = optx.LBFGS(rtol=1e-6, atol=1e-6)

    def fn(xs, args):
        us_, ys_, x_prior_, P_prior_, Q_noise_, R_noise_ = args
        return mhe_objective(
            f,
            h,
            xs,
            us_,
            ys_,
            x_prior_,
            P_prior_,
            Q_noise_,
            R_noise_,
            params,
            extra_cost,
        )

    args = (us, ys, x_prior, P_prior, Q_noise, R_noise)
    sol = optx.minimise(
        fn, solver, xs_init, args=args, max_steps=max_steps, throw=False
    )
    xs_opt = sol.value
    return MHEResult(
        xs=xs_opt,
        x_hat=xs_opt[-1],
        final_cost=fn(xs_opt, args),
        solver_converged=sol.result == optx.RESULTS.successful,
    )

mhe_objective(f: Callable[..., Array], h: Callable[..., Array], xs: Array, us: Array, ys: Array, x_prior: Array, P_prior: Array, Q_noise: Array, R_noise: Array, params: object | None = None, extra_cost: Callable[..., Array] | None = None) -> Array

Evaluate a fixed-window moving-horizon-estimation objective.

This is the first MHE primitive in Contrax: a pure cost function, not a solver. The candidate state trajectory is explicit in xs, so callers can optimize it with their preferred JAX-native solver while keeping fixed horizon shapes.

The objective is a weighted nonlinear least-squares cost:

  • arrival cost on xs[0] - x_prior
  • process cost on xs[k + 1] - f(xs[k], us[k], params)
  • measurement cost on ys[k] - h(xs[k], params)
  • optional user extra_cost(xs, us, ys, params)

Parameters:

  • f (Callable[..., Array]): Transition function. Called as f(x, u) when params=None, and f(x, u, params) otherwise. Required.
  • h (Callable[..., Array]): Observation function. Called as h(x) when params=None, and h(x, params) otherwise. Required.
  • xs (Array): Candidate state trajectory. Shape: (T + 1, n). Required.
  • us (Array): Input sequence. Shape: (T, m). Required.
  • ys (Array): Measurement sequence aligned with xs. Shape: (T + 1, p). Required.
  • x_prior (Array): Prior state for the start of the window. Shape: (n,). Required.
  • P_prior (Array): Arrival covariance. Shape: (n, n). Required.
  • Q_noise (Array): Process noise covariance. Shape: (n, n). Required.
  • R_noise (Array): Measurement noise covariance. Shape: (p, p). Required.
  • params (object | None): Optional parameters passed to f, h, and extra_cost. Default: None.
  • extra_cost (Callable[..., Array] | None): Optional callable adding additional smooth costs. Called as extra_cost(xs, us, ys) when params=None, and extra_cost(xs, us, ys, params) otherwise. Default: None.

Returns:

  • Array: Scalar MHE cost.

Source code in contrax/_mhe.py
def mhe_objective(
    f: Callable[..., Array],
    h: Callable[..., Array],
    xs: Array,
    us: Array,
    ys: Array,
    x_prior: Array,
    P_prior: Array,
    Q_noise: Array,
    R_noise: Array,
    params: object | None = None,
    extra_cost: Callable[..., Array] | None = None,
) -> Array:
    """Evaluate a fixed-window moving-horizon-estimation objective.

    This is the first MHE primitive in Contrax: a pure cost function, not a
    solver. The candidate state trajectory is explicit in `xs`, so callers can
    optimize it with their preferred JAX-native solver while keeping fixed
    horizon shapes.

    The objective is a weighted nonlinear least-squares cost:

    - arrival cost on `xs[0] - x_prior`
    - process cost on `xs[k + 1] - f(xs[k], us[k], params)`
    - measurement cost on `ys[k] - h(xs[k], params)`
    - optional user `extra_cost(xs, us, ys, params)`

    Args:
        f: Transition function. Called as `f(x, u)` when `params=None`, and
            `f(x, u, params)` otherwise.
        h: Observation function. Called as `h(x)` when `params=None`, and
            `h(x, params)` otherwise.
        xs: Candidate state trajectory. Shape: `(T + 1, n)`.
        us: Input sequence. Shape: `(T, m)`.
        ys: Measurement sequence aligned with `xs`. Shape: `(T + 1, p)`.
        x_prior: Prior state for the start of the window. Shape: `(n,)`.
        P_prior: Arrival covariance. Shape: `(n, n)`.
        Q_noise: Process noise covariance. Shape: `(n, n)`.
        R_noise: Measurement noise covariance. Shape: `(p, p)`.
        params: Optional parameters passed to `f`, `h`, and `extra_cost`.
        extra_cost: Optional callable adding additional smooth costs.
            Called as `extra_cost(xs, us, ys)` when `params=None`, and
            `extra_cost(xs, us, ys, params)` otherwise.

    Returns:
        Array: Scalar MHE cost.
    """
    arrival_residual = xs[0] - x_prior
    x_pred = jax.vmap(lambda x, u: _call_transition(f, x, u, params))(
        xs[:-1],
        us,
    )
    process_residuals = xs[1:] - x_pred
    y_pred = jax.vmap(lambda x: _call_observation(h, x, params))(xs)
    measurement_residuals = ys - y_pred

    cost = (
        _weighted_sum_squares(arrival_residual[None], P_prior)
        + _weighted_sum_squares(process_residuals, Q_noise)
        + _weighted_sum_squares(measurement_residuals, R_noise)
    )
    if extra_cost is not None:
        if params is None:
            cost = cost + extra_cost(xs, us, ys)
        else:
            cost = cost + extra_cost(xs, us, ys, params)
    return cost
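The three residual blocks are easy to verify by hand on a scalar system. A standalone NumPy sketch of the same window cost, assuming each block is inverse-covariance weighted (r²/P rather than P·r²), which is what covariance-valued weights suggest:

```python
import numpy as np

# Scalar model from the mhe() example: f(x) = 0.8 x, h(x) = x
a = 0.8
xs = np.array([0.0, 0.1, 0.3, 0.5, 0.45])   # candidate trajectory, shape (T + 1,)
ys = np.array([0.1, 0.4, 0.6, 0.5, 0.4])    # measurements, shape (T + 1,)
x_prior, P_prior = 0.0, 1.0
Q, R = 0.05, 0.2

arrival = (xs[0] - x_prior) ** 2 / P_prior           # arrival cost
process = np.sum((xs[1:] - a * xs[:-1]) ** 2) / Q    # process cost
measurement = np.sum((ys - xs) ** 2) / R             # measurement cost
cost = arrival + process + measurement
```

Because the trajectory is an explicit optimization variable, the gradient of this cost with respect to xs is what any first-order solver consumes; in Contrax that differentiation is handled by JAX.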

mhe_warm_start(xs: Array, *, transition: Callable[[Array, Array], Array] | None = None, terminal_input: Array | None = None) -> Array

Shift an MHE trajectory guess forward by one step.

The default behavior repeats the previous terminal state. If a transition and terminal_input are supplied, the terminal guess is propagated with that model instead. This is a generic rolling-window warm-start helper, not a full arrival-update policy.

Source code in contrax/_mhe.py
def mhe_warm_start(
    xs: Array,
    *,
    transition: Callable[[Array, Array], Array] | None = None,
    terminal_input: Array | None = None,
) -> Array:
    """Shift an MHE trajectory guess forward by one step.

    The default behavior repeats the previous terminal state. If a transition
    and `terminal_input` are supplied, the terminal guess is propagated with
    that model instead. This is a generic rolling-window warm-start helper,
    not a full arrival-update policy.
    """
    xs = jnp.asarray(xs)
    if xs.ndim < 2:
        raise ValueError("mhe_warm_start() expects a state trajectory with time axis.")
    tail = xs[-1] if transition is None else transition(xs[-1], terminal_input)
    return jnp.concatenate([xs[1:], tail[None]], axis=0)
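The default (repeat-terminal) shift is a one-line rolling-window operation; the equivalent in standalone NumPy:

```python
import numpy as np

xs = np.array([[0.0], [0.1], [0.2], [0.3]])  # previous window, shape (T, n)

# Drop the oldest state, repeat the terminal state as the new guess
xs_shifted = np.concatenate([xs[1:], xs[-1:]], axis=0)
```

The window length is preserved, so downstream solver shapes stay fixed from one MHE call to the next.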

rts(sys: DiscLTI, result: KalmanResult, Q_noise: Array) -> RTSResult

Run a Rauch-Tung-Striebel smoother on filtered Kalman results.

Applies the standard backward smoothing pass to the output of kalman() using the same process model and process noise covariance.

This is an offline smoothing primitive rather than an online filter. Q_noise should match the covariance used in the forward kalman() pass; for deterministic dynamics, pass zeros rather than omitting the term.

Parameters:

  • sys (DiscLTI): Discrete-time linear system used in the forward filter. Required.
  • result (KalmanResult): Output of kalman(). Required.
  • Q_noise (Array): Process noise covariance used by the forward filter. Required.

Returns:

  • RTSResult: A bundle containing smoothed state means and covariances.

Examples:

>>> import jax.numpy as jnp
>>> import contrax as cx
>>> sys = cx.dss(
...     jnp.array([[1.0, 0.1], [0.0, 1.0]]),
...     jnp.array([[0.0], [0.1]]),
...     jnp.eye(2),
...     jnp.zeros((2, 1)),
...     dt=0.1,
... )
>>> filtered = cx.kalman(
...     sys,
...     Q_noise=1e-3 * jnp.eye(2),
...     R_noise=1e-2 * jnp.eye(2),
...     ys=jnp.zeros((20, 2)),
... )
>>> smoothed = cx.rts(sys, filtered, Q_noise=1e-3 * jnp.eye(2))
Source code in contrax/_kalman.py
def rts(
    sys: DiscLTI,
    result: KalmanResult,
    Q_noise: Array,
) -> RTSResult:
    """Run a Rauch-Tung-Striebel smoother on filtered Kalman results.

    Applies the standard backward smoothing pass to the output of `kalman()`
    using the same process model and process noise covariance.

    This is an offline smoothing primitive rather than an online filter.
    `Q_noise` should match the covariance used in the forward `kalman()` pass;
    for deterministic dynamics, pass zeros rather than omitting the term.

    Args:
        sys: Discrete-time linear system used in the forward filter.
        result: Output of `kalman()`.
        Q_noise: Process noise covariance used by the forward filter.

    Returns:
        [RTSResult][contrax.types.RTSResult]: A bundle containing smoothed
            state means and covariances.

    Examples:
        >>> import jax.numpy as jnp
        >>> import contrax as cx
        >>> sys = cx.dss(
        ...     jnp.array([[1.0, 0.1], [0.0, 1.0]]),
        ...     jnp.array([[0.0], [0.1]]),
        ...     jnp.eye(2),
        ...     jnp.zeros((2, 1)),
        ...     dt=0.1,
        ... )
        >>> filtered = cx.kalman(
        ...     sys,
        ...     Q_noise=1e-3 * jnp.eye(2),
        ...     R_noise=1e-2 * jnp.eye(2),
        ...     ys=jnp.zeros((20, 2)),
        ... )
        >>> smoothed = cx.rts(sys, filtered, Q_noise=1e-3 * jnp.eye(2))
    """
    x_hats = result.x_hat  # (T, n)
    Ps = result.P  # (T, n, n)

    def smooth_step(carry, inputs):
        x_s, P_s = carry
        x_f, P_f = inputs

        # Predicted mean/cov from forward pass
        x_pred = sys.A @ x_f
        P_pred = sys.A @ P_f @ sys.A.T + Q_noise

        G = jnp.linalg.solve(
            P_pred.T, sys.A @ P_f.T
        ).T  # smoother gain: P_f A^T P_pred^{-1}
        x_s_new = x_f + G @ (x_s - x_pred)
        P_s_new = P_f + G @ (P_s - P_pred) @ G.T
        return (x_s_new, P_s_new), (x_s_new, P_s_new)

    # Initialise smoother at last filtered estimate
    init = (x_hats[-1], Ps[-1])
    # Scan backwards over t = T-2, ..., 0
    _, (x_smooth_rev, P_smooth_rev) = jax.lax.scan(
        smooth_step, init, (x_hats[:-1][::-1], Ps[:-1][::-1])
    )
    # Prepend the last (unsmoothed) estimate and reverse
    x_smooth = jnp.concatenate([x_smooth_rev[::-1], x_hats[-1:]], axis=0)
    P_smooth = jnp.concatenate([P_smooth_rev[::-1], Ps[-1:]], axis=0)
    return RTSResult(x_smooth=x_smooth, P_smooth=P_smooth)

smoother_diagnostics(smoothed: RTSResult, filtered: KalmanResult | UKFResult) -> SmootherDiagnostics

Compute health diagnostics for an RTS or UKS smoother result.

A numerically healthy smoother satisfies P_smooth ≤ P_filtered in the PSD sense at every step. This function checks that guarantee and surfaces the worst-case violation so callers can detect breakdown before trusting smoothed trajectories.

Parameters:

  • smoothed (RTSResult): Output of rts() or uks(). Required.
  • filtered (KalmanResult | UKFResult): The forward filter result used to produce smoothed. Accepts both KalmanResult and UKFResult. Required.

Returns:

  • SmootherDiagnostics: Diagnostic bundle. min_covariance_reduction < 0 is the primary health flag.

Examples:

>>> import jax.numpy as jnp
>>> import contrax as cx
>>> sys = cx.dss(
...     jnp.array([[0.9]]), jnp.zeros((1, 1)),
...     jnp.array([[1.0]]), jnp.zeros((1, 1)), dt=1.0,
... )
>>> filtered = cx.kalman(sys, 1e-3 * jnp.eye(1), 1e-2 * jnp.eye(1),
...                      jnp.zeros((10, 1)))
>>> smoothed = cx.rts(sys, filtered, Q_noise=1e-3 * jnp.eye(1))
>>> diag = cx.smoother_diagnostics(smoothed, filtered)
Source code in contrax/_estimation_diagnostics.py
def smoother_diagnostics(
    smoothed: RTSResult,
    filtered: KalmanResult | UKFResult,
) -> SmootherDiagnostics:
    """Compute health diagnostics for an RTS or UKS smoother result.

    A numerically healthy smoother satisfies ``P_smooth ≤ P_filtered`` in the
    PSD sense at every step. This function checks that guarantee and surfaces
    the worst-case violation so callers can detect breakdown before trusting
    smoothed trajectories.

    Args:
        smoothed: Output of `rts()` or `uks()`.
        filtered: The forward filter result used to produce ``smoothed``.
            Accepts both `KalmanResult` and `UKFResult`.

    Returns:
        [SmootherDiagnostics][contrax.types.SmootherDiagnostics]: Diagnostic
            bundle. ``min_covariance_reduction < 0`` is the primary health flag.

    Examples:
        >>> import jax.numpy as jnp
        >>> import contrax as cx
        >>> sys = cx.dss(
        ...     jnp.array([[0.9]]), jnp.zeros((1, 1)),
        ...     jnp.array([[1.0]]), jnp.zeros((1, 1)), dt=1.0,
        ... )
        >>> filtered = cx.kalman(sys, 1e-3 * jnp.eye(1), 1e-2 * jnp.eye(1),
        ...                      jnp.zeros((10, 1)))
        >>> smoothed = cx.rts(sys, filtered, Q_noise=1e-3 * jnp.eye(1))
        >>> diag = cx.smoother_diagnostics(smoothed, filtered)
    """
    x_smooth = smoothed.x_smooth  # (T, n)
    P_smooth = smoothed.P_smooth  # (T, n, n)
    x_filt = filtered.x_hat  # (T, n)
    P_filt = filtered.P  # (T, n, n)

    def _sym(M: Array) -> Array:
        return (M + M.T) / 2

    # Min eigenvalue of (P_f - P_s): >= 0 means smoother reduced uncertainty.
    cov_reduction = jax.vmap(
        lambda Pf, Ps: jnp.min(jnp.linalg.eigvalsh(_sym(Pf) - _sym(Ps)))
    )(P_filt, P_smooth)

    state_corrections = jnp.linalg.norm(x_smooth - x_filt, axis=-1)

    smoothed_min_eig = jax.vmap(lambda Ps: jnp.min(jnp.linalg.eigvalsh(_sym(Ps))))(
        P_smooth
    )

    nonfinite = jnp.logical_not(
        jnp.all(jnp.isfinite(x_smooth)) & jnp.all(jnp.isfinite(P_smooth))
    )

    return SmootherDiagnostics(
        covariance_reduction=cov_reduction,
        min_covariance_reduction=jnp.min(cov_reduction),
        state_corrections=state_corrections,
        max_state_correction=jnp.max(state_corrections),
        smoothed_min_eigenvalue=smoothed_min_eig,
        nonfinite=nonfinite,
    )

soft_quadratic_penalty(residuals: Array, weight: Array) -> Array

Return a quadratic soft penalty with matrix or scalar weighting.

This helper is intended for optional MHE soft costs such as state-envelope penalties or terminal-state regularization. It stays small on purpose: callers define the residuals they care about and use this helper for the weighted quadratic part.

Source code in contrax/_mhe.py
def soft_quadratic_penalty(residuals: Array, weight: Array) -> Array:
    """Return a quadratic soft penalty with matrix or scalar weighting.

    This helper is intended for optional MHE soft costs such as state-envelope
    penalties or terminal-state regularization. It stays small on purpose:
    callers define the residuals they care about and use this helper for the
    weighted quadratic part.
    """
    residuals = jnp.asarray(residuals)
    weight = jnp.asarray(weight)
    if weight.ndim == 0:
        return weight * jnp.sum(residuals**2)
    return _weighted_sum_squares(residuals, weight)
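As an illustration of the intended use, a hypothetical state-envelope soft cost of the kind one might pass as extra_cost. The envelope bound and weight are illustrative, and the standalone function below mirrors only the scalar-weight quadratic part:

```python
import numpy as np

def envelope_penalty(xs, bound=1.0, weight=10.0):
    """Quadratic soft penalty on trajectory excursions beyond |x| <= bound."""
    residuals = np.maximum(np.abs(xs) - bound, 0.0)  # zero inside the envelope
    return weight * np.sum(residuals**2)             # scalar-weight quadratic part

xs = np.array([[0.5], [0.9], [1.2], [1.1]])
penalty = envelope_penalty(xs)
```

Only the two states outside the envelope (1.2 and 1.1) contribute, so the penalty stays smooth and vanishes on feasible trajectories, which keeps the overall MHE objective well-suited to quasi-Newton solvers.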

ukf(model_or_f: DynamicsLike, Q_noise: Array, R_noise: Array, ys: Array, us: Array, x0: Array, P0: Array, *, observation: ObservationFn | None = None, alpha: float = 1.0, beta: float = 2.0, kappa: float = 0.0) -> UKFResult

Run an unscented Kalman filter for nonlinear dynamics and observations.

Uses a deterministic sigma-point transform instead of local Jacobians to propagate means and covariances through the nonlinear transition and observation models.

This is a baseline nonlinear recursive estimation path that is often more robust than ekf() when the observation model is strongly curved. The current implementation assumes additive process and measurement noise and keeps the public API deliberately small.

Parameters:

  • model_or_f (DynamicsLike): Nonlinear system model or dynamics function. Required.
  • Q_noise (Array): Process noise covariance. Shape: (n, n). Required.
  • R_noise (Array): Measurement noise covariance. Shape: (p, p). Required.
  • ys (Array): Measurement sequence. Shape: (T, p). Required.
  • us (Array): Input sequence. Shape: (T, m). Required.
  • x0 (Array): Initial mean. Shape: (n,). Required.
  • P0 (Array): Initial covariance. Shape: (n, n). Required.
  • observation (ObservationFn | None): Observation function h(x) for plain dynamics callables. Omit this when passing a system model with its own output map. Default: None.
  • alpha (float): Sigma-point spread parameter. The default is intentionally wider than the very small textbook setting so the baseline implementation stays numerically well-behaved on the current small-system focus. Default: 1.0.
  • beta (float): Prior-distribution parameter. 2.0 is standard for Gaussian priors. Default: 2.0.
  • kappa (float): Secondary sigma-point scaling parameter. Default: 0.0.

Returns:

  • UKFResult: A bundle containing filtered means/covariances plus prediction, innovation, and likelihood intermediates useful for diagnostics and smoothing.

Examples:

>>> import jax.numpy as jnp
>>> import contrax as cx
>>> def f(x, u):
...     return jnp.array([x[0] + 0.1 * x[1], x[1] + u[0]])
>>> def h(x):
...     return jnp.array([x[0] ** 2])
>>> result = cx.ukf(
...     f,
...     Q_noise=1e-3 * jnp.eye(2),
...     R_noise=1e-2 * jnp.eye(1),
...     ys=jnp.zeros((10, 1)),
...     us=jnp.zeros((10, 1)),
...     x0=jnp.ones(2),
...     P0=jnp.eye(2),
...     observation=h,
... )
Source code in contrax/_ukf.py
def ukf(
    model_or_f: DynamicsLike,
    Q_noise: Array,
    R_noise: Array,
    ys: Array,
    us: Array,
    x0: Array,
    P0: Array,
    *,
    observation: ObservationFn | None = None,
    alpha: float = 1.0,
    beta: float = 2.0,
    kappa: float = 0.0,
) -> UKFResult:
    """Run an unscented Kalman filter for nonlinear dynamics and observations.

    Uses a deterministic sigma-point transform instead of local Jacobians to
    propagate means and covariances through the nonlinear transition and
    observation models.

    This is a baseline nonlinear recursive estimation path that is often more
    robust than `ekf()` when the observation model is strongly curved. The
    current implementation assumes additive process and measurement noise and
    keeps the public API deliberately small.

    Args:
        model_or_f: Nonlinear system model or dynamics function.
        Q_noise: Process noise covariance. Shape: `(n, n)`.
        R_noise: Measurement noise covariance. Shape: `(p, p)`.
        ys: Measurement sequence. Shape: `(T, p)`.
        us: Input sequence. Shape: `(T, m)`.
        x0: Initial mean. Shape: `(n,)`.
        P0: Initial covariance. Shape: `(n, n)`.
        observation: Observation function `h(x)` for plain dynamics callables.
            Omit this when passing a system model with its own output map.
        alpha: Sigma-point spread parameter. The default is intentionally wider
            than the very small textbook setting so the baseline implementation
            stays numerically well-behaved on the small systems the library
            currently targets.
        beta: Prior-distribution parameter. `2.0` is standard for Gaussian
            priors.
        kappa: Secondary sigma-point scaling parameter.

    Returns:
        [UKFResult][contrax.types.UKFResult]: A bundle containing filtered
            means/covariances plus prediction, innovation, and likelihood
            intermediates useful for diagnostics and smoothing.

    Examples:
        >>> import jax.numpy as jnp
        >>> import contrax as cx
        >>> def f(x, u):
        ...     return jnp.array([x[0] + 0.1 * x[1], x[1] + u[0]])
        >>> def h(x):
        ...     return jnp.array([x[0] ** 2])
        >>> result = cx.ukf(
        ...     f,
        ...     Q_noise=1e-3 * jnp.eye(2),
        ...     R_noise=1e-2 * jnp.eye(1),
        ...     ys=jnp.zeros((10, 1)),
        ...     us=jnp.zeros((10, 1)),
        ...     x0=jnp.ones(2),
        ...     P0=jnp.eye(2),
        ...     observation=h,
        ... )
    """

    if _is_system_model(model_or_f):
        if observation is not None:
            raise ValueError("ukf(sys, ...) does not accept observation=.")
        obs = model_or_f
    else:
        if observation is None:
            raise ValueError(
                "ukf(f, ..., observation=...) requires an observation function."
            )
        obs = observation

    f = _coerce_dynamics(model_or_f)
    h = _coerce_observation(
        obs,
        None if _is_system_model(obs) else obs,
    )
    dt = _system_dt(model_or_f)
    ts = jnp.arange(ys.shape[0], dtype=x0.dtype)
    if dt is not None:
        ts = ts * dt

    # Update-then-predict: x0/P0 is the prior on x_0.
    def update_predict(carry, inputs):
        x_prior, P_prior = carry
        y, u, t = inputs

        # Update: unscented transform through h using the current prior.
        sigma_obs, Wm_obs, Wc_obs = _sigma_points(x_prior, P_prior, alpha, beta, kappa)
        y_sigma = jax.vmap(lambda s: h(s, u, t))(sigma_obs)
        y_pred = jnp.einsum("i,ij->j", Wm_obs, y_sigma)
        y_centered = y_sigma - y_pred
        S = _weighted_covariance(y_centered, Wc_obs) + R_noise
        S = _symmetrize(S)
        x_centered = sigma_obs - x_prior
        Pxy = _cross_covariance(x_centered, y_centered, Wc_obs)
        K = jnp.linalg.solve(S.T, Pxy.T).T
        innov = y - y_pred
        x_post = x_prior + K @ innov
        P_post = _symmetrize(P_prior - K @ S @ K.T)
        log_likelihood = _gaussian_log_likelihood(innov, S)

        # Predict: unscented transform through f from the posterior.
        sigma, Wm, Wc = _sigma_points(x_post, P_post, alpha, beta, kappa)
        sigma_pred = jax.vmap(lambda s: f(s, u, t))(sigma)
        x_next = jnp.einsum("i,ij->j", Wm, sigma_pred)
        sigma_pred_centered = sigma_pred - x_next
        P_next = _weighted_covariance(sigma_pred_centered, Wc) + Q_noise
        P_next = _symmetrize(P_next)
        P_cross = _cross_covariance(sigma - x_post, sigma_pred_centered, Wc)

        return (x_next, P_next), (
            x_post,
            P_post,
            innov,
            y_pred,
            S,
            log_likelihood,
            x_next,
            P_next,
            P_cross,
        )

    (
        _,
        (
            x_hats,
            Ps,
            innovations,
            predicted_measurements,
            innovation_covariances,
            log_likelihood_terms,
            predicted_state_means,
            predicted_state_covariances,
            transition_cross_covariances,
        ),
    ) = jax.lax.scan(update_predict, (x0, _symmetrize(P0)), (ys, us, ts))
    return UKFResult(
        x_hat=x_hats,
        P=Ps,
        innovations=innovations,
        predicted_measurements=predicted_measurements,
        innovation_covariances=innovation_covariances,
        log_likelihood_terms=log_likelihood_terms,
        predicted_state_means=predicted_state_means,
        predicted_state_covariances=predicted_state_covariances,
        transition_cross_covariances=transition_cross_covariances,
    )

ukf_diagnostics(result: UKFResult) -> tuple[InnovationDiagnostics, LikelihoodDiagnostics]

Build the standard diagnostics pair for a UKF result.

Source code in contrax/_estimation_diagnostics.py
def ukf_diagnostics(
    result: UKFResult,
) -> tuple[InnovationDiagnostics, LikelihoodDiagnostics]:
    """Build the standard diagnostics pair for a UKF result."""
    return (
        innovation_diagnostics(result.innovations, result.innovation_covariances),
        likelihood_diagnostics(result.log_likelihood_terms),
    )
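The internals of `innovation_diagnostics()` are not shown on this page; a minimal standalone sketch of the kind of innovation-consistency statistic such a summary typically includes is the normalized innovation squared (NIS), computed here in plain NumPy from the `innovations` and `innovation_covariances` fields of a `UKFResult` (an illustrative assumption, not contrax code):

```python
import numpy as np

def nis(innovations, innovation_covariances):
    # Normalized innovation squared: e_k^T S_k^{-1} e_k per step. For a
    # consistent filter this is chi-squared with p degrees of freedom,
    # so its running mean should hover near p.
    return np.array([
        e @ np.linalg.solve(S, e)
        for e, S in zip(innovations, innovation_covariances)
    ])
```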

uks(model_or_f: DynamicsLike, result: UKFResult, Q_noise: Array, us: Array, *, alpha: float = 1.0, beta: float = 2.0, kappa: float = 0.0) -> RTSResult

Run an unscented Rauch-Tung-Striebel-style smoother.

Applies a backward smoothing pass to the output of ukf(). The dynamics, noise, and sigma-point arguments are accepted for interface symmetry with ukf(); the backward recursion consumes the predicted means, predicted covariances, and transition cross-covariances already stored in the UKFResult, so the model is not re-evaluated.

Source code in contrax/_ukf.py
def uks(
    model_or_f: DynamicsLike,
    result: UKFResult,
    Q_noise: Array,
    us: Array,
    *,
    alpha: float = 1.0,
    beta: float = 2.0,
    kappa: float = 0.0,
) -> RTSResult:
    """Run an unscented Rauch-Tung-Striebel-style smoother.

    Applies a backward smoothing pass to the output of `ukf()`. The dynamics,
    noise, and sigma-point arguments are accepted for interface symmetry with
    `ukf()`; the backward recursion consumes the predicted means, predicted
    covariances, and transition cross-covariances already stored in the
    `UKFResult`, so the model is not re-evaluated.
    """
    # Unused here: the smoother relies on the intermediates stored in `result`.
    del model_or_f, Q_noise, alpha, beta, kappa
    x_hats = result.x_hat
    Ps = result.P
    if us.shape[0] != x_hats.shape[0]:
        raise ValueError(
            "uks() requires us to have the same leading length as the filtered result."
        )

    def smooth_step(carry, inputs):
        x_s, P_s = carry
        x_f, P_f, x_pred, P_pred, P_cross = inputs
        G = jnp.linalg.solve(P_pred.T, P_cross.T).T
        x_s_new = x_f + G @ (x_s - x_pred)
        P_s_new = _symmetrize(P_f + G @ (P_s - P_pred) @ G.T)
        return (x_s_new, P_s_new), (x_s_new, P_s_new)

    init = (x_hats[-1], Ps[-1])
    _, (x_smooth_rev, P_smooth_rev) = jax.lax.scan(
        smooth_step,
        init,
        (
            x_hats[:-1][::-1],
            Ps[:-1][::-1],
            result.predicted_state_means[:-1][::-1],
            result.predicted_state_covariances[:-1][::-1],
            result.transition_cross_covariances[:-1][::-1],
        ),
    )
    x_smooth = jnp.concatenate([x_smooth_rev[::-1], x_hats[-1:]], axis=0)
    P_smooth = jnp.concatenate([P_smooth_rev[::-1], Ps[-1:]], axis=0)
    return RTSResult(x_smooth=x_smooth, P_smooth=P_smooth)
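The backward recursion in `smooth_step` can be written as a standalone NumPy function for reference (a sketch of the same math, not contrax code). A useful invariant: when the smoothed quantities from step k+1 coincide with the filter's one-step predictions, smoothing leaves the filtered estimate unchanged.

```python
import numpy as np

def rts_backward_step(x_s, P_s, x_f, P_f, x_pred, P_pred, P_cross):
    # One unscented RTS backward step. The gain G = P_cross @ inv(P_pred)
    # is formed via a linear solve, as in uks() above.
    G = np.linalg.solve(P_pred.T, P_cross.T).T
    x_new = x_f + G @ (x_s - x_pred)
    P_new = P_f + G @ (P_s - P_pred) @ G.T
    return x_new, 0.5 * (P_new + P_new.T)  # symmetrize
```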