
Structured Nonlinear Estimation

This example shows what the structured nonlinear layer is for in practice: keep the model in port-Hamiltonian system (PHS) form, observe only part of the state, and still run the same recursive-estimation workflow as for any other nonlinear system.

The central estimation question here is simple: can a position-only sensor and a structured model recover hidden momentum cleanly enough that smoothing adds visible value?

Runnable script: examples/structured_nonlinear_estimation.py

Problem Setup

The state is (q, p), where q is configuration and p is momentum. The system is a damped, forced port-Hamiltonian oscillator with Hamiltonian

\[ H(q, p) = \tfrac{1}{2}(1.2q^2 + p^2). \]

The structure map J is the canonical skew-symmetric interconnection, the dissipation map R(x) = diag(0, 0.14) damps only momentum, and the input map G(x) = [0, 1]ᵀ injects force only into the momentum channel. The measurement is position-only:

\[ y_k = q_k + v_k. \]

So the hidden state is momentum: the channel that carries the kinetic energy, the only channel that is damped, and the only channel the input drives.
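Assuming the standard port-Hamiltonian form ẋ = (J − R(x))∇H(x) + G(x)u, with the canonical J and the R and G used in the code below, the model expands into an explicit damped, forced oscillator. The first row, q̇ = p, is what links the measured position to the hidden momentum, and the last line is the usual PHS energy balance (dissipation plus supplied power):

\[ \nabla H = \begin{bmatrix} 1.2q \\ p \end{bmatrix}, \quad J = \begin{bmatrix} 0 & 1 \\ -1 & 0 \end{bmatrix}, \quad R = \begin{bmatrix} 0 & 0 \\ 0 & 0.14 \end{bmatrix}, \quad G = \begin{bmatrix} 0 \\ 1 \end{bmatrix} \]

\[ \begin{bmatrix} \dot q \\ \dot p \end{bmatrix} = (J - R)\nabla H + G u = \begin{bmatrix} p \\ -1.2q - 0.14p + u \end{bmatrix} \]

\[ \dot H = \nabla H^\top \dot x = -0.14\,p^2 + p\,u. \]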

Build The Structured Model

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.random as jr

import contrax as cx

DT = 0.1
DURATION = 3.0
NUM_STEPS = int(DURATION / DT)
NOISE_KEY = jr.PRNGKey(17)
X0_TRUE = jnp.array([1.0, -0.15])
X0_EST = jnp.array([0.75, 0.2])
P0 = 0.12 * jnp.eye(2)
Q_NOISE = 2e-4 * jnp.eye(2)
R_NOISE = jnp.array([[3e-3]])
MEASUREMENT_STD = jnp.sqrt(R_NOISE[0, 0])
SAMPLE_TIMES = jnp.arange(NUM_STEPS + 1, dtype=jnp.float64) * DT
INPUT_TIMES = SAMPLE_TIMES[:-1]
INPUTS = (0.18 * jnp.sin(1.1 * INPUT_TIMES) + 0.08 * jnp.cos(0.55 * INPUT_TIMES))[
    :, None
]


def hamiltonian(x):
    q, p = x
    return 0.5 * (1.2 * q**2 + p**2)


OBSERVE_POSITION = cx.block_observation((1, 1), (0,))

SYS_CONT = cx.phs_system(
    hamiltonian,
    R=lambda x: jnp.array([[0.0, 0.0], [0.0, 0.14]], dtype=x.dtype),
    G=lambda x: jnp.array([[0.0], [1.0]], dtype=x.dtype),
    output=OBSERVE_POSITION,
    state_dim=2,
    input_dim=1,
    output_dim=1,
)
SYS_DISC = cx.sample_system(SYS_CONT, DT)

The useful parts of the public surface are all visible here:

  • phs_system() defines the continuous structured model
  • block_observation() builds the position-only measurement map without custom indexing glue
  • sample_system() turns the continuous structured dynamics into a discrete estimation model
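The page does not say which integrator sample_system() uses or exactly how phs_system() assembles the vector field. As a minimal sketch, assuming the standard PHS form ẋ = (J − R(x))∇H(x) + G(x)u with the canonical J, the continuous dynamics and one fixed-step RK4 discretization (input held constant over each step) can be written in plain JAX. The names phs_vector_field and rk4_step are hypothetical and not part of contrax:

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Canonical interconnection for a (q, p) state: skew-symmetric by construction.
J = jnp.array([[0.0, 1.0], [-1.0, 0.0]])


def hamiltonian(x):
    q, p = x
    return 0.5 * (1.2 * q**2 + p**2)


def dissipation(x):
    # Damping acts on momentum only, as in the example's R map.
    return jnp.array([[0.0, 0.0], [0.0, 0.14]], dtype=x.dtype)


def input_map(x):
    # Force enters the momentum channel only, as in the example's G map.
    return jnp.array([[0.0], [1.0]], dtype=x.dtype)


def phs_vector_field(x, u):
    """Assumed PHS form: dx/dt = (J - R(x)) grad H(x) + G(x) u."""
    grad_h = jax.grad(hamiltonian)(x)
    return (J - dissipation(x)) @ grad_h + input_map(x) @ u


def rk4_step(x, u, dt):
    """One classical RK4 step with the input held constant over the interval."""
    k1 = phs_vector_field(x, u)
    k2 = phs_vector_field(x + 0.5 * dt * k1, u)
    k3 = phs_vector_field(x + 0.5 * dt * k2, u)
    k4 = phs_vector_field(x + dt * k3, u)
    return x + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


x = jnp.array([1.0, -0.15])
u = jnp.array([0.08])
print(phs_vector_field(x, u))  # components: q' = p, p' = -1.2 q - 0.14 p + u
print(rk4_step(x, u, 0.1))
```

Using jax.grad on the Hamiltonian keeps the sketch in the same spirit as the example: only H, R, and G are specified by hand, and the gradient is derived automatically.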

Run Filtering, Smoothing, Diagnostics, And Structure Checks

def run_example():
    x_true = simulate_truth()
    ys = sample_measurements(x_true)

    filtered = cx.ukf(
        SYS_DISC,
        Q_noise=Q_NOISE,
        R_noise=R_NOISE,
        ys=ys,
        us=INPUTS,
        x0=X0_EST,
        P0=P0,
        alpha=0.5,
    )
    smoothed = cx.uks(SYS_DISC, filtered, Q_noise=Q_NOISE, us=INPUTS, alpha=0.5)
    innovation_diag, likelihood_diag = cx.ukf_diagnostics(filtered)
    structure_diag = cx.phs_diagnostics(SYS_CONT, X0_TRUE, INPUTS[0])
    rmse = summarize_rmse(filtered, smoothed, x_true)

    assert filtered.x_hat.shape == (NUM_STEPS, 2)
    assert filtered.predicted_measurements.shape == (NUM_STEPS, 1)
    assert smoothed.x_smooth.shape == (NUM_STEPS, 2)
    assert rmse["smoothed_q_rmse"] < rmse["filtered_q_rmse"]
    assert rmse["smoothed_p_rmse"] < rmse["filtered_p_rmse"]

    return {
        "times": INPUT_TIMES,
        "inputs": INPUTS[:, 0],
        "measurements": ys[:, 0],
        "true_q": x_true[:, 0],
        "filtered_q": filtered.x_hat[:, 0],
        "smoothed_q": smoothed.x_smooth[:, 0],
        "true_p": x_true[:, 1],
        "filtered_p": filtered.x_hat[:, 1],
        "smoothed_p": smoothed.x_smooth[:, 1],
        **rmse,
        "mean_nis": float(innovation_diag.mean_nis),
        "max_condition_number": float(
            innovation_diag.max_innovation_cov_condition_number
        ),
        "total_log_likelihood": float(likelihood_diag.total_log_likelihood),
        "skew_symmetry_residual": float(structure_diag.skew_symmetry_error),
        "dissipation_min_eigenvalue": float(structure_diag.min_dissipation_eigenvalue),
    }
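run_example() passes alpha=0.5 to both ukf() and uks(). The page does not document the parameterization, so as an assumption: if contrax follows the common scaled unscented transform (with beta = 2 and kappa = 0 as typical defaults), alpha controls how far the 2n + 1 sigma points spread around the mean. A minimal sketch of that construction; sigma_points is a hypothetical helper, not a contrax function:

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def sigma_points(x, P, alpha=0.5, beta=2.0, kappa=0.0):
    """Scaled unscented-transform sigma points and weights (assumed parameterization)."""
    n = x.shape[0]
    lam = alpha**2 * (n + kappa) - n
    # Columns of L satisfy L @ L.T == (n + lam) * P.
    L = jnp.linalg.cholesky((n + lam) * P)
    # Rows: the mean, then mean + each column of L, then mean - each column of L.
    points = jnp.vstack([x, x + L.T, x - L.T])  # shape (2n + 1, n)
    w_mean = jnp.full(2 * n + 1, 1.0 / (2.0 * (n + lam)))
    w_mean = w_mean.at[0].set(lam / (n + lam))
    # Covariance weights differ from mean weights only at the central point.
    w_cov = w_mean.at[0].add(1.0 - alpha**2 + beta)
    return points, w_mean, w_cov


# Same initial estimate and covariance as X0_EST and P0 in the example.
x0 = jnp.array([0.75, 0.2])
P0 = 0.12 * jnp.eye(2)
pts, wm, wc = sigma_points(x0, P0)

# The weighted points reproduce the mean and covariance they were built from.
mean = wm @ pts
diffs = pts - mean
cov = (wc[:, None] * diffs).T @ diffs
```

With n = 2 and alpha = 0.5, λ = −1.5, so the central mean weight is −3. Negative central weights are normal in this parameterization; the weighted points still reproduce the mean and covariance exactly.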

This example goes slightly beyond “run a filter and print a state.” It also calls phs_diagnostics() on the same continuous model that drives the estimator, evaluated at a single point (X0_TRUE with the first input sample), to check the structure contract: skew symmetry of J and nonnegative dissipation eigenvalues of R. The measurements use seeded Gaussian noise, so every run is reproducible while still behaving like a noisy sensor.
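The script's sample_measurements() helper is not shown. A minimal sketch of seeded, position-only measurement sampling, assuming it simply adds zero-mean Gaussian noise to q; sample_position_measurements and the placeholder trajectory are hypothetical, and the noise level matches R_NOISE = 3e-3:

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.random as jr


def sample_position_measurements(x_true, key, std):
    """Hypothetical stand-in: y_k = q_k + v_k with v_k ~ N(0, std**2)."""
    q = x_true[:, 0:1]  # keep a trailing measurement axis, shape (T, 1)
    noise = std * jr.normal(key, shape=q.shape, dtype=q.dtype)
    return q + noise


# Placeholder trajectory (not the script's simulated truth).
x_true = jnp.stack([jnp.cos(jnp.linspace(0.0, 3.0, 30)), jnp.zeros(30)], axis=1)
key = jr.PRNGKey(17)
std = jnp.sqrt(3e-3)

# The same key always yields the same noise draw, which is what makes runs reproducible.
ys_a = sample_position_measurements(x_true, key, std)
ys_b = sample_position_measurements(x_true, key, std)
```

JAX's explicit PRNG keys make this reproducibility automatic: there is no hidden global seed, so the same key passed twice produces bit-identical samples.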

Figure: Position-only observation of a PHS oscillator. The filter tracks configuration well from the measured channel, while smoothing substantially improves the hidden momentum estimate. (Panels: input forcing; true, filtered, and smoothed position; true, filtered, and smoothed momentum.)

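The middle and bottom panels tell the story. Position is the measured state, so filtered and smoothed q stay close to the truth. Momentum is never measured, so recovering p is the stronger structural test: the estimator can only infer it through the model's dynamics, which tie the rate of change of position to momentum.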
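Because H is quadratic and R and G are constant, the example's dynamics are linear, so a standard linear observability test explains why a position-only sensor can recover momentum at all. A minimal sketch, assuming the PHS form ẋ = (J − R)∇H + Gu, which gives q̇ = p and ṗ = −1.2q − 0.14p + u (the matrices below are derived by hand, not read from contrax):

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Linear form of the example: q' = p, p' = -1.2 q - 0.14 p + u.
A = jnp.array([[0.0, 1.0], [-1.2, -0.14]])
C = jnp.array([[1.0, 0.0]])  # position-only measurement

# Observability matrix [C; C A] for a 2-state system.
O = jnp.vstack([C, C @ A])
rank = int(jnp.linalg.matrix_rank(O))
print(O)
print(rank)
```

The second row, C A = [0, 1], says that the time derivative of the measured position is exactly the momentum, so the pair (A, C) is observable.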

What The Script Prints

Running examples/structured_nonlinear_estimation.py prints:

Structured nonlinear estimation
filtered q rmse         = 0.020113
smoothed q rmse         = 0.013961
filtered p rmse         = 0.104952
smoothed p rmse         = 0.017463
mean NIS                = 0.780241
max innovation cond     = 1.000000
total log likelihood    = 39.698811
skew residual           = 0.000000e+00
min dissipation eig     = 0.000000
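The mean NIS and innovation condition number summarize the innovations e_k = y_k − ŷ_k and their covariances S_k. A minimal sketch of the standard normalized innovation squared, NIS_k = e_kᵀ S_k⁻¹ e_k, not contrax's implementation; normalized_innovation_squared is hypothetical and the toy values are not taken from the script's run:

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def normalized_innovation_squared(innovations, innovation_covs):
    """NIS_k = e_k^T S_k^{-1} e_k for innovations e_k (T, m) and covariances S_k (T, m, m)."""
    # Batched solve S_k z_k = e_k, treating each e_k as an (m, 1) column.
    solved = jnp.linalg.solve(innovation_covs, innovations[..., None])[..., 0]
    return jnp.sum(innovations * solved, axis=-1)


# Scalar-measurement toy values (not taken from the script's run).
e = jnp.array([[0.05], [-0.03], [0.02]])
S = jnp.full((3, 1, 1), 2.5e-3)
nis = normalized_innovation_squared(e, S)
mean_nis = jnp.mean(nis)
cond = jnp.linalg.cond(S)  # a 1x1 matrix always has condition number 1
```

For a consistent filter, the expected NIS equals the measurement dimension, here 1. Assuming independent innovations and a mean over all 30 scalar innovations, the usual chi-square consistency test gives a two-sided 95% band of roughly [0.56, 1.57] for mean NIS, so the printed 0.78 does not indicate an inconsistent filter.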

The useful checks are:

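  • smoothing lowers both position and momentum RMSE, with the larger gain on the hidden momentum (0.105 to 0.017, about a factor of six)
  • the innovation covariance condition number is exactly 1, which is automatic for a scalar measurement; this check only becomes informative with multi-output sensors
  • the structure diagnostics match the model definition: J has zero skew-symmetry residual, and the minimum dissipation eigenvalue is 0 because R damps momentum but not position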
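A minimal sketch of the two structure checks, assuming phs_diagnostics() computes something comparable (its exact norms and definitions are not documented here): the Frobenius norm of J + Jᵀ, and the smallest eigenvalue of the symmetric part of R. structure_checks is a hypothetical helper, and J and R are the matrices from this example:

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def structure_checks(J, R):
    """Skew-symmetry residual of J and smallest eigenvalue of the symmetric part of R."""
    skew_residual = jnp.linalg.norm(J + J.T)  # zero iff J is skew-symmetric
    min_dissipation_eig = jnp.min(jnp.linalg.eigvalsh(0.5 * (R + R.T)))
    return skew_residual, min_dissipation_eig


J = jnp.array([[0.0, 1.0], [-1.0, 0.0]])
R = jnp.array([[0.0, 0.0], [0.0, 0.14]])
skew_residual, min_eig = structure_checks(J, R)
```

A negative minimum eigenvalue would mean R injects energy instead of dissipating it, which would break the passivity the PHS form is meant to guarantee.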