
How To Batch Controller Design Over Operating Points

This guide shows how to use vmap over linearization points for gain scheduling or operating-point sweeps.

Complete Working Code

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import contrax as cx


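def pendulum(x, u):
    # State x = [angle, angular rate]; input u = [control input].
    return jnp.array([x[1], -jnp.sin(x[0]) + u[0]])


def output(x, u):
    # Measure the full state, so C = I and D = 0.
    return x


def design(x_eq, u_eq):
    # Continuous-time linearization about the operating point (x_eq, u_eq).
    sys_c = cx.linearize_ss(pendulum, x_eq, u_eq, output=output)
    # Discretize with sample period dt = 0.05.
    sys_d = cx.c2d(sys_c, dt=0.05)
    # Discrete LQR with Q = I (2x2) and R = 1; keep only the gain matrix K.
    return cx.lqr(sys_d, jnp.eye(2), jnp.ones((1, 1))).K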


x_eqs = jnp.array([[0.0, 0.0], [0.1, 0.0], [-0.1, 0.0]])
# Equilibrium inputs: at rest, -sin(angle) + u = 0, so u_eq = sin(angle).
u_eqs = jnp.sin(x_eqs[:, :1])
Ks = jax.jit(jax.vmap(design))(x_eqs, u_eqs)

Why This Recipe Works

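  • linearize_ss() is designed to compose with vmap, so design() can be written for a single operating point and batched afterwards.
  • DiscLTI.dt is stored as an array leaf rather than static metadata, so the discretized system passes through jit and vmap like any other pytree.
  • design() runs one fixed-shape pipeline, linearize_ss -> c2d -> lqr, so every operating point produces arrays of identical shapes, which is what vmap needs to stack the results.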
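To see why this composes, here is a minimal pure-JAX sketch of the same three steps without contrax. It assumes zero-order-hold discretization (through the block matrix exponential) and solves the discrete Riccati equation with a fixed number of fixed-point iterations; these are stand-ins chosen for illustration, not necessarily what c2d() and lqr() do internally, and the helper names linearize, zoh, and dlqr are hypothetical. Each step maps fixed-shape arrays to fixed-shape arrays, so jax.vmap and jax.jit apply to the whole pipeline.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.scipy.linalg import expm


def pendulum(x, u):
    # Same dynamics as the recipe: x = [angle, angular rate], u = [control input].
    return jnp.array([x[1], -jnp.sin(x[0]) + u[0]])


def linearize(f, x_eq, u_eq):
    # Hypothetical helper: Jacobians A = df/dx and B = df/du at the operating point.
    A = jax.jacfwd(f, argnums=0)(x_eq, u_eq)
    B = jax.jacfwd(f, argnums=1)(x_eq, u_eq)
    return A, B


def zoh(A, B, dt):
    # Hypothetical helper: zero-order-hold discretization using
    # expm([[A, B], [0, 0]] * dt) = [[Ad, Bd], [0, I]].
    n, m = B.shape
    M = jnp.zeros((n + m, n + m)).at[:n, :n].set(A).at[:n, n:].set(B)
    E = expm(M * dt)
    return E[:n, :n], E[:n, n:]


def dlqr(Ad, Bd, Q, R, iters=2000):
    # Hypothetical helper: iterate the discrete Riccati recursion a fixed number
    # of times. The static loop length keeps shapes fixed under jit and vmap.
    def step(P, _):
        K = jnp.linalg.solve(R + Bd.T @ P @ Bd, Bd.T @ P @ Ad)
        P_next = Q + Ad.T @ P @ Ad - Ad.T @ P @ Bd @ K
        return P_next, None

    P, _ = jax.lax.scan(step, Q, None, length=iters)
    # Gain from the converged P: K = (R + Bd' P Bd)^-1 Bd' P Ad.
    return jnp.linalg.solve(R + Bd.T @ P @ Bd, Bd.T @ P @ Ad)


def design(x_eq, u_eq, dt=0.05):
    A, B = linearize(pendulum, x_eq, u_eq)
    Ad, Bd = zoh(A, B, dt)
    return dlqr(Ad, Bd, jnp.eye(2), jnp.ones((1, 1)))


x_eqs = jnp.array([[0.0, 0.0], [0.1, 0.0], [-0.1, 0.0]])
u_eqs = jnp.sin(x_eqs[:, :1])  # equilibrium inputs
Ks = jax.jit(jax.vmap(design))(x_eqs, u_eqs)
print(Ks.shape)  # (3, 1, 2)
```

In this model the linearized A matrix depends only on cos(angle), so the gains at +0.1 and -0.1 come out identical.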

Key Choices

  • Give every operating point the same state and input shapes (here x_eq has shape (2,) and u_eq has shape (1,)).
  • Enable float64 at startup, before creating any arrays or calling c2d() or lqr().
  • Start with a small grid of operating points before building a full gain-scheduling workflow.
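A small grid that follows these choices might look like the sketch below. The angle range and the five-point size are arbitrary placeholders, and the equilibrium inputs assume the pendulum model from the recipe, where holding an angle at rest requires u = sin(angle). Stacking the points along a leading axis is what vmap maps over.

```python
import jax

jax.config.update("jax_enable_x64", True)  # set before any arrays are created

import jax.numpy as jnp

# Placeholder grid of pendulum angles; widen it once the small case works.
thetas = jnp.linspace(-0.2, 0.2, 5)

# Every operating point has the same shapes: x_eq is (2,), u_eq is (1,).
x_eqs = jnp.stack([thetas, jnp.zeros_like(thetas)], axis=1)  # (5, 2)
u_eqs = jnp.sin(thetas)[:, None]  # (5, 1), equilibrium inputs

# Sanity checks: uniform shapes, float64, and zero dynamics at each point.
assert x_eqs.shape == (5, 2) and u_eqs.shape == (5, 1)
assert x_eqs.dtype == jnp.float64
residual = -jnp.sin(x_eqs[:, 0]) + u_eqs[:, 0]
assert bool(jnp.allclose(residual, 0.0))
```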

Check

Verify:

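  • Ks.shape == (num_points, m, n), where m is the number of inputs and n the number of states; for the example above this is (3, 1, 2)
  • the gains are finite (no NaN or inf values)
  • gains change smoothly between nearby operating points, with no abrupt jumps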
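These checks can be scripted. The helper below is a hypothetical sketch, not part of contrax. It assumes Ks is stacked along an ordered grid so that consecutive entries are neighbouring operating points, and the max_rel_step=0.5 threshold is an arbitrary placeholder to tune for your system.

```python
import jax.numpy as jnp


def check_gain_batch(Ks, num_points, m, n, max_rel_step=0.5):
    """Hypothetical helper: shape, finiteness, and smoothness checks on stacked gains."""
    # One (m, n) gain per operating point, stacked on the leading axis.
    assert Ks.shape == (num_points, m, n), Ks.shape
    # No NaN or inf values anywhere in the batch.
    assert bool(jnp.all(jnp.isfinite(Ks)))
    # Frobenius-norm change between consecutive gains, relative to the earlier gain.
    steps = jnp.linalg.norm(jnp.diff(Ks, axis=0), axis=(1, 2))
    scale = jnp.linalg.norm(Ks[:-1], axis=(1, 2))
    assert bool(jnp.all(steps <= max_rel_step * scale)), steps / scale
```

With the recipe above, call check_gain_batch(Ks, num_points=3, m=1, n=2); order the operating points along the grid first if you rely on the smoothness check.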