Kalman Filtering
This tutorial shows a standard estimation workflow on a small discrete system. We will define a state-space model, run a linear Kalman filter on a measurement sequence, smooth the result with a Rauch-Tung-Striebel (RTS) smoother, and finish with checks that tell us whether the estimate is behaving plausibly.
Runnable script: examples/kalman_filtering.py
This tutorial sits primarily in the Estimation API, with Systems providing the underlying discrete model.
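The model and measurement equations are

$$
x_{k+1} = A x_k + w_k, \qquad y_k = C x_k + v_k,
$$

with Gaussian noise covariances Q_noise and R_noise, that is $w_k \sim \mathcal{N}(0, Q_{\text{noise}})$ and $v_k \sim \mathcal{N}(0, R_{\text{noise}})$. For $N$ measurements $y_0, \dots, y_{N-1}$, the batch filter uses the update-first convention, with $\hat{x}_{0|-1}$ and $P_{0|-1}$ set to the prior (x0, P0):

$$
\begin{aligned}
e_k &= y_k - C \hat{x}_{k|k-1}, \\
S_k &= C P_{k|k-1} C^\top + R_{\text{noise}}, \\
K_k &= P_{k|k-1} C^\top S_k^{-1}, \\
\hat{x}_{k|k} &= \hat{x}_{k|k-1} + K_k e_k, \\
P_{k|k} &= (I - K_k C) P_{k|k-1}, \\
\hat{x}_{k+1|k} &= A \hat{x}_{k|k}, \\
P_{k+1|k} &= A P_{k|k} A^\top + Q_{\text{noise}},
\end{aligned}
$$

and RTS adds the backward pass, run for $k = N-2, \dots, 0$ starting from the final filtered estimate $\hat{x}_{N-1|N-1}$, $P_{N-1|N-1}$:

$$
\begin{aligned}
G_k &= P_{k|k} A^\top P_{k+1|k}^{-1}, \\
\hat{x}_{k|N-1} &= \hat{x}_{k|k} + G_k \left(\hat{x}_{k+1|N-1} - \hat{x}_{k+1|k}\right), \\
P_{k|N-1} &= P_{k|k} + G_k \left(P_{k+1|N-1} - P_{k+1|k}\right) G_k^\top.
\end{aligned}
$$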
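To make the update-first recursion concrete, here is a NumPy reference sketch written in the shape of a `jax.lax.scan` body: the carry is the prior $(\hat{x}_{k|k-1}, P_{k|k-1})$ and each step emits the filtered estimate and the innovation. It is not Contrax's implementation. `scan` and `make_step` are hypothetical helpers, the plain $(I - K_k C) P_{k|k-1}$ covariance update is an assumption, and `scan` only reproduces the loop semantics of `jax.lax.scan`, not its tracing or compilation.

```python
import numpy as np


def scan(f, init, xs):
    """Plain-Python loop with the same carry/output semantics as jax.lax.scan."""
    carry = init
    outputs = []
    for x in xs:
        carry, out = f(carry, x)
        outputs.append(out)
    # Stack each output field along a new leading time axis.
    return carry, tuple(np.stack(field) for field in zip(*outputs))


def make_step(A, C, Q, R):
    """Build one update-first Kalman step (hypothetical helper)."""
    n = A.shape[0]

    def step(carry, y):
        x_prior, P_prior = carry  # estimate of x_k before seeing y_k
        S = C @ P_prior @ C.T + R  # innovation covariance
        K = np.linalg.solve(S, C @ P_prior).T  # K = P C^T S^{-1}, using P and S symmetric
        innovation = y - C @ x_prior
        x_filt = x_prior + K @ innovation
        P_filt = (np.eye(n) - K @ C) @ P_prior
        # Predict the prior for the next sample.
        x_next = A @ x_filt
        P_next = A @ P_filt @ A.T + Q
        return (x_next, P_next), (x_filt, P_filt, innovation)

    return step


# Same model, noise, measurements, and prior as the tutorial.
A = np.array([[1.0, 0.1], [0.0, 1.0]])
C = np.array([[1.0, 0.0]])
Q = np.diag([1e-4, 5e-4])
R = np.array([[2.5e-3]])
ys = np.array([0.92, 0.95, 0.99, 1.03, 1.06, 1.05, 1.02, 1.00, 0.98, 1.01]).reshape(-1, 1)

_, (x_hat, P, innovations) = scan(make_step(A, C, Q, R), (np.zeros(2), np.eye(2)), ys)
print(x_hat.shape, P.shape, innovations.shape)  # (10, 2) (10, 2, 2) (10, 1)
```

With `jax.lax.scan` the outputs come back already stacked, so the `zip`/`np.stack` step is only needed for this plain-Python stand-in. Matching Contrax to the last digit would depend on details this sketch does not claim to reproduce, such as the exact covariance update form.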
Define The Estimation Problem
The model is a constant-velocity discrete system with position-only measurements, sampled at dt = 0.1. That is a familiar place to start: the hidden state is x = [position, velocity], but the sensor only reports position, so velocity has to be inferred from how the position measurements change over time.
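import jax

# Enable 64-bit floats before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import contrax as cx

# Constant-velocity model: position += dt * velocity, velocity stays constant.
SYS = cx.dss(
    A=jnp.array([[1.0, 0.1], [0.0, 1.0]]),
    B=jnp.zeros((2, 1)),  # no control input in this example
    C=jnp.array([[1.0, 0.0]]),  # the sensor measures position only
    D=jnp.zeros((1, 1)),
    dt=0.1,
)

# Process and measurement noise covariances.
Q_NOISE = jnp.diag(jnp.array([1e-4, 5e-4]))
R_NOISE = jnp.array([[2.5e-3]])

# Ten scalar position measurements, shape (10, 1).
YS = jnp.array(
    [
        [0.92],
        [0.95],
        [0.99],
        [1.03],
        [1.06],
        [1.05],
        [1.02],
        [1.00],
        [0.98],
        [1.01],
    ]
)

# Prior on x_0: zero mean, unit covariance.
X0 = jnp.array([0.0, 0.0])
P0 = jnp.eye(2)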
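Velocity is never measured, so it is worth confirming that position measurements alone can still pin it down. For a two-state model the observability matrix is $[C;\ CA]$; if it has full rank, two consecutive noise-free position samples determine both states. A short NumPy check, independent of Contrax:

```python
import numpy as np

dt = 0.1
A = np.array([[1.0, dt], [0.0, 1.0]])  # position += dt * velocity; velocity is constant
C = np.array([[1.0, 0.0]])  # position-only sensor

# Observability matrix for n = 2 states: stack C and C A.
O = np.vstack([C, C @ A])
print(O)  # rows [1, 0] and [1, 0.1]
print(np.linalg.matrix_rank(O))  # 2, so the pair (A, C) is observable

# With noise-free data, y_0 = p and y_1 = p + dt * v, so solving O x = [y_0, y_1]
# recovers velocity as a finite difference of positions.
x_true = np.array([1.0, 0.5])
y_pair = O @ x_true
x_recovered = np.linalg.solve(O, y_pair)
print(np.allclose(x_recovered, x_true))  # True
```

The Kalman filter does the noisy, recursive version of the same thing: it builds up a velocity estimate from successive position innovations.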
Run The Filter And Smoother
kalman() runs the batch discrete-time filter as a jax.lax.scan, so the full
pass stays in JAX. rts() then applies an offline backward smoothing pass to
the filtered trajectory using the same process model and process noise.
def run_example():
    # Forward pass: batch Kalman filter over all ten measurements.
    filtered = cx.kalman(
        SYS,
        Q_noise=Q_NOISE,
        R_noise=R_NOISE,
        ys=YS,
        x0=X0,
        P0=P0,
    )
    # Backward pass: RTS smoothing of the filtered trajectory.
    smoothed = cx.rts(SYS, filtered, Q_noise=Q_NOISE)

    # State component 0 is position, component 1 is velocity.
    final_measurement = float(YS[-1, 0])
    final_filtered = float(filtered.x_hat[-1, 0])
    final_smoothed = float(smoothed.x_smooth[-1, 0])
    midpoint_index = 4  # an interior sample, where smoothing can change the estimate
    midpoint_filtered = float(filtered.x_hat[midpoint_index, 0])
    midpoint_smoothed = float(smoothed.x_smooth[midpoint_index, 0])
    innovation_norm = float(jnp.linalg.norm(filtered.innovations))

    # Plausibility checks: shapes, final fit to the data, bounded innovations.
    assert filtered.x_hat.shape == (10, 2)
    assert filtered.P.shape == (10, 2, 2)
    assert smoothed.x_smooth.shape == (10, 2)
    assert abs(final_filtered - final_measurement) < 0.1
    assert innovation_norm < 1.5

    return {
        "final_measurement": final_measurement,
        "final_filtered_position": final_filtered,
        "final_smoothed_position": final_smoothed,
        "midpoint_filtered_position": midpoint_filtered,
        "midpoint_smoothed_position": midpoint_smoothed,
        "final_velocity": float(filtered.x_hat[-1, 1]),
        "innovation_norm": innovation_norm,
        "filtered_cov_trace": float(jnp.trace(filtered.P[-1])),
    }
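The innovation assertion is easier to read once the first update is worked by hand. With the prior mean at zero, $P_0 = I$, $C = [1\ \ 0]$, $R_{\text{noise}} = 2.5 \times 10^{-3}$, and $y_0 = 0.92$, the first update-first step gives

$$
\begin{aligned}
e_0 &= y_0 - C x_0 = 0.92 - 0 = 0.92, \\
S_0 &= C P_0 C^\top + R_{\text{noise}} = 1 + 0.0025 = 1.0025, \\
K_0 &= P_0 C^\top S_0^{-1} = \begin{bmatrix} 1/1.0025 \\ 0 \end{bmatrix} \approx \begin{bmatrix} 0.9975 \\ 0 \end{bmatrix}, \\
\hat{x}_{0|0} &= x_0 + K_0 e_0 \approx \begin{bmatrix} 0.9177 \\ 0 \end{bmatrix}.
\end{aligned}
$$

So $e_0$ alone contributes 0.92 to the innovation norm; given the printed norm of 0.932813, the remaining nine innovations together contribute only about $\sqrt{0.932813^2 - 0.92^2} \approx 0.15$. That is why `innovation_norm < 1.5` is a loose sanity bound rather than a tight one. The first update also leaves velocity at zero: $P_0$ has no position-velocity correlation, so $K_0$ has no velocity component until a prediction step couples the two states.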
In Contrax, the filter treats (x0, P0) as the prior on x_0 and updates
first with y_0. That convention matters when you compare the batch result to
the one-step kalman_update() and kalman_step() helpers in runtime loops.
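To see why the convention matters, the sketch below applies both orderings to the first measurement using the tutorial's values. `update` and `predict` are hypothetical NumPy helpers written for this comparison; they are not Contrax's `kalman_update()` or `kalman_step()`, whose signatures are not shown here. Update-first treats (x0, P0) as the prior on $x_0$; predict-first instead reads it as an estimate of the previous state and propagates once before using $y_0$.

```python
import numpy as np

# Tutorial values for the model, noise, prior, and first measurement.
A = np.array([[1.0, 0.1], [0.0, 1.0]])
C = np.array([[1.0, 0.0]])
Q = np.diag([1e-4, 5e-4])
R = np.array([[2.5e-3]])
x0, P0 = np.zeros(2), np.eye(2)
y0 = np.array([0.92])


def update(x, P, y):
    """Measurement update (hypothetical helper)."""
    S = C @ P @ C.T + R
    K = np.linalg.solve(S, C @ P).T
    return x + K @ (y - C @ x), (np.eye(len(x)) - K @ C) @ P


def predict(x, P):
    """Time update (hypothetical helper)."""
    return A @ x, A @ P @ A.T + Q


# Update-first: (x0, P0) is the prior on x_0, so y_0 is used immediately.
x_uf, P_uf = update(x0, P0, y0)

# Predict-first: (x0, P0) is read as an estimate of the previous state,
# so it is propagated once before y_0 is used.
x_pf, P_pf = update(*predict(x0, P0), y0)

print(x_uf)  # velocity component is exactly zero
print(x_pf)  # velocity component is about 0.09
```

The positions nearly agree, but the velocity estimates and covariances do not, so a runtime loop that predicts before its first update will not reproduce the batch trajectory exactly.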
What The Script Prints
Running examples/kalman_filtering.py prints a compact summary of the midpoint
and final estimates:
Kalman filtering and RTS smoothing
final measurement = 1.010000
mid filtered position = 1.060639
mid smoothed position = 1.003812
final filtered position = 1.024387
final smoothed position = 1.024387
final filtered velocity = 0.055978
innovation norm = 0.932813
final covariance trace = 0.007035
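The exact numbers are not the point. What matters is that the filtered position tracks the measurements, the covariance stays bounded, and smoothing behaves as expected: the interior smoothed state can differ from the filtered one (here the midpoint moves from 1.060639 to 1.003812 once later measurements are taken into account), while the terminal RTS state matches the final filtered state by construction, because the backward pass starts from the last filtered estimate and there are no later measurements to revise it.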
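Both smoothing properties can be checked with a NumPy sketch of the standard RTS backward pass, run on a compact update-first filter over the tutorial's data. `kalman_filter` and `rts_smooth` are hypothetical helpers, not Contrax functions, and this sketch recomputes the one-step predictions from A and Q_noise instead of storing them during the forward pass.

```python
import numpy as np


def kalman_filter(A, C, Q, R, ys, x0, P0):
    """Update-first Kalman filter returning filtered means and covariances."""
    n = len(x0)
    x, P = x0, P0
    xs, Ps = [], []
    for y in ys:
        S = C @ P @ C.T + R
        K = np.linalg.solve(S, C @ P).T
        x = x + K @ (y - C @ x)
        P = (np.eye(n) - K @ C) @ P
        xs.append(x)
        Ps.append(P)
        x, P = A @ x, A @ P @ A.T + Q  # prior for the next sample
    return np.stack(xs), np.stack(Ps)


def rts_smooth(A, Q, x_f, P_f):
    """Rauch-Tung-Striebel backward pass over filtered estimates."""
    x_s, P_s = x_f.copy(), P_f.copy()  # the last entry stays equal to the filtered one
    for k in range(len(x_f) - 2, -1, -1):
        x_pred = A @ x_f[k]
        P_pred = A @ P_f[k] @ A.T + Q
        G = np.linalg.solve(P_pred, A @ P_f[k]).T  # G = P_f A^T P_pred^{-1}
        x_s[k] = x_f[k] + G @ (x_s[k + 1] - x_pred)
        P_s[k] = P_f[k] + G @ (P_s[k + 1] - P_pred) @ G.T
    return x_s, P_s


A = np.array([[1.0, 0.1], [0.0, 1.0]])
C = np.array([[1.0, 0.0]])
Q = np.diag([1e-4, 5e-4])
R = np.array([[2.5e-3]])
ys = np.array([0.92, 0.95, 0.99, 1.03, 1.06, 1.05, 1.02, 1.00, 0.98, 1.01]).reshape(-1, 1)

x_f, P_f = kalman_filter(A, C, Q, R, ys, np.zeros(2), np.eye(2))
x_s, P_s = rts_smooth(A, Q, x_f, P_f)

print(np.array_equal(x_s[-1], x_f[-1]))  # True: the terminal state is untouched
print(x_f[4, 0], x_s[4, 0])  # interior position: smoothing uses the later measurements
```

Because each smoothed covariance adds a negative semidefinite correction to the filtered one, the smoothed covariance trace should never exceed the filtered trace at any sample.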
Validate The Result
For a first estimation workflow, the checks should be easy to interpret:
- the filtered state shape should match the measurement horizon and state size
- the final filtered position should be close to the last measurement
- the covariance trace should stay finite and reasonably small
- the innovation sequence should not blow up
The runnable example turns the shape, final-position, and innovation checks into assertions and returns the final covariance trace for inspection, so this tutorial stays tied to executable code rather than drifting into static documentation.
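Written as named flags instead of bare assertions, the four criteria become easier to report one by one. This is a sketch under stated assumptions: `filter_update_first` and `plausibility_checks` are hypothetical NumPy helpers, position is taken to be state component 0, `pos_tol` and `innov_max` reuse the tutorial's thresholds, and "reasonably small" covariance is interpreted here as a final trace below the prior's trace, which is our choice rather than a Contrax rule.

```python
import numpy as np


def filter_update_first(A, C, Q, R, ys, x0, P0):
    """Compact update-first Kalman filter (NumPy sketch, not the Contrax implementation)."""
    x, P = x0, P0
    xs, Ps, es = [], [], []
    for y in ys:
        S = C @ P @ C.T + R
        K = np.linalg.solve(S, C @ P).T
        e = y - C @ x  # innovation
        x, P = x + K @ e, (np.eye(len(x0)) - K @ C) @ P
        xs.append(x)
        Ps.append(P)
        es.append(e)
        x, P = A @ x, A @ P @ A.T + Q  # prior for the next sample
    return np.stack(xs), np.stack(Ps), np.stack(es)


def plausibility_checks(x_hat, P, innovations, ys, P0, pos_tol=0.1, innov_max=1.5):
    """Hypothetical helper: one named pass/fail flag per validation criterion."""
    T, n = len(ys), len(P0)
    traces = np.trace(P, axis1=1, axis2=2)  # trace of each filtered covariance
    return {
        "shapes_match_horizon": x_hat.shape == (T, n) and P.shape == (T, n, n),
        "final_position_near_last_measurement": bool(abs(x_hat[-1, 0] - ys[-1, 0]) < pos_tol),
        "covariance_trace_finite_and_shrunk": bool(
            np.all(np.isfinite(traces)) and traces[-1] < np.trace(P0)
        ),
        "innovations_bounded": bool(np.linalg.norm(innovations) < innov_max),
    }


A = np.array([[1.0, 0.1], [0.0, 1.0]])
C = np.array([[1.0, 0.0]])
Q = np.diag([1e-4, 5e-4])
R = np.array([[2.5e-3]])
ys = np.array([0.92, 0.95, 0.99, 1.03, 1.06, 1.05, 1.02, 1.00, 0.98, 1.01]).reshape(-1, 1)
P0 = np.eye(2)

x_hat, P, innovations = filter_update_first(A, C, Q, R, ys, np.zeros(2), P0)
checks = plausibility_checks(x_hat, P, innovations, ys, P0)
for name, passed in checks.items():
    print(f"{name}: {'ok' if passed else 'FAILED'}")
```

Returning a dictionary instead of raising on the first failure makes it possible to see every failing criterion at once, which is more useful when tuning Q_noise and R_noise.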
Where To Go Next
- Getting started for the fastest route into the library
- Estimation API reference for batch and one-step filter helpers
- Handle missing measurements for runtime loop behavior when sensors drop out
- Build an MHE objective for the optimization-based estimation sibling workflow
- Estimation pipelines for how filters, smoothers, and MHE fit together
- JAX transform contract for scan, jit, and differentiation expectations