
Kalman Filtering

This tutorial walks through a standard estimation workflow on a small discrete system: define a state-space model, run a linear Kalman filter over a measurement sequence, smooth the result with the Rauch-Tung-Striebel (RTS) smoother, and finish with checks that show whether the estimate is behaving plausibly.

Figure: Estimation pipeline from model and measurements through filtering, smoothing, and MHE.
The tutorial path: start from a discrete state-space model, run the batch Kalman filter, then use RTS to revisit the same sequence with future information available.

Runnable script: examples/kalman_filtering.py

This tutorial sits primarily in the Estimation API, with Systems providing the underlying discrete model.

The model and measurement equations are

\[ x_{k+1} = A x_k + w_k, \qquad y_k = C x_k + v_k \]

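where w_k and v_k are zero-mean Gaussian process and measurement noise with covariances Q_noise and R_noise. The batch filter uses the update-first convention, correcting with y_k before predicting step k+1: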

\[ \hat{x}_k^+ = \hat{x}_k^- + K_k\bigl(y_k - C\hat{x}_k^-\bigr), \qquad \hat{x}_{k+1}^- = A\hat{x}_k^+ \]

and RTS adds the backward pass

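\[ \hat{x}_k^{\,s} = \hat{x}_k^{\,f} + G_k\bigl(\hat{x}_{k+1}^{\,s} - \hat{x}_{k+1}^{-}\bigr) \]

where \(\hat{x}_k^{\,f} = \hat{x}_k^+\) is the filtered estimate from the forward pass and \(G_k\) is the smoother gain.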
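The two mean updates leave the gains and covariances implicit. For reference, the textbook recursions that pair with the update-first convention are below. Contrax may compute them in a numerically different but algebraically equivalent form (for example a Joseph-form covariance update), so read this as a guide to the math rather than as its implementation:

\[ \begin{aligned}
S_k &= C P_k^- C^\top + R_{\mathrm{noise}}, & K_k &= P_k^- C^\top S_k^{-1}, \\
P_k^+ &= (I - K_k C)\,P_k^-, & P_{k+1}^- &= A P_k^+ A^\top + Q_{\mathrm{noise}}, \\
G_k &= P_k^+ A^\top \bigl(P_{k+1}^-\bigr)^{-1}, & P_k^{\,s} &= P_k^+ + G_k\bigl(P_{k+1}^{\,s} - P_{k+1}^-\bigr)G_k^\top,
\end{aligned} \]

with \(\hat{x}_0^- = x_0\) and \(P_0^- = P_0\) as the update-first prior. The backward pass starts at the last step \(N-1\) with \(\hat{x}_{N-1}^{\,s} = \hat{x}_{N-1}^+\) and \(P_{N-1}^{\,s} = P_{N-1}^+\).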

Define The Estimation Problem

The model is a constant-velocity discrete system with position-only measurements. That is a familiar place to start: the hidden state contains both position and velocity, but the sensor only reports position.

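import jax

# Enable float64 before creating arrays so covariance arithmetic runs in double precision.
jax.config.update("jax_enable_x64", True)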

import jax.numpy as jnp

import contrax as cx

SYS = cx.dss(
    A=jnp.array([[1.0, 0.1], [0.0, 1.0]]),
    B=jnp.zeros((2, 1)),
    C=jnp.array([[1.0, 0.0]]),
    D=jnp.zeros((1, 1)),
    dt=0.1,
)
Q_NOISE = jnp.diag(jnp.array([1e-4, 5e-4]))
R_NOISE = jnp.array([[2.5e-3]])
YS = jnp.array(
    [
        [0.92],
        [0.95],
        [0.99],
        [1.03],
        [1.06],
        [1.05],
        [1.02],
        [1.00],
        [0.98],
        [1.01],
    ]
)
X0 = jnp.array([0.0, 0.0])
P0 = jnp.eye(2)
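Because the sensor reports only position, it is worth confirming that velocity can be recovered from position measurements at all, that is, that the pair (A, C) is observable. A minimal plain-JAX sketch follows. It does not use any Contrax helper, and the variable name obs is only illustrative. It builds the observability matrix and checks its rank:

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Same A and C as SYS above (dt = 0.1).
A = jnp.array([[1.0, 0.1], [0.0, 1.0]])
C = jnp.array([[1.0, 0.0]])
n = A.shape[0]

# Observability matrix [C; C A; ...; C A^(n-1)], stacked row-block by row-block.
obs = jnp.vstack([C @ jnp.linalg.matrix_power(A, i) for i in range(n)])
rank = int(jnp.linalg.matrix_rank(obs))
print(rank)  # 2: full rank, so velocity is observable from position alone
```

Intuitively, the second row C A = [1, 0.1] says the next position equals the current position plus dt times velocity. Two consecutive position readings therefore pin down velocity.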

Run The Filter And Smoother

kalman() runs the batch discrete-time filter as a jax.lax.scan, so the full pass stays in JAX. rts() then applies an offline backward smoothing pass to the filtered trajectory using the same process model and process noise.

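def run_example():
    # Forward pass: update-first Kalman filter over all ten measurements.
    filtered = cx.kalman(
        SYS,
        Q_noise=Q_NOISE,
        R_noise=R_NOISE,
        ys=YS,
        x0=X0,
        P0=P0,
    )
    # Backward pass: RTS smoothing of the filtered trajectory.
    smoothed = cx.rts(SYS, filtered, Q_noise=Q_NOISE)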

    final_measurement = float(YS[-1, 0])
    final_filtered = float(filtered.x_hat[-1, 0])
    final_smoothed = float(smoothed.x_smooth[-1, 0])
    midpoint_index = 4
    midpoint_filtered = float(filtered.x_hat[midpoint_index, 0])
    midpoint_smoothed = float(smoothed.x_smooth[midpoint_index, 0])
    innovation_norm = float(jnp.linalg.norm(filtered.innovations))

    # Sanity checks: output shapes, tracking of the last measurement, innovation size.
    assert filtered.x_hat.shape == (10, 2)
    assert filtered.P.shape == (10, 2, 2)
    assert smoothed.x_smooth.shape == (10, 2)
    assert abs(final_filtered - final_measurement) < 0.1
    assert innovation_norm < 1.5

    return {
        "final_measurement": final_measurement,
        "final_filtered_position": final_filtered,
        "final_smoothed_position": final_smoothed,
        "midpoint_filtered_position": midpoint_filtered,
        "midpoint_smoothed_position": midpoint_smoothed,
        "final_velocity": float(filtered.x_hat[-1, 1]),
        "innovation_norm": innovation_norm,
        "filtered_cov_trace": float(jnp.trace(filtered.P[-1])),
    }
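To make the scan inside kalman() concrete, here is a minimal plain-JAX sketch of an update-first linear Kalman filter over the same data. It follows the textbook equations with the simple (I - K C) P covariance update. It is not Contrax's implementation, and the function name kalman_update_first is hypothetical:

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def kalman_update_first(A, C, Q, R, ys, x0, P0):
    """Batch linear Kalman filter, update-first convention.

    (x0, P0) is the prior on x_0; each scan step first corrects with y_k,
    then predicts the prior for step k+1.
    """
    n = A.shape[0]

    def step(carry, y):
        x_prior, P_prior = carry
        # Measurement update with y_k.
        innov = y - C @ x_prior
        S = C @ P_prior @ C.T + R
        K = jnp.linalg.solve(S, C @ P_prior).T  # P C^T S^{-1} (P, S symmetric)
        x_post = x_prior + K @ innov
        P_post = (jnp.eye(n) - K @ C) @ P_prior
        # Time update: prior for step k+1.
        x_next = A @ x_post
        P_next = A @ P_post @ A.T + Q
        return (x_next, P_next), (x_post, P_post, innov)

    _, (xs, Ps, innovs) = jax.lax.scan(step, (x0, P0), ys)
    return xs, Ps, innovs


# Tutorial data, copied from the problem definition.
A = jnp.array([[1.0, 0.1], [0.0, 1.0]])
C = jnp.array([[1.0, 0.0]])
Q = jnp.diag(jnp.array([1e-4, 5e-4]))
R = jnp.array([[2.5e-3]])
ys = jnp.array([[0.92], [0.95], [0.99], [1.03], [1.06],
                [1.05], [1.02], [1.00], [0.98], [1.01]])
x0 = jnp.zeros(2)
P0 = jnp.eye(2)

x_hat, P, innovations = kalman_update_first(A, C, Q, R, ys, x0, P0)
print(x_hat.shape, P.shape, innovations.shape)  # (10, 2) (10, 2, 2) (10, 1)
```

The scan carries the prior (mean and covariance) and emits the posterior at each step. That is why the stacked outputs have one row per measurement, the same shapes run_example() asserts on filtered.x_hat and filtered.P.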

In Contrax, the batch filter treats (x0, P0) as the prior mean and covariance of x_0 and corrects it with y_0 before the first prediction. That convention matters when you compare the batch result with the one-step kalman_update() and kalman_step() helpers in runtime loops.
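The effect of the convention is easiest to see on the very first step. The plain-JAX sketch below uses a hypothetical helper, measurement_update, and treats (x0, P0) in two ways. Under update-first, as the Contrax batch filter does, (x0, P0) is the prior on x_0. Under predict-first, a common alternative, it describes the state one step earlier and is propagated through A before y_0 is used. The sketch does not assume which convention any particular one-step helper follows, so check its documentation.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

A = jnp.array([[1.0, 0.1], [0.0, 1.0]])
C = jnp.array([[1.0, 0.0]])
Q = jnp.diag(jnp.array([1e-4, 5e-4]))
R = jnp.array([[2.5e-3]])
x0 = jnp.zeros(2)
P0 = jnp.eye(2)
y0 = jnp.array([0.92])


def measurement_update(x, P, y):
    """Standard Kalman correction of the prior (x, P) with measurement y."""
    S = C @ P @ C.T + R
    K = jnp.linalg.solve(S, C @ P).T
    return x + K @ (y - C @ x), (jnp.eye(2) - K @ C) @ P


# Update-first: (x0, P0) is the prior on x_0, corrected directly with y_0.
x_uf, P_uf = measurement_update(x0, P0, y0)

# Predict-first: (x0, P0) is propagated through the model before y_0 is used.
x_pf, P_pf = measurement_update(A @ x0, A @ P0 @ A.T + Q, y0)

print(x_uf)  # velocity stays 0: P0 has no position-velocity correlation
print(x_pf)  # velocity moves: the prediction through A correlates the states
```

Both runs see the same measurement, yet the first velocity estimate differs, so comparing a batch result against a step-by-step loop only makes sense when both use the same convention.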

What The Script Prints

Running examples/kalman_filtering.py prints a compact summary of the midpoint and final estimates:

Kalman filtering and RTS smoothing
final measurement        = 1.010000
mid filtered position    = 1.060639
mid smoothed position    = 1.003812
final filtered position  = 1.024387
final smoothed position  = 1.024387
final filtered velocity  = 0.055978
innovation norm          = 0.932813
final covariance trace   = 0.007035

The exact numbers are not the point. The useful checks are whether the filter produces a state estimate consistent with the measurements, keeps the covariance bounded, and shows the role of smoothing clearly. The interior smoothed state can differ from the filtered one because it also uses later measurements. The terminal RTS state matches the final filtered state by construction, since the backward pass starts from it.
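The terminal equality follows from how the backward recursion is seeded. The minimal plain-JAX sketch below uses hypothetical helpers filter_pass and rts_pass with the textbook gains, and it is not Contrax's code. It makes that explicit: the reverse scan starts from the last filtered estimate, so that row passes through unchanged, while every earlier row is corrected using future information.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

A = jnp.array([[1.0, 0.1], [0.0, 1.0]])
C = jnp.array([[1.0, 0.0]])
Q = jnp.diag(jnp.array([1e-4, 5e-4]))
R = jnp.array([[2.5e-3]])
ys = jnp.array([[0.92], [0.95], [0.99], [1.03], [1.06],
                [1.05], [1.02], [1.00], [0.98], [1.01]])
x0, P0 = jnp.zeros(2), jnp.eye(2)


def filter_pass(ys):
    """Update-first forward pass; returns filtered means and covariances."""

    def step(carry, y):
        x_prior, P_prior = carry
        S = C @ P_prior @ C.T + R
        K = jnp.linalg.solve(S, C @ P_prior).T
        x_post = x_prior + K @ (y - C @ x_prior)
        P_post = (jnp.eye(2) - K @ C) @ P_prior
        return (A @ x_post, A @ P_post @ A.T + Q), (x_post, P_post)

    _, (xf, Pf) = jax.lax.scan(step, (x0, P0), ys)
    return xf, Pf


def rts_pass(xf, Pf):
    """Backward RTS pass seeded with the last filtered estimate."""

    def step(carry, filt):
        xs_next, Ps_next = carry          # smoothed estimate at k+1
        x_post, P_post = filt             # filtered estimate at k
        x_pred = A @ x_post               # one-step prediction of x_{k+1}
        P_pred = A @ P_post @ A.T + Q     # its covariance P_{k+1}^-
        G = jnp.linalg.solve(P_pred, A @ P_post).T  # P_k^+ A^T (P_{k+1}^-)^{-1}
        xs = x_post + G @ (xs_next - x_pred)
        Ps = P_post + G @ (Ps_next - P_pred) @ G.T
        return (xs, Ps), (xs, Ps)

    # reverse=True walks k = N-2, ..., 0 but stacks outputs in forward order.
    _, (xs_head, Ps_head) = jax.lax.scan(
        step, (xf[-1], Pf[-1]), (xf[:-1], Pf[:-1]), reverse=True
    )
    xs = jnp.concatenate([xs_head, xf[-1:]], axis=0)
    Ps = jnp.concatenate([Ps_head, Pf[-1:]], axis=0)
    return xs, Ps


xf, Pf = filter_pass(ys)
xs, Ps = rts_pass(xf, Pf)
print(float(xf[4, 0]), float(xs[4, 0]))       # interior: smoothing changes it
print(bool(jnp.array_equal(xs[-1], xf[-1])))  # True: terminal rows coincide
```

Smoothing can only remove uncertainty, so the smoothed covariance at an interior step should never have a larger trace than the filtered one.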

Validate The Result

For a first estimation workflow, the checks should be easy to interpret:

  • the filtered state shape should match the measurement horizon and state size
  • the final filtered position should be close to the last measurement
  • the covariance trace should stay finite and reasonably small
  • the innovation sequence should not blow up

The runnable example encodes the shape, final-position, and innovation checks as assertions and returns the final covariance trace for inspection, so this tutorial stays tied to executable code rather than drifting into static documentation.

Where To Go Next