FOH Estimation: EKF with First-Order-Hold Inputs

This example filters a nonlinear oscillator whose input varies continuously between sample steps. Rather than assuming the input is constant over each sample interval (zero-order hold), it models the piecewise-linear profile exactly using first-order-hold (FOH) interpolation.

Runnable script: examples/foh_estimation.py

Problem Setup

The plant is a Van der Pol oscillator driven by a scalar input:

\[ \dot{x}_1 = x_2, \qquad \dot{x}_2 = \mu (1 - x_1^2) x_2 - x_1 + u(t) \]

The sensor observes the first state only:

\[ y_k = x_1(t_k) + v_k. \]

The input \(u(t)\) is a sinusoid sampled on the same grid as the measurements. Between samples it is treated as piecewise linear rather than constant, so the FOH discrete model follows the underlying signal more closely than a ZOH model would.
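To make the ZOH versus FOH difference concrete, the sketch below reconstructs the example's sinusoid from its samples both ways and compares the worst-case reconstruction error. It uses plain NumPy and does not call contrax; the dense comparison grid and the helper name u_true are illustrative assumptions. For a smooth input, the ZOH error scales with the sample period times the input slope, while the linear-interpolation error scales with the square of the sample period.

```python
import numpy as np

DT = 0.1        # sample period, matching the example
T_END = 5.0     # total duration (50 steps of DT)


def u_true(t):
    """Underlying sinusoid: amplitude 0.3, frequency 0.5 Hz."""
    return 0.3 * np.sin(2 * np.pi * 0.5 * t)


# Sample grid t_k = k * DT and the sampled input values u_k.
t_samples = np.arange(51) * DT
u_samples = u_true(t_samples)

# Dense grid on which both reconstructions are compared with the true signal.
t_fine = np.linspace(0.0, T_END, 5001)

# ZOH: hold u_k over the interval starting at t_k.
k = np.clip(np.floor(t_fine / DT).astype(int), 0, len(t_samples) - 1)
u_zoh = u_samples[k]

# FOH: linear interpolation between u_k and u_{k+1}.
u_foh = np.interp(t_fine, t_samples, u_samples)

zoh_err = np.max(np.abs(u_zoh - u_true(t_fine)))
foh_err = np.max(np.abs(u_foh - u_true(t_fine)))
print(f"max ZOH error = {zoh_err:.4f}, max FOH error = {foh_err:.4f}")
```

At this sample period and input frequency the FOH reconstruction error comes out more than an order of magnitude below the ZOH error.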

Setup and Model

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

import contrax as cx

# Van der Pol oscillator: ẋ = [x2, μ(1-x1²)x2 - x1 + u]
MU = 1.0
DT = 0.1  # sample period
T_STEPS = 50  # number of steps


def vdp_dynamics(t, x, u):
    """Continuous-time Van der Pol dynamics with additive scalar input."""
    x1, x2 = x[0], x[1]
    dx1 = x2
    dx2 = MU * (1 - x1**2) * x2 - x1 + u[0]
    return jnp.array([dx1, dx2])


def vdp_observation(t, x, u):
    """Observe the first state (position) only."""
    return x[:1]


nl = cx.nonlinear_system(
    vdp_dynamics, vdp_observation, dt=None, state_dim=2, input_dim=1, output_dim=1
)

# Build a discrete transition model via FOH integration
discrete_foh = cx.sample_system(nl, DT, input_interpolation="foh")

sample_system(..., input_interpolation="foh") builds a discrete transition model that integrates the continuous dynamics over one sample period under a linearly interpolated input \(u(t_k + \tau\,\Delta t) = (1 - \tau) u_k + \tau u_{k+1}\) for \(\tau \in [0, 1]\), where \(\Delta t\) is the sample period (DT above).
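As a rough illustration of what one FOH step involves, the sketch below integrates the Van der Pol dynamics over a single sample period with a fixed-step RK4 scheme, evaluating the interpolated input at \(\tau = 0\), \(0.5\), and \(1\). The function foh_rk4_step is a hypothetical helper written for this illustration; the integrator that contrax actually uses is not stated here, so the choice of RK4 is an assumption.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

MU = 1.0
DT = 0.1


def vdp_dynamics(t, x, u):
    """Continuous-time Van der Pol dynamics with additive scalar input."""
    return jnp.array([x[1], MU * (1 - x[0] ** 2) * x[1] - x[0] + u[0]])


def foh_rk4_step(f, x, u_pair, dt, t0=0.0):
    """Hypothetical helper: one RK4 step with linearly interpolated input.

    u_pair has shape (2, m) and stacks u_k and u_{k+1}.
    """
    u0, u1 = u_pair[0], u_pair[1]
    u_mid = 0.5 * (u0 + u1)  # interpolated input at tau = 0.5
    k1 = f(t0, x, u0)
    k2 = f(t0 + 0.5 * dt, x + 0.5 * dt * k1, u_mid)
    k3 = f(t0 + 0.5 * dt, x + 0.5 * dt * k2, u_mid)
    k4 = f(t0 + dt, x + dt * k3, u1)
    return x + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)


# One Van der Pol step with the input ramping from 0.0 to 0.3.
x_next = foh_rk4_step(
    vdp_dynamics, jnp.array([0.5, 0.0]), jnp.array([[0.0], [0.3]]), DT
)

# Sanity check on a pure integrator dx/dt = u with a 0 -> 1 ramp:
# the exact result is dt * (u_k + u_{k+1}) / 2.
x_ramp = foh_rk4_step(
    lambda t, x, u: u, jnp.array([0.0]), jnp.array([[0.0], [1.0]]), DT
)
```

The integrator check shows why the midpoint input matters: holding u_k over the whole interval would return 0 for the ramp instead of the trapezoidal value.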

Each step of this discrete model takes a stacked input pair u_pair = [[u_k], [u_{k+1}]] with shape (2, m). foh_inputs(us) converts a raw (T, m) sequence into the (T, 2, m) pairs the model expects.
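The pairing described above can be pictured with a few lines of NumPy. The helper make_foh_pairs below is hypothetical and is not the contrax implementation. In particular, how foh_inputs fills the final pair, where no u_{k+1} exists, is not documented here, so repeating the last sample is an assumption made only to keep the output length at T.

```python
import numpy as np


def make_foh_pairs(us):
    """Hypothetical stand-in for the pairing step: (T, m) -> (T, 2, m).

    Row k of the result stacks u_k and u_{k+1}. The last sample is
    repeated for the final pair (an assumption, see the text above).
    """
    us = np.asarray(us)
    us_next = np.concatenate([us[1:], us[-1:]], axis=0)
    return np.stack([us, us_next], axis=1)


us = np.array([[0.0], [1.0], [2.0]])  # T = 3 samples of a scalar input
pairs = make_foh_pairs(us)
print(pairs.shape)
```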

Simulating and Filtering

def generate_trajectory(seed: int = 0):
    """Simulate ground-truth trajectory and noisy measurements."""
    key = jax.random.PRNGKey(seed)

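    # Slowly varying sinusoidal input samples (0.5 Hz, amplitude 0.3);
    # FOH interpolates linearly between consecutive samples.
    # Note: linspace places T_STEPS points on [0, T_STEPS * DT], so the spacing
    # is slightly larger than DT; jnp.arange(T_STEPS) * DT gives exact multiples.
    t_axis = jnp.linspace(0.0, T_STEPS * DT, T_STEPS)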
    us_raw = 0.3 * jnp.sin(2 * jnp.pi * 0.5 * t_axis)[:, None]  # (T, 1)
    us_foh = cx.foh_inputs(us_raw)  # (T, 2, 1)

    # Simulate true trajectory using FOH discrete model
    x0_true = jnp.array([0.5, 0.0])
    xs_true = cx.rollout(
        lambda x, u: discrete_foh.dynamics(0.0, x, u), x0_true, us_foh
    )  # (T+1, 2)

    # Noisy measurements of first state only
    R_noise = jnp.array([[0.05]])
    key, subkey = jax.random.split(key)
    noise = jax.random.normal(subkey, (T_STEPS + 1, 1)) * jnp.sqrt(R_noise[0, 0])
    ys = xs_true[:, :1] + noise  # (T+1, 1)

    return xs_true, ys, us_raw, us_foh


def run_ekf(xs_true, ys, us_foh):
    """Run an online EKF using FOH-sampled discrete model."""
    Q_noise = jnp.diag(jnp.array([1e-4, 1e-3]))
    R_noise = jnp.array([[0.05]])

    def h(x):
        return x[:1]

    def f_foh(x, u_pair):
        return discrete_foh.dynamics(0.0, x, u_pair)

    x = ys[0, :1]  # initialise from first measurement
    x = jnp.concatenate([x, jnp.zeros(1)])
    P = jnp.eye(2) * 0.5

    xs_est = [x]
    for k in range(T_STEPS):
        x, P = cx.ekf_predict(f_foh, x, P, us_foh[k], Q_noise)
        x, P, _ = cx.ekf_update(h, x, P, ys[k + 1], R_noise)
        xs_est.append(x)

    return jnp.stack(xs_est)  # (T+1, 2)

The EKF loop uses the one-step helpers ekf_predict and ekf_update. The FOH-discrete model is passed as the dynamics callable, so the filter linearizes it automatically via JAX Jacobians.
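For readers who want to see what such one-step helpers typically compute, here is a minimal textbook EKF predict/update pair that obtains its Jacobians from jax.jacfwd. The functions ekf_predict_sketch and ekf_update_sketch are hypothetical names for this illustration and are not the contrax implementations; the third return value of the update (the innovation) is also an assumption, since the loop above discards it.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def ekf_predict_sketch(f, x, P, u, Q):
    """Propagate mean and covariance through x_next = f(x, u)."""
    F = jax.jacfwd(f, argnums=0)(x, u)  # Jacobian of f with respect to x
    return f(x, u), F @ P @ F.T + Q


def ekf_update_sketch(h, x, P, y, R):
    """Correct the prediction with measurement y = h(x) + noise."""
    H = jax.jacfwd(h)(x)                # Jacobian of h with respect to x
    innovation = y - h(x)
    S = H @ P @ H.T + R                 # innovation covariance
    K = jnp.linalg.solve(S, H @ P).T    # Kalman gain P H^T S^{-1}
    x_new = x + K @ innovation
    P_new = (jnp.eye(x.shape[0]) - K @ H) @ P
    return x_new, P_new, innovation


# Tiny check on a double-integrator-like model that observes position only.
def f_demo(x, u):
    return jnp.array([x[0] + 0.1 * x[1], x[1] + 0.1 * u[0]])


def h_demo(x):
    return x[:1]


x_pred, P_pred = ekf_predict_sketch(
    f_demo, jnp.array([1.0, 2.0]), jnp.eye(2), jnp.array([3.0]), jnp.zeros((2, 2))
)
x_upd, P_upd, innov = ekf_update_sketch(
    h_demo, jnp.zeros(2), jnp.eye(2), jnp.array([2.0]), jnp.array([[1.0]])
)
```

Because the Jacobians come from automatic differentiation of the callable, the same two functions work unchanged whether f wraps a ZOH or an FOH discrete model, as long as u has the shape that model expects.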

What the Script Prints

Continuous-model EKF with FOH input interpolation
  steps              = 50
  position RMSE      = 0.1204
  velocity RMSE      = 0.1999
  true  final state  = [0.77012901 2.5961284 ]
  estim final state  = [0.75520658 2.62896228]

All assertions passed.
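
The position and velocity RMSE figures are per-state summaries of the estimation error. How the script computes them is not shown on this page; the sketch below assumes the usual definition, the root of the mean squared error over all time steps, taken separately for each state. The helper per_state_rmse and the demo arrays are hypothetical.

```python
import numpy as np


def per_state_rmse(xs_est, xs_true):
    """Root-mean-square error over time, one value per state column."""
    err = np.asarray(xs_est) - np.asarray(xs_true)
    return np.sqrt(np.mean(err**2, axis=0))


# Demo data: a constant offset of +0.1 in position and -0.2 in velocity.
xs_true_demo = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
xs_est_demo = xs_true_demo + np.array([0.1, -0.2])
rmse = per_state_rmse(xs_est_demo, xs_true_demo)
```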