Source code for bellman_filter_dfsv.filters

import equinox as eqx
import jax
import jax.numpy as jnp
from jax import Array
from jaxtyping import Float

from ._bellman_math import invert_info_matrix, predict_info_step, update_info_step
from ._particle_math import (
    calculate_ess,
    compute_log_likelihood_particles,
    initialize_particles,
    predict_particles,
    systematic_resample,
)
from .types import (
    BIFState,
    DFSVParams,
    FilterResult,
    ParticleFilterResult,
    ParticleState,
)


class BellmanFilter(eqx.Module):
    """Bellman Information Filter for DFSV models.

    The filter state is kept in information form: a mean vector together
    with an information (inverse covariance) matrix, wrapped in a
    ``BIFState``.

    Attributes:
        params: Model parameters used by every prediction/update step.
    """

    params: DFSVParams

    def __init__(self, params: DFSVParams):
        """Store the model parameters on the filter."""
        self.params = params

    def initialize(self) -> BIFState:
        """Build the initial filter state from the model parameters.

        The state vector is the concatenation ``[f, h]`` of two blocks of
        length ``K``, where ``K = params.lambda_r.shape[1]``. The ``f`` block
        starts at zero with identity covariance; the ``h`` block starts at
        ``params.mu`` with the covariance ``P_h`` that solves
        ``P_h = Phi_h P_h Phi_h^T + Q_h``.

        Returns:
            A ``BIFState`` holding the initial mean and the inverse of the
            block-diagonal initial covariance.
        """
        model = self.params
        num_factors = model.lambda_r.shape[1]

        # Vectorised discrete Lyapunov equation:
        #   (I - Phi_h (x) Phi_h) vec(P_h) = vec(Q_h)
        # flatten()/reshape() are both row-major, which is consistent with
        # using kron(Phi_h, Phi_h) here.
        lyapunov_lhs = jnp.eye(num_factors * num_factors) - jnp.kron(
            model.Phi_h, model.Phi_h
        )
        cov_h = jnp.linalg.solve(lyapunov_lhs, model.Q_h.flatten()).reshape(
            num_factors, num_factors
        )

        off_diagonal = jnp.zeros((num_factors, num_factors))
        initial_cov = jnp.block(
            [
                [jnp.eye(num_factors), off_diagonal],
                [off_diagonal, cov_h],
            ]
        )
        initial_mean = jnp.concatenate([jnp.zeros(num_factors), model.mu])

        return BIFState(mean=initial_mean, info=jnp.linalg.inv(initial_cov))

    def filter(self, observations: Float[Array, "T N"]) -> FilterResult:
        """Run the filter over a sequence of observations.

        Args:
            observations: Observation sequence; ``jax.lax.scan`` iterates
                over its leading axis.

        Returns:
            A ``FilterResult`` with the stacked updated means, the stacked
            updated information matrices, and the accumulated
            log-likelihood.
        """

        def _advance(carry, obs_t):
            previous, running_log_lik = carry
            predicted = predict_info_step(self.params, previous)
            posterior, increment = update_info_step(self.params, predicted, obs_t)
            next_carry = (posterior, running_log_lik + increment)
            return next_carry, (posterior.mean, posterior.info)

        # The accumulator starts as a JAX array so the carry keeps a
        # consistent type across scan iterations.
        start = (self.initialize(), jnp.array(0.0))
        (_, total_log_lik), (means, infos) = jax.lax.scan(
            _advance, start, observations
        )
        return FilterResult(means=means, infos=infos, log_likelihood=total_log_lik)

    def smooth_state(
        self, mean: Float[Array, "2K"], info: Float[Array, "2K 2K"]
    ) -> tuple[Float[Array, "2K"], Float[Array, "2K 2K"]]:
        """Convert an information state (mean, info) to (mean, covariance).

        No smoothing recursion is run here: the mean is returned unchanged
        and the information matrix is passed through ``invert_info_matrix``.

        Args:
            mean: State mean.
            info: Information matrix associated with ``mean``.

        Returns:
            The unchanged mean and the result of inverting ``info``.
        """
        return mean, invert_info_matrix(info)
class ParticleFilter(eqx.Module):
    """Particle Filter (SISR) for DFSV models.

    Attributes:
        params: Model parameters.
        num_particles: Number of particles.
        resample_threshold_frac: Resampling is triggered when the effective
            sample size drops below ``num_particles * resample_threshold_frac``.
        seed: Integer seed for the JAX PRNG key used by ``filter``.
    """

    params: DFSVParams
    num_particles: int
    resample_threshold_frac: float
    seed: int

    def __init__(
        self,
        params: DFSVParams,
        num_particles: int = 1000,
        resample_threshold_frac: float = 0.5,
        seed: int = 42,
    ):
        """Initializes the particle filter.

        Args:
            params: Model parameters.
            num_particles: Number of particles.
            resample_threshold_frac: Fraction of ``num_particles`` below
                which the effective sample size triggers resampling.
            seed: Seed for the PRNG key created in ``filter``.
        """
        self.params = params
        self.num_particles = num_particles
        self.resample_threshold_frac = resample_threshold_frac
        self.seed = seed

    def filter(self, observations: Float[Array, "T N"]) -> ParticleFilterResult:
        """Runs the particle filter over a sequence of observations.

        Each step predicts the particles, reweights them by the observation
        likelihood, records the weighted mean/covariance (using the weights
        after the update but before any resampling), and resamples when the
        effective sample size falls below the threshold.

        Args:
            observations: Observation sequence; ``jax.lax.scan`` iterates
                over its leading axis.

        Returns:
            A ``ParticleFilterResult`` with the stacked filtered means and
            covariances and the accumulated log-likelihood estimate.
        """
        # Use independent keys for initialisation and for the propagation
        # loop; feeding one key to two consumers can correlate their draws.
        init_key, loop_key = jax.random.split(jax.random.PRNGKey(self.seed))

        init_state = initialize_particles(self.params, self.num_particles, init_key)
        init_carry = (init_state, loop_key, jnp.array(0.0))

        # Loop-invariant, so computed once outside the scan body.
        threshold = self.num_particles * self.resample_threshold_frac

        def do_resample(args):
            resample_key, log_w, particles = args
            new_particles, new_log_w = systematic_resample(
                resample_key, log_w, particles
            )
            return new_particles, new_log_w

        def no_resample(args):
            _, log_w, particles = args
            return particles, log_w

        def scan_step(carry, obs_t):
            state_prev, key_curr, ll_accum = carry

            # 1. Predict.
            state_pred, key_next = predict_particles(
                self.params, state_prev, key_curr
            )
            # The resampler gets its own key instead of sharing the one that
            # is carried into the next prediction step.
            key_carry, key_resample = jax.random.split(key_next)

            # 2. Reweight by the observation likelihood.
            log_liks = compute_log_likelihood_particles(
                self.params, state_pred.particles, obs_t
            )
            prior_log_weights = state_pred.log_weights
            unnorm_log_weights = prior_log_weights + log_liks
            log_norm = jax.scipy.special.logsumexp(unnorm_log_weights)

            # Likelihood increment: log( sum_i w_i * l_i / sum_i w_i ).
            # Dividing by the incoming weight mass keeps the increment correct
            # whether or not the incoming weights sum to one, so earlier
            # likelihood factors are never counted a second time.
            ll_step = log_norm - jax.scipy.special.logsumexp(prior_log_weights)
            ll_next = ll_accum + ll_step

            # Normalised log-weights are used for the ESS, the estimates, the
            # resampler, and as the weights carried to the next step.
            log_weights = unnorm_log_weights - log_norm
            weights = jnp.exp(log_weights)

            # 3. Filtered estimates from the post-update, pre-resampling
            # weights. Particles are indexed along axis 1.
            mean_est = jnp.sum(state_pred.particles * weights[None, :], axis=1)
            diff = state_pred.particles - mean_est[:, None]
            cov_est = (diff * weights[None, :]) @ diff.T

            # 4. Resample only when the effective sample size is too low.
            ess = calculate_ess(log_weights)
            particles_next, log_weights_next = jax.lax.cond(
                ess < threshold,
                do_resample,
                no_resample,
                (key_resample, log_weights, state_pred.particles),
            )
            state_next = ParticleState(
                particles=particles_next, log_weights=log_weights_next
            )

            return (state_next, key_carry, ll_next), (mean_est, cov_est)

        final_carry, (means, covs) = jax.lax.scan(
            scan_step, init_carry, observations
        )
        return ParticleFilterResult(
            means=means, covs=covs, log_likelihood=final_carry[2]
        )