
Systems

Systems is the public modeling namespace in Contrax. It brings together the objects and helpers you use to represent dynamics before you design a controller, run an estimator, or simulate a trajectory.

The namespace includes:

  • continuous and discrete LTI models: ContLTI, DiscLTI, ss(), dss()
  • nonlinear and port-Hamiltonian model objects: NonlinearSystem, PHSSystem, nonlinear_system(), phs_system(), schedule_phs()
  • PHS structure helpers: canonical_J() to define the symplectic structure map, phs_to_ss() to extract the local state-space linearization around an operating point
  • structured-system helpers: partition_state(), block_observation(), block_matrix(), symmetrize_matrix(), project_psd(), phs_diagnostics()
  • discretization and local model construction: c2d(), linearize(), linearize_ss()
  • simple state-space composition: series(), parallel(), plus operator overloads on ContLTI and DiscLTI

Minimal Example

Import contrax.systems directly when you only need the model-building layer:

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import contrax.systems as cxs

sys_c = cxs.ss(
    jnp.array([[0.0, 1.0], [0.0, 0.0]]),
    jnp.array([[0.0], [1.0]]),
    jnp.eye(2),
    jnp.zeros((2, 1)),
)
sys_d = cxs.c2d(sys_c, dt=0.05)

Canonical Model Forms

LTI And Nonlinear Systems

The continuous LTI, discrete LTI, and generic nonlinear forms are the general-purpose modeling surface:

\[ \dot{x} = A x + B u, \qquad y = C x + D u \]
\[ x_{k+1} = A x_k + B u_k, \qquad y_k = C x_k + D u_k \]
\[ \dot{x} = f(t, x, u), \qquad y = h(t, x, u) \]
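As a concrete reading of the discrete-time equation, the sketch below rolls a DiscLTI-style model forward with jax.lax.scan. It works on raw matrices rather than a Contrax object. The helper name simulate_discrete_lti is hypothetical and is not part of the Contrax API.

```python
import jax
import jax.numpy as jnp


def simulate_discrete_lti(A, B, C, D, x0, us):
    """Roll out x_{k+1} = A x_k + B u_k, y_k = C x_k + D u_k.

    Hypothetical helper for illustration; `us` has shape (T, m).
    Returns the final state (n,) and the stacked outputs (T, p).
    """

    def step(x, u):
        y = C @ x + D @ u        # output at step k uses x_k and u_k
        x_next = A @ x + B @ u   # state update to x_{k+1}
        return x_next, y

    x_final, ys = jax.lax.scan(step, x0, us)
    return x_final, ys


# Same matrices as the dss() example further down this page.
A = jnp.array([[1.0, 0.1], [0.0, 1.0]])
B = jnp.array([[0.0], [0.1]])
C = jnp.eye(2)
D = jnp.zeros((2, 1))
x_final, ys = simulate_discrete_lti(A, B, C, D, jnp.zeros(2), jnp.ones((3, 1)))
```

Because the step function is pure, the whole rollout can be wrapped in jax.jit or differentiated with respect to the matrices.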

Port-Hamiltonian Systems

PHSSystem is the structured nonlinear family in Contrax.

\[ \dot{x} = \bigl(J - R(x)\bigr)\nabla H(x) + G(x)u \]
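For intuition, the sketch below evaluates the canonical form by hand for a one-degree-of-freedom mass-spring-damper with state x = (q, p). It uses plain JAX and assumed parameter values. It does not construct a PHSSystem, but it mirrors the (J - R)∇H + G u evaluation that PHSSystem.dynamics performs in the source listed below.

```python
import jax
import jax.numpy as jnp

k, m, c = 2.0, 0.5, 0.3  # assumed stiffness, mass, and damping values


def H(x):
    """Storage function: spring energy plus kinetic energy, x = (q, p)."""
    q, p = x[0], x[1]
    return 0.5 * k * q**2 + 0.5 * p**2 / m


J = jnp.array([[0.0, 1.0], [-1.0, 0.0]])  # canonical structure for one DOF
R = jnp.array([[0.0, 0.0], [0.0, c]])     # damping acts on the momentum
G = jnp.array([[0.0], [1.0]])             # force enters the momentum equation


def phs_vector_field(x, u):
    grad_H = jax.grad(H)(x)               # (k q, p / m)
    return (J - R) @ grad_H + G @ u


x = jnp.array([1.0, 0.5])
u = jnp.array([0.2])
xdot = phs_vector_field(x, u)
# By hand: q_dot = p/m = 1.0, p_dot = -k q - c p/m + u = -2.1
```

Expanding the product recovers the familiar equations q̇ = p/m and ṗ = -kq - c p/m + u. This gives a quick sanity check on J, R, and G before wiring them into a model.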

For scheduled coefficients, schedule_phs() binds an exogenous context into that canonical model:

\[ \dot{x} = \bigl(J - R(t, x, \theta(t))\bigr)\nabla H(x) + G(t, x, \theta(t))u \]
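Conceptually, binding a context amounts to closing over θ(t) before evaluating the canonical expression. The sketch below uses a damping coefficient that follows an assumed schedule θ(t) = 0.1 + 0.05 t. The names theta, R_scheduled, and scheduled_vector_field are hypothetical and do not reproduce the schedule_phs() signature.

```python
import jax
import jax.numpy as jnp

k, m = 2.0, 0.5  # assumed stiffness and mass


def H(x):
    return 0.5 * k * x[0] ** 2 + 0.5 * x[1] ** 2 / m


J = jnp.array([[0.0, 1.0], [-1.0, 0.0]])
G = jnp.array([[0.0], [1.0]])


def theta(t):
    """Hypothetical exogenous schedule: damping that ramps with time."""
    return 0.1 + 0.05 * t


def R_scheduled(t, x, th):
    del t, x  # this illustrative dissipation depends only on the context
    return jnp.array([[0.0, 0.0], [0.0, th]])


def scheduled_vector_field(t, x, u):
    # Evaluate the context first, then the canonical PHS expression.
    R = R_scheduled(t, x, theta(t))
    return (J - R) @ jax.grad(H)(x) + G @ u


xdot = scheduled_vector_field(2.0, jnp.array([1.0, 0.5]), jnp.array([0.0]))
```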

Conventions

  • A: state matrix with shape (n, n)
  • B: input matrix with shape (n, m)
  • C: output matrix with shape (p, n)
  • D: feedthrough matrix with shape (p, m)
  • DiscLTI.dt: scalar sample time stored as an array leaf
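These shape conventions can be checked mechanically. The helper below is a hypothetical sketch, not a Contrax function. It infers (n, m, p) from B and C and raises on any mismatch, similar in spirit to the assertions that dss() performs.

```python
import jax.numpy as jnp


def lti_dims(A, B, C, D):
    """Return (n, m, p) after checking the A/B/C/D shape conventions."""
    A, B, C, D = (jnp.asarray(M) for M in (A, B, C, D))
    n, m = B.shape
    p = C.shape[0]
    expected = {"A": (n, n), "B": (n, m), "C": (p, n), "D": (p, m)}
    for name, M in zip("ABCD", (A, B, C, D)):
        if M.shape != expected[name]:
            raise ValueError(f"{name} has shape {M.shape}, expected {expected[name]}")
    return n, m, p


dims = lti_dims(
    jnp.array([[0.0, 1.0], [0.0, 0.0]]),
    jnp.array([[0.0], [1.0]]),
    jnp.eye(2),
    jnp.zeros((2, 1)),
)
```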

For nonlinear models, the reusable model-object contract is:

  • dynamics(t, x, u) for the transition or vector field
  • output(t, x, u) for the output or measurement map
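Any object that exposes these two methods has the right shape for this contract. The plain-Python pendulum below is a hypothetical sketch of that shape. Like NonlinearSystem.output, it falls back to the full state when no observation map is supplied.

```python
import jax.numpy as jnp


class Pendulum:
    """Hypothetical model object following the dynamics/output contract."""

    def __init__(self, g_over_l=9.81, observation=None):
        self.g_over_l = g_over_l
        self.observation = observation

    def dynamics(self, t, x, u):
        del t  # time-invariant model
        theta, omega = x[0], x[1]
        return jnp.array([omega, -self.g_over_l * jnp.sin(theta) + u[0]])

    def output(self, t, x, u):
        if self.observation is None:
            return x  # default: measure the full state
        return self.observation(t, x, u)


full = Pendulum()
angle_only = Pendulum(observation=lambda t, x, u: x[:1])
x, u = jnp.array([0.0, 0.0]), jnp.array([1.0])
```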

Contrax keeps continuous and discrete LTI systems on the same field names so code can move between them with minimal ceremony.

For PHSSystem, the base library contract is intentionally narrower:

  • H(x) is the storage function
  • R(x) is the dissipation map when present
  • G(x) is the input map when present
  • J(x) defaults to the canonical symplectic structure canonical_J(n // 2) for state dimension n, so the default requires an even n; pass your own J to override it
  • schedule_phs() is the extension path when dissipation or input maps depend on observed scheduling context
  • schedule_phs() can also bind scheduled structure maps J(t, x, context) when the structure itself varies with observed context

Structured Helpers

Contrax includes a small set of generic helpers for structured nonlinear models, so downstream projects do not need their own utility module just to express common PHS patterns.

Use:

  • partition_state(x, block_sizes) to split a structured state into consecutive blocks
  • block_observation(block_sizes, block_indices) to build partial-state output maps from those blocks
  • block_matrix(row_block_sizes, col_block_sizes, blocks) to assemble dense structured input, process, or observation maps from block entries
  • symmetrize_matrix(M) and project_psd(M) when dissipation or covariance matrices pick up small numerical asymmetry or PSD drift
  • phs_diagnostics(sys, x, u) for local skew-symmetry, dissipation, and instantaneous power-balance diagnostics

These are intentionally lightweight helpers, not a large structured-model framework. They are there to make the public PHS layer ergonomic and reusable.
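The two state-block helpers compose naturally, as the sketch below shows in plain JAX. partition_state_sketch is a hypothetical stand-in: the real partition_state() source is not shown on this page, and the sum check is an assumption of this sketch. block_observation_sketch follows the block_observation() source listed in the API reference.

```python
import jax.numpy as jnp


def partition_state_sketch(x, block_sizes):
    """Split x into consecutive blocks (illustrative stand-in for partition_state)."""
    sizes = tuple(int(size) for size in block_sizes)
    if sum(sizes) != x.shape[0]:
        raise ValueError("Block sizes must sum to the state dimension.")
    parts, start = [], 0
    for size in sizes:
        parts.append(x[start:start + size])
        start += size
    return tuple(parts)


def block_observation_sketch(block_sizes, block_indices):
    """Return an output map (t, x, u) -> concatenation of the selected blocks."""
    indices = tuple(int(index) for index in block_indices)

    def output(t, x, u):
        del t, u  # a pure partial-state measurement
        parts = partition_state_sketch(x, block_sizes)
        return jnp.concatenate([parts[index] for index in indices], axis=0)

    return output


x = jnp.arange(5.0)                      # blocks: [0, 1] | [2, 3] | [4]
parts = partition_state_sketch(x, (2, 2, 1))
h = block_observation_sketch((2, 2, 1), (0, 2))
y = h(0.0, x, None)                      # first and last blocks
```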

Port-Hamiltonian Structure Helpers

canonical_J(n_dofs) returns the \(2n \times 2n\) canonical symplectic structure matrix, where \(n\) is n_dofs, the number of degrees of freedom (so the state dimension is \(2n\)):

\[ J_c = \begin{bmatrix} 0 & I_{n} \\ -I_{n} & 0 \end{bmatrix} \]

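Use it when building a PHSSystem whose structure map is the standard symplectic one rather than a custom skew-symmetric map; PHSSystem already falls back to it when J is omitted. Dissipation enters through R(x), not through J.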
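The defining properties are easy to check numerically. The sketch below rebuilds the block layout from n_dofs, matching the canonical_J source in the API reference. It then verifies skew-symmetry and J² = -I, and shows that the quadratic form zᵀJz vanishes. That last property is why J does no work on the stored energy.

```python
import jax
import jax.numpy as jnp


def canonical_J_sketch(n_dofs, dtype=jnp.float32):
    """Illustrative rebuild of the (2 n_dofs, 2 n_dofs) canonical structure."""
    eye = jnp.eye(n_dofs, dtype=dtype)
    zeros = jnp.zeros((n_dofs, n_dofs), dtype=dtype)
    return jnp.block([[zeros, eye], [-eye, zeros]])


J = canonical_J_sketch(2)                      # 4 x 4 for a 4-dimensional state
z = jax.random.normal(jax.random.PRNGKey(0), (4,))

is_skew = jnp.allclose(J.T, -J)                # J^T = -J
squares_to_minus_identity = jnp.allclose(J @ J, -jnp.eye(4))
power = z @ J @ z                              # zero for any z because J is skew
```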

phs_to_ss(sys, x_eq, u_eq) linearizes a PHS at an operating point and returns a ContLTI. This is the standard bridge from a structured nonlinear PHS model into the linear control design stack (lqr, place, kalman_gain, etc.):

lti = cx.phs_to_ss(phs_sys, x_eq=jnp.zeros(4), u_eq=jnp.zeros(1))
design = cx.lqr(cx.c2d(lti, dt=0.01), Q, R)
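For a quadratic storage function H(x) = ½ xᵀ Q_H x with constant J, R, and G, the first-order expansion is exact: ∇H = Q_H x, so A = (J - R) Q_H and B = G. The sketch below checks that identity with jax.jacfwd on an assumed mass-spring-damper. It illustrates the math a PHS linearization should reproduce and does not call phs_to_ss() itself. Q_H is a storage matrix, not an LQR weight.

```python
import jax
import jax.numpy as jnp

k, m, c = 2.0, 0.5, 0.3                        # assumed parameters
Q_H = jnp.diag(jnp.array([k, 1.0 / m]))        # H(x) = 0.5 x^T Q_H x
J = jnp.array([[0.0, 1.0], [-1.0, 0.0]])
R = jnp.array([[0.0, 0.0], [0.0, c]])
G = jnp.array([[0.0], [1.0]])


def H(x):
    return 0.5 * x @ Q_H @ x


def f(x, u):
    return (J - R) @ jax.grad(H)(x) + G @ u


x_eq, u_eq = jnp.zeros(2), jnp.zeros(1)
A = jax.jacfwd(f, argnums=0)(x_eq, u_eq)
B = jax.jacfwd(f, argnums=1)(x_eq, u_eq)
A_closed_form = (J - R) @ Q_H                  # [[0, 1/m], [-k, -c/m]]
```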

Linearization And Discretization

linearize() and linearize_ss() are the main JAX-native bridge from nonlinear plant code into the linear control stack.

Use:

  • linearize(f, x0, u0) when you only need A and B
  • linearize_ss(f, x0, u0, output=h) when you have plain dynamics and output callables
  • linearize_ss(sys, x0, u0) when you want a full local ContLTI from a reusable NonlinearSystem or PHSSystem

The local linear model is the usual first-order expansion around an operating point:

\[ \delta \dot{x} \approx A\,\delta x + B\,\delta u, \qquad \delta y \approx C\,\delta x + D\,\delta u \]

with

\[ A = \left.\frac{\partial f}{\partial x}\right|_{(x_0, u_0)}, \quad B = \left.\frac{\partial f}{\partial u}\right|_{(x_0, u_0)}, \quad C = \left.\frac{\partial h}{\partial x}\right|_{(x_0, u_0)}, \quad D = \left.\frac{\partial h}{\partial u}\right|_{(x_0, u_0)} \]
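In plain JAX, these four Jacobians are jax.jacfwd calls, the same forward-mode approach used in the linearize_ss() source. The self-contained sketch below applies them to a pendulum with unit g/l and an assumed angle-only measurement. It does not depend on Contrax.

```python
import jax
import jax.numpy as jnp


def f(x, u):
    """Pendulum vector field with unit g/l: x = (angle, rate)."""
    return jnp.array([x[1], -jnp.sin(x[0]) + u[0]])


def h(x, u):
    """Assumed measurement: the angle only."""
    return x[:1]


def local_model(x0, u0):
    A = jax.jacfwd(f, argnums=0)(x0, u0)   # (n, n)
    B = jax.jacfwd(f, argnums=1)(x0, u0)   # (n, m)
    C = jax.jacfwd(h, argnums=0)(x0, u0)   # (p, n)
    D = jax.jacfwd(h, argnums=1)(x0, u0)   # (p, m), zero here
    return A, B, C, D


A_down, B, C, D = local_model(jnp.zeros(2), jnp.zeros(1))
A_up, *_ = local_model(jnp.array([jnp.pi, 0.0]), jnp.zeros(1))
```

A[1, 0] is -1 at the hanging equilibrium and +1 at the inverted one. That sign flip is a reminder that each local model is only valid near its own operating point.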

Then use c2d() when the downstream workflow is discrete-time design or simulation.

The c2d() methods are:

  • zoh (default): zero-order hold, the primary implementation, differentiated through a custom VJP on the matrix exponential
  • tustin: a convenience bilinear transform
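Zero-order hold has a standard closed form: one matrix exponential of an augmented matrix. The sketch below uses jax.scipy.linalg.expm. Contrax's own _c2d_zoh is not shown on this page, so treat this as the textbook construction rather than its implementation. For the double integrator from the minimal example, it reproduces the exact values A_d = [[1, dt], [0, 1]] and B_d = [[dt²/2], [dt]].

```python
import jax

jax.config.update("jax_enable_x64", True)  # match the precision c2d() expects

import jax.numpy as jnp
from jax.scipy.linalg import expm


def zoh_sketch(A, B, dt):
    """Textbook ZOH: expm([[A, B], [0, 0]] * dt) = [[A_d, B_d], [0, I]]."""
    n, m = B.shape
    M = jnp.zeros((n + m, n + m), dtype=A.dtype)
    M = M.at[:n, :n].set(A).at[:n, n:].set(B)
    E = expm(M * dt)
    return E[:n, :n], E[:n, n:]


dt = 0.05
A = jnp.array([[0.0, 1.0], [0.0, 0.0]])
B = jnp.array([[0.0], [1.0]])
A_d, B_d = zoh_sketch(A, B, dt)
```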

Figure: Pipeline from nonlinear dynamics through linearization, discretization, LQR design, and closed-loop simulation. The main model-preparation path: start with a nonlinear or structured model, build a local ContLTI with linearize_ss(), discretize it with c2d(), then move into controller design and validation.

Interconnection

Simple state-space composition lives in the same public namespace because it is part of the model-building story.

Use:

  • series(sys2, sys1) or sys2 @ sys1 for cascade composition, where the output of sys1 drives the input of sys2
  • parallel(sys1, sys2) or sys1 + sys2 for a shared input with summed outputs
  • sys1 - sys2 for parallel subtraction

Mixed continuous/discrete interconnections are intentionally unsupported, and discrete interconnections require matching dt.

State ordering is explicit:

  • series(sys2, sys1) puts sys1 states first, then sys2
  • parallel(sys1, sys2) puts sys1 states first, then sys2

Transform Behavior

The system containers themselves are pytrees, so they work naturally with jit and vmap when shapes line up. That matters because the intended Contrax workflows include vmapped linearization, batched controller design, and compiled design-and-simulate objectives.

The important user-facing transform contracts here are:

  • linearize() and linearize_ss() are designed to compose with jit and vmap
  • DiscLTI.dt stays as an array leaf rather than static metadata, which makes batching over systems more natural
  • composition helpers are ordinary state-space algebra once family and shape checks pass
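Because linearization is a pure function of the operating point, batching it takes a single jax.vmap. The plain-JAX sketch below linearizes an assumed unit pendulum at two equilibria inside one jitted call. The same pattern is what the list above describes for linearize(), which is designed to compose with jit and vmap.

```python
import jax
import jax.numpy as jnp


def f(x, u):
    return jnp.array([x[1], -jnp.sin(x[0]) + u[0]])


def local_A(x0, u0):
    return jax.jacfwd(f, argnums=0)(x0, u0)


batched_A = jax.jit(jax.vmap(local_A))         # one compiled call for the whole batch
x0s = jnp.array([[0.0, 0.0], [jnp.pi, 0.0]])   # hanging and inverted equilibria
u0s = jnp.zeros((2, 1))
As = batched_A(x0s, u0s)                       # shape (2, 2, 2)
```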

Numerical Notes

Contrax does not enable float64 globally on import. Precision-sensitive system helpers such as c2d() require jax_enable_x64=True.

For nonlinear local models, remember that linearize() and linearize_ss() produce local behavior around one operating point; they are not a guarantee of global model fidelity.

The structured-system diagnostics are local algebraic checks. They are useful for sanity checking model definitions and recursive-estimation plumbing, but they are not a formal certificate of passivity or robustness.
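The sketch below makes "local algebraic check" concrete by computing three numbers at a single (x, u):

  • the skew-symmetry residual of J
  • the smallest eigenvalue of the symmetric part of R
  • the power-balance residual dH/dt - (∇Hᵀ G u - ∇Hᵀ R ∇H)

The function name and returned fields are hypothetical and need not match what phs_diagnostics() reports.

```python
import jax
import jax.numpy as jnp


def local_phs_checks(H, J, R, G, x, u):
    """Hypothetical local diagnostics for x_dot = (J - R) grad H + G u."""
    Jx, Rx, Gx = J(x), R(x), G(x)
    grad_H = jax.grad(H)(x)
    x_dot = (Jx - Rx) @ grad_H + Gx @ u
    return {
        "skew_residual": jnp.max(jnp.abs(Jx + Jx.T)),
        "min_dissipation_eig": jnp.min(jnp.linalg.eigvalsh(0.5 * (Rx + Rx.T))),
        "power_residual": grad_H @ x_dot - (grad_H @ Gx @ u - grad_H @ Rx @ grad_H),
    }


checks = local_phs_checks(
    H=lambda x: 0.5 * 2.0 * x[0] ** 2 + 0.5 * x[1] ** 2 / 0.5,
    J=lambda x: jnp.array([[0.0, 1.0], [-1.0, 0.0]]),
    R=lambda x: jnp.array([[0.0, 0.0], [0.0, 0.3]]),
    G=lambda x: jnp.array([[0.0], [1.0]]),
    x=jnp.array([1.0, 0.5]),
    u=jnp.array([0.2]),
)
```

Suppose the skew residual is near zero, the minimum eigenvalue is nonnegative, and the power residual is near zero. That only says the model is algebraically consistent at this one point. Passing such checks at a few states does not prove passivity.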

contrax.systems

Public systems namespace.

This module gathers the system-model concepts that users typically look for first: LTI containers and constructors, nonlinear model containers, basic interconnection, and linearization helpers.

ContLTI

Bases: Module

Continuous-time LTI system: ẋ = Ax + Bu, y = Cx + Du.

Source code in contrax/core.py
class ContLTI(eqx.Module):
    """Continuous-time LTI system: ẋ = Ax + Bu, y = Cx + Du."""

    A: Array  # (n, n)
    B: Array  # (n, m)
    C: Array  # (p, n)
    D: Array  # (p, m)

    def __matmul__(self, other):
        from contrax.interconnect import series

        return series(self, other)

    def __add__(self, other):
        from contrax.interconnect import parallel

        return parallel(self, other)

    def __sub__(self, other):
        from contrax.interconnect import parallel

        return parallel(self, other, sign=-1.0)

DiscLTI

Bases: Module

Discrete-time LTI system: x_{k+1} = Ax_k + Bu_k, y_k = Cx_k + Du_k.

Source code in contrax/core.py
class DiscLTI(eqx.Module):
    """Discrete-time LTI system: x_{k+1} = Ax_k + Bu_k, y_k = Cx_k + Du_k."""

    A: Array  # (n, n) — state transition matrix (Phi)
    B: Array  # (n, m) — input matrix (Gamma)
    C: Array  # (p, n)
    D: Array  # (p, m)
    dt: Array  # scalar — sampling period as plain JAX array, NOT static

    def __matmul__(self, other):
        from contrax.interconnect import series

        return series(self, other)

    def __add__(self, other):
        from contrax.interconnect import parallel

        return parallel(self, other)

    def __sub__(self, other):
        from contrax.interconnect import parallel

        return parallel(self, other, sign=-1.0)

NonlinearSystem

Bases: Module

General nonlinear system container.

The reusable model-object contract in Contrax is:

  • dynamics(t, x, u) for the state transition or vector field
  • output(t, x, u) for the measured/output map

The container keeps the underlying field name observation so existing pytrees stay stable, while the public constructor nonlinear_system() prefers the more control-oriented output= spelling.

Source code in contrax/nonlinear.py
class NonlinearSystem(eqx.Module):
    """General nonlinear system container.

    The reusable model-object contract in Contrax is:

    - `dynamics(t, x, u)` for the state transition or vector field
    - `output(t, x, u)` for the measured/output map

    The constructor keeps the underlying field name `observation` so existing
    pytrees stay stable, while the public constructor prefers the more
    control-oriented `output=` spelling.
    """

    dynamics: Callable
    observation: Callable | None = None
    dt: Array | None = None
    state_dim: int | None = eqx.field(static=True, default=None)
    input_dim: int | None = eqx.field(static=True, default=None)
    output_dim: int | None = eqx.field(static=True, default=None)

    def output(self, t: Array | float, x: Array, u: Array) -> Array:
        """Evaluate the observation map, defaulting to the full state."""
        if self.observation is None:
            return x
        return self.observation(t, x, u)

output(t: Array | float, x: Array, u: Array) -> Array

Evaluate the observation map, defaulting to the full state.

PHSSystem

Bases: Module

Port-Hamiltonian system with user-supplied storage and structure maps.

Source code in contrax/phs.py
class PHSSystem(eqx.Module):
    """Port-Hamiltonian system with user-supplied storage and structure maps."""

    H: Callable[[Array], Array]
    R: Callable[[Array], Array] | None = None
    G: Callable[[Array], Array] | None = None
    J: Callable[[Array], Array] | None = None
    observation: Callable | None = None
    dt: Array | None = None
    # Dimension metadata. static=True is appropriate here because these are
    # structural integers, not array leaves. Consequence: vmapping over a batch
    # of PHSSystem instances requires all systems in the batch to share the same
    # (state_dim, input_dim, output_dim) values or JAX will retrace per shape.
    state_dim: int | None = eqx.field(static=True, default=None)
    input_dim: int | None = eqx.field(static=True, default=None)
    output_dim: int | None = eqx.field(static=True, default=None)

    def dynamics(self, t: Array | float, x: Array, u: Array) -> Array:
        """Evaluate the port-Hamiltonian vector field."""
        del t  # canonical first pass is time-invariant
        dim = x.shape[0]
        J = self.J(x) if self.J is not None else canonical_J(dim // 2, dtype=x.dtype)
        R = self.R(x) if self.R is not None else jnp.zeros((dim, dim), dtype=x.dtype)
        G = self.G(x) if self.G is not None else _zero_input_matrix(x, u)
        grad_H = jax.grad(self.H)(x)
        return (J - R) @ grad_H + G @ u

    def output(self, t: Array | float, x: Array, u: Array) -> Array:
        """Evaluate the observation map, defaulting to the full state."""
        if self.observation is None:
            return x
        return self.observation(t, x, u)

    def as_nonlinear_system(self) -> NonlinearSystem:
        """Expose the PHS model under the generic nonlinear-system contract."""
        return NonlinearSystem(
            dynamics=self.dynamics,
            observation=self.output,
            dt=self.dt,
            state_dim=self.state_dim,
            input_dim=self.input_dim,
            output_dim=self.output_dim,
        )

as_nonlinear_system() -> NonlinearSystem

Expose the PHS model under the generic nonlinear-system contract.

dynamics(t: Array | float, x: Array, u: Array) -> Array

Evaluate the port-Hamiltonian vector field.

output(t: Array | float, x: Array, u: Array) -> Array

Evaluate the observation map, defaulting to the full state.

block_matrix(row_block_sizes: Sequence[int], col_block_sizes: Sequence[int], blocks: dict[tuple[int, int], ArrayLike] | None = None, *, dtype: dtype | None = None) -> Array

Assemble a dense matrix from block entries.

Source code in contrax/phs.py
def block_matrix(
    row_block_sizes: Sequence[int],
    col_block_sizes: Sequence[int],
    blocks: dict[tuple[int, int], ArrayLike] | None = None,
    *,
    dtype: jnp.dtype | None = None,
) -> Array:
    """Assemble a dense matrix from block entries."""
    row_sizes = tuple(int(size) for size in row_block_sizes)
    col_sizes = tuple(int(size) for size in col_block_sizes)
    if not row_sizes or not col_sizes:
        raise ValueError("row_block_sizes and col_block_sizes must not be empty.")
    if any(size <= 0 for size in row_sizes + col_sizes):
        raise ValueError("Block sizes must contain only positive integers.")

    block_entries = {} if blocks is None else dict(blocks)

    # First pass: determine the fully promoted dtype across all user-supplied blocks.
    result_dtype = jnp.float32 if dtype is None else dtype
    for block_val in block_entries.values():
        result_dtype = jnp.result_type(result_dtype, jnp.asarray(block_val).dtype)

    # Second pass: build rows with a consistent dtype so zero-fill blocks match.
    rows = []
    for i, row_size in enumerate(row_sizes):
        row = []
        for j, col_size in enumerate(col_sizes):
            if (i, j) in block_entries:
                block = jnp.asarray(block_entries[(i, j)], dtype=result_dtype)
                if block.shape != (row_size, col_size):
                    raise ValueError(
                        f"Block {(i, j)} has shape {block.shape}, expected "
                        f"{(row_size, col_size)}."
                    )
            else:
                block = jnp.zeros((row_size, col_size), dtype=result_dtype)
            row.append(block)
        rows.append(row)

    return jnp.block(rows)

block_observation(block_sizes: Sequence[int], block_indices: Sequence[int]) -> Callable[[Array | float, Array, Array], Array]

Build an output map that concatenates selected state blocks.

Source code in contrax/phs.py
def block_observation(
    block_sizes: Sequence[int],
    block_indices: Sequence[int],
) -> Callable[[Array | float, Array, Array], Array]:
    """Build an output map that concatenates selected state blocks."""
    sizes = tuple(int(size) for size in block_sizes)
    indices = tuple(int(index) for index in block_indices)
    if not sizes:
        raise ValueError("block_sizes must not be empty.")
    if not indices:
        raise ValueError("block_indices must not be empty.")
    if any(index < 0 or index >= len(sizes) for index in indices):
        raise ValueError("block_indices must refer to valid state blocks.")

    def output(t: Array | float, x: Array, u: Array) -> Array:
        del t, u
        parts = partition_state(x, sizes)
        return jnp.concatenate([parts[index] for index in indices], axis=0)

    return output

c2d(sys: ContLTI, dt: float, method: str = 'zoh') -> DiscLTI

Discretize a continuous-time state-space system.

Converts a ContLTI into a DiscLTI using either zero-order hold or the bilinear Tustin transform.

The zoh path is the stronger current implementation and is differentiated through a custom VJP on the matrix exponential. tustin is available as a convenience discretization.

Parameters:

  • sys (ContLTI, required): Continuous-time system.
  • dt (float, required): Sampling period. Shape: scalar.
  • method (str, default 'zoh'): Discretization method. Supported values are "zoh" and "tustin".

Returns:

  • DiscLTI: A discrete-time realization with the same output matrices and the requested sampling period.

Raises:

  • ValueError: If method is not one of the supported discretizations.

Examples:

>>> import jax
>>> jax.config.update("jax_enable_x64", True)  # c2d() requires float64
>>> import jax.numpy as jnp
>>> import contrax as cx
>>> sys_c = cx.ss(
...     jnp.array([[0.0, 1.0], [0.0, 0.0]]),
...     jnp.array([[0.0], [1.0]]),
...     jnp.eye(2),
...     jnp.zeros((2, 1)),
... )
>>> sys_d = cx.c2d(sys_c, dt=0.05)
Source code in contrax/core.py
def c2d(sys: ContLTI, dt: float, method: str = "zoh") -> DiscLTI:
    """Discretize a continuous-time state-space system.

    Converts a [ContLTI][contrax.systems.ContLTI] into a
    [DiscLTI][contrax.systems.DiscLTI] using either zero-order hold or the
    bilinear Tustin transform.

    The `zoh` path is the stronger current implementation and is differentiated
    through a custom VJP on the matrix exponential. `tustin` is available as a
    convenience discretization.

    Args:
        sys: Continuous-time system.
        dt: Sampling period. Shape: scalar.
        method: Discretization method. Supported values are `"zoh"` and
            `"tustin"`.

    Returns:
        [DiscLTI][contrax.systems.DiscLTI]: A discrete-time realization with
            the same output matrices and the requested sampling period.

    Raises:
        ValueError: If `method` is not one of the supported discretizations.

    Examples:
        >>> import jax.numpy as jnp
        >>> import contrax as cx
        >>> sys_c = cx.ss(
        ...     jnp.array([[0.0, 1.0], [0.0, 0.0]]),
        ...     jnp.array([[0.0], [1.0]]),
        ...     jnp.eye(2),
        ...     jnp.zeros((2, 1)),
        ... )
        >>> sys_d = cx.c2d(sys_c, dt=0.05)
    """
    require_x64("c2d")
    if method == "zoh":
        return _c2d_zoh(sys, dt)
    elif method == "tustin":
        return _c2d_tustin(sys, dt)
    else:
        raise ValueError(f"Unknown discretization method: {method!r}")

canonical_J(n_dofs: int, *, dtype: dtype | None = None) -> Array

Construct the canonical skew-symmetric PHS structure matrix.

Source code in contrax/phs.py
def canonical_J(n_dofs: int, *, dtype: jnp.dtype | None = None) -> Array:
    """Construct the canonical skew-symmetric PHS structure matrix."""
    if n_dofs < 1:
        raise ValueError("n_dofs must be positive.")
    dtype = jnp.float32 if dtype is None else dtype
    eye = jnp.eye(n_dofs, dtype=dtype)
    zeros = jnp.zeros((n_dofs, n_dofs), dtype=dtype)
    return jnp.block([[zeros, eye], [-eye, zeros]])

dss(A: ArrayLike, B: ArrayLike, C: ArrayLike, D: ArrayLike, dt: ArrayLike) -> DiscLTI

Construct a discrete-time state-space system.

Wraps discrete-time matrices in a DiscLTI pytree with MATLAB-familiar ss(A, B, C, D, dt) ergonomics.

This is a lightweight constructor rather than a solver. The sampling period is stored as a plain array leaf instead of static metadata, which keeps batching over systems more natural in JAX.

Parameters:

  • A (ArrayLike, required): State transition matrix. Shape: (n, n).
  • B (ArrayLike, required): Input matrix. Shape: (n, m).
  • C (ArrayLike, required): Output matrix. Shape: (p, n).
  • D (ArrayLike, required): Feedthrough matrix. Shape: (p, m).
  • dt (ArrayLike, required): Sampling period. Shape: scalar.

Returns:

  • DiscLTI: A discrete-time system with array-valued A, B, C, D, and dt fields.

Raises:

  • AssertionError: If the matrix shapes are inconsistent.

Examples:

>>> import jax.numpy as jnp
>>> import contrax as cx
>>> sys = cx.dss(
...     jnp.array([[1.0, 0.1], [0.0, 1.0]]),
...     jnp.array([[0.0], [0.1]]),
...     jnp.eye(2),
...     jnp.zeros((2, 1)),
...     dt=0.1,
... )
Source code in contrax/core.py
def dss(
    A: ArrayLike,
    B: ArrayLike,
    C: ArrayLike,
    D: ArrayLike,
    dt: ArrayLike,
) -> DiscLTI:
    """Construct a discrete-time state-space system.

    Wraps discrete-time matrices in a
    [DiscLTI][contrax.systems.DiscLTI] pytree with MATLAB-familiar
    `ss(A, B, C, D, dt)` ergonomics.

    This is a lightweight constructor rather than a solver. The sampling period
    is stored as a plain array leaf instead of static metadata, which keeps
    batching over systems more natural in JAX.

    Args:
        A: State transition matrix. Shape: `(n, n)`.
        B: Input matrix. Shape: `(n, m)`.
        C: Output matrix. Shape: `(p, n)`.
        D: Feedthrough matrix. Shape: `(p, m)`.
        dt: Sampling period. Shape: scalar.

    Returns:
        [DiscLTI][contrax.systems.DiscLTI]: A discrete-time system with
            array-valued `A`, `B`, `C`, `D`, and `dt` fields.

    Raises:
        AssertionError: If the matrix shapes are inconsistent.

    Examples:
        >>> import jax.numpy as jnp
        >>> import contrax as cx
        >>> sys = cx.dss(
        ...     jnp.array([[1.0, 0.1], [0.0, 1.0]]),
        ...     jnp.array([[0.0], [0.1]]),
        ...     jnp.eye(2),
        ...     jnp.zeros((2, 1)),
        ...     dt=0.1,
        ... )
    """
    A = jnp.asarray(A, dtype=float)
    B = jnp.asarray(B, dtype=float)
    C = jnp.asarray(C, dtype=float)
    D = jnp.asarray(D, dtype=float)
    dt = jnp.asarray(dt, dtype=float)
    n = A.shape[0]
    assert A.shape == (n, n)
    assert B.shape[0] == n
    assert C.shape[1] == n
    assert D.shape == (C.shape[0], B.shape[1])
    return DiscLTI(A=A, B=B, C=C, D=D, dt=dt)

linearize(model_or_f: DynamicsLike, x0: Array, u0: Array, *, t: ArrayLike = 0.0) -> tuple[Array, Array]

Linearize a dynamics function around an operating point.

Computes the Jacobians A = df/dx and B = df/du at (x0, u0) using forward-mode automatic differentiation.

This is one of the main JAX-native workflow bridges in Contrax. It works cleanly with jit, vmap, and differentiation as long as f is written as a pure JAX function without Python control flow over traced values.

Parameters:

  • model_or_f (DynamicsLike, required): A nonlinear system model or pure JAX dynamics function mapping (x, u) to x_dot or x_next.
  • x0 (Array, required): State operating point. Shape: (n,).
  • u0 (Array, required): Input operating point. Shape: (m,).
  • t (ArrayLike, default 0.0): Operating-point time for system-model linearization.

Returns:

  • tuple[Array, Array]: Jacobian pair (A, B). Shapes: (n, n) and (n, m).

Examples:

>>> import jax.numpy as jnp
>>> import contrax as cx
>>> def dynamics(x, u):
...     return jnp.array([x[1], -jnp.sin(x[0]) + u[0]])
>>> A, B = cx.linearize(dynamics, jnp.zeros(2), jnp.zeros(1))
Source code in contrax/core.py
def linearize(
    model_or_f: DynamicsLike,
    x0: Array,
    u0: Array,
    *,
    t: ArrayLike = 0.0,
) -> tuple[Array, Array]:
    """Linearize a dynamics function around an operating point.

    Computes the Jacobians `A = df/dx` and `B = df/du` at `(x0, u0)` using
    forward-mode automatic differentiation.

    This is one of the main JAX-native workflow bridges in Contrax. It works
    cleanly with `jit`, `vmap`, and differentiation as long as `f` is written
    as a pure JAX function without Python control flow over traced values.

    Args:
        model_or_f: A nonlinear system model or pure JAX dynamics function
            mapping `(x, u)` to `x_dot` or `x_next`.
        x0: State operating point. Shape: `(n,)`.
        u0: Input operating point. Shape: `(m,)`.
        t: Operating-point time for system-model linearization.

    Returns:
        tuple[Array, Array]: Jacobian pair `(A, B)`. Shapes: `(n, n)` and
        `(n, m)`.

    Examples:
        >>> import jax.numpy as jnp
        >>> import contrax as cx
        >>> def dynamics(x, u):
        ...     return jnp.array([x[1], -jnp.sin(x[0]) + u[0]])
        >>> A, B = cx.linearize(dynamics, jnp.zeros(2), jnp.zeros(1))
    """
    f = _coerce_dynamics(model_or_f)
    t = jnp.asarray(t, dtype=x0.dtype)
    A = jax.jacfwd(f, argnums=0)(x0, u0, t)
    B = jax.jacfwd(f, argnums=1)(x0, u0, t)
    return A, B

linearize_ss(model_or_f: DynamicsLike, x0: Array, u0: Array, *, output: OutputFn | None = None, t: ArrayLike = 0.0) -> ContLTI

Linearize dynamics and output maps into a continuous state-space model.

Computes A, B, C, and D Jacobians at (x0, u0) and returns them as a ContLTI.

This is the main bridge from nonlinear plant code into the linear control stack. Pair it with c2d() and lqr() when building local controllers around an operating point.

Parameters:

  • model_or_f (DynamicsLike, required): A nonlinear system model or pure JAX dynamics function mapping (x, u) to x_dot.
  • x0 (Array, required): State operating point. Shape: (n,).
  • u0 (Array, required): Input operating point. Shape: (m,).
  • output (OutputFn | None, default None): Output function mapping (x, u) to y when model_or_f is a plain dynamics callable. Omit this when passing a system model with its own output map.
  • t (ArrayLike, default 0.0): Operating-point time for system-model linearization.

Returns:

  • ContLTI: A continuous-time state-space model with linearized A, B, C, and D matrices.

Examples:

>>> import jax.numpy as jnp
>>> import contrax as cx
>>> def dynamics(x, u):
...     return jnp.array([x[1], -jnp.sin(x[0]) + u[0]])
>>> def output(x, u):
...     return x
>>> sys_c = cx.linearize_ss(
...     dynamics,
...     jnp.zeros(2),
...     jnp.zeros(1),
...     output=output,
... )
Source code in contrax/core.py
def linearize_ss(
    model_or_f: DynamicsLike,
    x0: Array,
    u0: Array,
    *,
    output: OutputFn | None = None,
    t: ArrayLike = 0.0,
) -> ContLTI:
    """Linearize dynamics and output maps into a continuous state-space model.

    Computes `A`, `B`, `C`, and `D` Jacobians at `(x0, u0)` and returns them as
    a [ContLTI][contrax.systems.ContLTI].

    This is the main bridge from nonlinear plant code into the linear control
    stack. Pair it with `c2d()` and `lqr()` when building local controllers
    around an operating point.

    Args:
        model_or_f: A nonlinear system model or pure JAX dynamics function
            mapping `(x, u)` to `x_dot`.
        x0: State operating point. Shape: `(n,)`.
        u0: Input operating point. Shape: `(m,)`.
        output: Output function mapping `(x, u)` to `y` when `model_or_f` is a
            plain dynamics callable. Omit this when passing a system model with
            its own output map.
        t: Operating-point time for system-model linearization.

    Returns:
        [ContLTI][contrax.systems.ContLTI]: A continuous-time state-space
            model with linearized `A`, `B`, `C`, and `D` matrices.

    Examples:
        >>> import jax.numpy as jnp
        >>> import contrax as cx
        >>> def dynamics(x, u):
        ...     return jnp.array([x[1], -jnp.sin(x[0]) + u[0]])
        >>> def output(x, u):
        ...     return x
        >>> sys_c = cx.linearize_ss(
        ...     dynamics,
        ...     jnp.zeros(2),
        ...     jnp.zeros(1),
        ...     output=output,
        ... )
    """
    if _is_system_model(model_or_f):
        if output is not None:
            raise ValueError("linearize_ss(sys, x0, u0) does not accept output=.")
        sys = model_or_f
        f = _coerce_dynamics(sys)
        h = _coerce_observation(sys)
    else:
        if output is None:
            raise ValueError(
                "linearize_ss(f, x0, u0, output=...) requires an output function."
            )
        f = _coerce_dynamics(model_or_f)

        def h(x: Array, u: Array, t_: Array) -> Array:
            del t_
            return output(x, u)

    t = jnp.asarray(t, dtype=x0.dtype)
    A = jax.jacfwd(f, argnums=0)(x0, u0, t)
    B = jax.jacfwd(f, argnums=1)(x0, u0, t)
    C = jax.jacfwd(h, argnums=0)(x0, u0, t)
    D = jax.jacfwd(h, argnums=1)(x0, u0, t)
    return ContLTI(A=A, B=B, C=C, D=D)

nonlinear_system(dynamics: Callable, output: Callable | None = None, *, observation: Callable | None = None, dt: ArrayLike | None = None, state_dim: int | None = None, input_dim: int | None = None, output_dim: int | None = None) -> NonlinearSystem

Construct a reusable nonlinear system model.

Contrax uses NonlinearSystem as the shared contract for nonlinear simulation, linearization, and estimation. The intended callable shape is dynamics(t, x, u) and, when present, output(t, x, u).

Parameters:

  • dynamics (Callable, required): State transition or vector-field callable with signature (t, x, u).
  • output (Callable | None, default None): Optional output/measurement map with signature (t, x, u). When omitted, the full state is used as the output.
  • observation (Callable | None, default None): Deprecated synonym for output. Pass only one of output or observation.
  • dt (ArrayLike | None, default None): Optional discrete sample time. Omit for continuous models.
  • state_dim (int | None, default None): Optional static state dimension metadata.
  • input_dim (int | None, default None): Optional static input dimension metadata.
  • output_dim (int | None, default None): Optional static output dimension metadata.

Returns:

  • NonlinearSystem: A reusable nonlinear system model.

Source code in contrax/nonlinear.py
def nonlinear_system(
    dynamics: Callable,
    output: Callable | None = None,
    *,
    observation: Callable | None = None,
    dt: ArrayLike | None = None,
    state_dim: int | None = None,
    input_dim: int | None = None,
    output_dim: int | None = None,
) -> NonlinearSystem:
    """Construct a reusable nonlinear system model.

    Contrax uses [NonlinearSystem][contrax.systems.NonlinearSystem] as the
    shared contract for nonlinear simulation, linearization, and estimation.
    The intended callable shape is `dynamics(t, x, u)` and, when present,
    `output(t, x, u)`.

    Args:
        dynamics: State transition or vector-field callable with signature
            `(t, x, u)`.
        output: Optional output/measurement map with signature `(t, x, u)`.
            When omitted, the full state is used as the output.
        observation: Deprecated synonym for `output`. Pass only one of
            `output` or `observation`.
        dt: Optional discrete sample time. Omit for continuous models.
        state_dim: Optional static state dimension metadata.
        input_dim: Optional static input dimension metadata.
        output_dim: Optional static output dimension metadata.

    Returns:
        [NonlinearSystem][contrax.systems.NonlinearSystem]: A reusable
            nonlinear system model.
    """
    if output is not None and observation is not None:
        raise ValueError("Pass only one of output= or observation=.")
    observation = output if output is not None else observation
    dt_arr = None if dt is None else jnp.asarray(dt, dtype=float)
    return NonlinearSystem(
        dynamics=dynamics,
        observation=observation,
        dt=dt_arr,
        state_dim=state_dim,
        input_dim=input_dim,
        output_dim=output_dim,
    )
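The (t, x, u) convention means one pair of callables can be shared by any routine that respects it. The sketch below assumes an illustrative spring with stiffness 4 and a hypothetical rk4_step helper (not a Contrax API) that consumes dynamics(t, x, u) exactly as written; the same two functions are what you would pass to nonlinear_system(dynamics, output).

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


# Callables in the (t, x, u) shape that nonlinear_system() expects.
def dynamics(t, x, u):
    del t
    # Illustrative unit-mass spring with stiffness 4 and a force input.
    return jnp.array([x[1], -4.0 * x[0] + u[0]])


def output(t, x, u):
    del t, u
    return x[:1]  # position only


def rk4_step(f, t, x, u, dt):
    """Hypothetical helper (not part of Contrax): one RK4 step with u held constant."""
    k1 = f(t, x, u)
    k2 = f(t + dt / 2, x + dt / 2 * k1, u)
    k3 = f(t + dt / 2, x + dt / 2 * k2, u)
    k4 = f(t + dt, x + dt * k3, u)
    return x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


dt, n_steps = 0.01, 100
x = jnp.array([1.0, 0.0])
u = jnp.zeros(1)
for k in range(n_steps):
    x = rk4_step(dynamics, k * dt, x, u, dt)
y = output(n_steps * dt, x, u)
```

With zero input the exact position is cos(2t), so y at t = 1 should sit very close to cos(2).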

parallel(sys1: DiscLTI | ContLTI, sys2: DiscLTI | ContLTI, *, sign: float = 1.0) -> DiscLTI | ContLTI

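Connect two systems in parallel and sum their outputs.

Both systems are driven by the same input u, and the interconnected output is y = y1 + sign * y2. The two systems must belong to the same family (both ContLTI or both DiscLTI) and have matching input and output dimensions; the combined state stacks sys1 states first, then sys2 states.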

Parameters:

  • sys1 (DiscLTI | ContLTI, required): First system.
  • sys2 (DiscLTI | ContLTI, required): Second system.
  • sign (float, default 1.0): Output sign applied to sys2. Use -1.0 for subtraction.

Returns:

  • DiscLTI | ContLTI: Interconnected system.

Source code in contrax/interconnect.py
def parallel(
    sys1: DiscLTI | ContLTI,
    sys2: DiscLTI | ContLTI,
    *,
    sign: float = 1.0,
) -> DiscLTI | ContLTI:
    """Connect two systems in parallel and sum their outputs.

    Args:
        sys1: First system.
        sys2: Second system.
        sign: Output sign applied to `sys2`. Use `-1.0` for subtraction.

    Returns:
        [DiscLTI][contrax.systems.DiscLTI] |
            [ContLTI][contrax.systems.ContLTI]: Interconnected system.
    """
    _check_same_family(sys1, sys2)
    if sys1.B.shape[1] != sys2.B.shape[1]:
        raise ValueError("Parallel interconnection requires matching input dimensions.")
    if sys1.C.shape[0] != sys2.C.shape[0]:
        raise ValueError(
            "Parallel interconnection requires matching output dimensions."
        )

    A1, B1, C1, D1 = sys1.A, sys1.B, sys1.C, sys1.D
    A2, B2, C2, D2 = sys2.A, sys2.B, sys2.C, sys2.D
    n1, n2 = A1.shape[0], A2.shape[0]
    sign_arr = jnp.asarray(sign, dtype=A1.dtype)

    A = jnp.block(
        [
            [A1, jnp.zeros((n1, n2), dtype=A1.dtype)],
            [jnp.zeros((n2, n1), dtype=A2.dtype), A2],
        ]
    )
    B = jnp.vstack([B1, B2])
    C = jnp.hstack([C1, sign_arr * C2])
    D = D1 + sign_arr * D2

    if isinstance(sys1, DiscLTI):
        return DiscLTI(A=A, B=B, C=C, D=D, dt=sys1.dt)
    return ContLTI(A=A, B=B, C=C, D=D)
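One quick sanity check on the block layout parallel() builds is the continuous-time steady-state gain G(0) = D - C A⁻¹ B: for a parallel connection it should equal G1(0) + sign * G2(0). The sketch below rebuilds the same blocks in plain JAX for two illustrative stable systems (the matrices and the dc_gain helper are assumptions for the example, not Contrax APIs).

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def dc_gain(A, B, C, D):
    # Continuous-time steady-state gain G(0) = D - C A^{-1} B (A must be invertible).
    return D - C @ jnp.linalg.solve(A, B)


# Two illustrative stable SISO systems.
A1, B1 = jnp.array([[-1.0]]), jnp.array([[1.0]])
C1, D1 = jnp.array([[2.0]]), jnp.array([[0.0]])
A2 = jnp.array([[0.0, 1.0], [-2.0, -3.0]])
B2 = jnp.array([[0.0], [1.0]])
C2 = jnp.array([[1.0, 0.0]])
D2 = jnp.array([[0.5]])
sign = -1.0

# Same layout as parallel(): block-diagonal A, stacked B, signed C and D.
n1, n2 = A1.shape[0], A2.shape[0]
A = jnp.block([[A1, jnp.zeros((n1, n2))], [jnp.zeros((n2, n1)), A2]])
B = jnp.vstack([B1, B2])
C = jnp.hstack([C1, sign * C2])
D = D1 + sign * D2

g_parallel = dc_gain(A, B, C, D)
g_expected = dc_gain(A1, B1, C1, D1) + sign * dc_gain(A2, B2, C2, D2)
```

Here G1(0) = 2 and G2(0) = 1, so with sign = -1.0 both expressions evaluate to 1.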

partition_state(x: ArrayLike, block_sizes: Sequence[int]) -> tuple[Array, ...]

Split a one-dimensional state vector into consecutive logical blocks.

Source code in contrax/phs.py
def partition_state(x: ArrayLike, block_sizes: Sequence[int]) -> tuple[Array, ...]:
    """Split a one-dimensional state vector into consecutive logical blocks."""
    x = jnp.asarray(x)
    sizes = tuple(int(size) for size in block_sizes)
    if x.ndim != 1:
        raise ValueError("partition_state() expects a one-dimensional state vector.")
    if not sizes or any(size <= 0 for size in sizes):
        raise ValueError("block_sizes must contain only positive integers.")
    total = sum(sizes)
    if x.shape[0] != total:
        raise ValueError(
            f"State length {x.shape[0]} does not match block_sizes sum {total}."
        )
    offsets = [0]
    for size in sizes:
        offsets.append(offsets[-1] + size)
    return tuple(x[offsets[i] : offsets[i + 1]] for i in range(len(sizes)))
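For intuition, the same consecutive slicing can be written with jnp.split once block sizes are turned into interior cut points. This sketch assumes an illustrative five-element state (positions, momenta, one actuator state) and skips the validation partition_state() performs (one-dimensional input, positive sizes, matching total length), so prefer partition_state() itself when inputs are not already trusted.

```python
import itertools

import jax.numpy as jnp

block_sizes = (2, 2, 1)  # illustrative: positions q, momenta p, one actuator state
x = jnp.arange(5.0)

# partition_state() walks cumulative offsets; jnp.split wants only the interior cuts.
cuts = list(itertools.accumulate(block_sizes))[:-1]  # [2, 4]
q, p, a = jnp.split(x, cuts)
```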

phs_diagnostics(sys: PHSSystem, x: ArrayLike, u: ArrayLike | None = None) -> PHSStructureDiagnostics

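Evaluate local structure and power-balance diagnostics for a PHS state.

The returned PHSStructureDiagnostics reports the skew-symmetry error of J (max |J + J^T|), the symmetry error of R (max |R - R^T|), and the smallest eigenvalue of the symmetric part of R (negative values mean R is not positive semidefinite). At (x, u) it also reports the storage rate ∇H^T ẋ, the dissipation power ∇H^T R ∇H, the supplied power ∇H^T G u, and the power-balance residual supplied_power - dissipation_power - storage_rate, which vanishes whenever J is skew-symmetric.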

Source code in contrax/phs.py
def phs_diagnostics(
    sys: PHSSystem,
    x: ArrayLike,
    u: ArrayLike | None = None,
) -> PHSStructureDiagnostics:
    """Evaluate local structure and power-balance diagnostics for a PHS state."""
    x = jnp.asarray(x)
    u_arr = jnp.zeros((0,), dtype=x.dtype) if u is None else jnp.asarray(u)
    dim = x.shape[0]
    J = sys.J(x) if sys.J is not None else canonical_J(dim // 2, dtype=x.dtype)
    R = sys.R(x) if sys.R is not None else jnp.zeros((dim, dim), dtype=x.dtype)
    G = sys.G(x) if sys.G is not None else _zero_input_matrix(x, u_arr)
    grad_H = jax.grad(sys.H)(x)
    R_sym = symmetrize_matrix(R)
    # Reuse already-computed J, R, G rather than calling sys.dynamics() again.
    Gu = G @ u_arr
    xdot = (J - R) @ grad_H + Gu
    storage_rate = jnp.vdot(grad_H, xdot)
    dissipation_power = jnp.vdot(grad_H, R_sym @ grad_H)
    supplied_power = jnp.vdot(grad_H, Gu)
    # Standard PHS energy balance: dH/dt = supply - dissipation
    # => supplied_power - dissipation_power - storage_rate = 0
    return PHSStructureDiagnostics(
        skew_symmetry_error=jnp.max(jnp.abs(J + J.T)),
        dissipation_symmetry_error=jnp.max(jnp.abs(R - R.T)),
        min_dissipation_eigenvalue=jnp.min(jnp.linalg.eigvalsh(R_sym)),
        storage_rate=storage_rate,
        dissipation_power=dissipation_power,
        supplied_power=supplied_power,
        power_balance_residual=supplied_power - dissipation_power - storage_rate,
    )
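The power-balance quantities can be reproduced by hand for a small model. The sketch below assumes an illustrative mass-spring-damper with H(x) = ½(q² + p²), writes J, R, and G out as constant matrices (rather than calling canonical_J()), and repeats the source's computations. Because J is skew-symmetric, the storage rate equals supplied power minus dissipation, so the residual should vanish up to rounding.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


# Illustrative mass-spring-damper in PHS form: x = [q, p].
def H(x):
    return 0.5 * jnp.dot(x, x)


J = jnp.array([[0.0, 1.0], [-1.0, 0.0]])  # skew-symmetric interconnection
R = jnp.array([[0.0, 0.0], [0.0, 0.3]])   # damping acts on the momentum
G = jnp.array([[0.0], [1.0]])             # force enters the momentum equation

x = jnp.array([0.4, -1.2])
u = jnp.array([0.7])

grad_H = jax.grad(H)(x)
R_sym = 0.5 * (R + R.T)
Gu = G @ u
xdot = (J - R) @ grad_H + Gu

storage_rate = jnp.vdot(grad_H, xdot)  # dH/dt along the vector field
dissipation_power = jnp.vdot(grad_H, R_sym @ grad_H)
supplied_power = jnp.vdot(grad_H, Gu)
residual = supplied_power - dissipation_power - storage_rate
skew_error = jnp.max(jnp.abs(J + J.T))
```

At this state the dissipation power is 0.432, the supplied power is -0.84, and the storage rate is -1.272, so the balance closes exactly.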

phs_system(H: Callable[[Array], Array], *, R: Callable[[Array], Array] | None = None, G: Callable[[Array], Array] | None = None, J: Callable[[Array], Array] | None = None, output: Callable | None = None, observation: Callable | None = None, dt: ArrayLike | None = None, state_dim: int | None = None, input_dim: int | None = None, output_dim: int | None = None) -> PHSSystem

Construct a port-Hamiltonian system object.

Parameters:

  • H (Callable[[Array], Array], required): Storage/Hamiltonian function.
  • R (Callable[[Array], Array] | None, default None): Optional dissipation map.
  • G (Callable[[Array], Array] | None, default None): Optional input map.
  • J (Callable[[Array], Array] | None, default None): Optional structure map.
  • output (Callable | None, default None): Optional output/measurement map with signature (t, x, u).
  • observation (Callable | None, default None): Deprecated synonym for output. Pass only one of output or observation.
  • dt (ArrayLike | None, default None): Optional discrete sample time.
  • state_dim (int | None, default None): Optional static state dimension metadata.
  • input_dim (int | None, default None): Optional static input dimension metadata.
  • output_dim (int | None, default None): Optional static output dimension metadata.

Returns:

  • PHSSystem: A reusable port-Hamiltonian model.

Source code in contrax/phs.py
def phs_system(
    H: Callable[[Array], Array],
    *,
    R: Callable[[Array], Array] | None = None,
    G: Callable[[Array], Array] | None = None,
    J: Callable[[Array], Array] | None = None,
    output: Callable | None = None,
    observation: Callable | None = None,
    dt: ArrayLike | None = None,
    state_dim: int | None = None,
    input_dim: int | None = None,
    output_dim: int | None = None,
) -> PHSSystem:
    """Construct a port-Hamiltonian system object.

    Args:
        H: Storage/Hamiltonian function.
        R: Optional dissipation map.
        G: Optional input map.
        J: Optional structure map.
        output: Optional output/measurement map with signature `(t, x, u)`.
        observation: Deprecated synonym for `output`. Pass only one of
            `output` or `observation`.
        dt: Optional discrete sample time.
        state_dim: Optional static state dimension metadata.
        input_dim: Optional static input dimension metadata.
        output_dim: Optional static output dimension metadata.

    Returns:
        [PHSSystem][contrax.systems.PHSSystem]: A reusable port-Hamiltonian
            model.
    """
    if output is not None and observation is not None:
        raise ValueError("Pass only one of output= or observation=.")
    observation = output if output is not None else observation
    dt_arr = None if dt is None else jnp.asarray(dt, dtype=float)
    return PHSSystem(
        H=H,
        R=R,
        G=G,
        J=J,
        observation=observation,
        dt=dt_arr,
        state_dim=state_dim,
        input_dim=input_dim,
        output_dim=output_dim,
    )
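The H, J, R, and G arguments are state-only callables of shape Callable[[Array], Array]. The sketch below assumes an illustrative damped pendulum (mass, length, gravity, and damping values are made up) written in that shape, evaluates the canonical right-hand side (J - R(x)) ∇H(x) + G(x) u by hand, and checks it against the Newtonian pendulum equations.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

MASS, LENGTH, GRAVITY, DAMPING = 1.0, 0.5, 9.81, 0.2  # illustrative values


# State-only callables; x = [angle q, angular momentum p].
def H(x):
    q, p = x[0], x[1]
    return p**2 / (2.0 * MASS * LENGTH**2) + MASS * GRAVITY * LENGTH * (1.0 - jnp.cos(q))


def J(x):
    del x
    return jnp.array([[0.0, 1.0], [-1.0, 0.0]])


def R(x):
    del x
    return jnp.array([[0.0, 0.0], [0.0, DAMPING]])


def G(x):
    del x
    return jnp.array([[0.0], [1.0]])


def vector_field(x, u):
    # Canonical PHS right-hand side (J - R(x)) grad H(x) + G(x) u.
    return (J(x) - R(x)) @ jax.grad(H)(x) + G(x) @ u


x = jnp.array([0.3, 0.1])
u = jnp.array([0.05])
xdot = vector_field(x, u)

# Newtonian form of the same pendulum for comparison.
q, p = x
q_dot_ref = p / (MASS * LENGTH**2)
p_dot_ref = -MASS * GRAVITY * LENGTH * jnp.sin(q) - DAMPING * q_dot_ref + u[0]
```

With Contrax installed, the same callables would be passed as cxs.phs_system(H, R=R, G=G, J=J, state_dim=2, input_dim=1), using the keyword-only signature shown above.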

phs_to_ss(sys: PHSSystem, x_eq: Array, u_eq: Array) -> ContLTI

Linearize a port-Hamiltonian system to a continuous-time state-space model.

Computes the Jacobian linearization of the PHS dynamics and observation map at the operating point (x_eq, u_eq). The result is the standard small-signal form δẋ ≈ A δx + B δu, δy ≈ C δx + D δu around that point, where δx = x - x_eq, δu = u - u_eq, and δy is the deviation of the output from its value at (x_eq, u_eq).

For a linear PHS, meaning constant J, R, and G together with a quadratic Hamiltonian H(x) = ½ x^T M x with symmetric M, the linearization is exact and gives A = (J - R) M and B = G.

This is a thin convenience wrapper around linearize_ss(sys.as_nonlinear_system(), x_eq, u_eq, t=0.0). It exists so that the step from a PHS model to its local LTI state-space representation is explicit at the call site.

Parameters:

  • sys (PHSSystem, required): Port-Hamiltonian system.
  • x_eq (Array, required): Equilibrium state. Shape: (n,).
  • u_eq (Array, required): Equilibrium input. Shape: (m,).

Returns:

  • ContLTI: Continuous-time LTI approximation at (x_eq, u_eq).

Examples:

>>> import jax.numpy as jnp
>>> import contrax as cx
>>> def H(x): return 0.5 * jnp.dot(x, x)
>>> phs = cx.phs_system(H, state_dim=2, input_dim=1)
>>> lti = cx.phs_to_ss(phs, jnp.zeros(2), jnp.zeros(1))
>>> isinstance(lti, cx.ContLTI)
True
Source code in contrax/phs.py
def phs_to_ss(
    sys: PHSSystem,
    x_eq: Array,
    u_eq: Array,
) -> ContLTI:
    """Linearize a port-Hamiltonian system to a continuous-time state-space model.

    Computes the Jacobian linearization of the PHS dynamics and observation map
    at the operating point `(x_eq, u_eq)`. The result is the standard
    `ẋ ≈ A δx + B δu` / `y ≈ C δx + D δu` form around that point.

    For a linear PHS with quadratic Hamiltonian `H(x) = ½ x^T M x`, the
    linearization is exact and gives `A = (J - R) M`, `B = G`.

    This is a thin convenience wrapper around
    `linearize_ss(sys.as_nonlinear_system(), x_eq, u_eq, t=0.0)`. It adds
    an explicit PHS-to-state-space path so that the intent is documented and
    the connection between the PHS model and its LTI state-space representation
    is clear.

    Args:
        sys: Port-Hamiltonian system.
        x_eq: Equilibrium state. Shape: `(n,)`.
        u_eq: Equilibrium input. Shape: `(m,)`.

    Returns:
        [ContLTI][contrax.systems.ContLTI]: Continuous-time LTI approximation
        at `(x_eq, u_eq)`.

    Examples:
        >>> import jax.numpy as jnp
        >>> import contrax as cx
        >>> def H(x): return 0.5 * jnp.dot(x, x)
        >>> phs = cx.phs_system(H, state_dim=2, input_dim=1)
        >>> lti = cx.phs_to_ss(phs, jnp.zeros(2), jnp.zeros(1))
        >>> isinstance(lti, cx.ContLTI)
        True
    """
    from contrax.core import linearize_ss

    nl = sys.as_nonlinear_system()
    return linearize_ss(nl, x_eq, u_eq, t=0.0)
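The claim that a linear PHS linearizes exactly to A = (J - R) M and B = G can be checked numerically. This sketch assumes illustrative constant J, R, G and a symmetric positive definite M, differentiates the PHS vector field with jax.jacfwd (the same Jacobian approach linearize_ss() uses), and compares against the closed form.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

M = jnp.array([[2.0, 0.5], [0.5, 1.0]])  # symmetric positive definite
J = jnp.array([[0.0, 1.0], [-1.0, 0.0]])
R = jnp.array([[0.1, 0.0], [0.0, 0.4]])
G = jnp.array([[0.0], [1.0]])


def H(x):
    # Quadratic Hamiltonian; its gradient is M x because M is symmetric.
    return 0.5 * x @ M @ x


def f(x, u):
    return (J - R) @ jax.grad(H)(x) + G @ u


x_eq, u_eq = jnp.zeros(2), jnp.zeros(1)
A = jax.jacfwd(f, argnums=0)(x_eq, u_eq)
B = jax.jacfwd(f, argnums=1)(x_eq, u_eq)
```

Since f is linear here, A x + B u reproduces f(x, u) at every point, not only near x_eq.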

project_psd(M: ArrayLike, *, min_eigenvalue: float = 0.0) -> Array

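Project a square matrix onto the symmetric positive semidefinite (PSD) cone. The input is first replaced by its symmetric part, and its eigenvalues are then clipped from below at min_eigenvalue; with the default min_eigenvalue=0.0 the result is the nearest symmetric PSD matrix in the Frobenius norm.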

Source code in contrax/phs.py
def project_psd(
    M: ArrayLike,
    *,
    min_eigenvalue: float = 0.0,
) -> Array:
    """Project a square matrix onto the symmetric PSD cone."""
    M = symmetrize_matrix(M)
    eigvals, eigvecs = jnp.linalg.eigh(M)
    clipped = jnp.maximum(eigvals, jnp.asarray(min_eigenvalue, dtype=eigvals.dtype))
    return symmetrize_matrix((eigvecs * clipped) @ eigvecs.T)
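A typical use is cleaning up a dissipation matrix that came from a fit or an optimizer step. The sketch below repeats the project_psd() steps in plain JAX under a hypothetical name, project_psd_sketch, for an illustrative non-symmetric matrix whose symmetric part has eigenvalues 1.5 and -1.0.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def project_psd_sketch(M, min_eigenvalue=0.0):
    """Hypothetical mirror of project_psd(): symmetrize, eigendecompose, clip, rebuild."""
    M = 0.5 * (M + M.T)
    eigvals, eigvecs = jnp.linalg.eigh(M)
    clipped = jnp.maximum(eigvals, min_eigenvalue)
    # eigvecs * clipped scales column i by clipped[i], i.e. V @ diag(clipped).
    P = (eigvecs * clipped) @ eigvecs.T
    return 0.5 * (P + P.T)


R_noisy = jnp.array([[1.0, 2.0], [0.0, -0.5]])  # e.g. an unconstrained estimate
R_psd = project_psd_sketch(R_noisy)
R_floor = project_psd_sketch(R_noisy, min_eigenvalue=0.1)
```

The positive eigenvalue 1.5 is kept, while the negative one becomes 0.0 (or 0.1 with the floor), which keeps the model passive without touching the dominant direction.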

schedule_phs(sys: PHSSystem, context_fn: Callable[[Array | float], Any], *, J: Callable[[Array | float, Array, Any], Array] | None = None, R: Callable[[Array | float, Array, Any], Array] | None = None, G: Callable[[Array | float, Array, Any], Array] | None = None, observation: Callable[[Array | float, Array, Array, Any], Array] | None = None) -> NonlinearSystem

Bind an exogenous schedule/context to a canonical PHS model.

The returned object is an ordinary NonlinearSystem, so it composes with the same simulation, estimation, and linearization surface as any other nonlinear system in Contrax.

context_fn(t) should be a pure JAX-friendly callable returning any pytree of observed scheduling/context variables, such as a scale parameter theta(t).

Parameters:

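  • sys (PHSSystem, required): Canonical PHS model with state-only H(x) and optional state-only J(x) / R(x) / G(x) fallbacks.
  • context_fn (Callable[[Array | float], Any], required): Callable mapping time to observed scheduling/context data.
  • J (Callable[[Array | float, Array, Any], Array] | None, default None): Optional scheduled structure map J(t, x, context). When omitted, sys.J(x) is used, or canonical_J(n // 2) if sys.J is also None.
  • R (Callable[[Array | float, Array, Any], Array] | None, default None): Optional scheduled dissipation map R(t, x, context). When omitted, sys.R(x) is used, or a zero matrix if sys.R is also None.
  • G (Callable[[Array | float, Array, Any], Array] | None, default None): Optional scheduled input map G(t, x, context). When omitted, sys.G(x) is used, or a zero input matrix if sys.G is also None.
  • observation (Callable[[Array | float, Array, Array, Any], Array] | None, default None): Optional scheduled observation map observation(t, x, u, context). When omitted, sys.observation(t, x, u) is used, or the full state if sys.observation is also None.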

Returns:

  • NonlinearSystem: A scheduled nonlinear system with the context bound.

Source code in contrax/phs.py
def schedule_phs(
    sys: PHSSystem,
    context_fn: Callable[[Array | float], Any],
    *,
    J: Callable[[Array | float, Array, Any], Array] | None = None,
    R: Callable[[Array | float, Array, Any], Array] | None = None,
    G: Callable[[Array | float, Array, Any], Array] | None = None,
    observation: Callable[[Array | float, Array, Array, Any], Array] | None = None,
) -> NonlinearSystem:
    """Bind an exogenous schedule/context to a canonical PHS model.

    The returned object is an ordinary
    [NonlinearSystem][contrax.systems.NonlinearSystem], so it composes with the
    same simulation, estimation, and linearization surface as any other
    nonlinear system in Contrax.

    `context_fn(t)` should be a pure JAX-friendly callable returning any pytree
    of observed scheduling/context variables, such as a scale parameter
    `theta(t)`.

    Args:
        sys: Canonical PHS model with state-only `H(x)` and optional state-only
            `J(x)` / `R(x)` / `G(x)` fallbacks.
        context_fn: Callable mapping time to observed scheduling/context data.
        J: Optional scheduled structure map `J(t, x, context)`.
        R: Optional scheduled dissipation map `R(t, x, context)`.
        G: Optional scheduled input map `G(t, x, context)`.
        observation: Optional scheduled observation map
            `observation(t, x, u, context)`.

    Returns:
        [NonlinearSystem][contrax.systems.NonlinearSystem]: A scheduled
            nonlinear system with the context bound.
    """

    def dynamics(t: Array | float, x: Array, u: Array) -> Array:
        context = context_fn(t)
        dim = x.shape[0]
        J_mat = (
            J(t, x, context)
            if J is not None
            else (
                sys.J(x) if sys.J is not None else canonical_J(dim // 2, dtype=x.dtype)
            )
        )
        R_mat = (
            R(t, x, context)
            if R is not None
            else (
                sys.R(x) if sys.R is not None else jnp.zeros((dim, dim), dtype=x.dtype)
            )
        )
        G_mat = (
            G(t, x, context)
            if G is not None
            else (sys.G(x) if sys.G is not None else _zero_input_matrix(x, u))
        )
        grad_H = jax.grad(sys.H)(x)
        return (J_mat - R_mat) @ grad_H + G_mat @ u

    def scheduled_observation(t: Array | float, x: Array, u: Array) -> Array:
        if observation is not None:
            context = context_fn(t)
            return observation(t, x, u, context)
        if sys.observation is None:
            return x
        return sys.observation(t, x, u)

    return NonlinearSystem(
        dynamics=dynamics,
        observation=scheduled_observation,
        dt=sys.dt,
        state_dim=sys.state_dim,
        input_dim=sys.input_dim,
        output_dim=sys.output_dim,
    )
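The scheduled dynamics are a closure: at each call, context_fn(t) is evaluated first and its result is handed to the scheduled maps. The sketch below reproduces that pattern in plain JAX, assuming an illustrative damping scale theta(t) that ramps from 0.1 to 1.1 over one second, a dict-valued context (any pytree works), and no input map.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def H(x):
    return 0.5 * jnp.dot(x, x)


def context_fn(t):
    # Illustrative schedule: damping scale theta(t) ramps from 0.1 to 1.1.
    return {"theta": 0.1 + jnp.clip(t, 0.0, 1.0)}


def R_scheduled(t, x, context):
    # Scheduled dissipation map in the R(t, x, context) shape schedule_phs() accepts.
    del t, x
    return jnp.diag(jnp.array([0.0, 1.0])) * context["theta"]


J = jnp.array([[0.0, 1.0], [-1.0, 0.0]])


def scheduled_dynamics(t, x, u):
    # Mirrors the closure in schedule_phs(): evaluate the context at t, then
    # assemble (J - R(t, x, theta(t))) grad H(x). This sketch has no input map.
    del u
    context = context_fn(t)
    return (J - R_scheduled(t, x, context)) @ jax.grad(H)(x)


x = jnp.array([0.0, 1.0])
u = jnp.zeros(0)
xdot_early = scheduled_dynamics(0.0, x, u)
xdot_late = scheduled_dynamics(1.0, x, u)
```

Only the momentum derivative changes between the two times (-0.1 versus -1.1), because the schedule enters through R alone while H stays state-only.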

series(sys2: DiscLTI | ContLTI, sys1: DiscLTI | ContLTI) -> DiscLTI | ContLTI

Connect two systems in series as sys2 @ sys1.

The output of sys1 feeds the input of sys2. State ordering follows the MATLAB-style convention: sys1 states first, then sys2 states.

Parameters:

  • sys2 (DiscLTI | ContLTI, required): Downstream system.
  • sys1 (DiscLTI | ContLTI, required): Upstream system.

Returns:

  • DiscLTI | ContLTI: Interconnected system.

Source code in contrax/interconnect.py
def series(sys2: DiscLTI | ContLTI, sys1: DiscLTI | ContLTI) -> DiscLTI | ContLTI:
    """Connect two systems in series as `sys2 @ sys1`.

    The output of `sys1` feeds the input of `sys2`. State ordering follows the
    MATLAB-style convention: `sys1` states first, then `sys2` states.

    Args:
        sys2: Downstream system.
        sys1: Upstream system.

    Returns:
        [DiscLTI][contrax.systems.DiscLTI] |
            [ContLTI][contrax.systems.ContLTI]: Interconnected system.
    """
    _check_same_family(sys2, sys1)
    if sys1.C.shape[0] != sys2.B.shape[1]:
        raise ValueError(
            "Series interconnection requires sys1 output dimension to match "
            "sys2 input dimension."
        )

    A1, B1, C1, D1 = sys1.A, sys1.B, sys1.C, sys1.D
    A2, B2, C2, D2 = sys2.A, sys2.B, sys2.C, sys2.D
    n1, n2 = A1.shape[0], A2.shape[0]

    A = jnp.block(
        [
            [A1, jnp.zeros((n1, n2), dtype=A1.dtype)],
            [B2 @ C1, A2],
        ]
    )
    B = jnp.vstack([B1, B2 @ D1])
    C = jnp.hstack([D2 @ C1, C2])
    D = D2 @ D1

    if isinstance(sys1, DiscLTI):
        return DiscLTI(A=A, B=B, C=C, D=D, dt=sys1.dt)
    return ContLTI(A=A, B=B, C=C, D=D)
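For a series connection, the continuous-time steady-state gain should factor as G(0) = G2(0) G1(0), with the downstream gain on the left. The sketch below rebuilds the block matrices of series(sys2, sys1) in plain JAX for two illustrative stable systems (the matrices and the dc_gain helper are assumptions for the example) and checks that factorization.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def dc_gain(A, B, C, D):
    # Continuous-time steady-state gain G(0) = D - C A^{-1} B.
    return D - C @ jnp.linalg.solve(A, B)


# Upstream sys1 (first-order lag) and downstream sys2 (second order), illustrative.
A1, B1 = jnp.array([[-2.0]]), jnp.array([[1.0]])
C1, D1 = jnp.array([[4.0]]), jnp.array([[0.0]])
A2 = jnp.array([[0.0, 1.0], [-1.0, -1.0]])
B2 = jnp.array([[0.0], [1.0]])
C2 = jnp.array([[1.0, 0.0]])
D2 = jnp.array([[0.2]])

# Same layout as series(sys2, sys1): sys1 states first, then sys2 states.
n1, n2 = A1.shape[0], A2.shape[0]
A = jnp.block([[A1, jnp.zeros((n1, n2))], [B2 @ C1, A2]])
B = jnp.vstack([B1, B2 @ D1])
C = jnp.hstack([D2 @ C1, C2])
D = D2 @ D1

g_series = dc_gain(A, B, C, D)
g_expected = dc_gain(A2, B2, C2, D2) @ dc_gain(A1, B1, C1, D1)
```

Here G1(0) = 2 and G2(0) = 1.2, so both expressions evaluate to 2.4. For MIMO systems the order of the product matters, which is why series() takes the downstream system first.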

ss(A: ArrayLike, B: ArrayLike, C: ArrayLike, D: ArrayLike) -> ContLTI

Construct a continuous-time state-space system.

Wraps continuous-time matrices in a ContLTI pytree with MATLAB-familiar ss(A, B, C, D) ergonomics.

This is a lightweight constructor rather than a solver. It is safe under normal JAX transforms with fixed-shape array inputs because it only wraps validated arrays into a data-only eqx.Module.

Parameters:

  • A (ArrayLike, required): State matrix. Shape: (n, n).
  • B (ArrayLike, required): Input matrix. Shape: (n, m).
  • C (ArrayLike, required): Output matrix. Shape: (p, n).
  • D (ArrayLike, required): Feedthrough matrix. Shape: (p, m).

Returns:

  • ContLTI: A continuous-time system with array-valued A, B, C, and D fields.

Raises:

  • AssertionError: If the matrix shapes are inconsistent.

Examples:

>>> import jax.numpy as jnp
>>> import contrax as cx
>>> sys = cx.ss(
...     jnp.array([[0.0, 1.0], [0.0, 0.0]]),
...     jnp.array([[0.0], [1.0]]),
...     jnp.eye(2),
...     jnp.zeros((2, 1)),
... )
Source code in contrax/core.py
def ss(A: ArrayLike, B: ArrayLike, C: ArrayLike, D: ArrayLike) -> ContLTI:
    """Construct a continuous-time state-space system.

    Wraps continuous-time matrices in a
    [ContLTI][contrax.systems.ContLTI] pytree with MATLAB-familiar
    `ss(A, B, C, D)` ergonomics.

    This is a lightweight constructor rather than a solver. It is safe under
    normal JAX transforms with fixed-shape array inputs because it only wraps
    validated arrays into a data-only `eqx.Module`.

    Args:
        A: State matrix. Shape: `(n, n)`.
        B: Input matrix. Shape: `(n, m)`.
        C: Output matrix. Shape: `(p, n)`.
        D: Feedthrough matrix. Shape: `(p, m)`.

    Returns:
        [ContLTI][contrax.systems.ContLTI]: A continuous-time system with
            array-valued `A`, `B`, `C`, and `D` fields.

    Raises:
        AssertionError: If the matrix shapes are inconsistent.

    Examples:
        >>> import jax.numpy as jnp
        >>> import contrax as cx
        >>> sys = cx.ss(
        ...     jnp.array([[0.0, 1.0], [0.0, 0.0]]),
        ...     jnp.array([[0.0], [1.0]]),
        ...     jnp.eye(2),
        ...     jnp.zeros((2, 1)),
        ... )
    """
    A = jnp.asarray(A, dtype=float)
    B = jnp.asarray(B, dtype=float)
    C = jnp.asarray(C, dtype=float)
    D = jnp.asarray(D, dtype=float)
    n = A.shape[0]
    assert A.shape == (n, n), f"A must be square, got {A.shape}"
    assert B.shape[0] == n, f"B rows must match A, got {B.shape}"
    assert C.shape[1] == n, f"C cols must match A, got {C.shape}"
    assert D.shape == (C.shape[0], B.shape[1]), f"D shape mismatch, got {D.shape}"
    return ContLTI(A=A, B=B, C=C, D=D)
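ss() enforces its shape rules with assert statements, and Python skips assert statements when run with the -O flag. If model matrices come from user input, a pre-check that raises ValueError regardless of interpreter flags can help. The helper below, check_ss_shapes, is hypothetical (not part of Contrax) and encodes the same (n, n), (n, m), (p, n), (p, m) rules.

```python
import jax.numpy as jnp


def check_ss_shapes(A, B, C, D):
    """Hypothetical helper: the same shape rules as ss(), raising ValueError."""
    A, B, C, D = (jnp.asarray(M) for M in (A, B, C, D))
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"A must be square, got {A.shape}")
    n = A.shape[0]
    if B.ndim != 2 or B.shape[0] != n:
        raise ValueError(f"B must be (n, m) with n = {n}, got {B.shape}")
    if C.ndim != 2 or C.shape[1] != n:
        raise ValueError(f"C must be (p, n) with n = {n}, got {C.shape}")
    if D.shape != (C.shape[0], B.shape[1]):
        raise ValueError(f"D must be {(C.shape[0], B.shape[1])}, got {D.shape}")
    return n, B.shape[1], C.shape[0]  # (states, inputs, outputs)


# Double integrator from the example above: n = 2, m = 1, p = 2.
dims = check_ss_shapes(
    jnp.array([[0.0, 1.0], [0.0, 0.0]]),
    jnp.array([[0.0], [1.0]]),
    jnp.eye(2),
    jnp.zeros((2, 1)),
)

# A (1, 1) feedthrough cannot match two outputs and one input.
try:
    check_ss_shapes(jnp.eye(2), jnp.ones((2, 1)), jnp.eye(2), jnp.zeros((1, 1)))
    bad_shape_rejected = False
except ValueError:
    bad_shape_rejected = True
```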

symmetrize_matrix(M: ArrayLike) -> Array

Return the symmetric part 0.5 * (M + M.T) of a square matrix.

Source code in contrax/phs.py
def symmetrize_matrix(M: ArrayLike) -> Array:
    """Return the symmetric part `0.5 * (M + M.T)` of a square matrix."""
    M = jnp.asarray(M)
    if M.ndim != 2 or M.shape[0] != M.shape[1]:
        raise ValueError("symmetrize_matrix() expects a square matrix.")
    return 0.5 * (M + M.T)
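The symmetric part is what quadratic forms actually see: for any square M and vector v, v^T M v = v^T sym(M) v, because the skew part contributes nothing. That is why phs_diagnostics() can compute dissipation power from symmetrize_matrix(R) without changing its value, while jnp.linalg.eigvalsh, which assumes a symmetric input, gets a matrix it can handle. A small numerical check with an illustrative matrix:

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

M = jnp.array([[1.0, 3.0], [-1.0, 2.0]])
M_sym = 0.5 * (M + M.T)   # what symmetrize_matrix(M) returns
M_skew = 0.5 * (M - M.T)  # the remaining skew-symmetric part

v = jnp.array([0.7, -1.3])
quad_full = v @ M @ v
quad_sym = v @ M_sym @ v
quad_skew = v @ M_skew @ v  # zero for any skew-symmetric matrix
```

Both quadratic forms evaluate to 2.05 here.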