Nonlinear MHE with Constrained Parameter Estimation

This example estimates both state and a physical parameter from noisy measurements using Moving Horizon Estimation. The parameter (damping coefficient) is kept positive throughout optimization by representing it in an unconstrained raw form and applying positive_softplus() inside the dynamics.

Runnable script: examples/nonlinear_mhe.py

Problem Setup

A damped pendulum is driven by a small random torque:

\[ \ddot{\theta} = -\frac{g}{l} \sin\theta - b\,\dot{\theta} + \tau, \qquad y_k = \theta(t_k) + v_k. \]

The damping coefficient \(b > 0\) is unknown. The estimator must recover both the state \((\theta, \dot\theta)\) and \(b\) jointly from angle measurements alone.
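The measurement, torque, and ground-truth sequences consumed below are produced by the script. A minimal, self-contained sketch of that simulation is shown here; the names `simulate`, `B_TRUE`, `N_STEPS`, and the specific noise scales are illustrative assumptions, not taken verbatim from examples/nonlinear_mhe.py:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

DT, N_STEPS = 0.05, 40           # sample period (s), trajectory length
G_ACCEL, L_PEND = 9.81, 1.0      # gravity, pendulum length (assumed names)
B_TRUE = 0.3                     # true damping coefficient


def simulate(key):
    """Euler-integrate the true pendulum and record noisy angle measurements."""
    k_u, k_v = jax.random.split(key)
    us = 0.05 * jax.random.normal(k_u, (N_STEPS, 1))     # small random torque
    vs = 0.2 * jax.random.normal(k_v, (N_STEPS + 1, 1))  # std 0.2 -> R = 0.04
    xs = [jnp.array([0.6, 0.0])]                         # initial [theta, omega]
    for k in range(N_STEPS):
        theta, omega = xs[-1]
        alpha = -(G_ACCEL / L_PEND) * jnp.sin(theta) - B_TRUE * omega + us[k, 0]
        xs.append(jnp.array([theta + DT * omega, omega + DT * alpha]))
    xs = jnp.stack(xs)
    ys = xs[:, :1] + vs                                  # angle-only measurements
    return xs, ys, us


xs_phys, ys, us = simulate(jax.random.PRNGKey(0))
```

The same Euler scheme and step size are reused in the estimator's model below, so model mismatch comes only from the noise terms.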

Augmented State and Constrained Parameterization

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

import contrax as cx

DT = 0.05  # sample period (s)
WINDOW = 12  # MHE window length
G_ACCEL = 9.81  # gravitational acceleration (m/s²)
L = 1.0  # pendulum length (m), assumed known
DAMPING_TRUE = 0.3  # true damping coefficient (kg·m²/s)
N_AUG = 3  # augmented state dimension: [theta, omega, b_raw]

def aug_dynamics(x_aug, u):
    """Euler step for augmented state z = [theta, omega, b_raw].

    The damping parameter b_raw is treated as a slowly-varying state:
    its dynamics are an identity (random-walk model). The physical damping
    b = positive_softplus(b_raw) is always positive.
    """
    theta, omega, b_raw = x_aug[0], x_aug[1], x_aug[2]
    b = cx.positive_softplus(b_raw)
    torque = u[0]
    alpha = -(G_ACCEL / L) * jnp.sin(theta) - b * omega + torque
    theta_next = theta + DT * omega
    omega_next = omega + DT * alpha
    return jnp.array([theta_next, omega_next, b_raw])  # b_raw unchanged


def aug_observation(x_aug):
    """Observe angle only from augmented state."""
    return x_aug[:1]

The key design pattern is parameter augmentation: \(b\) is lifted into the state vector as \(b_\text{raw}\) with identity dynamics (random-walk model), and the physical value is recovered inside the model via

\[ b = \text{softplus}(b_\text{raw}) > 0. \]

This keeps the MHE objective unconstrained everywhere. The optimizer explores all of \(\mathbb{R}\) in \(b_\text{raw}\) space and the constraint \(b > 0\) is enforced structurally, not by an inequality constraint or a projection step.
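To see the structural constraint concretely, here is a small illustration using `jax.nn.softplus`, which is assumed here to match the behavior of `cx.positive_softplus()`:

```python
import jax.numpy as jnp
from jax.nn import softplus

# softplus maps all of R onto (0, inf): large negative raw values give a tiny
# but strictly positive b, and softplus(x) ≈ x for x >> 0, so the mapping is
# nearly the identity once the optimizer moves into clearly positive territory.
for b_raw in (-5.0, 0.0, 3.0):
    b = softplus(jnp.asarray(b_raw))
    print(f"b_raw = {b_raw:+.1f}  ->  b = {float(b):.4f}")
```

No step of the optimizer can ever produce a non-physical `b`, regardless of where it wanders in raw space.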

Solving MHE

def run_mhe(xs_phys, ys, us):
    """Solve MHE over the first window and report results."""
    # Noise covariances in augmented state space
    Q_noise = jnp.diag(jnp.array([1e-4, 1e-3, 1e-5]))  # small drift on b_raw
    R_noise = jnp.array([[0.04]])

    # Prior: start near the first measurement; b_raw=0 → b≈0.69 (soft overestimate)
    x_prior = jnp.array([float(ys[0, 0]), 0.0, 0.0])
    P_prior = jnp.diag(jnp.array([0.1, 0.5, 2.0]))

    ys_win = ys[: WINDOW + 1]  # (WINDOW+1, 1)
    us_win = us[:WINDOW]  # (WINDOW, 1)

    # Warm-start: replicate prior across window
    xs_init = jnp.tile(x_prior, (WINDOW + 1, 1))

    result = cx.mhe(
        aug_dynamics,
        aug_observation,
        xs_init,
        us_win,
        ys_win,
        x_prior,
        P_prior,
        Q_noise,
        R_noise,
        max_steps=512,
    )
    return result

mhe() minimizes mhe_objective() over the full window trajectory using L-BFGS. The Kalman-style prior term (x_prior, P_prior) regularizes the start of the window, and the b_raw entry of Q_noise controls how much the parameter is allowed to drift across the horizon.
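For intuition, the cost being minimized has the standard MHE form: an arrival-cost prior term plus inverse-covariance-weighted process and measurement residuals over the window. The sketch below is an illustrative re-derivation of that structure, not contrax's actual mhe_objective():

```python
import jax.numpy as jnp


def mhe_objective_sketch(xs, us, ys, x_prior, P_prior, Q, R, f, h):
    """Generic MHE cost: arrival prior + process-noise + measurement terms.

    xs: (N+1, nx) decision trajectory, us: (N, nu), ys: (N+1, ny).
    Illustrative only; contrax's mhe_objective() may differ in details.
    """
    def quad(e, M):
        return e @ jnp.linalg.solve(M, e)  # e^T M^{-1} e

    cost = quad(xs[0] - x_prior, P_prior)          # arrival-cost prior
    for k in range(us.shape[0]):
        w = xs[k + 1] - f(xs[k], us[k])            # process residual
        cost = cost + quad(w, Q)
        cost = cost + quad(h(xs[k]) - ys[k], R)    # measurement residual
    cost = cost + quad(h(xs[-1]) - ys[-1], R)      # final measurement
    return 0.5 * cost
```

Because every term is smooth in the decision variables (including the softplus inside `f`), the whole objective is amenable to unconstrained quasi-Newton methods like L-BFGS.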

What the Script Prints

Nonlinear MHE — damped pendulum with parameter estimation
  converged           = False (LBFGS may report False before tolerance; check cost)
  final cost          = 5.6983e+00
  damping: true=0.300  estimated=0.680
  angle at window end:
    true      = 0.6481 rad
    estimated = 0.6552 rad

All assertions passed.

The angle estimate is close to ground truth. The damping estimate is biased toward its prior (softplus(0) ≈ 0.69) because a 12-step window with a slowly-varying random-walk model is a short identification horizon for a single scalar parameter; longer windows or a tighter process-noise entry on \(b_\text{raw}\) improve convergence. The converged=False flag reflects L-BFGS tolerance reporting, not a failed solve; checking the final cost directly is more informative.
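If a rough prior guess for \(b\) is available, the window can also be warm-started in raw space with the softplus inverse rather than b_raw = 0. A sketch, assuming `cx.positive_softplus()` behaves like `jax.nn.softplus`:

```python
import jax.numpy as jnp
from jax.nn import softplus


def inv_softplus(b):
    """Inverse of softplus: softplus(inv_softplus(b)) == b for b > 0."""
    # log(expm1(b)) inverts log1p(exp(x)) exactly; fine for moderate b.
    return jnp.log(jnp.expm1(b))


# Start the raw parameter at a prior guess b ≈ 0.3 instead of b ≈ 0.69.
b_raw0 = inv_softplus(jnp.asarray(0.3))
```

Seeding `x_prior` (and the tiled warm-start trajectory) with `b_raw0` removes the initial overestimate and lets the short window spend its information budget on refinement rather than correction.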