"""Smoothing for bellman_filter_dfsv: RTS smoother and Rao-Blackwellized particle smoother."""

import equinox as eqx
import jax
import jax.numpy as jnp
from jax import Array
from jaxtyping import Float

from ._bellman_math import (
    K_CHOL_JITTER,
    LOG_DIAG_FLOOR,
    WOODBURY_DIAG_JITTER,
    invert_info_matrix,
    predict_info_step,
)
from .types import BIFState, DFSVParams, RBParticleState, RBPSResult


class SmootherResult(eqx.Module):
    """Result container for the RTS smoother.

    Holds per-timestep smoothed statistics for the stacked state of
    dimension 2K (factors and log-volatilities).
    """

    # Smoothed state means, shape (T, 2K).
    smoothed_means: Float[Array, "T 2K"]
    # Smoothed state covariances, shape (T, 2K, 2K).
    smoothed_covs: Float[Array, "T 2K 2K"]
    # Lag-one smoothed cross-covariances; the final slot is zero-padded
    # by ``rts_smoother`` (see its docstring).
    smoothed_lag1_covs: Float[Array, "T 2K 2K"]
def rts_smoother(
    params: DFSVParams,
    filter_means: Float[Array, "T 2K"],
    filter_infos: Float[Array, "T 2K 2K"],
) -> SmootherResult:
    """Runs the Rauch-Tung-Striebel (RTS) smoother adapted for information
    filter results.

    Args:
        params: Model parameters; ``Phi_f``/``Phi_h`` form the block-diagonal
            transition matrix and ``lambda_r`` fixes the factor count K.
        filter_means: Filtered means for t = 0..T-1, shape (T, 2K).
        filter_infos: Filtered information (inverse covariance) matrices,
            shape (T, 2K, 2K).

    Returns:
        SmootherResult with smoothed means, covariances, and lag-1
        cross-covariances for all T steps (the last lag-1 slot is
        zero-padded, as it is undefined).
    """
    T, state_dim = filter_means.shape
    K = params.lambda_r.shape[1]

    # Pre-calculate covariances from information matrices.
    vmap_invert = jax.vmap(invert_info_matrix)
    filter_covs = vmap_invert(filter_infos)

    # Re-run the prediction step to get predicted statistics.
    # Note: we could cache this from the forward pass if memory allows,
    # but re-computing is often cheaper than storing.
    filtered_states_bif = BIFState(mean=filter_means, info=filter_infos)
    predicted_states = jax.vmap(lambda s: predict_info_step(params, s))(
        filtered_states_bif
    )
    pred_means = predicted_states.mean
    pred_infos = predicted_states.info
    pred_covs = vmap_invert(pred_infos)

    # Block-diagonal transition matrix for the stacked [f; h] state.
    F = jnp.block(
        [[params.Phi_f, jnp.zeros((K, K))], [jnp.zeros((K, K)), params.Phi_h]]
    )

    # Initialize the backward pass with the final filtered state.
    init_carry = (filter_means[-1], filter_covs[-1])

    # Backward steps cover t = 0..T-2; predicted_states[t] is the one-step
    # prediction of x_{t+1} from the filtered state at t.
    xs = (
        filter_means[:-1],
        filter_covs[:-1],
        pred_means[:-1],
        pred_infos[:-1],
        pred_covs[:-1],
    )

    def backward_step(carry, x):
        smooth_mean_tp1, smooth_cov_tp1 = carry
        filt_mean_t, filt_cov_t, pred_mean_tp1, pred_info_tp1, pred_cov_tp1 = x

        # RTS gain: J_t = P_{t|t} F^T P_{t+1|t}^{-1}.
        # pred_info_tp1 is exactly P_{t+1|t}^{-1}, so no extra inversion.
        J_t = filt_cov_t @ F.T @ pred_info_tp1
        smooth_mean_t = filt_mean_t + J_t @ (smooth_mean_tp1 - pred_mean_tp1)
        cov_diff = smooth_cov_tp1 - pred_cov_tp1
        smooth_cov_t = filt_cov_t + J_t @ cov_diff @ J_t.T
        # Symmetrize to counter floating-point drift.
        smooth_cov_t = 0.5 * (smooth_cov_t + smooth_cov_t.T)
        # Lag-1 cross-covariance: Cov(x_{t+1}, x_t | T) = P_{t+1|T} J_t^T.
        lag1_cov = smooth_cov_tp1 @ J_t.T
        return (smooth_mean_t, smooth_cov_t), (
            smooth_mean_t,
            smooth_cov_t,
            lag1_cov,
        )

    # With reverse=True, scan consumes xs from the end but stacks each output
    # at its input index, so the results come back already time-ascending.
    _, (smooth_means, smooth_covs, lag1_covs) = jax.lax.scan(
        backward_step, init_carry, xs, reverse=True
    )

    full_means = jnp.concatenate([smooth_means, init_carry[0][None, :]], axis=0)
    full_covs = jnp.concatenate([smooth_covs, init_carry[1][None, :, :]], axis=0)

    # Lag-1 covariances exist for t = 0..T-2; pad the final slot with zeros.
    pad_lag1 = jnp.zeros((1, state_dim, state_dim))
    full_lag1 = jnp.concatenate([lag1_covs, pad_lag1], axis=0)

    return SmootherResult(
        smoothed_means=full_means,
        smoothed_covs=full_covs,
        smoothed_lag1_covs=full_lag1,
    )
# --- RBPS Logic ---


def _predict_rbpf(
    params: DFSVParams, state: RBParticleState, key: jax.Array
) -> RBParticleState:
    """RBPF prediction: propagate h particles, Kalman-predict f statistics.

    Log-weights are carried over unchanged; only state statistics move.
    """
    K = params.lambda_r.shape[1]
    P = state.h_particles.shape[1]

    # Propagate the log-volatility particles:
    # h_{t+1} = mu + Phi_h (h_t - mu) + eta_t.
    mean_col = params.mu[:, None]
    centered = state.h_particles - mean_col
    h_drift = mean_col + params.Phi_h @ centered
    # Sample the transition noise via a jittered Cholesky factor of Q_h.
    chol_Qh = jnp.linalg.cholesky(params.Q_h + K_CHOL_JITTER * jnp.eye(K))
    shocks = jax.random.normal(key, shape=(K, P))
    h_next = h_drift + chol_Qh @ shocks

    # Kalman predict for the conditionally-Gaussian factors:
    # E[f_{t+1}] = Phi_f E[f_t],
    # Var(f_{t+1}) = Phi_f Var(f_t) Phi_f' + Q_f(h_{t+1}).
    f_pred_means = params.Phi_f @ state.f_means

    def one_cov(cov, h_col):
        # State-dependent process noise: Q_f = diag(exp(h_{t+1})).
        Q_f = jnp.diag(jnp.exp(h_col))
        return params.Phi_f @ cov @ params.Phi_f.T + Q_f

    # Batch the (K, K, P) covariance stack over the particle axis.
    f_pred_covs = jax.vmap(one_cov, in_axes=(2, 1), out_axes=2)(
        state.f_covs, h_next
    )

    return RBParticleState(
        h_particles=h_next,
        f_means=f_pred_means,
        f_covs=f_pred_covs,
        log_weights=state.log_weights,
    )


def _rbpf_obs_cache(params: DFSVParams):
    """Precomputes observation-model terms shared by every particle.

    Returns:
        ``(inv_sigma2, B, logdet_D)``: the jittered inverse idiosyncratic
        variances, the Gram matrix ``lambda_r' D^{-1} lambda_r``, and the
        floored log-determinant of ``D = diag(sigma2)``.
    """
    inv_sigma2 = 1.0 / (params.sigma2 + WOODBURY_DIAG_JITTER)
    scaled_loadings = params.lambda_r * inv_sigma2[:, None]
    B = params.lambda_r.T @ scaled_loadings
    logdet_D = jnp.sum(jnp.log(jnp.maximum(params.sigma2, LOG_DIAG_FLOOR)))
    return inv_sigma2, B, logdet_D


def _update_rbpf(
    params: DFSVParams, state: RBParticleState, observation: Float[Array, "N"]
) -> tuple[RBParticleState, Float[Array, "P"]]:
    """RBPF measurement update.

    Runs an information-form Kalman update per particle and returns the
    updated state together with the per-particle incremental
    log-likelihoods (also already folded into the new log-weights).
    """
    N = params.lambda_r.shape[0]
    K = params.lambda_r.shape[1]
    inv_sigma2, B, logdet_D = _rbpf_obs_cache(params)

    def kalman_update(f_mean, f_cov):
        # Innovation of the observation against the predicted factor mean.
        v = observation - (params.lambda_r @ f_mean)
        chol_P = jnp.linalg.cholesky(f_cov + K_CHOL_JITTER * jnp.eye(K))
        P_inv = jax.scipy.linalg.cho_solve((chol_P, True), jnp.eye(K))
        # Posterior information matrix, symmetrized for numerical stability.
        M = 0.5 * (P_inv + B + (P_inv + B).T)
        chol_M = jnp.linalg.cholesky(M + K_CHOL_JITTER * jnp.eye(K))
        Dinv_v = v * inv_sigma2
        u = params.lambda_r.T @ Dinv_v
        # Posterior covariance and mean.
        Sigma = jax.scipy.linalg.cho_solve((chol_M, True), jnp.eye(K))
        f_new_mean = f_mean + Sigma @ u
        # Woodbury quadratic form plus the matrix determinant lemma give the
        # Gaussian log-likelihood without forming the N x N covariance.
        quad = jnp.dot(v, Dinv_v) - jnp.dot(
            u, jax.scipy.linalg.cho_solve((chol_M, True), u)
        )
        logdet_P = 2.0 * jnp.sum(
            jnp.log(jnp.maximum(jnp.diag(chol_P), LOG_DIAG_FLOOR))
        )
        logdet_M = 2.0 * jnp.sum(
            jnp.log(jnp.maximum(jnp.diag(chol_M), LOG_DIAG_FLOOR))
        )
        logdet_S = logdet_D + logdet_P + logdet_M
        log_lik = -0.5 * (N * jnp.log(2.0 * jnp.pi) + logdet_S + quad)
        return f_new_mean, Sigma, log_lik

    # Vectorize over the particle axis of the mean/cov stacks.
    f_means_new, f_covs_new, incremental_log_liks = jax.vmap(
        kalman_update, in_axes=(1, 2), out_axes=(1, 2, 0)
    )(state.f_means, state.f_covs)

    new_log_weights = state.log_weights + incremental_log_liks
    return (
        RBParticleState(
            h_particles=state.h_particles,
            f_means=f_means_new,
            f_covs=f_covs_new,
            log_weights=new_log_weights,
        ),
        incremental_log_liks,
    )
def systematic_resample_indices(key, log_weights, num_particles):
    """Draws ancestor indices via systematic (stratified, shared-offset)
    resampling from unnormalized log-weights.

    Args:
        key: PRNG key for the single shared uniform offset.
        log_weights: Unnormalized particle log-weights, shape (P,).
        num_particles: Number of indices to draw (need not equal P).

    Returns:
        Integer array of shape (num_particles,) of indices into
        ``log_weights``.
    """
    # Normalize in log-space for numerical stability.
    max_log_w = jnp.max(log_weights)
    lw_norm = (
        log_weights - max_log_w
        - jax.scipy.special.logsumexp(log_weights - max_log_w)
    )
    w = jnp.exp(lw_norm)
    # One uniform offset shared by all strata (systematic, not multinomial).
    u = jax.random.uniform(key, (), minval=0.0, maxval=1.0 / num_particles)
    points = u + jnp.arange(num_particles) / num_particles
    cum_w = jnp.cumsum(w)
    indices = jnp.searchsorted(cum_w, points)
    # Guard against cum_w[-1] < 1 from rounding. Bug fix: the valid index
    # range is determined by the number of particles (log_weights size),
    # not by the number of draws, which may differ.
    return jnp.clip(indices, 0, log_weights.shape[0] - 1)
def run_rbps(
    params: DFSVParams,
    observations: Float[Array, "T N"],
    num_particles: int = 100,
    num_trajectories: int = 20,
    seed: int = 42,
) -> RBPSResult:
    """Runs a Rao-Blackwellized particle smoother.

    A forward RBPF over the log-volatilities h is followed by backward
    FFBS sampling of h trajectories; conditional on each sampled
    trajectory, the factor moments are smoothed exactly.

    Args:
        params: Model parameters.
        observations: Observed returns, shape (T, N).
        num_particles: Number of forward-filter particles.
        num_trajectories: Number of backward trajectories (M) to sample.
        seed: PRNG seed.

    Returns:
        RBPSResult with sampled h trajectories (M, T, K) and the
        conditionally smoothed f means/covs/lag-1 covs.
    """
    T, N = observations.shape
    K = params.lambda_r.shape[1]
    key = jax.random.PRNGKey(seed)

    # --- Forward RBPF ---
    key, init_key = jax.random.split(key)
    # Stationary covariance of h: vec(P_h) = (I - Phi_h ⊗ Phi_h)^{-1} vec(Q_h).
    kron_prod = jnp.kron(params.Phi_h, params.Phi_h)
    vec_Qh = params.Q_h.flatten()
    vec_Ph = jnp.linalg.solve(jnp.eye(K * K) - kron_prod, vec_Qh)
    P_h = vec_Ph.reshape(K, K)
    L_h = jnp.linalg.cholesky(P_h + K_CHOL_JITTER * jnp.eye(K))
    h0 = params.mu[:, None] + L_h @ jax.random.normal(init_key, (K, num_particles))
    # Initialize f statistics (unconditional: zero mean, identity covariance).
    f0_means = jnp.zeros((K, num_particles))
    f0_covs = jnp.tile(jnp.eye(K)[:, :, None], (1, 1, num_particles))
    init_state = RBParticleState(
        h0, f0_means, f0_covs, jnp.full((num_particles,), -jnp.log(num_particles))
    )

    def scan_step(carry, obs_t):
        state_prev, key_curr = carry
        key_curr, pred_key, res_key = jax.random.split(key_curr, 3)
        # Predict & update.
        state_pred = _predict_rbpf(params, state_prev, pred_key)
        state_upd, _ = _update_rbpf(params, state_pred, obs_t)
        # Save the weighted (pre-resampling) state for the backward pass.
        saved_state = state_upd
        # Always resample for now (simplifies code, standard PF).
        idx = systematic_resample_indices(
            res_key, state_upd.log_weights, num_particles
        )
        h_new = state_upd.h_particles[:, idx]
        f_m_new = state_upd.f_means[:, idx]
        f_c_new = state_upd.f_covs[:, :, idx]
        lw_new = jnp.full((num_particles,), -jnp.log(num_particles))
        state_resampled = RBParticleState(h_new, f_m_new, f_c_new, lw_new)
        return (state_resampled, key_curr), saved_state

    _, history = jax.lax.scan(scan_step, (init_state, key), observations)

    # --- Backward FFBS ---
    final_weights = history.log_weights[-1]
    key, idx_key = jax.random.split(key)
    final_indices = jax.random.categorical(
        idx_key, final_weights, shape=(num_trajectories,)
    )
    # Two-step indexing avoids JAX/NumPy advanced-indexing surprises.
    h_T = history.h_particles[-1][:, final_indices]  # (K, M)

    def backward_step(h_next_batch, t):
        particles_t = history.h_particles[t]  # (K, P)
        log_w_t = history.log_weights[t]  # (P,)
        # Transition density log p(h_{t+1} | h_t) with
        # h_{t+1} ~ N(mu + Phi_h (h_t - mu), Q_h); the diagonal of Q_h is
        # used, and the normalizing constant is dropped since it is shared
        # by all particles at a fixed t and cancels in categorical sampling.
        h_n = h_next_batch[:, None, :]  # (K, 1, M)
        p_t = particles_t[:, :, None]  # (K, P, 1)
        mu_col = params.mu[:, None, None]
        dev = p_t - mu_col
        mean = mu_col + jnp.tensordot(params.Phi_h, dev, axes=(1, 0))  # (K, P, 1)
        res = h_n - mean  # (K, P, M)
        q_diag = jnp.diag(params.Q_h)[:, None, None]
        log_trans = -0.5 * jnp.sum(
            (res**2) / (q_diag + WOODBURY_DIAG_JITTER), axis=0
        )  # (P, M)
        log_wb = log_w_t[:, None] + log_trans
        # Stabilize the backward weights per trajectory.
        max_log_wb = jnp.max(log_wb, axis=0, keepdims=True)
        log_wb = log_wb - max_log_wb
        # Sample one ancestor index per trajectory.
        rngs = jax.random.split(jax.random.fold_in(key, t), num_trajectories)
        indices = jax.vmap(
            lambda lw, k: jax.random.categorical(k, lw), in_axes=(1, 0)
        )(log_wb, rngs)
        h_t_sampled = particles_t[:, indices]
        return h_t_sampled, h_t_sampled

    # This is a FORWARD scan over explicitly reversed time indices
    # (t = T-2 down to 0), so the stacked outputs are in descending time
    # order and must be flipped back to ascending order (0..T-2).
    _, h_samples_rev = jax.lax.scan(backward_step, h_T, jnp.arange(T - 1)[::-1])
    h_samples = jnp.concatenate([h_samples_rev[::-1], h_T[None, ...]], axis=0)
    h_samples = jnp.transpose(h_samples, (2, 0, 1))  # (M, T, K)

    # --- Conditional smoother for f, given each sampled h trajectory ---
    def conditional_smooth(h_traj):
        return _conditional_rts_smoother(params, observations, h_traj)

    f_means, f_covs, f_lag1_covs = jax.vmap(conditional_smooth)(h_samples)

    return RBPSResult(h_samples, f_means, f_covs, f_lag1_covs)
def _conditional_rts_smoother(params, observations, h_traj):
    """Exact Kalman filter + RTS smoother for f, conditional on an h path.

    Given a fixed log-volatility trajectory ``h_traj``, the factor process
    is linear-Gaussian, so its smoothed moments are available in closed
    form.

    Returns:
        Tuple of smoothed means (T, K), smoothed covariances (T, K, K),
        and lag-1 cross-covariances (T-1, K, K).
    """
    T, N = observations.shape
    K = params.lambda_r.shape[1]
    inv_sigma2, B, _ = _rbpf_obs_cache(params)

    # --- Forward Kalman filter ---
    def forward_step(carry, inputs):
        mean_prev, cov_prev = carry
        y_t, h_t = inputs
        # Predict with state-dependent noise Q = diag(exp(h_t)).
        Q = jnp.diag(jnp.exp(h_t))
        mean_pred = params.Phi_f @ mean_prev
        cov_pred = params.Phi_f @ cov_prev @ params.Phi_f.T + Q
        # Information-form measurement update.
        innov = y_t - (params.lambda_r @ mean_pred)
        chol_pred = jnp.linalg.cholesky(cov_pred + K_CHOL_JITTER * jnp.eye(K))
        info_pred = jax.scipy.linalg.cho_solve((chol_pred, True), jnp.eye(K))
        # Symmetrize the posterior information matrix for stability.
        info_post = 0.5 * (info_pred + B + (info_pred + B).T)
        chol_info = jnp.linalg.cholesky(info_post + K_CHOL_JITTER * jnp.eye(K))
        weighted_innov = innov * inv_sigma2
        u = params.lambda_r.T @ weighted_innov
        cov_post = jax.scipy.linalg.cho_solve((chol_info, True), jnp.eye(K))
        mean_post = mean_pred + cov_post @ u
        return (mean_post, cov_post), (mean_pred, cov_pred, mean_post, cov_post)

    init = (jnp.zeros(K), jnp.eye(K))
    _, (mean_preds, cov_preds, mean_filts, cov_filts) = jax.lax.scan(
        forward_step, init, (observations, h_traj)
    )

    # --- Backward RTS smoother ---
    def backward_step(carry, inputs):
        mean_next, cov_next = carry
        mean_pred, cov_pred, mean_filt, cov_filt = inputs
        # Smoother gain J = P_f Phi' P_pred^{-1}, computed by solving
        # P_pred J' = Phi P_f with a Cholesky factorization.
        chol_pred = jnp.linalg.cholesky(cov_pred + K_CHOL_JITTER * jnp.eye(K))
        gain_T = jax.scipy.linalg.cho_solve(
            (chol_pred, True), params.Phi_f @ cov_filt
        )
        gain = gain_T.T
        mean_sm = mean_filt + gain @ (mean_next - mean_pred)
        cov_sm = cov_filt + gain @ (cov_next - cov_pred) @ gain.T
        lag1 = cov_next @ gain.T  # P_{t+1, t | T}
        # Enforce symmetry against floating-point drift.
        cov_sm = 0.5 * (cov_sm + cov_sm.T)
        return (mean_sm, cov_sm), (mean_sm, cov_sm, lag1)

    last_mean = mean_filts[-1]
    last_cov = cov_filts[-1]
    # Align prediction (t+1) with filtered (t) for the backward recursion.
    xs = (mean_preds[1:], cov_preds[1:], mean_filts[:-1], cov_filts[:-1])
    # scan(reverse=True) stacks each output at its input index, so the
    # smoothed sequences come back already in ascending time order.
    _, (means_sm, covs_sm, lag1_covs) = jax.lax.scan(
        backward_step, (last_mean, last_cov), xs, reverse=True
    )

    f_smooth_means = jnp.concatenate([means_sm, last_mean[None, :]], axis=0)
    f_smooth_covs = jnp.concatenate([covs_sm, last_cov[None, :, :]], axis=0)
    return f_smooth_means, f_smooth_covs, lag1_covs