Estimation¶
Estimation is the public namespace for recursive filtering, smoothing, and the first optimization-based estimation primitives in Contrax.
The namespace includes:
- `kalman()` and `kalman_gain()` for linear Gaussian filtering and steady-state estimator design
- `kalman_predict()`, `kalman_update()`, `kalman_step()` for one-step runtime loops
- `ekf()` and `ukf()` for nonlinear filtering
- `ekf_predict()`, `ekf_update()`, `ekf_step()` for one-step nonlinear loops
- `rts()` and `uks()` for offline smoothing
- `innovation_diagnostics()`, `likelihood_diagnostics()`, `ukf_diagnostics()`, and `smoother_diagnostics()` for routine estimator-health summaries
- `mhe_objective()`, `mhe()`, `mhe_warm_start()`, and `soft_quadratic_penalty()` for fixed-window optimization-based estimation
- `positive_exp()`, `positive_softplus()`, `spd_from_cholesky_raw()`, and `diagonal_spd()` for lightweight constrained parameterization
Minimal Example¶
import jax.numpy as jnp
import contrax as cx
sys = cx.dss(
jnp.array([[1.0, 0.1], [0.0, 1.0]]),
jnp.array([[0.0], [0.1]]),
jnp.eye(2),
jnp.zeros((2, 1)),
dt=0.1,
)
result = cx.kalman(
sys,
Q_noise=1e-3 * jnp.eye(2),
R_noise=1e-2 * jnp.eye(2),
ys=jnp.zeros((20, 2)),
)
Conventions¶
- `ys`: measurement sequence with shape `(T, p)`
- `x0`: prior mean on `x_0`
- `P0`: prior covariance on `x_0`
- `KalmanResult.x_hat`: filtered state sequence with shape `(T, n)`
- `KalmanResult.P`: filtered covariance sequence with shape `(T, n, n)`
- `KalmanResult.innovations`: measurement residual sequence with shape `(T, p)` for `kalman()` and `ekf()`
- `UKFResult.predicted_measurements`: pre-update measurement means with shape `(T, p)` for `ukf()`
- `UKFResult.innovation_covariances`: innovation covariance sequence with shape `(T, p, p)` for `ukf()`
Contrax uses the update-first batch convention for kalman(), ekf(), and
ukf(): (x0, P0) is the prior on x_0, and the first scan step updates
with y_0.
Estimator Equations¶
Contrax uses an update-first batch convention: the pair (x0, P0) is the
prior on x_0, the first step updates with y_0, and only then predicts the
prior on x_1.
For linear discrete models, the state and observation maps are

\[
x_{k+1} = A x_k + B u_k + w_k, \qquad y_k = C x_k + D u_k + v_k,
\]

with process noise covariance \(Q\) (for \(w_k\)) and measurement noise covariance \(R\) (for \(v_k\)).
Linear Kalman Filter¶
kalman() implements the standard discrete Gaussian filter. The measurement update fuses \(y_k\) into the prior \((\hat{x}_k^-, P_k^-)\),

\[
S_k = C P_k^- C^\top + R, \qquad K_k = P_k^- C^\top S_k^{-1},
\]
\[
\hat{x}_k^+ = \hat{x}_k^- + K_k \left(y_k - C \hat{x}_k^-\right), \qquad
P_k^+ = \left(I - K_k C\right) P_k^-,
\]

and the prediction step then propagates the posterior:

\[
\hat{x}_{k+1}^- = A \hat{x}_k^+ + B u_k, \qquad
P_{k+1}^- = A P_k^+ A^\top + Q.
\]
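The update-then-predict recursion can be written out directly. The following is a minimal NumPy sketch of the standard equations, not the contrax implementation (which uses linear solves instead of explicit inverses and runs inside `lax.scan`); the function names are illustrative:

```python
import numpy as np

def kalman_update_step(x_pred, P_pred, y, C, R):
    """Measurement update: fuse y into the prior (x_pred, P_pred)."""
    S = C @ P_pred @ C.T + R              # innovation covariance S_k
    K = P_pred @ C.T @ np.linalg.inv(S)   # Kalman gain K_k
    innov = y - C @ x_pred                # measurement residual
    x = x_pred + K @ innov
    P = (np.eye(len(x)) - K @ C) @ P_pred
    return x, P, innov

def kalman_predict_step(x, P, u, A, B, Q):
    """Time update: propagate the posterior one step forward."""
    return A @ x + B @ u, A @ P @ A.T + Q
```

Updating with an informative measurement shrinks the covariance, which is the behavior the diagnostics section later checks for.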
Extended Kalman Filter¶
For nonlinear models

\[
x_{k+1} = f(x_k, u_k) + w_k, \qquad y_k = h(x_k) + v_k,
\]

ekf() linearizes the transition and observation maps with JAX Jacobians,

\[
A_k = \left.\frac{\partial f}{\partial x}\right|_{\hat{x}_k^+,\, u_k}, \qquad
C_k = \left.\frac{\partial h}{\partial x}\right|_{\hat{x}_k^-},
\]

and then runs the linear recursion with \(A_k\) and \(C_k\) in place of \(A\) and \(C\).
If you use the one-step ekf_update(..., num_iter>1) helper, the measurement
map is relinearized around the current update iterate while keeping the
prediction covariance fixed.
Unscented Kalman Filter¶
ukf() avoids local Jacobians and instead propagates sigma points. Given
dimension \(n\) and scaling parameter \(\lambda = \alpha^2 (n + \kappa) - n\), a mean and covariance pair \((\mu, P)\) generates \(2n + 1\) sigma points and weights

\[
\mathcal{X}_0 = \mu, \qquad
\mathcal{X}_{\pm i} = \mu \pm \left(\sqrt{(n + \lambda) P}\right)_i, \quad i = 1, \dots, n,
\]
\[
W_0^{m} = \frac{\lambda}{n + \lambda}, \qquad
W_0^{c} = W_0^{m} + 1 - \alpha^2 + \beta, \qquad
W_{\pm i}^{m} = W_{\pm i}^{c} = \frac{1}{2(n + \lambda)}.
\]

At the update stage, sigma points from the current prior \((\hat{x}_k^-, P_k^-)\) are pushed through the observation map; for prediction, sigma points from the posterior \((\hat{x}_k^+, P_k^+)\) are pushed through the transition map.
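The sigma-point machinery is small enough to sketch in NumPy. This is a generic unscented transform under the standard Wan–Merwe weighting, not the contrax internals:

```python
import numpy as np

def sigma_points(mu, P, alpha=1.0, beta=2.0, kappa=0.0):
    """Generate 2n+1 sigma points and their mean/covariance weights."""
    n = len(mu)
    lam = alpha**2 * (n + kappa) - n
    L = np.linalg.cholesky((n + lam) * P)      # matrix square root
    pts = np.vstack([mu, mu + L.T, mu - L.T])  # rows are sigma points
    wm = np.full(2 * n + 1, 1.0 / (2 * (n + lam)))
    wc = wm.copy()
    wm[0] = lam / (n + lam)
    wc[0] = wm[0] + 1.0 - alpha**2 + beta
    return pts, wm, wc

def unscented_transform(pts, wm, wc, g):
    """Push sigma points through g and re-estimate mean and covariance."""
    ys = np.array([g(p) for p in pts])
    mean = wm @ ys
    diff = ys - mean
    cov = (wc[:, None] * diff).T @ diff
    return mean, cov
```

For a linear map the transform is exact, which gives a quick sanity check when tuning `alpha`, `beta`, and `kappa`.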
Smoothers¶
rts() is the linear Rauch-Tung-Striebel smoother on top of
KalmanResult. The backward pass computes the smoother gain and corrected moments

\[
G_k = P_k^+ A^\top \left(P_{k+1}^-\right)^{-1},
\]
\[
\hat{x}_k^{s} = \hat{x}_k^+ + G_k \left(\hat{x}_{k+1}^{s} - \hat{x}_{k+1}^-\right), \qquad
P_k^{s} = P_k^+ + G_k \left(P_{k+1}^{s} - P_{k+1}^-\right) G_k^\top.
\]

uks() applies the same backward recursion to the unscented filter, but replaces
the linear cross term \(P_k^+ A^\top\) with an unscented cross-covariance between the filtered state and its one-step prediction.
Moving-Horizon Estimation¶
mhe_objective() and mhe() expose the fixed-window optimization form

\[
J(x_{0:T-1}) =
\lVert x_0 - \bar{x} \rVert_{P_0^{-1}}^2
+ \sum_{k=0}^{T-2} \lVert x_{k+1} - f(x_k, u_k) \rVert_{Q^{-1}}^2
+ \sum_{k=0}^{T-1} \lVert y_k - h(x_k) \rVert_{R^{-1}}^2,
\]

where \(\bar{x}\) and \(P_0\) are the arrival prior mean and covariance for the start of the window.
This is the optimization sibling of the recursive filters: the same model and noise assumptions appear, but they are expressed as an explicit horizon cost.
For rolling-window use:
- `mhe_warm_start(xs, ...)` shifts a previous solution forward by one step to build the next initial guess
- `soft_quadratic_penalty(residuals, weight)` is a small helper for soft state, terminal, or envelope penalties inside `extra_cost`
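The warm-start shift is simple enough to spell out. A minimal NumPy sketch of the behavior described above (repeat the terminal state, or propagate it with a supplied transition); the name `warm_start_shift` is illustrative, not the contrax API:

```python
import numpy as np

def warm_start_shift(xs, transition=None, terminal_input=None):
    """Shift a window solution forward one step to seed the next solve."""
    shifted = np.roll(xs, -1, axis=0)            # drop the oldest state
    if transition is None:
        shifted[-1] = xs[-1]                     # repeat the previous terminal state
    else:
        shifted[-1] = transition(xs[-1], terminal_input)
    return shifted
```

A good warm start usually cuts the per-window solver iteration count substantially, since consecutive windows share all but one state.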
One-Step Helper Signatures¶
The one-step helpers follow the order: model, state (x, P), runtime data
(u/y), noise matrices (Q/R).
# predict: advances state by one step
x_pred, P_pred = cx.ekf_predict(model, x, P, u, Q_noise)
x_pred, P_pred = cx.kalman_predict(sys, x, P, Q_noise, u=None) # u optional
# update: fuses a new measurement
x, P, innov = cx.ekf_update(model, x_pred, P_pred, y, R_noise)
x, P, innov = cx.kalman_update(sys, x_pred, P_pred, y, R_noise, u=None)
# combined predict-update step
x, P, innov = cx.ekf_step(model, x, P, u, y, Q_noise, R_noise, observation=h)
x, P, innov = cx.kalman_step(sys, x, P, y, Q_noise, R_noise, u=None)
has_measurement=False can be passed to all step helpers to skip the update
(useful for streams with missing samples inside jax.lax.scan).
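The contract behind `has_measurement` can be mimicked with arithmetic gating rather than Python `if` branches, which is what makes it safe on traced values inside `jax.lax.scan`. A NumPy sketch of the idea (illustrative, not the contrax implementation):

```python
import numpy as np

def gated_update(x_pred, P_pred, y, C, R, has_measurement):
    """Measurement update that degrades to the prediction when no sample arrived."""
    S = C @ P_pred @ C.T + R
    K = P_pred @ C.T @ np.linalg.inv(S)
    # Zero innovation when there is no measurement -> the gain has no effect.
    innov = np.where(has_measurement, y - C @ x_pred, np.zeros_like(y))
    x = x_pred + K @ innov
    P = np.where(has_measurement,
                 (np.eye(len(x_pred)) - K @ C) @ P_pred,
                 P_pred)
    return x, P, innov
```

Because both branches are always computed, the same code compiles cleanly whether `has_measurement` is a Python bool or a traced scalar.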
Model Inputs¶
For nonlinear estimation, the main docs path is reusable system models such as
NonlinearSystem or
PHSSystem.
That keeps simulation, linearization, and estimation on one shared model object instead of redefining the same dynamics in several places.
The intended public workflow is:
sys_c = cx.nonlinear_system(f_continuous, output=h, dt=None)
sys_d = cx.sample_system(sys_c, dt=0.1)
result = cx.ukf(sys_d, Q_noise, R_noise, ys, us, x0, P0)
smoothed = cx.uks(sys_d, result, Q_noise, us)
ukf() returns
UKFResult rather than
KalmanResult because sigma-point filtering
exposes additional public intermediates: predicted measurements, innovation
covariances, per-step likelihood terms, and the one-step prediction quantities
used by uks().
For continuous dynamics with discrete observations, sample_system() is the
main library bridge. It integrates the continuous vector field over one sample
interval and returns a discrete-time NonlinearSystem that can be passed
directly to ekf() or ukf().
Transform Behavior¶
Use the one-step helpers when measurements arrive in a runtime loop or when the
filter needs to live inside a larger jax.lax.scan.
The important transform contracts are:
- `kalman()`, `ekf()`, and `ukf()` are batch scans over fixed-shape sequences
- the one-step helpers expose missing-measurement handling without Python branching on traced values
- `mhe_objective()` is a pure cost function over an explicit candidate trajectory, which makes it suitable for JAX-native optimization loops
Numerical Notes¶
kalman_gain() is the estimator dual of discrete LQR and returns the
measurement-update gain used by kalman_update().
rts() requires the same Q_noise used by the forward kalman() pass so the
smoother uses the same process model. Pass zeros for deterministic dynamics.
mhe_objective() is the lowest-level optimization-based estimator surface in
the library. mhe() is a thin solver wrapper, not a full nonlinear
programming framework.
Diagnostics¶
Contrax now includes a lightweight diagnostics layer for routine estimation health checks:
- `innovation_diagnostics(innovations, innovation_covariances)` computes per-step normalized innovation squared (NIS), innovation norms, and covariance conditioning summaries
- `likelihood_diagnostics(log_likelihood_terms)` summarizes innovation-form Gaussian log-likelihood terms
- `ukf_diagnostics(result)` builds the standard innovation and likelihood summaries directly from a `UKFResult`
- `innovation_rms(result)` gives a quick residual-scale summary for either `KalmanResult` or `UKFResult`
- `smoother_diagnostics(smoothed, filtered)` checks whether the smoother satisfied its covariance-reduction guarantee (`P_smooth ≤ P_filtered`) and reports the worst-case violation, state correction magnitudes, and PSD health of the smoothed covariances. Accepts `RTSResult` paired with either `KalmanResult` or `UKFResult`.
These helpers are intentionally honest and lightweight. They are good for filter tuning, regression checks, and notebook-free workflow validation, but they are not a substitute for a full statistical consistency study.
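The core NIS quantity is a one-liner per step. A NumPy sketch of the statistic that `innovation_diagnostics()` summarizes (illustrative, not the library code):

```python
import numpy as np

def nis(innovations, innovation_covariances):
    """Per-step normalized innovation squared: v_k^T S_k^{-1} v_k."""
    return np.array([
        v @ np.linalg.solve(S, v)          # solve, not explicit inverse
        for v, S in zip(innovations, innovation_covariances)
    ])
```

For a well-tuned filter the NIS values should scatter around the measurement dimension `p` (they are approximately chi-squared with `p` degrees of freedom), which is why they are useful for covariance tuning even without a formal consistency test.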
Parameterization Helpers¶
Contrax also exposes a small constrained-parameterization layer for estimation and tuning workflows:
- `positive_exp(raw)` and `positive_softplus(raw)` map unconstrained values to positive ones
- `spd_from_cholesky_raw(raw)` maps an unconstrained square matrix to an SPD matrix through a lower-triangular Cholesky factor
- `diagonal_spd(raw_diagonal)` builds a diagonal SPD matrix from unconstrained diagonal parameters
These are intentionally low-level building blocks rather than a parameter framework. They exist so downstream repos do not need to reinvent the same positivity and covariance-parameterization transforms.
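The transforms are standard and easy to sketch. A NumPy illustration of the softplus and Cholesky-based SPD parameterizations; the exact triangle and diagonal conventions of the contrax helpers are an assumption here:

```python
import numpy as np

def positive_softplus(raw):
    """Map an unconstrained value to a strictly positive one."""
    return np.log1p(np.exp(raw))

def spd_from_cholesky_raw(raw):
    """Unconstrained square matrix -> SPD matrix via a Cholesky factor.

    The strict lower triangle of `raw` is kept, the diagonal is made
    positive with softplus, and the SPD matrix is L @ L.T.
    """
    L = np.tril(raw, k=-1) + np.diag(positive_softplus(np.diag(raw)))
    return L @ L.T
```

Optimizing over `raw` instead of the covariance itself keeps gradient-based noise tuning unconstrained while guaranteeing a valid SPD matrix at every step.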
Related Pages¶
- Types for `KalmanResult`, `UKFResult`, `KalmanGainResult`, `RTSResult`, `InnovationDiagnostics`, `SmootherDiagnostics`, and `MHEResult`
- Kalman filtering for an end-to-end workflow
- Handle missing measurements for a task-oriented runtime recipe
- Estimation pipelines for how the batch, one-step, smoothing, and MHE pieces fit together
contrax.estimation¶
contrax.estimation — Kalman, EKF, UKF, MHE, and RTS helpers.
ekf(model_or_f: DynamicsLike, Q_noise: Array, R_noise: Array, ys: Array, us: Array, x0: Array, P0: Array, *, observation: ObservationFn | None = None) -> KalmanResult
¶
Run an extended Kalman filter for nonlinear dynamics and observations.
Uses automatic Jacobians of the transition and observation functions rather than requiring a hand-derived local linearization.
This is the main nonlinear recursive estimation primitive in the current library. It works well with JAX-defined model functions, but the API is intentionally simple and does not yet expose multiple EKF variants.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_or_f` | `DynamicsLike` | Nonlinear system model or dynamics function. | required |
| `Q_noise` | `Array` | Process noise covariance. Shape: `(n, n)`. | required |
| `R_noise` | `Array` | Measurement noise covariance. Shape: `(p, p)`. | required |
| `ys` | `Array` | Measurement sequence. Shape: `(T, p)`. | required |
| `us` | `Array` | Input sequence. Shape: `(T, m)`. | required |
| `x0` | `Array` | Initial mean. Shape: `(n,)`. | required |
| `P0` | `Array` | Initial covariance. Shape: `(n, n)`. | required |
| `observation` | `ObservationFn \| None` | Observation function. | `None` |

Returns:

| Type | Description |
|---|---|
| `KalmanResult` | A bundle containing filtered means, covariances, and innovations. |
Examples:
>>> import jax.numpy as jnp
>>> import contrax as cx
>>> def f(x, u):
... return jnp.array([x[0] + 0.1 * x[1], x[1] + u[0]])
>>> def h(x):
... return x[:1]
>>> result = cx.ekf(
... f,
... Q_noise=1e-3 * jnp.eye(2),
... R_noise=1e-2 * jnp.eye(1),
... ys=jnp.zeros((10, 1)),
... us=jnp.zeros((10, 1)),
... x0=jnp.zeros(2),
... P0=jnp.eye(2),
... observation=h,
... )
Source code in contrax/_ekf.py
ekf_predict(model_or_f: DynamicsLike, x: Array, P: Array, u: Array, Q_noise: Array, *, t: Array | float = 0.0) -> tuple[Array, Array]
¶
Predict one extended Kalman filter step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_or_f` | `DynamicsLike` | Nonlinear system model or dynamics function mapping the current state and input to the next state. | required |
| `x` | `Array` | Filtered mean at the previous time step. Shape: `(n,)`. | required |
| `P` | `Array` | Filtered covariance at the previous time step. Shape: `(n, n)`. | required |
| `u` | `Array` | Current input. Shape: `(m,)`. | required |
| `Q_noise` | `Array` | Process noise covariance. Shape: `(n, n)`. | required |
| `t` | `Array \| float` | Current sample time. | `0.0` |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, Array]` | Tuple `(x_pred, P_pred)` of predicted mean and covariance. |
Source code in contrax/_ekf.py
ekf_step(model_or_f: DynamicsLike, x: Array, P: Array, u: Array, y: Array, Q_noise: Array, R_noise: Array, *, t: Array | float = 0.0, has_measurement: Array | bool = True, num_iter: int = 1, observation: ObservationFn | None = None) -> tuple[Array, Array, Array]
¶
Run one predict-update step of an extended Kalman filter.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_or_f` | `DynamicsLike` | Nonlinear system model or dynamics function. | required |
| `x` | `Array` | Filtered mean at the previous time step. Shape: `(n,)`. | required |
| `P` | `Array` | Filtered covariance at the previous time step. Shape: `(n, n)`. | required |
| `u` | `Array` | Current input. Shape: `(m,)`. | required |
| `y` | `Array` | Measurement vector. Shape: `(p,)`. | required |
| `Q_noise` | `Array` | Process noise covariance. Shape: `(n, n)`. | required |
| `R_noise` | `Array` | Measurement noise covariance. Shape: `(p, p)`. | required |
| `has_measurement` | `Array \| bool` | Whether to apply the measurement update. | `True` |
| `num_iter` | `int` | Number of iterated-EKF measurement linearization passes. | `1` |
| `observation` | `ObservationFn \| None` | Observation function. | `None` |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, Array, Array]` | Tuple `(x, P, innovation)` of updated mean, updated covariance, and measurement residual. |
Source code in contrax/_ekf.py
ekf_update(model_or_h: ObservationLike, x_pred: Array, P_pred: Array, y: Array, R_noise: Array, *, u: Array | None = None, t: Array | float = 0.0, has_measurement: Array | bool = True, num_iter: int = 1) -> tuple[Array, Array, Array]
¶
Update one extended Kalman filter step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_or_h` | `ObservationLike` | Nonlinear system model or observation function. | required |
| `x_pred` | `Array` | Predicted mean. Shape: `(n,)`. | required |
| `P_pred` | `Array` | Predicted covariance. Shape: `(n, n)`. | required |
| `y` | `Array` | Measurement vector. Shape: `(p,)`. | required |
| `R_noise` | `Array` | Measurement noise covariance. Shape: `(p, p)`. | required |
| `u` | `Array \| None` | Current input used by the observation model. Shape: `(m,)`. | `None` |
| `t` | `Array \| float` | Current sample time. | `0.0` |
| `has_measurement` | `Array \| bool` | Whether to apply the measurement update. | `True` |
| `num_iter` | `int` | Number of iterated-EKF measurement linearization passes. | `1` |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, Array, Array]` | Tuple `(x, P, innovation)` of updated mean, updated covariance, and final measurement residual. Missing-measurement steps return the prediction and a zero innovation. |
Source code in contrax/_ekf.py
innovation_diagnostics(innovations: Array, innovation_covariances: Array) -> InnovationDiagnostics
¶
Summarize innovation magnitude and covariance conditioning.
The returned NIS values are the per-step normalized innovation squared
terms v_k^T S_k^{-1} v_k. They are useful for routine filter health
inspection, but Contrax does not present them as a formal standalone
consistency certificate.
Source code in contrax/_estimation_diagnostics.py
innovation_rms(result: KalmanResult | UKFResult) -> Array
¶
Return a quick root-mean-square summary of the filter innovations.
kalman(sys: DiscLTI, Q_noise: Array, R_noise: Array, ys: Array, x0: Array | None = None, P0: Array | None = None) -> KalmanResult
¶
Run a linear discrete-time Kalman filter.
Filters a measurement sequence for a
DiscLTI system using the standard
predict-update recursion expressed as a pure lax.scan.
This is the baseline linear-Gaussian estimation path in Contrax. Innovation gains are computed with linear solves rather than explicit inverses, and the scan structure makes small differentiable covariance-tuning loops practical.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sys` | `DiscLTI` | Discrete-time linear system. | required |
| `Q_noise` | `Array` | Process noise covariance. Shape: `(n, n)`. | required |
| `R_noise` | `Array` | Measurement noise covariance. Shape: `(p, p)`. | required |
| `ys` | `Array` | Measurement sequence. Shape: `(T, p)`. | required |
| `x0` | `Array \| None` | Optional initial mean. Shape: `(n,)`. | `None` |
| `P0` | `Array \| None` | Optional initial covariance. Shape: `(n, n)`. | `None` |

Returns:

| Type | Description |
|---|---|
| `KalmanResult` | A bundle containing filtered means, covariances, and innovations. |
Examples:
>>> import jax.numpy as jnp
>>> import contrax as cx
>>> sys = cx.dss(
... jnp.array([[1.0, 0.1], [0.0, 1.0]]),
... jnp.array([[0.0], [0.1]]),
... jnp.eye(2),
... jnp.zeros((2, 1)),
... dt=0.1,
... )
>>> result = cx.kalman(
... sys,
... Q_noise=1e-3 * jnp.eye(2),
... R_noise=1e-2 * jnp.eye(2),
... ys=jnp.zeros((20, 2)),
... )
Source code in contrax/_kalman.py
kalman_gain(sys: DiscLTI, Q_noise: Array, R_noise: Array) -> KalmanGainResult
¶
Design a steady-state Kalman filter gain for a discrete LTI system.
This is the estimator dual of discrete LQR. It solves the DARE for
(A.T, C.T, Q_noise, R_noise) and returns the measurement-update gain used
by kalman_update().
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sys` | `DiscLTI` | Discrete-time linear system. | required |
| `Q_noise` | `Array` | Process noise covariance. Shape: `(n, n)`. | required |
| `R_noise` | `Array` | Measurement noise covariance. Shape: `(p, p)`. | required |

Returns:

| Type | Description |
|---|---|
| `KalmanGainResult` | A bundle with the steady-state measurement-update gain. |
Examples:
>>> import jax.numpy as jnp
>>> import contrax as cx
>>> sys = cx.dss(
... jnp.array([[1.0, 0.1], [0.0, 1.0]]),
... jnp.array([[0.0], [0.1]]),
... jnp.eye(2),
... jnp.zeros((2, 1)),
... dt=0.1,
... )
>>> design = cx.kalman_gain(sys, 1e-3 * jnp.eye(2), 1e-2 * jnp.eye(2))
Source code in contrax/_kalman.py
kalman_predict(sys: DiscLTI, x: Array, P: Array, Q_noise: Array, u: Array | None = None) -> tuple[Array, Array]
¶
Predict one linear Kalman filter step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sys` | `DiscLTI` | Discrete-time linear system. | required |
| `x` | `Array` | Filtered mean at the previous time step. Shape: `(n,)`. | required |
| `P` | `Array` | Filtered covariance at the previous time step. Shape: `(n, n)`. | required |
| `Q_noise` | `Array` | Process noise covariance. Shape: `(n, n)`. | required |
| `u` | `Array \| None` | Optional input for the transition. Shape: `(m,)`. | `None` |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, Array]` | Tuple `(x_pred, P_pred)` of predicted mean and covariance. |
Source code in contrax/_kalman.py
kalman_step(sys: DiscLTI, x: Array, P: Array, y: Array, Q_noise: Array, R_noise: Array, u: Array | None = None, *, has_measurement: Array | bool = True) -> tuple[Array, Array, Array]
¶
Run one predict-update step of a linear Kalman filter.
This is the online counterpart to kalman(). It is a pure array function,
so service loops can call it directly and batch workflows can place it
inside jax.lax.scan.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sys` | `DiscLTI` | Discrete-time linear system. | required |
| `x` | `Array` | Filtered mean at the previous time step. Shape: `(n,)`. | required |
| `P` | `Array` | Filtered covariance at the previous time step. Shape: `(n, n)`. | required |
| `y` | `Array` | Measurement vector. Shape: `(p,)`. | required |
| `Q_noise` | `Array` | Process noise covariance. Shape: `(n, n)`. | required |
| `R_noise` | `Array` | Measurement noise covariance. Shape: `(p, p)`. | required |
| `u` | `Array \| None` | Optional input. Shape: `(m,)`. | `None` |
| `has_measurement` | `Array \| bool` | Whether to apply the measurement update. | `True` |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, Array, Array]` | Tuple `(x, P, innovation)` of updated mean, updated covariance, and measurement residual. |
Source code in contrax/_kalman.py
kalman_update(sys: DiscLTI, x_pred: Array, P_pred: Array, y: Array, R_noise: Array, u: Array | None = None, *, has_measurement: Array | bool = True) -> tuple[Array, Array, Array]
¶
Update one linear Kalman filter step.
has_measurement=False returns the predicted state and covariance without
applying a measurement update. The flag may be a JAX scalar boolean, making
this helper usable inside jax.lax.scan for streams with missing samples.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sys` | `DiscLTI` | Discrete-time linear system. | required |
| `x_pred` | `Array` | Predicted mean. Shape: `(n,)`. | required |
| `P_pred` | `Array` | Predicted covariance. Shape: `(n, n)`. | required |
| `y` | `Array` | Measurement vector. Shape: `(p,)`. | required |
| `R_noise` | `Array` | Measurement noise covariance. Shape: `(p, p)`. | required |
| `u` | `Array \| None` | Optional input for direct-feedthrough measurement models. Shape: `(m,)`. | `None` |
| `has_measurement` | `Array \| bool` | Whether to apply the measurement update. | `True` |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, Array, Array]` | Tuple `(x, P, innovation)` of updated mean, updated covariance, and measurement residual. Missing-measurement steps return a zero innovation with the same shape as `y`. |
Source code in contrax/_kalman.py
likelihood_diagnostics(log_likelihood_terms: Array) -> LikelihoodDiagnostics
¶
Summarize per-step innovation-form log-likelihood terms.
Source code in contrax/_estimation_diagnostics.py
mhe(f: Callable[..., Array], h: Callable[..., Array], xs_init: Array, us: Array, ys: Array, x_prior: Array, P_prior: Array, Q_noise: Array, R_noise: Array, params: object | None = None, extra_cost: Callable[..., Array] | None = None, solver: AbstractMinimiser | None = None, max_steps: int = 256) -> MHEResult
¶
Solve a fixed-window moving-horizon-estimation problem.
Minimises mhe_objective over the candidate state trajectory xs_init
using an optimistix solver. The default solver is
optx.LBFGS(rtol=1e-6, atol=1e-6), which is well-suited to the
deterministic, smooth, fixed-horizon objectives that arise in MHE.
The optimised terminal state x_hat = xs[-1] is the natural input to a
downstream controller or observer. For a linear-Gaussian model, the
full trajectory xs matches the RTS smoother output within solver
tolerance.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[..., Array]` | Transition function. Called as `f(x, u, params)`. | required |
| `h` | `Callable[..., Array]` | Observation function. Called as `h(x, params)`. | required |
| `xs_init` | `Array` | Initial trajectory guess. Shape: `(T, n)`. | required |
| `us` | `Array` | Input sequence. Shape: `(T - 1, m)`. | required |
| `ys` | `Array` | Measurement sequence aligned with `xs_init`. Shape: `(T, p)`. | required |
| `x_prior` | `Array` | Prior mean for the start of the window. Shape: `(n,)`. | required |
| `P_prior` | `Array` | Arrival covariance. Shape: `(n, n)`. | required |
| `Q_noise` | `Array` | Process noise covariance. Shape: `(n, n)`. | required |
| `R_noise` | `Array` | Measurement noise covariance. Shape: `(p, p)`. | required |
| `params` | `object \| None` | Optional parameters passed to `f`, `h`, and `extra_cost`. | `None` |
| `extra_cost` | `Callable[..., Array] \| None` | Optional callable adding additional smooth costs. Called as `extra_cost(xs, us, ys, params)`. | `None` |
| `solver` | `AbstractMinimiser \| None` | Optimistix minimiser. Defaults to `optx.LBFGS(rtol=1e-6, atol=1e-6)`. | `None` |
| `max_steps` | `int` | Maximum solver iterations. Defaults to 256. | `256` |

Returns:

| Type | Description |
|---|---|
| `MHEResult` | Estimated trajectory, terminal state, final cost, and convergence flag. |
Examples:
>>> import jax.numpy as jnp
>>> import contrax as cx
>>> sys = cx.dss(
... jnp.array([[0.8]]), jnp.zeros((1, 1)),
... jnp.array([[1.0]]), jnp.zeros((1, 1)), dt=1.0,
... )
>>> f = lambda x, u: sys.A @ x
>>> h = lambda x: sys.C @ x
>>> ys = jnp.array([[0.1], [0.4], [0.6], [0.5], [0.4]])
>>> result = cx.mhe(
... f, h,
... xs_init=jnp.zeros((5, 1)),
... us=jnp.zeros((4, 1)),
... ys=ys,
... x_prior=jnp.zeros(1),
... P_prior=jnp.eye(1),
... Q_noise=0.05 * jnp.eye(1),
... R_noise=0.2 * jnp.eye(1),
... )
Source code in contrax/_mhe.py
mhe_objective(f: Callable[..., Array], h: Callable[..., Array], xs: Array, us: Array, ys: Array, x_prior: Array, P_prior: Array, Q_noise: Array, R_noise: Array, params: object | None = None, extra_cost: Callable[..., Array] | None = None) -> Array
¶
Evaluate a fixed-window moving-horizon-estimation objective.
This is the first MHE primitive in Contrax: a pure cost function, not a
solver. The candidate state trajectory is explicit in xs, so callers can
optimize it with their preferred JAX-native solver while keeping fixed
horizon shapes.
The objective is a weighted nonlinear least-squares cost:
- arrival cost on `xs[0] - x_prior`
- process cost on `xs[k + 1] - f(xs[k], us[k], params)`
- measurement cost on `ys[k] - h(xs[k], params)`
- optional user `extra_cost(xs, us, ys, params)`
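The cost terms above can be written down directly. A minimal NumPy sketch of the same objective without `params` or `extra_cost` (function and variable names here are illustrative, not the contrax signature):

```python
import numpy as np

def quad(r, W):
    """Weighted quadratic form r^T W^{-1} r."""
    return r @ np.linalg.solve(W, r)

def mhe_cost(f, h, xs, us, ys, x_prior, P_prior, Q, R):
    """Arrival + process + measurement costs over one fixed window."""
    cost = quad(xs[0] - x_prior, P_prior)                 # arrival cost
    for k in range(len(us)):
        cost += quad(xs[k + 1] - f(xs[k], us[k]), Q)      # process cost
    for k in range(len(ys)):
        cost += quad(ys[k] - h(xs[k]), R)                 # measurement cost
    return cost
```

A trajectory that exactly satisfies the model and reproduces the measurements drives the cost to zero, which is a convenient unit-test invariant for custom `extra_cost` terms as well.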
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
f
|
Callable[..., Array]
|
Transition function. Called as |
required |
h
|
Callable[..., Array]
|
Observation function. Called as |
required |
xs
|
Array
|
Candidate state trajectory. Shape: |
required |
us
|
Array
|
Input sequence. Shape: |
required |
ys
|
Array
|
Measurement sequence aligned with |
required |
x_prior
|
Array
|
Prior state for the start of the window. Shape: |
required |
P_prior
|
Array
|
Arrival covariance. Shape: |
required |
Q_noise
|
Array
|
Process noise covariance. Shape: |
required |
R_noise
|
Array
|
Measurement noise covariance. Shape: |
required |
params
|
object | None
|
Optional parameters passed to |
None
|
extra_cost
|
Callable[..., Array] | None
|
Optional callable adding additional smooth costs.
Called as |
None
|
Returns:
| Name | Type | Description |
|---|---|---|
Array |
Array
|
Scalar MHE cost. |
Source code in contrax/_mhe.py
mhe_warm_start(xs: Array, *, transition: Callable[[Array, Array], Array] | None = None, terminal_input: Array | None = None) -> Array
¶
Shift an MHE trajectory guess forward by one step.
The default behavior repeats the previous terminal state. If a transition
and terminal_input are supplied, the terminal guess is propagated with
that model instead. This is a generic rolling-window warm-start helper,
not a full arrival-update policy.
Source code in contrax/_mhe.py
rts(sys: DiscLTI, result: KalmanResult, Q_noise: Array) -> RTSResult
¶
Run a Rauch-Tung-Striebel smoother on filtered Kalman results.
Applies the standard backward smoothing pass to the output of kalman()
using the same process model and process noise covariance.
This is an offline smoothing primitive rather than an online filter.
Q_noise should match the covariance used in the forward kalman() pass;
for deterministic dynamics, pass zeros rather than omitting the term.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
sys
|
DiscLTI
|
Discrete-time linear system used in the forward filter. |
required |
result
|
KalmanResult
|
Output of |
required |
Q_noise
|
Array
|
Process noise covariance used by the forward filter. |
required |
Returns:
| Type | Description |
|---|---|
RTSResult
|
RTSResult: A bundle containing smoothed state means and covariances. |
Examples:
>>> import jax.numpy as jnp
>>> import contrax as cx
>>> sys = cx.dss(
... jnp.array([[1.0, 0.1], [0.0, 1.0]]),
... jnp.array([[0.0], [0.1]]),
... jnp.eye(2),
... jnp.zeros((2, 1)),
... dt=0.1,
... )
>>> filtered = cx.kalman(
... sys,
... Q_noise=1e-3 * jnp.eye(2),
... R_noise=1e-2 * jnp.eye(2),
... ys=jnp.zeros((20, 2)),
... )
>>> smoothed = cx.rts(sys, filtered, Q_noise=1e-3 * jnp.eye(2))
Source code in contrax/_kalman.py
smoother_diagnostics(smoothed: RTSResult, filtered: KalmanResult | UKFResult) -> SmootherDiagnostics
¶
Compute health diagnostics for an RTS or UKS smoother result.
A numerically healthy smoother satisfies P_smooth ≤ P_filtered in the
PSD sense at every step. This function checks that guarantee and surfaces
the worst-case violation so callers can detect breakdown before trusting
smoothed trajectories.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `smoothed` | `RTSResult` | Output of `rts()` or `uks()`. | required |
| `filtered` | `KalmanResult \| UKFResult` | The forward filter result used to produce `smoothed`. | required |

Returns:

| Type | Description |
|---|---|
| `SmootherDiagnostics` | Diagnostic bundle. |
Examples:
>>> import jax.numpy as jnp
>>> import contrax as cx
>>> sys = cx.dss(
... jnp.array([[0.9]]), jnp.zeros((1, 1)),
... jnp.array([[1.0]]), jnp.zeros((1, 1)), dt=1.0,
... )
>>> filtered = cx.kalman(sys, 1e-3 * jnp.eye(1), 1e-2 * jnp.eye(1),
... jnp.zeros((10, 1)))
>>> smoothed = cx.rts(sys, filtered, Q_noise=1e-3 * jnp.eye(1))
>>> diag = cx.smoother_diagnostics(smoothed, filtered)
Source code in contrax/_estimation_diagnostics.py
soft_quadratic_penalty(residuals: Array, weight: Array) -> Array
¶
Return a quadratic soft penalty with matrix or scalar weighting.
This helper is intended for optional MHE soft costs such as state-envelope penalties or terminal-state regularization. It stays small on purpose: callers define the residuals they care about and use this helper for the weighted quadratic part.
Source code in contrax/_mhe.py
ukf(model_or_f: DynamicsLike, Q_noise: Array, R_noise: Array, ys: Array, us: Array, x0: Array, P0: Array, *, observation: ObservationFn | None = None, alpha: float = 1.0, beta: float = 2.0, kappa: float = 0.0) -> UKFResult
¶
Run an unscented Kalman filter for nonlinear dynamics and observations.
Uses a deterministic sigma-point transform instead of local Jacobians to propagate means and covariances through the nonlinear transition and observation models.
This is a baseline nonlinear recursive estimation path that is often more
robust than ekf() when the observation model is strongly curved. The
current implementation assumes additive process and measurement noise and
keeps the public API deliberately small.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_or_f` | `DynamicsLike` | Nonlinear system model or dynamics function. | required |
| `Q_noise` | `Array` | Process noise covariance. Shape: `(n, n)`. | required |
| `R_noise` | `Array` | Measurement noise covariance. Shape: `(p, p)`. | required |
| `ys` | `Array` | Measurement sequence. Shape: `(T, p)`. | required |
| `us` | `Array` | Input sequence. Shape: `(T, m)`. | required |
| `x0` | `Array` | Initial mean. Shape: `(n,)`. | required |
| `P0` | `Array` | Initial covariance. Shape: `(n, n)`. | required |
| `observation` | `ObservationFn \| None` | Observation function. | `None` |
| `alpha` | `float` | Sigma-point spread parameter. The default is intentionally wider than the very small textbook setting so the baseline implementation stays numerically well-behaved on the current small-system focus. | `1.0` |
| `beta` | `float` | Prior-distribution parameter. | `2.0` |
| `kappa` | `float` | Secondary sigma-point scaling parameter. | `0.0` |

Returns:

| Type | Description |
|---|---|
| `UKFResult` | A bundle containing filtered means/covariances plus prediction, innovation, and likelihood intermediates useful for diagnostics and smoothing. |
Examples:
>>> import jax.numpy as jnp
>>> import contrax as cx
>>> def f(x, u):
... return jnp.array([x[0] + 0.1 * x[1], x[1] + u[0]])
>>> def h(x):
... return jnp.array([x[0] ** 2])
>>> result = cx.ukf(
... f,
... Q_noise=1e-3 * jnp.eye(2),
... R_noise=1e-2 * jnp.eye(1),
... ys=jnp.zeros((10, 1)),
... us=jnp.zeros((10, 1)),
... x0=jnp.ones(2),
... P0=jnp.eye(2),
... observation=h,
... )
Source code in contrax/_ukf.py
ukf_diagnostics(result: UKFResult) -> tuple[InnovationDiagnostics, LikelihoodDiagnostics]
¶
Build the standard diagnostics pair for a UKF result.
Source code in contrax/_estimation_diagnostics.py
uks(model_or_f: DynamicsLike, result: UKFResult, Q_noise: Array, us: Array, *, alpha: float = 1.0, beta: float = 2.0, kappa: float = 0.0) -> RTSResult
¶
Run an unscented Rauch-Tung-Striebel-style smoother.
Applies a backward smoothing pass to the output of ukf() using the same
nonlinear dynamics model, process noise covariance, and input sequence.