
Control

Control is the public namespace for controller design and Riccati-backed feedback synthesis in Contrax.

The namespace includes:

  • dare() and care() for algebraic Riccati solves
  • lqr() and lqi() for regulator design
  • augment_integrator() for explicit integral-state augmentation
  • place() for design-time pole placement
  • state_feedback() for applying a state-feedback gain to an LTI model

Minimal Example

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import contrax as cx

sys = cx.dss(
    jnp.array([[1.0, 0.05], [0.0, 1.0]]),
    jnp.array([[0.0], [0.05]]),
    jnp.eye(2),
    jnp.zeros((2, 1)),
    dt=0.05,
)
result = cx.lqr(sys, jnp.eye(2), jnp.array([[1.0]]))
closed_loop = cx.state_feedback(sys, result.K)

[Figure: workflow from plant model and quadratic weights through dare() or care(), gain recovery, and closed-loop validation.]

The main control path: combine a linear model with quadratic weights, solve the appropriate Riccati equation, recover the stabilizing gain, then inspect the closed-loop poles and residual.

Conventions

For lqi(), the returned gain acts on the augmented state [x; z], where z is the integral state added by augment_integrator().
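
As a sketch of that convention, the augmented gain can be split back into its physical-state and integral-state blocks by plain slicing. Everything here is illustrative: the gain values are made up, and the names Kx and Kz are not part of the Contrax API.

```python
import numpy as np

# Hypothetical augmented gain for n = 2 plant states and r = 1 integral state.
K_bar = np.array([[4.0, 1.5, 0.8]])  # shape (m, n + r)

n, r = 2, 1
Kx = K_bar[:, :n]   # block acting on the physical state x
Kz = K_bar[:, n:]   # block acting on the integral state z

# The law u = -K_bar @ [x; z] is identical to u = -(Kx @ x + Kz @ z).
x = np.array([0.1, -0.2])
z = np.array([0.05])
u_full = -K_bar @ np.concatenate([x, z])
u_split = -(Kx @ x + Kz @ z)
assert np.allclose(u_full, u_split)
```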

Control Equations

Contrax's control surface is centered on state-feedback design for linear systems.

For the discrete-time model

\[ x_{k+1} = A x_k + B u_k \]

the infinite-horizon quadratic objective used by dare(), lqr(), and discrete lqi() is

\[ J_d = \sum_{k=0}^{\infty} \left(x_k^\top Q x_k + u_k^\top R u_k\right) \]

with state-feedback law

\[ u_k = -K x_k \]

For the continuous-time model

\[ \dot{x} = A x + B u \]

the corresponding continuous objective behind care() and continuous lqr() is

\[ J_c = \int_0^\infty \left(x(t)^\top Q x(t) + u(t)^\top R u(t)\right)\,dt \]

with feedback law

\[ u(t) = -K x(t) \]

Riccati Equations

The stabilizing discrete Riccati equation is

\[ A^\top S A - S - A^\top S B \left(R + B^\top S B\right)^{-1} B^\top S A + Q = 0 \]

and the continuous Riccati equation is

\[ A^\top S + S A - S B R^{-1} B^\top S + Q = 0 \]

Given the stabilizing solution S, the corresponding gains are

\[ K_d = \left(R + B^\top S B\right)^{-1} B^\top S A \]
\[ K_c = R^{-1} B^\top S \]
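
The gain-recovery formulas can be checked numerically. The sketch below uses SciPy's Riccati solvers as a stand-in so the snippet stays independent of Contrax; the matrices are the same illustrative double-integrator models used elsewhere on this page.

```python
import numpy as np
from scipy.linalg import solve_discrete_are, solve_continuous_are

A = np.array([[1.0, 0.05], [0.0, 1.0]])
B = np.array([[0.0], [0.05]])
Q = np.eye(2)
R = np.array([[1.0]])

# Discrete gain: K_d = (R + B' S B)^{-1} B' S A
Sd = solve_discrete_are(A, B, Q, R)
Kd = np.linalg.solve(R + B.T @ Sd @ B, B.T @ Sd @ A)

# Continuous gain: K_c = R^{-1} B' S, for a continuous double integrator.
Ac = np.array([[0.0, 1.0], [0.0, 0.0]])
Bc = np.array([[0.0], [1.0]])
Sc = solve_continuous_are(Ac, Bc, Q, R)
Kc = np.linalg.solve(R, Bc.T @ Sc)

# Both gains stabilize their respective closed loops.
assert np.all(np.abs(np.linalg.eigvals(A - B @ Kd)) < 1.0)
assert np.all(np.linalg.eigvals(Ac - Bc @ Kc).real < 0.0)
```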

Closed-Loop Models

state_feedback() applies the gain directly to the plant matrices.

For discrete systems,

\[ x_{k+1} = (A - B K) x_k \]

and for continuous systems,

\[ \dot{x} = (A - B K) x \]

That is the model whose poles appear in LQRResult.poles.
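
A quick stability check on such a closed loop needs only the eigenvalues of A - B K. For a discrete model they must lie strictly inside the unit circle; the gain below is a hand-picked illustration, not one produced by lqr().

```python
import numpy as np

A = np.array([[1.0, 0.05], [0.0, 1.0]])
B = np.array([[0.0], [0.05]])
K = np.array([[10.0, 6.0]])  # illustrative stabilizing gain

# Closed-loop poles are the eigenvalues of A - B K.
poles = np.linalg.eigvals(A - B @ K)
spectral_radius = np.max(np.abs(poles))
assert spectral_radius < 1.0  # discrete-time stability
```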

[Figure: output-feedback architecture with controller, plant, measurement, and estimator feeding a state estimate back to the controller.]

Where state feedback fits: `state_feedback()` handles the plant-side algebra, while estimation supplies the state estimate when the full state is not measured directly.

Integral Augmentation

augment_integrator() and lqi() introduce an integral state z driven by an output-tracking error. At the level of the augmented state

\[ \bar{x} = \begin{bmatrix} x \\ z \end{bmatrix} \]

the returned gain acts as

\[ u = -\bar{K}\,\bar{x} \]

with \(\bar{K}\) partitioned over the physical and integral states.
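
The discrete augmentation can be sketched directly as block-matrix assembly. This is a simplified NumPy mirror of what augment_integrator() builds, assuming sign = 1 and the defaults C_integral = C, D_integral = D.

```python
import numpy as np

A = np.array([[1.0, 0.05], [0.0, 1.0]])
B = np.array([[0.0], [0.05]])
C = np.array([[1.0, 0.0]])   # integrate the first state's output
D = np.zeros((1, 1))
dt = 0.05

n, r = A.shape[0], C.shape[0]
# Augmented update: x[k+1] = A x + B u,  z[k+1] = z + dt * (C x + D u)
A_aug = np.block([[A, np.zeros((n, r))], [dt * C, np.eye(r)]])
B_aug = np.vstack([B, dt * D])

# One step from x = [1, 0], z = 0: z accumulates dt * y.
xz = np.array([1.0, 0.0, 0.0])
u = np.zeros(1)
xz_next = A_aug @ xz + B_aug @ u
assert np.isclose(xz_next[2], dt * 1.0)
```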

Solver Maturity

Contrax is explicit about solver maturity:

  • dare() is the most mature Riccati path
  • care() is supported and validated, but less benchmarked than dare()
  • place() is suitable as a design-time helper, not as a hardened large-scale assignment solver

If the numerical context matters to your workflow, read the Riccati solvers page before treating all paths as equally mature.

Transform Behavior

dare(), care(), and lqr() are designed to participate in compiled JAX workflows.

The key contracts are:

  • lqr() dispatches on DiscLTI vs ContLTI outside traced runtime values
  • the Riccati paths use custom or implicit VJPs rather than blindly unrolling the forward solve in reverse mode
  • state_feedback() is just state-space algebra on the model matrices

That makes controller design inside a JIT-compiled objective a first-class use case rather than an accidental one.

Numerical Notes

Contrax does not enable float64 globally on import. Riccati solves require jax_enable_x64=True.

state_feedback(sys, K) is specifically about applying a state-feedback gain to an LTI model. It is not the general interconnection operator for arbitrary controller/plant block diagrams.

For integral-action designs, augment_integrator() exposes the augmented model directly and lqi() is the thin convenience wrapper on top.

contrax.control

contrax.control — LQR, pole placement, closed-loop construction.

augment_integrator(sys: DiscLTI | ContLTI, C_integral: Array | None = None, D_integral: Array | None = None, *, sign: float = 1.0, dt_scale: ArrayLike | None = None) -> DiscLTI | ContLTI

Augment an LTI system with integral-of-output states.

This is the small design utility behind LQI-style workflows. It returns an augmented system that can be passed to lqr() with a state cost that also weights the added integral states.

For continuous systems, the added state evolves as z_dot = sign * (C_integral x + D_integral u). For discrete systems, the added state evolves as z[k+1] = z[k] + dt_scale * sign * (C_integral x[k] + D_integral u[k]). When dt_scale is omitted for a discrete system, sys.dt is used.

Parameters:

Name Type Description Default
sys DiscLTI | ContLTI

Continuous or discrete LTI system to augment.

required
C_integral Array | None

Output map to integrate. Shape: (r, n). Defaults to sys.C.

None
D_integral Array | None

Feedthrough map for the integrated output. Shape: (r, m). Defaults to sys.D.

None
sign float

Sign applied to the integrated output. Use -1.0 for the common reference - output convention when the reference is handled outside the design model.

1.0
dt_scale ArrayLike | None

Discrete integration scale. Defaults to sys.dt for discrete systems and is ignored for continuous systems.

None

Returns:

Type Description
DiscLTI | ContLTI

DiscLTI | ContLTI: System with augmented state [x; z].

Source code in contrax/control.py
def augment_integrator(
    sys: DiscLTI | ContLTI,
    C_integral: Array | None = None,
    D_integral: Array | None = None,
    *,
    sign: float = 1.0,
    dt_scale: ArrayLike | None = None,
) -> DiscLTI | ContLTI:
    """Augment an LTI system with integral-of-output states.

    This is the small design utility behind LQI-style workflows. It returns an
    augmented system that can be passed to `lqr()` with a state cost that also
    weights the added integral states.

    For continuous systems, the added state evolves as
    `z_dot = sign * (C_integral x + D_integral u)`. For discrete systems, the
    added state evolves as
    `z[k+1] = z[k] + dt_scale * sign * (C_integral x[k] + D_integral u[k])`.
    When `dt_scale` is omitted for a discrete system, `sys.dt` is used.

    Args:
        sys: Continuous or discrete LTI system to augment.
        C_integral: Output map to integrate. Shape: `(r, n)`. Defaults to
            `sys.C`.
        D_integral: Feedthrough map for the integrated output. Shape: `(r, m)`.
            Defaults to `sys.D`.
        sign: Sign applied to the integrated output. Use `-1.0` for the common
            `reference - output` convention when the reference is handled
            outside the design model.
        dt_scale: Discrete integration scale. Defaults to `sys.dt` for
            discrete systems and is ignored for continuous systems.

    Returns:
        [DiscLTI][contrax.systems.DiscLTI] |
            [ContLTI][contrax.systems.ContLTI]: System with augmented state
            `[x; z]`.
    """
    C_int = sys.C if C_integral is None else jnp.asarray(C_integral, dtype=sys.A.dtype)
    D_int = sys.D if D_integral is None else jnp.asarray(D_integral, dtype=sys.A.dtype)
    scale = jnp.asarray(sign, dtype=sys.A.dtype)

    n = sys.A.shape[0]
    r = C_int.shape[0]
    zeros_xz = jnp.zeros((n, r), dtype=sys.A.dtype)
    eye_z = jnp.eye(r, dtype=sys.A.dtype)
    C_aug = jnp.hstack([sys.C, jnp.zeros((sys.C.shape[0], r), dtype=sys.A.dtype)])

    if isinstance(sys, ContLTI):
        A_aug = jnp.block(
            [
                [sys.A, zeros_xz],
                [scale * C_int, jnp.zeros((r, r), dtype=sys.A.dtype)],
            ]
        )
        B_aug = jnp.vstack([sys.B, scale * D_int])
        return ContLTI(A=A_aug, B=B_aug, C=C_aug, D=sys.D)

    if isinstance(sys, DiscLTI):
        dt = sys.dt if dt_scale is None else jnp.asarray(dt_scale, dtype=sys.A.dtype)
        row_scale = scale * dt
        A_aug = jnp.block([[sys.A, zeros_xz], [row_scale * C_int, eye_z]])
        B_aug = jnp.vstack([sys.B, row_scale * D_int])
        return DiscLTI(A=A_aug, B=B_aug, C=C_aug, D=sys.D, dt=sys.dt)

    raise TypeError(f"Expected DiscLTI or ContLTI, got {type(sys)}")

care(A: Array, B: Array, Q: Array, R: Array) -> LQRResult

Solve the continuous algebraic Riccati equation.

Computes the stabilizing Riccati solution and corresponding continuous-time LQR gain for the matrix tuple (A, B, Q, R).

The forward solve uses a Hamiltonian stable-subspace construction with jnp.linalg.eig, and gradients are provided through an implicit-differentiation custom VJP. This is now a real baseline continuous-time solver, but it is not yet as hardened as dare().
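
The Hamiltonian stable-subspace construction can be sketched in plain NumPy. This is an illustrative reimplementation, not the Contrax source; it assumes R is invertible and that the stable invariant subspace is well separated.

```python
import numpy as np

def care_hamiltonian(A, B, Q, R):
    """Solve the CARE via the stable invariant subspace of the Hamiltonian."""
    n = A.shape[0]
    G = B @ np.linalg.solve(R, B.T)
    # 2n x 2n Hamiltonian matrix.
    H = np.block([[A, -G], [-Q, -A.T]])
    # Stable subspace [X1; X2] gives S = X2 X1^{-1}.
    w, V = np.linalg.eig(H)
    stable = V[:, w.real < 0]
    X1, X2 = stable[:n, :], stable[n:, :]
    return (X2 @ np.linalg.inv(X1)).real

A = np.array([[0.0, 1.0], [0.0, 0.0]])
B = np.array([[0.0], [1.0]])
Q = np.eye(2)
R = np.array([[1.0]])
S = care_hamiltonian(A, B, Q, R)

# S satisfies the continuous Riccati equation A'S + SA - S B R^{-1} B'S + Q = 0.
res = A.T @ S + S @ A - S @ B @ np.linalg.solve(R, B.T @ S) + Q
assert np.max(np.abs(res)) < 1e-8
```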

Parameters:

  • A (Array, required): Continuous-time state matrix. Shape: (n, n).
  • B (Array, required): Continuous-time input matrix. Shape: (n, m).
  • Q (Array, required): State cost matrix. Shape: (n, n).
  • R (Array, required): Input cost matrix. Shape: (m, m).

Returns:

  • LQRResult: A bundle containing the optimal feedback gain K, Riccati solution S, and closed-loop poles poles.

Examples:

>>> import jax.numpy as jnp
>>> import contrax as cx
>>> A = jnp.array([[0.0, 1.0], [0.0, 0.0]])
>>> B = jnp.array([[0.0], [1.0]])
>>> result = cx.care(A, B, jnp.eye(2), jnp.array([[1.0]]))
Source code in contrax/_riccati.py
def care(A: Array, B: Array, Q: Array, R: Array) -> LQRResult:
    """Solve the continuous algebraic Riccati equation.

    Computes the stabilizing Riccati solution and corresponding continuous-time
    LQR gain for the matrix tuple `(A, B, Q, R)`.

    The forward solve uses a Hamiltonian stable-subspace construction with
    `jnp.linalg.eig`, and gradients are provided through an
    implicit-differentiation custom VJP. This is now a real baseline
    continuous-time solver, but it is not yet as hardened as `dare()`.

    Args:
        A: Continuous-time state matrix. Shape: `(n, n)`.
        B: Continuous-time input matrix. Shape: `(n, m)`.
        Q: State cost matrix. Shape: `(n, n)`.
        R: Input cost matrix. Shape: `(m, m)`.

    Returns:
        [LQRResult][contrax.types.LQRResult]: A bundle containing the optimal
            feedback gain `K`, Riccati solution `S`, and closed-loop poles
            `poles`.

    Examples:
        >>> import jax.numpy as jnp
        >>> import contrax as cx
        >>> A = jnp.array([[0.0, 1.0], [0.0, 0.0]])
        >>> B = jnp.array([[0.0], [1.0]])
        >>> result = cx.care(A, B, jnp.eye(2), jnp.array([[1.0]]))
    """
    require_x64("care")
    S = _care_solve(A, B, Q, R)
    K = jnp.linalg.solve(_symmetrize(R), B.T @ S)
    A_cl = A - B @ K
    poles = jnp.linalg.eigvals(A_cl)
    residual_norm = jnp.max(jnp.abs(_care_residual(S, A, B, Q, R)))
    if not _is_tracing(A, B, Q, R, S, poles):
        _validate_care_solution(A, B, Q, R, S, poles)
    return LQRResult(K=K, S=S, poles=poles, residual_norm=residual_norm)

dare(A: Array, B: Array, Q: Array, R: Array, max_iter: int = 64, tol: float = 1e-10) -> LQRResult

Solve the discrete algebraic Riccati equation.

Computes the stabilizing Riccati solution and corresponding infinite-horizon discrete LQR gain for the matrix tuple (A, B, Q, R).

This is currently the strongest Riccati solver path in Contrax. The forward solve uses structured doubling with tolerance-based stopping, and gradients are provided through a custom VJP built from implicit differentiation of the Riccati residual.
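
The structured-doubling recurrence can be sketched in a few lines of NumPy. This is illustrative only; the actual dare() implementation adds symmetrization, tolerance handling, and a custom VJP. Here H_k converges quadratically to the stabilizing Riccati solution S.

```python
import numpy as np

def dare_sda(A, B, Q, R, max_iter=64, tol=1e-10):
    """Structured doubling for the DARE; H iterates converge to S."""
    Ak = A.copy()
    G = B @ np.linalg.solve(R, B.T)
    H = Q.copy()
    for _ in range(max_iter):
        W_inv = np.linalg.inv(np.eye(A.shape[0]) + G @ H)
        A_next = Ak @ W_inv @ Ak
        G = G + Ak @ W_inv @ G @ Ak.T
        H_next = H + Ak.T @ H @ W_inv @ Ak
        if np.max(np.abs(H_next - H)) < tol:
            return H_next
        Ak, H = A_next, H_next
    return H

A = np.array([[1.0, 0.05], [0.0, 1.0]])
B = np.array([[0.0], [0.05]])
Q = np.eye(2)
R = np.array([[1.0]])
S = dare_sda(A, B, Q, R)

# S satisfies the discrete Riccati equation.
K = np.linalg.solve(R + B.T @ S @ B, B.T @ S @ A)
res = A.T @ S @ A - S - A.T @ S @ B @ K + Q
assert np.max(np.abs(res)) < 1e-8
```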

Parameters:

  • A (Array, required): Discrete-time state transition matrix. Shape: (n, n).
  • B (Array, required): Discrete-time input matrix. Shape: (n, m).
  • Q (Array, required): State cost matrix. Shape: (n, n).
  • R (Array, required): Input cost matrix. Shape: (m, m).
  • max_iter (int, default 64): Maximum structured-doubling iterations.
  • tol (float, default 1e-10): Forward stopping tolerance on Riccati iterate changes.

Returns:

  • LQRResult: A bundle containing the optimal feedback gain K, Riccati solution S, and closed-loop poles poles.

Examples:

>>> import jax.numpy as jnp
>>> import contrax as cx
>>> A = jnp.array([[1.0, 0.05], [0.0, 1.0]])
>>> B = jnp.array([[0.0], [0.05]])
>>> result = cx.dare(A, B, jnp.eye(2), jnp.array([[1.0]]))
Source code in contrax/_riccati.py
def dare(
    A: Array,
    B: Array,
    Q: Array,
    R: Array,
    max_iter: int = 64,
    tol: float = 1e-10,
) -> LQRResult:
    """Solve the discrete algebraic Riccati equation.

    Computes the stabilizing Riccati solution and corresponding infinite-horizon
    discrete LQR gain for the matrix tuple `(A, B, Q, R)`.

    This is currently the strongest Riccati solver path in Contrax. The
    forward solve uses structured doubling with tolerance-based stopping, and
    gradients are provided through a custom VJP built from implicit
    differentiation of the Riccati residual.

    Args:
        A: Discrete-time state transition matrix. Shape: `(n, n)`.
        B: Discrete-time input matrix. Shape: `(n, m)`.
        Q: State cost matrix. Shape: `(n, n)`.
        R: Input cost matrix. Shape: `(m, m)`.
        max_iter: Maximum structured-doubling iterations.
        tol: Forward stopping tolerance on Riccati iterate changes.

    Returns:
        [LQRResult][contrax.types.LQRResult]: A bundle containing the optimal
            feedback gain `K`, Riccati solution `S`, and closed-loop poles
            `poles`.

    Examples:
        >>> import jax.numpy as jnp
        >>> import contrax as cx
        >>> A = jnp.array([[1.0, 0.05], [0.0, 1.0]])
        >>> B = jnp.array([[0.0], [0.05]])
        >>> result = cx.dare(A, B, jnp.eye(2), jnp.array([[1.0]]))
    """
    require_x64("dare")
    S = _dare_structured_doubling_solve(A, B, Q, R, max_iter, tol)
    K = _lqr_gain(A, B, S, _symmetrize(R))
    A_cl = A - B @ K
    poles = jnp.linalg.eigvals(A_cl)
    residual_norm = jnp.max(jnp.abs(_dare_residual(S, A, B, Q, R)))
    return LQRResult(K=K, S=S, poles=poles, residual_norm=residual_norm)

feedback(sys: DiscLTI | ContLTI, K: Array) -> DiscLTI | ContLTI

Alias for state_feedback().

This alias remains available for MATLAB-style ergonomics, but the more explicit state_feedback() name is preferred in new Contrax code because it distinguishes gain application from general closed-loop interconnection.

Source code in contrax/control.py
def feedback(sys: DiscLTI | ContLTI, K: Array) -> DiscLTI | ContLTI:
    """Alias for `state_feedback()`.

    This alias remains available for MATLAB-style ergonomics, but the more
    explicit `state_feedback()` name is preferred in new Contrax code because it
    distinguishes gain application from general closed-loop interconnection.
    """
    return state_feedback(sys, K)

lqi(sys: DiscLTI | ContLTI, Q: Array, R: Array, C_integral: Array | None = None, D_integral: Array | None = None, *, sign: float = 1.0, dt_scale: ArrayLike | None = None) -> LQRResult

Design an LQI gain by augmenting integral states and calling lqr().

This is intentionally a thin composition of augment_integrator() and lqr(). The returned gain acts on the augmented state [x; z], where z is the integral state.

Parameters:

  • sys (DiscLTI | ContLTI, required): Continuous or discrete LTI system.
  • Q (Array, required): Augmented state cost matrix. Shape: (n + r, n + r).
  • R (Array, required): Input cost matrix. Shape: (m, m).
  • C_integral (Array | None, default None): Output map to integrate. Shape: (r, n). Defaults to sys.C.
  • D_integral (Array | None, default None): Feedthrough map for the integrated output. Shape: (r, m). Defaults to sys.D.
  • sign (float, default 1.0): Sign applied to the integrated output.
  • dt_scale (ArrayLike | None, default None): Discrete integration scale. Defaults to sys.dt for discrete systems and is ignored for continuous systems.

Returns:

  • LQRResult: LQR result for the integrator-augmented design model.

Source code in contrax/control.py
def lqi(
    sys: DiscLTI | ContLTI,
    Q: Array,
    R: Array,
    C_integral: Array | None = None,
    D_integral: Array | None = None,
    *,
    sign: float = 1.0,
    dt_scale: ArrayLike | None = None,
) -> LQRResult:
    """Design an LQI gain by augmenting integral states and calling `lqr()`.

    This is intentionally a thin composition of `augment_integrator()` and
    `lqr()`. The returned gain acts on the augmented state `[x; z]`, where `z`
    is the integral state.

    Args:
        sys: Continuous or discrete LTI system.
        Q: Augmented state cost matrix. Shape: `(n + r, n + r)`.
        R: Input cost matrix. Shape: `(m, m)`.
        C_integral: Output map to integrate. Shape: `(r, n)`. Defaults to
            `sys.C`.
        D_integral: Feedthrough map for the integrated output. Shape: `(r, m)`.
            Defaults to `sys.D`.
        sign: Sign applied to the integrated output.
        dt_scale: Discrete integration scale. Defaults to `sys.dt` for
            discrete systems and is ignored for continuous systems.

    Returns:
        [LQRResult][contrax.types.LQRResult]: LQR result for the
            integrator-augmented design model.
    """
    augmented = augment_integrator(
        sys,
        C_integral=C_integral,
        D_integral=D_integral,
        sign=sign,
        dt_scale=dt_scale,
    )
    return lqr(augmented, Q, R)

lqr(sys: DiscLTI | ContLTI, Q: Array, R: Array) -> LQRResult

Solve the infinite-horizon linear quadratic regulator problem.

Dispatches to dare() for discrete systems and care() for continuous systems, returning a unified LQRResult.

The discrete path is more benchmarked than the continuous path, but both routes are first-class public solvers. Type dispatch happens outside traced runtime values, so jit works as long as the system type is fixed at trace time.

Parameters:

  • sys (DiscLTI | ContLTI, required): DiscLTI or ContLTI system.
  • Q (Array, required): State cost matrix. Shape: (n, n).
  • R (Array, required): Input cost matrix. Shape: (m, m).

Returns:

  • LQRResult: A bundle containing the optimal gain, Riccati solution, and closed-loop poles.

Raises:

  • TypeError: If sys is not a DiscLTI or ContLTI.

Examples:

>>> import jax.numpy as jnp
>>> import contrax as cx
>>> sys = cx.dss(
...     jnp.array([[1.0, 0.05], [0.0, 1.0]]),
...     jnp.array([[0.0], [0.05]]),
...     jnp.eye(2),
...     jnp.zeros((2, 1)),
...     dt=0.05,
... )
>>> result = cx.lqr(sys, jnp.eye(2), jnp.array([[1.0]]))
Source code in contrax/control.py
def lqr(sys: DiscLTI | ContLTI, Q: Array, R: Array) -> LQRResult:
    """Solve the infinite-horizon linear quadratic regulator problem.

    Dispatches to `dare()` for discrete systems and `care()` for continuous
    systems, returning a unified [LQRResult][contrax.types.LQRResult].

    The discrete path is more benchmarked than the continuous path, but both
    routes are first-class public solvers. Type dispatch happens outside traced
    runtime values, so `jit` works as long as the system type is fixed at trace
    time.

    Args:
        sys: [DiscLTI][contrax.systems.DiscLTI] or
            [ContLTI][contrax.systems.ContLTI] system.
        Q: State cost matrix. Shape: `(n, n)`.
        R: Input cost matrix. Shape: `(m, m)`.

    Returns:
        [LQRResult][contrax.types.LQRResult]: A bundle containing the optimal
            gain, Riccati solution, and closed-loop poles.

    Raises:
        TypeError: If `sys` is not a
            [DiscLTI][contrax.systems.DiscLTI] or
            [ContLTI][contrax.systems.ContLTI].

    Examples:
        >>> import jax.numpy as jnp
        >>> import contrax as cx
        >>> sys = cx.dss(
        ...     jnp.array([[1.0, 0.05], [0.0, 1.0]]),
        ...     jnp.array([[0.0], [0.05]]),
        ...     jnp.eye(2),
        ...     jnp.zeros((2, 1)),
        ...     dt=0.05,
        ... )
        >>> result = cx.lqr(sys, jnp.eye(2), jnp.array([[1.0]]))
    """
    if isinstance(sys, DiscLTI):
        return dare(sys.A, sys.B, Q, R)
    elif isinstance(sys, ContLTI):
        return care(sys.A, sys.B, Q, R)
    else:
        raise TypeError(f"Expected DiscLTI or ContLTI, got {type(sys)}")

place(sys: DiscLTI | ContLTI, poles: ArrayLike, *, method: str | None = None, rtol: float = 0.001, maxiter: int = 30) -> Array

Place closed-loop poles for a state-space system.

Uses a JAX-native robust assignment path based on KNV0 for real-pole cases and YT for complex-pole cases. Both methods are design-time helpers rather than traced differentiable primitives.

If method is left as None, Contrax selects KNV0 for all-real pole sets and YT when any complex-conjugate pair is present. For single-input systems, Ackermann remains available as a fallback if the robust path fails.
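
SciPy exposes the same KNV0/YT family of robust-assignment methods, which makes it a convenient design-time cross-check. The snippet below is a SciPy-based illustration of pole placement on the same example system, not the Contrax implementation.

```python
import numpy as np
from scipy.signal import place_poles

A = np.array([[1.0, 0.1], [0.0, 0.9]])
B = np.array([[0.0], [1.0]])
desired = np.array([0.6, 0.7])

# SciPy's robust placement (YT by default for its full-rank paths).
result = place_poles(A, B, desired)
L = result.gain_matrix

# The closed-loop eigenvalues of A - B L match the requested poles.
achieved = np.sort(np.linalg.eigvals(A - B @ L).real)
assert np.allclose(achieved, np.sort(desired), atol=1e-6)
```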

Parameters:

  • sys (DiscLTI | ContLTI, required): Continuous or discrete-time system.
  • poles (ArrayLike, required): Desired closed-loop pole locations. Shape: (n,).
  • method (str | None, default None): Optional robust assignment method. Supported values are "KNV0" and "YT". When omitted, Contrax chooses based on whether the requested poles are all real.
  • rtol (float, default 0.001): Determinant-improvement stopping tolerance for the robust update.
  • maxiter (int, default 30): Maximum robust-assignment iterations.

Returns:

  • Array: State-feedback gain L. Shape: (m, n).

Examples:

>>> import jax.numpy as jnp
>>> import contrax as cx
>>> sys = cx.dss(
...     jnp.array([[1.0, 0.1], [0.0, 0.9]]),
...     jnp.array([[0.0], [1.0]]),
...     jnp.eye(2),
...     jnp.zeros((2, 1)),
... )
>>> L = cx.place(sys, jnp.array([0.6, 0.7]))
Source code in contrax/_place.py
def place(
    sys: DiscLTI | ContLTI,
    poles: ArrayLike,
    *,
    method: str | None = None,
    rtol: float = 1e-3,
    maxiter: int = 30,
) -> Array:
    """Place closed-loop poles for a state-space system.

    Uses a JAX-native robust assignment path based on KNV0 for real-pole cases
    and YT for complex-pole cases. Both methods are design-time helpers rather
    than traced differentiable primitives.

    If `method` is left as `None`, Contrax selects `KNV0` for all-real pole
    sets and `YT` when any complex-conjugate pair is present. For single-input
    systems, Ackermann remains available as a fallback if the robust path
    fails.

    Args:
        sys: Continuous or discrete-time system.
        poles: Desired closed-loop pole locations. Shape: `(n,)`.
        method: Optional robust assignment method. Supported values are
            `"KNV0"` and `"YT"`. When omitted, Contrax chooses based on
            whether the requested poles are all real.
        rtol: Determinant-improvement stopping tolerance for the robust update.
        maxiter: Maximum robust-assignment iterations.

    Returns:
        Array: State-feedback gain `L`. Shape: `(m, n)`.

    Examples:
        >>> import jax.numpy as jnp
        >>> import contrax as cx
        >>> sys = cx.dss(
        ...     jnp.array([[1.0, 0.1], [0.0, 0.9]]),
        ...     jnp.array([[0.0], [1.0]]),
        ...     jnp.eye(2),
        ...     jnp.zeros((2, 1)),
        ... )
        >>> L = cx.place(sys, jnp.array([0.6, 0.7]))
    """
    A, B = sys.A, sys.B
    poles_ordered = _order_complex_poles(poles)
    has_complex = bool(np.any(np.imag(np.asarray(poles_ordered)) != 0.0))
    chosen_method = method or ("YT" if has_complex else "KNV0")

    try:
        return _place_robust(
            A,
            B,
            poles_ordered,
            method=chosen_method,
            rtol=rtol,
            maxiter=maxiter,
        )
    except Exception:
        if B.shape[1] == 1:
            return _place_ackermann(A, B, poles_ordered)
        raise

state_feedback(sys: DiscLTI | ContLTI, K: Array) -> DiscLTI | ContLTI

Construct the closed-loop system under state feedback.

Returns a new LTI system with A_cl = A - B @ K and unchanged B, C, D, and, for discrete systems, dt.

This is a pure functional constructor rather than a simulation routine. It dispatches on either DiscLTI or ContLTI and behaves like normal array algebra under JAX transforms.

Parameters:

  • sys (DiscLTI | ContLTI, required): Continuous or discrete-time system.
  • K (Array, required): State-feedback gain. Shape: (m, n).

Returns:

  • DiscLTI | ContLTI: A new system representing the closed-loop dynamics.

Examples:

>>> import jax.numpy as jnp
>>> import contrax as cx
>>> sys = cx.dss(
...     jnp.array([[1.0, 0.05], [0.0, 1.0]]),
...     jnp.array([[0.0], [0.05]]),
...     jnp.eye(2),
...     jnp.zeros((2, 1)),
...     dt=0.05,
... )
>>> result = cx.lqr(sys, jnp.eye(2), jnp.array([[1.0]]))
>>> closed_loop = cx.state_feedback(sys, result.K)
Source code in contrax/control.py
def state_feedback(sys: DiscLTI | ContLTI, K: Array) -> DiscLTI | ContLTI:
    """Construct the closed-loop system under state feedback.

    Returns a new LTI system with `A_cl = A - B @ K` and unchanged `B`, `C`,
    `D`, and, for discrete systems, `dt`.

    This is a pure functional constructor rather than a simulation routine. It
    dispatches on either [DiscLTI][contrax.systems.DiscLTI] or
    [ContLTI][contrax.systems.ContLTI] and behaves like normal array algebra
    under JAX transforms.

    Args:
        sys: Continuous or discrete-time system.
        K: State-feedback gain. Shape: `(m, n)`.

    Returns:
        [DiscLTI][contrax.systems.DiscLTI] |
            [ContLTI][contrax.systems.ContLTI]: A new system representing the
            closed-loop dynamics.

    Examples:
        >>> import jax.numpy as jnp
        >>> import contrax as cx
        >>> sys = cx.dss(
        ...     jnp.array([[1.0, 0.05], [0.0, 1.0]]),
        ...     jnp.array([[0.0], [0.05]]),
        ...     jnp.eye(2),
        ...     jnp.zeros((2, 1)),
        ...     dt=0.05,
        ... )
        >>> result = cx.lqr(sys, jnp.eye(2), jnp.array([[1.0]]))
        >>> closed_loop = cx.state_feedback(sys, result.K)
    """
    if isinstance(sys, DiscLTI):
        return DiscLTI(
            A=sys.A - sys.B @ K,
            B=sys.B,
            C=sys.C,
            D=sys.D,
            dt=sys.dt,
        )
    elif isinstance(sys, ContLTI):
        return ContLTI(
            A=sys.A - sys.B @ K,
            B=sys.B,
            C=sys.C,
            D=sys.D,
        )
    else:
        raise TypeError(f"Expected DiscLTI or ContLTI, got {type(sys)}")