FOH Estimation: EKF with First-Order-Hold Inputs
This example filters a nonlinear oscillator whose input changes continuously between sample steps. A zero-order hold (ZOH) would hold the input constant over each sample interval. This example instead treats the input as piecewise-linear between samples, using first-order-hold (FOH) interpolation, and integrates the continuous dynamics along each ramp.
Runnable script: examples/foh_estimation.py
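Problem Setup
The plant is a Van der Pol oscillator (\(\mu = 1\)) driven by an additive scalar input:
\[
\dot x_1 = x_2, \qquad \dot x_2 = \mu\,(1 - x_1^2)\,x_2 - x_1 + u.
\]
The sensor observes only the first state (position), corrupted by Gaussian noise:
\[
y_k = x_{1,k} + v_k, \qquad v_k \sim \mathcal{N}(0,\, 0.05).
\]
The input is a 0.5 Hz sinusoid of amplitude 0.3, sampled once per measurement step. Between samples it is treated as piecewise-linear (the linear interpolation of consecutive samples) rather than constant. An FOH discretization therefore follows the input more closely than a ZOH discretization, which holds each sample flat for the whole interval.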
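The sketch below makes the ZOH/FOH comparison concrete. It integrates the oscillator over one sample interval three ways: with the exact sinusoidal input, with the input held at \(u_k\) (ZOH), and with the linear ramp from \(u_k\) to \(u_{k+1}\) (FOH). It is an illustration under stated assumptions, not part of the example script. It uses a hand-written fixed-step RK4 integrator in plain JAX instead of contrax. The interval start `t_k = 0.4`, the initial state and the substep count are arbitrary illustrative choices.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

MU = 1.0
DT = 0.1


def vdp(x, u):
    # Van der Pol vector field with additive scalar input u
    return jnp.array([x[1], MU * (1 - x[0] ** 2) * x[1] - x[0] + u])


def rk4_interval(x, u_of_s, dt=DT, n_sub=20):
    # Integrate dx/dt = vdp(x, u(s)) over s in [0, dt] with fixed-step RK4;
    # u_of_s maps time since the start of the interval to the input value.
    h = dt / n_sub
    for i in range(n_sub):
        s = i * h
        k1 = vdp(x, u_of_s(s))
        k2 = vdp(x + 0.5 * h * k1, u_of_s(s + 0.5 * h))
        k3 = vdp(x + 0.5 * h * k2, u_of_s(s + 0.5 * h))
        k4 = vdp(x + h * k3, u_of_s(s + h))
        x = x + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return x


def u_true(t):
    # Same input signal as the example: 0.3 * sin(2*pi*0.5*t)
    return 0.3 * jnp.sin(2 * jnp.pi * 0.5 * t)


t_k = 0.4  # start of an arbitrary sample interval (illustrative)
u_k, u_k1 = u_true(t_k), u_true(t_k + DT)
x_k = jnp.array([0.5, 0.0])

x_ref = rk4_interval(x_k, lambda s: u_true(t_k + s))  # exact sinusoid
x_zoh = rk4_interval(x_k, lambda s: u_k)  # input held at u_k
x_foh = rk4_interval(x_k, lambda s: u_k + (s / DT) * (u_k1 - u_k))  # linear ramp

err_zoh = jnp.linalg.norm(x_zoh - x_ref)
err_foh = jnp.linalg.norm(x_foh - x_ref)
print(f"ZOH one-step error: {float(err_zoh):.2e}")
print(f"FOH one-step error: {float(err_foh):.2e}")
```

With these values the FOH error should come out noticeably smaller than the ZOH error. Within the interval, the ramp matches a smooth input to second order, while a held value matches it only to first order.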
Setup and Model
```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

import contrax as cx

# Van der Pol oscillator: ẋ = [x2, μ(1-x1²)x2 - x1 + u]
MU = 1.0
DT = 0.1  # sample period
T_STEPS = 50  # number of steps


def vdp_dynamics(t, x, u):
    """Continuous-time Van der Pol dynamics with additive scalar input."""
    x1, x2 = x[0], x[1]
    dx1 = x2
    dx2 = MU * (1 - x1**2) * x2 - x1 + u[0]
    return jnp.array([dx1, dx2])


def vdp_observation(t, x, u):
    """Observe the first state (position) only."""
    return x[:1]


nl = cx.nonlinear_system(
    vdp_dynamics, vdp_observation, dt=None, state_dim=2, input_dim=1, output_dim=1
)

# Build a discrete transition model via FOH integration
discrete_foh = cx.sample_system(nl, DT, input_interpolation="foh")
```
`sample_system(..., input_interpolation="foh")` builds a discrete transition model. Each step integrates the continuous dynamics over one sample interval under the linearly interpolated input \(u(t_k + \tau\,\Delta t) = (1 - \tau)\,u_k + \tau\,u_{k+1}\), \(\tau \in [0, 1]\), where \(\Delta t\) is the sample period `DT`.
Each step of this discrete model takes a stacked input pair `u_pair = [[u_k], [u_{k+1}]]` of shape `(2, m)`. `foh_inputs(us)` converts a raw `(T, m)` sequence into the `(T, 2, m)` pairs the model expects.
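Both ideas can be sketched in plain JAX without relying on contrax internals. Two helpers in the block are hypothetical:

- `make_foh_pairs` stands in for `foh_inputs`. It stacks each sample with its successor and, as an assumption, repeats the final sample to form the last pair; the library's actual end-of-sequence handling may differ.
- `foh_step` stands in for the FOH transition. It integrates the dynamics with fixed-step RK4 under the ramp \((1 - \tau)\,u_k + \tau\,u_{k+1}\).

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

MU, DT = 1.0, 0.1


def vdp_dynamics(t, x, u):
    # Same continuous-time Van der Pol model as above; u has shape (1,)
    return jnp.array([x[1], MU * (1 - x[0] ** 2) * x[1] - x[0] + u[0]])


def make_foh_pairs(us):
    # Hypothetical stand-in for foh_inputs: (T, m) -> (T, 2, m).
    # Assumption: the last pair repeats the final sample as u_{k+1}.
    us_next = jnp.concatenate([us[1:], us[-1:]], axis=0)
    return jnp.stack([us, us_next], axis=1)


def foh_step(x, u_pair, dt=DT, n_sub=10):
    # Hypothetical FOH transition: RK4 over one interval with
    # u(tau) = (1 - tau) * u_k + tau * u_{k+1}, tau = elapsed time / dt.
    u_k, u_k1 = u_pair[0], u_pair[1]
    u_at = lambda tau: (1 - tau) * u_k + tau * u_k1
    h = dt / n_sub
    for i in range(n_sub):
        tau0, tau_mid, tau1 = i / n_sub, (i + 0.5) / n_sub, (i + 1) / n_sub
        k1 = vdp_dynamics(0.0, x, u_at(tau0))
        k2 = vdp_dynamics(0.0, x + 0.5 * h * k1, u_at(tau_mid))
        k3 = vdp_dynamics(0.0, x + 0.5 * h * k2, u_at(tau_mid))
        k4 = vdp_dynamics(0.0, x + h * k3, u_at(tau1))
        x = x + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return x


us_raw = 0.3 * jnp.sin(jnp.pi * jnp.arange(5) * DT)[:, None]  # (T=5, m=1)
pairs = make_foh_pairs(us_raw)  # (5, 2, 1)
x1 = foh_step(jnp.array([0.5, 0.0]), pairs[0])
print(pairs.shape, x1.shape)  # (5, 2, 1) (2,)
```

A quick sanity check for any FOH implementation: when \(u_k = u_{k+1}\) the ramp is flat, so `foh_step` reduces to a ZOH step.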
Simulating and Filtering
```python
def generate_trajectory(seed: int = 0):
    """Simulate ground-truth trajectory and noisy measurements."""
    key = jax.random.PRNGKey(seed)

    # Slowly varying sinusoidal input (0.5 Hz, amplitude 0.3), one sample per step;
    # linspace gives T evenly spaced times from 0 to T*DT inclusive
    t_axis = jnp.linspace(0.0, T_STEPS * DT, T_STEPS)
    us_raw = 0.3 * jnp.sin(2 * jnp.pi * 0.5 * t_axis)[:, None]  # (T, 1)
    us_foh = cx.foh_inputs(us_raw)  # (T, 2, 1)

    # Simulate true trajectory using FOH discrete model
    x0_true = jnp.array([0.5, 0.0])
    xs_true = cx.rollout(
        lambda x, u: discrete_foh.dynamics(0.0, x, u), x0_true, us_foh
    )  # (T+1, 2)

    # Noisy measurements of first state only
    R_noise = jnp.array([[0.05]])
    key, subkey = jax.random.split(key)
    noise = jax.random.normal(subkey, (T_STEPS + 1, 1)) * jnp.sqrt(R_noise[0, 0])
    ys = xs_true[:, :1] + noise  # (T+1, 1)

    return xs_true, ys, us_raw, us_foh
```
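`cx.rollout` is used above to produce a `(T+1, n)` trajectory that includes the initial state. The semantics implied by that shape comment can be sketched with `jax.lax.scan`. `rollout_sketch` is a hypothetical helper written for illustration, not the contrax implementation.

```python
import jax
import jax.numpy as jnp


def rollout_sketch(step, x0, us):
    # Apply x_{k+1} = step(x_k, u_k) for each input along the leading axis of us
    # and return every state, including x0, stacked with shape (T+1, n).
    def body(x, u):
        x_next = step(x, u)
        return x_next, x_next

    _, xs = jax.lax.scan(body, x0, us)
    return jnp.concatenate([x0[None, :], xs], axis=0)


# Toy usage with a linear transition (illustrative values)
step = lambda x, u: 0.5 * x + u
xs = rollout_sketch(step, jnp.array([1.0]), jnp.ones((3, 1)))
# xs[:, 0] holds 1.0, 1.5, 1.75, 1.875 (shape (4, 1))
```

Because `scan` carries the state through a single traced loop body, this form also stays compatible with `jax.jit` and `jax.grad` for long horizons.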
```python
def run_ekf(xs_true, ys, us_foh):
    """Run an online EKF using FOH-sampled discrete model."""
    Q_noise = jnp.diag(jnp.array([1e-4, 1e-3]))
    R_noise = jnp.array([[0.05]])

    def h(x):
        return x[:1]

    def f_foh(x, u_pair):
        return discrete_foh.dynamics(0.0, x, u_pair)

    x = ys[0, :1]  # initialise from first measurement
    x = jnp.concatenate([x, jnp.zeros(1)])
    P = jnp.eye(2) * 0.5

    xs_est = [x]
    for k in range(T_STEPS):
        x, P = cx.ekf_predict(f_foh, x, P, us_foh[k], Q_noise)
        x, P, _ = cx.ekf_update(h, x, P, ys[k + 1], R_noise)
        xs_est.append(x)
    return jnp.stack(xs_est)  # (T+1, 2)
```
The EKF loop uses the one-step helpers `ekf_predict` and `ekf_update`. The FOH-discrete model is passed as the dynamics callable (`f_foh`), so the filter linearizes it automatically via JAX Jacobians.
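The phrase "linearizes it automatically via JAX Jacobians" can be unpacked with a minimal sketch of one predict/update cycle built on `jax.jacfwd`. `ekf_predict_sketch` and `ekf_update_sketch` are hypothetical names implementing the textbook EKF equations. contrax's own `ekf_predict` and `ekf_update` may differ in details, for example the covariance-update form or what the third return value of `ekf_update` contains. The toy transition `f` is an Euler-style placeholder, not the FOH model.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def ekf_predict_sketch(f, x, P, u, Q):
    # Propagate the mean through f and the covariance through F = df/dx at x
    F = jax.jacfwd(f, argnums=0)(x, u)
    return f(x, u), F @ P @ F.T + Q


def ekf_update_sketch(h, x, P, y, R):
    # Linearize the measurement model at x, then apply the Kalman correction
    H = jax.jacfwd(h)(x)
    innovation = y - h(x)
    S = H @ P @ H.T + R
    K = jnp.linalg.solve(S, H @ P).T  # K = P H^T S^{-1} (P and S symmetric)
    I = jnp.eye(x.shape[0])
    P_post = (I - K @ H) @ P @ (I - K @ H).T + K @ R @ K.T  # Joseph form
    return x + K @ innovation, P_post, innovation


# Toy usage: position-only measurement of a 2-state system (illustrative values)
DT = 0.1
f = lambda x, u: jnp.array([x[0] + DT * x[1], x[1] + DT * (-x[0] + u[0])])
h = lambda x: x[:1]
Q, R = jnp.diag(jnp.array([1e-4, 1e-3])), jnp.array([[0.05]])

x_pred, P_pred = ekf_predict_sketch(f, jnp.array([0.5, 0.0]), 0.5 * jnp.eye(2), jnp.array([0.1]), Q)
x_post, P_post, nu = ekf_update_sketch(h, x_pred, P_pred, jnp.array([0.55]), R)
```

The Joseph form is used here because it keeps the updated covariance symmetric positive semidefinite even with a slightly inaccurate gain. The simpler \((I - KH)P\) form is algebraically equivalent only for the optimal gain.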
What the Script Prints
```text
Continuous-model EKF with FOH input interpolation
steps = 50
position RMSE = 0.1204
velocity RMSE = 0.1999
true final state = [0.77012901 2.5961284 ]
estim final state = [0.75520658 2.62896228]
All assertions passed.
```
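The RMSE lines report a per-state root-mean-square error between the estimated and true trajectories. The sketch below shows that metric. It assumes the average runs over all `T+1` time steps, including the initial state; the script may exclude step 0. `per_state_rmse` is a hypothetical helper, and the arrays are illustrative values, not output from the example.

```python
import jax.numpy as jnp


def per_state_rmse(xs_est, xs_true):
    # Root-mean-square error over time, computed separately for each state column
    return jnp.sqrt(jnp.mean((xs_est - xs_true) ** 2, axis=0))


xs_true = jnp.array([[0.0, 1.0], [1.0, 2.0]])
xs_est = jnp.array([[0.3, 1.0], [0.7, 2.4]])
rmse = per_state_rmse(xs_est, xs_true)  # position 0.3, velocity sqrt(0.08) ≈ 0.283
```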
Related Pages
- Continuous nonlinear estimation for a similar workflow using UKF/UKS instead of EKF
- Estimation API
- Simulation API for `sample_system` and `foh_inputs`