Simulation
Simulation is the public namespace for open-loop, closed-loop, and fixed-horizon trajectory generation in Contrax.
The namespace includes:
- lsim() for open-loop discrete LTI simulation
- simulate() for closed-loop discrete and continuous simulation
- rollout() for fixed-horizon nonlinear transitions without a system object
- sample_system() for solver-backed continuous-to-discrete nonlinear model construction
- foh_inputs() for first-order-hold endpoint pairing on sampled inputs
- step_response(), impulse_response(), and initial_response() for standard response views
- as_ode_term() as a lower-level bridge into direct Diffrax use
Minimal Example
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import contrax as cx
sys = cx.dss(
    jnp.array([[1.0, 0.05], [0.0, 1.0]]),
    jnp.array([[0.0], [0.05]]),
    jnp.eye(2),
    jnp.zeros((2, 1)),
    dt=0.05,
)
K = cx.lqr(sys, jnp.eye(2), jnp.array([[1.0]])).K
ts, xs, ys = cx.simulate(
    sys,
    jnp.array([1.0, 0.0]),
    lambda t, x: -K @ x,
    num_steps=80,
)
Horizon Conventions
simulate() is explicit about the horizon type:
- use num_steps=... for discrete systems
- use duration=... for continuous systems
For returned trajectories:
- discrete simulate() returns xs including x0
- continuous simulate() returns samples on a fixed save grid
- lsim() returns (ts, xs, ys) for an open-loop input sequence
- response helpers return (ts, ys)
Use rollout() when there is no system object and you simply want to scan a
transition function over a fixed input sequence.
Use sample_system() when the underlying plant model is continuous-time but
the estimation or control workflow lives on a discrete grid. The returned
object is a discrete NonlinearSystem, so
it plugs directly into ekf(), ukf(), simulate(..., num_steps=...), and
other discrete-time workflows.
With the default input_interpolation="zoh", each step input has shape (m,)
and is held constant across the interval. With
input_interpolation="foh", each step input has shape (2, m) and is treated
as the endpoint pair (u_k, u_{k+1}); foh_inputs() builds that paired
sequence from ordinary sampled inputs.
Transform Behavior
The main transform contracts are:
- discrete lsim(), simulate(), and rollout() are built around fixed-shape JAX scans
- continuous simulate() uses Diffrax internally but keeps the public output on a fixed save grid
- rollout() is intended to compose with jit, vmap, and grad
That makes simulation suitable for both ordinary control scripts and compiled design loops.
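As a rough illustration of why a fixed-shape scan composes with these transforms, the sketch below uses plain JAX rather than Contrax itself. The helper `scan_rollout` is a hypothetical stand-in for a scan-based rollout and is not the library's implementation.

```python
import jax
import jax.numpy as jnp


def scan_rollout(f, x0, us, gain):
    """Hypothetical stand-in: scan a transition f(x, u, gain) over us."""

    def step(x, u):
        x_next = f(x, u, gain)
        return x_next, x_next

    _, xs_tail = jax.lax.scan(step, x0, us)
    # Prepend x0 so the trajectory holds T + 1 states.
    return jnp.concatenate([x0[None], xs_tail], axis=0)


def f(x, u, gain):
    return gain * x + u


us = jnp.zeros((10, 1))
x0 = jnp.array([1.0])

# jit: compile the whole rollout once for this input shape.
xs = jax.jit(scan_rollout, static_argnums=0)(f, x0, us, 0.9)

# vmap: batch over initial states without rewriting the rollout.
x0_batch = jnp.array([[1.0], [2.0], [3.0]])
xs_batch = jax.vmap(lambda x: scan_rollout(f, x, us, 0.9))(x0_batch)


# grad: differentiate a terminal cost with respect to the gain.
def terminal_cost(gain):
    return jnp.sum(scan_rollout(f, x0, us, gain)[-1] ** 2)


dcost = jax.grad(terminal_cost)(0.9)
```

Because the input sequence has a fixed length, the scan traces once and the same traced program serves all three transforms.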
Numerical Notes
Continuous simulate() is intentionally narrow. It exposes the most important
Diffrax escape hatches without turning every simulation call into a solver
configuration exercise.
step_response() and impulse_response() are analysis helpers. They are good
for inspection and sanity-checking, but they are not the preferred path for
optimization over nonsmooth event timing.
Related Pages
- Systems for model construction
- Control for controller design before simulation
- JAX transform contract for scan and solver behavior under jit, vmap, and grad
- Linearize, LQR, simulate for the main discrete control workflow
contrax.simulation
Public simulation namespace.
as_ode_term(sys: ContLTI | NonlinearSystem | PHSSystem, control_fn: Callable[[float, Array], Array]) -> Any
Wrap a continuous system as a Diffrax ODETerm.
This is the lower-level bridge from
ContLTI into direct Diffrax workflows when the
public simulate() surface is not the right level of control.
This function does not simulate by itself. Its behavior under jit,
vmap, and grad follows the downstream Diffrax solver and adjoint
configuration used by the caller.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sys | ContLTI \| NonlinearSystem \| PHSSystem | Continuous-time system. | required |
| control_fn | Callable[[float, Array], Array] | Control function mapping | required |
Returns:
| Type | Description |
|---|---|
| Any | diffrax.ODETerm: An ODE term representing |
Raises:
| Type | Description |
|---|---|
| ImportError | If Diffrax is not installed. |
Source code in contrax/sim.py
foh_inputs(us: Array) -> Array
Build first-order-hold endpoint pairs from a sample sequence.
Given inputs u_k, returns a sequence with shape (T, 2, m) where each
item contains (u_k, u_{k+1}), using u_{T-1} for the final right
endpoint.
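The pairing described above can be sketched in plain JAX. The function `foh_pairs` is a hypothetical re-implementation written for illustration only; it follows the stated convention of repeating the last sample as the final right endpoint.

```python
import jax.numpy as jnp


def foh_pairs(us):
    """Hypothetical re-implementation of the endpoint pairing."""
    # Right endpoints: shift by one sample and repeat the last input.
    us_next = jnp.concatenate([us[1:], us[-1:]], axis=0)
    # Stack left and right endpoints on a new axis: (T, m) -> (T, 2, m).
    return jnp.stack([us, us_next], axis=1)


us = jnp.arange(4.0).reshape(4, 1)  # u_0..u_3 with m = 1
pairs = foh_pairs(us)
# pairs[k] holds (u_k, u_{k+1}); the last item holds (u_3, u_3).
```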
impulse_response(sys: DiscLTI | ContLTI, *, num_steps: int | None = None, duration: float | None = None, input_index: int = 0, x0: Array | None = None, dt: float | None = None, solver: Any = None, adjoint: Any = None, dt0: float | None = None) -> tuple[Array, Array]
Compute the impulse response of an LTI system.
For discrete systems this applies a unit pulse at k=0. For continuous
systems this applies the equivalent state jump x(0+) = x0 + B[:, i] and
then simulates with zero input. That means the returned sampled trajectory
omits any singular direct-feedthrough D delta(t) spike at t=0.
This is an analysis helper rather than a smooth-input differentiation path. In particular, the continuous-time version uses the standard state-jump convention instead of representing a Dirac delta as an ordinary control signal.
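Both conventions can be reproduced by hand in plain JAX. The sketch below does not call Contrax; the matrices are illustrative, D is taken as zero, and the number of returned samples is an assumption rather than the library's documented behavior.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

i = 0  # input channel receiving the impulse
C = jnp.array([[1.0, 0.0]])

# Continuous convention: jump the state by B[:, i], then run with u = 0.
A = jnp.array([[0.0, 1.0], [-2.0, -3.0]])
B = jnp.array([[0.0], [1.0]])
x_jump = jnp.zeros(2) + B[:, i]
ts = jnp.linspace(0.0, 2.0, 5)
ys_cont = jnp.stack([C @ expm(A * t) @ x_jump for t in ts])

# Discrete convention: a unit pulse on channel i at k = 0 only.
Ad = jnp.array([[1.0, 0.1], [0.0, 1.0]])
Bd = jnp.array([[0.0], [0.1]])
us = jnp.zeros((4, 1)).at[0, i].set(1.0)


def step(x, u):
    y = C @ x  # D = 0 in this sketch
    return Ad @ x + Bd @ u, y


_, ys_disc = jax.lax.scan(step, jnp.zeros(2), us)
```

For this continuous example the transfer function is 1 / (s^2 + 3s + 2), so the sampled outputs should match exp(-t) - exp(-2t).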
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sys | DiscLTI \| ContLTI | Continuous or discrete-time system. | required |
| num_steps | int \| None | Discrete-only number of time steps. | None |
| duration | float \| None | Continuous-only response duration. | None |
| input_index | int | Input channel receiving the impulse. | 0 |
| x0 | Array \| None | Optional initial state. Defaults to zeros. | None |
| dt | float \| None | Continuous-only output sample spacing hint. | None |
| solver | Any | Continuous-only Diffrax solver override. | None |
| adjoint | Any | Continuous-only Diffrax adjoint override. | None |
| dt0 | float \| None | Continuous-only initial solver step size hint. | None |
Returns:
| Type | Description |
|---|---|
| tuple[Array, Array] | Sample times and output trajectory |
initial_response(sys: DiscLTI | ContLTI, x0: Array, *, num_steps: int | None = None, duration: float | None = None, dt: float | None = None, solver: Any = None, adjoint: Any = None, dt0: float | None = None) -> tuple[Array, Array]
Compute the zero-input response from a nonzero initial state.
This is the third standard inspection response alongside step and impulse.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sys | DiscLTI \| ContLTI | Continuous or discrete-time system. | required |
| x0 | Array | Initial state. Shape: | required |
| num_steps | int \| None | Discrete-only number of time steps. | None |
| duration | float \| None | Continuous-only response duration. | None |
| dt | float \| None | Continuous-only output sample spacing hint. | None |
| solver | Any | Continuous-only Diffrax solver override. | None |
| adjoint | Any | Continuous-only Diffrax adjoint override. | None |
| dt0 | float \| None | Continuous-only initial solver step size hint. | None |
Returns:
| Type | Description |
|---|---|
| tuple[Array, Array] | Sample times and output trajectory |
lsim(sys: DiscLTI, us: Array, x0: Array | None = None) -> tuple[Array, Array, Array]
Simulate a discrete system in open loop from an input sequence.
Mirrors MATLAB-style lsim() for the current discrete state-space surface.
This is the public open-loop simulation name in Contrax. The implementation
is a pure lax.scan, so it composes naturally with jit, vmap, and
differentiation on fixed-shape inputs.
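To make the scan structure concrete, here is a minimal plain-JAX sketch of an open-loop discrete LTI simulation. The name `lsim_sketch` is hypothetical, and aligning xs with the inputs (T states rather than T + 1) is an assumption that may differ from what lsim() actually returns.

```python
import jax
import jax.numpy as jnp


def lsim_sketch(A, B, C, D, dt, us, x0):
    """Hypothetical open-loop scan for x[k+1] = A x[k] + B u[k]."""

    def step(x, u):
        y = C @ x + D @ u
        # Emit the state and output seen at step k, then advance.
        return A @ x + B @ u, (x, y)

    _, (xs, ys) = jax.lax.scan(step, x0, us)
    ts = dt * jnp.arange(us.shape[0])
    return ts, xs, ys


A = jnp.array([[1.0, 0.1], [0.0, 1.0]])
B = jnp.array([[0.0], [0.1]])
C = jnp.eye(2)
D = jnp.zeros((2, 1))

# Constant unit input on a discrete double integrator, compiled with jit.
ts, xs, ys = jax.jit(lsim_sketch)(A, B, C, D, 0.1, jnp.ones((20, 1)), jnp.zeros(2))
```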
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sys | DiscLTI | Discrete-time system. | required |
| us | Array | Input sequence. Shape: | required |
| x0 | Array \| None | Initial state. Shape: | None |
Returns:
| Type | Description |
|---|---|
| tuple[Array, Array, Array] | Simulation outputs (ts, xs, ys). |
Examples:
>>> import jax.numpy as jnp
>>> import contrax as cx
>>> sys = cx.dss(
...     jnp.array([[1.0, 0.1], [0.0, 1.0]]),
...     jnp.array([[0.0], [0.1]]),
...     jnp.eye(2),
...     jnp.zeros((2, 1)),
...     dt=0.1,
... )
>>> ts, xs, ys = cx.lsim(sys, jnp.zeros((20, 1)))
rollout(f: Callable[..., Array], x0: Array, us: Array, params: Any = None) -> Array
Roll out arbitrary discrete-time dynamics over a fixed input sequence.
This is the generic nonlinear trajectory primitive in Contrax. Unlike
lsim() and simulate(), it does not require an LTI system object and does
not compute outputs. It simply applies a transition function with
jax.lax.scan.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| f | Callable[..., Array] | Transition function. If | required |
| x0 | Array | Initial state. Shape: | required |
| us | Array | Input sequence. Shape: | required |
| params | Any | Optional static or array-valued parameters passed to | None |
Returns:
| Type | Description |
|---|---|
| Array | State trajectory including the initial state. Shape: |
Examples:
>>> import jax.numpy as jnp
>>> import contrax as cx
>>> def f(x, u, gain):
...     return gain * x + u
>>> xs = cx.rollout(f, jnp.array([1.0]), jnp.zeros((10, 1)), 0.9)
sample_system(sys: NonlinearSystem | PHSSystem, dt: float, *, input_interpolation: str = 'zoh', solver: Any = None, adjoint: Any = None, dt0: float | None = None) -> NonlinearSystem
Sample a continuous nonlinear system into a discrete transition model.
This helper builds a discrete-time NonlinearSystem whose one-step dynamics are obtained by integrating the continuous model over one sample interval.
With input_interpolation="zoh", each discrete input has shape (m,) and
is held constant over the interval. With input_interpolation="foh", each
discrete input has shape (2, m) and is interpreted as the pair
(u_k, u_{k+1}); use foh_inputs() to build those pairs from a sampled
input sequence.
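The idea of integrating over one sample interval under each input convention can be sketched as follows. This is an illustration only: `sample_dynamics` is a hypothetical helper, it assumes dynamics of the form f(t, x, u), and it substitutes a fixed-step RK4 integrator for the Diffrax solver that sample_system() is described as using.

```python
import jax
import jax.numpy as jnp


def sample_dynamics(f, dt, interpolation="zoh", substeps=10):
    """Hypothetical sampler: integrate x' = f(t, x, u) over one interval."""
    h = dt / substeps

    def u_at(u, s):
        # s is the time elapsed since the start of the interval.
        if interpolation == "zoh":
            return u  # shape (m,), held constant
        return u[0] + (s / dt) * (u[1] - u[0])  # shape (2, m), linear ramp

    def step(t, x, u):
        def rk4(x, j):
            s = j * h
            k1 = f(t + s, x, u_at(u, s))
            k2 = f(t + s + h / 2, x + h / 2 * k1, u_at(u, s + h / 2))
            k3 = f(t + s + h / 2, x + h / 2 * k2, u_at(u, s + h / 2))
            k4 = f(t + s + h, x + h * k3, u_at(u, s + h))
            return x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4), None

        x_next, _ = jax.lax.scan(rk4, x, jnp.arange(substeps))
        return x_next

    return step


def f(t, x, u):
    return -x + u  # first-order lag


x0 = jnp.array([0.0])
zoh_step = sample_dynamics(f, 0.1, "zoh")
foh_step = sample_dynamics(f, 0.1, "foh")

x_zoh = zoh_step(0.0, x0, jnp.array([1.0]))           # u held at 1
x_foh = foh_step(0.0, x0, jnp.array([[0.0], [1.0]]))  # u ramps from 0 to 1
```

With the ramped input the state ends lower than with the held input, since the ramp delivers only part of the step over the interval.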
simulate(sys: DiscLTI | ContLTI | NonlinearSystem | PHSSystem, x0: Array, policy: Callable[[float, Array], Array], *, num_steps: int | None = None, duration: float | None = None, dt: float | None = None, solver: Any = None, adjoint: Any = None, dt0: float | None = None) -> tuple[Array, Array, Array]
Simulate a system in closed loop under a control policy.
Evaluates policy(t, x) along the system trajectory and dispatches to the
appropriate simulation path for discrete or continuous LTI models.
For DiscLTI, this keeps the existing
pure-lax.scan path and requires num_steps. For
ContLTI, this uses Diffrax under the hood and
requires duration. Continuous trajectories are sampled on a fixed uniform
save grid so output shapes stay predictable.
The continuous path intentionally exposes only a small amount of Diffrax
configuration: a sample spacing hint dt, plus optional solver,
adjoint, and dt0 overrides. The default continuous solver is
diffrax.Tsit5() with diffrax.RecursiveCheckpointAdjoint().
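For the discrete path, the closed-loop scan can be sketched in plain JAX as below. `closed_loop_scan` is a hypothetical helper, not the Contrax implementation; the gain K is hand-picked rather than computed by lqr(), D is taken as zero, and the time passed to the policy is assumed to be k * dt.

```python
import jax
import jax.numpy as jnp


def closed_loop_scan(A, B, C, dt, x0, policy, num_steps):
    """Hypothetical discrete closed loop: u[k] = policy(t[k], x[k])."""
    ts = dt * jnp.arange(num_steps)

    def step(x, t):
        u = policy(t, x)
        x_next = A @ x + B @ u
        return x_next, x_next

    _, xs_tail = jax.lax.scan(step, x0, ts)
    xs = jnp.concatenate([x0[None], xs_tail], axis=0)  # includes x0
    ys = xs @ C.T  # D = 0 in this sketch
    return xs, ys


A = jnp.array([[1.0, 0.05], [0.0, 1.0]])
B = jnp.array([[0.0], [0.05]])
C = jnp.eye(2)
K = jnp.array([[1.0, 1.5]])  # illustrative stabilizing gain

xs, ys = closed_loop_scan(
    A, B, C, 0.05, jnp.array([1.0, 0.0]), lambda t, x: -K @ x, num_steps=80
)
```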
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sys | DiscLTI \| ContLTI \| NonlinearSystem \| PHSSystem | Continuous or discrete-time system. | required |
| x0 | Array | Initial state. Shape: | required |
| policy | Callable[[float, Array], Array] | Control policy mapping | required |
| num_steps | int \| None | Discrete-only number of time steps. | None |
| duration | float \| None | Continuous-only simulation duration. | None |
| dt | float \| None | Continuous-only output sample spacing hint. When omitted, the continuous path returns 201 samples including | None |
| solver | Any | Continuous-only Diffrax solver override. | None |
| adjoint | Any | Continuous-only Diffrax adjoint override. | None |
| dt0 | float \| None | Continuous-only initial solver step size hint. | None |
Returns:
| Type | Description |
|---|---|
| tuple[Array, Array, Array] | Simulation outputs (ts, xs, ys). For DiscLTI, xs includes x0; for continuous systems, samples lie on the fixed continuous output grid. |
Examples:
>>> import jax.numpy as jnp
>>> import contrax as cx
>>> sys_d = cx.dss(
...     jnp.array([[1.0, 0.05], [0.0, 1.0]]),
...     jnp.array([[0.0], [0.05]]),
...     jnp.eye(2),
...     jnp.zeros((2, 1)),
...     dt=0.05,
... )
>>> K = cx.lqr(sys_d, jnp.eye(2), jnp.array([[1.0]])).K
>>> ts, xs, ys = cx.simulate(
...     sys_d,
...     jnp.array([1.0, 0.0]),
...     lambda t, x: -K @ x,
...     num_steps=60,
... )
step_response(sys: DiscLTI | ContLTI, *, num_steps: int | None = None, duration: float | None = None, input_index: int = 0, x0: Array | None = None, dt: float | None = None, solver: Any = None, adjoint: Any = None, dt0: float | None = None) -> tuple[Array, Array]
Compute the unit-step response of an LTI system.
This is the standard control sanity-check response: apply a unit step on one input channel and return the output trajectory.
For MIMO systems, this first version follows a MATLAB-like workflow by
selecting one input channel at a time via input_index.
As an analysis helper, this is usually fine to differentiate with respect to system parameters when the step timing is fixed. It is not the right path for differentiating with respect to the discontinuity itself, such as a switching time or other hard event location.
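The distinction can be illustrated with a scalar example in plain JAX. The function `step_output` is hypothetical and does not call step_response(); it only shows that, with the step applied at a fixed sample, the output is a smooth function of a system parameter and can be differentiated.

```python
import jax
import jax.numpy as jnp


def step_output(a, num_steps=10):
    """Final output of y[k+1] = a*y[k] + (1 - a)*u[k] under a unit step."""

    def step(y, u):
        y_next = a * y + (1.0 - a) * u
        return y_next, y_next

    us = jnp.ones(num_steps)  # step applied at k = 0, timing held fixed
    y_final, _ = jax.lax.scan(step, jnp.zeros(()), us)
    return y_final


y_final = step_output(0.5)
# Sensitivity of the final output to the pole location a.
dy_da = jax.grad(step_output)(0.5)
```

Here the closed form is y[N] = 1 - a^N, so the gradient is -N * a^(N - 1). Moving the step to a different sample index, by contrast, is a discrete change that this kind of gradient does not capture.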
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sys | DiscLTI \| ContLTI | Continuous or discrete-time system. | required |
| num_steps | int \| None | Discrete-only number of time steps. | None |
| duration | float \| None | Continuous-only response duration. | None |
| input_index | int | Input channel receiving the unit step. | 0 |
| x0 | Array \| None | Optional initial state. Defaults to zeros. | None |
| dt | float \| None | Continuous-only output sample spacing hint. | None |
| solver | Any | Continuous-only Diffrax solver override. | None |
| adjoint | Any | Continuous-only Diffrax adjoint override. | None |
| dt0 | float \| None | Continuous-only initial solver step size hint. | None |
Returns:
| Type | Description |
|---|---|
| tuple[Array, Array] | Sample times and output trajectory |