# Differentiable LQR
This tutorial shows one of Contrax's defining workflows: optimizing controller weights inside an ordinary JAX objective.
The idea is simple: instead of choosing Q and R offline and then freezing them, make them parameters in a differentiable loop.
Runnable script: examples/differentiable_lqr.py
This tutorial mainly spans the Control API and Simulation API. If you want the shorter recipe version, see Tune LQR with gradients.
## The Workflow
```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

import contrax as cx

A = jnp.array([[1.0, 0.05], [0.0, 1.0]])
B = jnp.array([[0.0], [0.05]])
C = jnp.eye(2)
D = jnp.zeros((2, 1))
SYS = cx.dss(A, B, C, D, dt=0.05)
X0 = jnp.array([1.0, 0.0])


def closed_loop_cost(log_q_diag, log_r):
    # Log-parameterize the weights so Q and R stay positive.
    q_diag = jnp.exp(log_q_diag)
    r = jnp.exp(log_r)[None, None]
    K = cx.lqr(SYS, jnp.diag(q_diag), r).K
    _, xs, _ = cx.simulate(SYS, X0, lambda t, x: -K @ x, num_steps=80)
    control_energy = jnp.sum((xs[:-1] @ K.T) ** 2)
    return jnp.sum(xs**2) + 1e-2 * control_energy


objective_and_grad = jax.jit(jax.value_and_grad(closed_loop_cost, argnums=(0, 1)))
_, (dlog_q, dlog_r) = objective_and_grad(jnp.zeros(2), jnp.array(0.0))
```
The important part is not the syntax. It is the workflow shape: lqr() and
simulate() inside an ordinary compiled JAX objective with gradients that stay
finite and usable.
At the control level, the script is optimizing the discrete LQR design with controller-design weights

\[
Q(\theta_Q) = \operatorname{diag}\big(e^{\theta_Q}\big), \qquad
R(\theta_R) = e^{\theta_R},
\]

and gain map

\[
K(\theta_Q, \theta_R) = \big(R + B^\top P B\big)^{-1} B^\top P A,
\]

where \(P\) solves the discrete algebraic Riccati equation for \((A, B, Q(\theta_Q), R(\theta_R))\). The rollout is then evaluated under a separate closed-loop objective of the form

\[
J(\theta_Q, \theta_R) = \sum_{t} x_t^\top x_t + 10^{-2} \sum_{t} \lVert K x_t \rVert^2,
\]

where the dependence on \(\theta_Q\) and \(\theta_R\) enters through \(K(\theta_Q, \theta_R)\) and the resulting closed-loop trajectory.
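If you want to see the same workflow shape without Contrax, it can be sketched in plain JAX: solve the Riccati fixed point, close the loop with `lax.scan`, and differentiate the whole pipeline. This is a minimal standalone sketch under stated assumptions (`dare_fixed_point` is a naive value iteration written for this example; Contrax's `dare()` instead uses an implicit-differentiation custom VJP rather than unrolling the solver):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

A = jnp.array([[1.0, 0.05], [0.0, 1.0]])
B = jnp.array([[0.0], [0.05]])


def dare_fixed_point(Q, R, iters=200):
    # Naive DARE value iteration: P <- A'P(A - BK) + Q with
    # K = (R + B'PB)^{-1} B'PA. Not contrax's solver.
    def step(P, _):
        BtP = B.T @ P
        K = jnp.linalg.solve(R + BtP @ B, BtP @ A)
        return A.T @ P @ (A - B @ K) + Q, None

    P, _ = jax.lax.scan(step, Q, None, length=iters)
    return P


def closed_loop_cost(log_q_diag, log_r):
    Q = jnp.diag(jnp.exp(log_q_diag))
    R = jnp.exp(log_r)[None, None]
    P = dare_fixed_point(Q, R)
    K = jnp.linalg.solve(R + B.T @ P @ B, B.T @ P @ A)

    def rollout(x, _):
        x_next = (A - B @ K) @ x  # closed loop under u = -Kx
        return x_next, x

    _, xs = jax.lax.scan(rollout, jnp.array([1.0, 0.0]), None, length=80)
    return jnp.sum(xs**2) + 1e-2 * jnp.sum((xs @ K.T) ** 2)


cost, grads = jax.value_and_grad(closed_loop_cost, argnums=(0, 1))(
    jnp.zeros(2), jnp.array(0.0)
)
```

Because every step is ordinary JAX, the gradients with respect to the log-weights come out finite with no special handling.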
## What The Script Prints
Running examples/differentiable_lqr.py prints the initial and final objective,
plus the tuned weights and gain:
```
Differentiable LQR tuning
initial cost = 29.906683
final cost = 22.984589
final Q diag = [4.30274642 4.1384593 ]
final R = 0.056159
final K = [[6.90505027 7.89914659]]
```
The exact tuned weights depend on the optimization loop configuration. The important check is simpler: the cost goes down and the gradients stay finite.
## Why This Works
Three pieces matter:
- `lqr()` on discrete systems goes through `dare()`, which has an implicit-differentiation custom VJP rather than a backward pass that unrolls the forward solver iterations.
- `simulate()` is a pure JAX closed-loop scan for discrete systems.
- the system object is a pytree-friendly Equinox module rather than a class with hidden runtime behavior.
That combination keeps controller design, simulation, and optimization in the same JAX world.
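The implicit-differentiation idea behind the first point can be illustrated on a much smaller fixed point. The sketch below is a toy scalar example of the pattern, not Contrax's `dare()` code: the backward pass applies the implicit function theorem at the converged point instead of differentiating through the forward iterations.

```python
from functools import partial

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def fwd_solve(f, a, iters=100):
    # Plain fixed-point iteration x <- f(a, x).
    x = jnp.zeros(())
    for _ in range(iters):
        x = f(a, x)
    return x


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a):
    return fwd_solve(f, a)


def fixed_point_fwd(f, a):
    x_star = fwd_solve(f, a)
    return x_star, (a, x_star)


def fixed_point_bwd(f, res, g):
    # Implicit function theorem at x* = f(a, x*), scalar case:
    #   dx*/da = (df/da) / (1 - df/dx)
    a, x_star = res
    dfdx = jax.grad(f, argnums=1)(a, x_star)
    dfda = jax.grad(f, argnums=0)(a, x_star)
    return (g * dfda / (1.0 - dfdx),)


fixed_point.defvjp(fixed_point_fwd, fixed_point_bwd)

f = lambda a, x: a * jnp.cos(x)  # contractive toy map
grad_a = jax.grad(lambda a: fixed_point(f, a))(0.8)
```

The backward pass here solves one small linear relation instead of replaying 100 forward iterations, which is the same reason `dare()` gradients stay cheap and well-behaved.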
## What the Script Demonstrates
The runnable example in examples/differentiable_lqr.py performs a small
gradient-descent loop over log-parameterized Q and R. The optimization step
itself is ordinary JAX: call the compiled value-and-gradient function, update
the parameters, and repeat.
The script prints:
- initial cost
- final cost
- the final tuned `Q` diagonal
- the final tuned `R`
- the resulting feedback gain
The exact numbers are not the main point. The point is that the optimization loop is ordinary JAX code rather than a separate control-design procedure.
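The loop itself has the shape below. The objective here is a deliberately simple stand-in (a quadratic in the exp-parameterized weights, with arbitrary round-number targets), not the Contrax closed-loop cost, but the update structure is the same one the script uses: compile value-and-grad once, then step in log space so the underlying weights stay positive.

```python
import jax
import jax.numpy as jnp


def toy_cost(log_q_diag, log_r):
    # Stand-in objective; in the script this is the compiled closed-loop cost.
    q = jnp.exp(log_q_diag)
    r = jnp.exp(log_r)
    return jnp.sum((q - 4.0) ** 2) + (r - 0.05) ** 2


objective_and_grad = jax.jit(jax.value_and_grad(toy_cost, argnums=(0, 1)))

log_q, log_r = jnp.zeros(2), jnp.array(0.0)
lr = 0.05
costs = []
for _ in range(200):
    cost, (g_q, g_r) = objective_and_grad(log_q, log_r)
    # Plain gradient descent in log space; exp keeps Q and R positive.
    log_q = log_q - lr * g_q
    log_r = log_r - lr * g_r
    costs.append(float(cost))
```

Swapping the stand-in for the real closed-loop cost changes nothing about the loop: it is still ordinary Python driving a compiled JAX function.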
From a control perspective, the important object is still the closed-loop matrix

\[
A_{\mathrm{cl}} = A - B K,
\]

but the workflow treats \(K\) as the result of a differentiable design map \((Q, R) \mapsto K\).
## Validate The Result
For this workflow, the first checks are:
- the final cost should be lower than the initial cost
- the tuned `Q` and `R` should remain positive because they are log-parameterized
- the resulting feedback gain should be finite
The runnable script asserts the cost decrease directly, and the printed values give a fast sanity check that the optimization stayed in a plausible regime.
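These checks translate directly into assertions. The sketch below uses the printed numbers from the run above as stand-in inputs, and adds one extra check that is not in the script (closed-loop spectral radius below one for the tutorial's `A` and `B`):

```python
import jax.numpy as jnp

# Stand-in values copied from the printed output above; in practice
# these come straight out of the tuning loop.
initial_cost, final_cost = 29.906683, 22.984589
q_diag = jnp.array([4.30274642, 4.1384593])
r = jnp.array(0.056159)
K = jnp.array([[6.90505027, 7.89914659]])

# 1. The objective must have improved.
assert final_cost < initial_cost

# 2. Log-parameterized weights are positive by construction.
assert bool(jnp.all(q_diag > 0)) and bool(r > 0)

# 3. The gain must be finite.
assert bool(jnp.all(jnp.isfinite(K)))

# 4. Extra check (not in the script): the closed loop should be stable,
# i.e. the spectral radius of A - BK should be below one.
A = jnp.array([[1.0, 0.05], [0.0, 1.0]])
B = jnp.array([[0.0], [0.05]])
rho = jnp.max(jnp.abs(jnp.linalg.eigvals(A - B @ K)))
assert bool(rho < 1.0)
```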
## Scope Of This Tutorial
This tutorial intentionally stays on the discrete LTI path. It is the cleanest
path for differentiating through controller design: lqr() uses dare(), and
the closed-loop rollout uses the pure JAX discrete simulate() scan.
Continuous-time LQR and continuous simulation are available through care() and
the Diffrax-backed simulate() path, but they are a different workflow with
different solver and adjoint choices. Pole placement is also available for
design-time feedback design, but it is not the right primitive for gradient
tuning of Q and R.
## Where to Go Next
- Getting started for the fastest route into the library
- Control API for `lqr`, `dare`, and `state_feedback`
- Simulation API for discrete closed-loop rollout semantics
- Linearize, LQR, simulate for the classic control flow
- JAX-native workflows for the broader pattern map
- Riccati solvers for the numerical details