Source code for bellman_filter_dfsv.estimation

from collections.abc import Callable

import equinox as eqx
import jax
import jax.numpy as jnp
import optax
from jax import Array
from jaxtyping import Float

from .filters import BellmanFilter
from .smoothing import RBPSResult, run_rbps
from .types import DFSVParams, EMSufficientStats


def constrain_params_default(p_unc: DFSVParams) -> DFSVParams:
    """Map unconstrained (optimizer-space) parameters to the model space.

    - lambda_r, mu: passed through unchanged (already unconstrained).
    - Phi_f, Phi_h: squashed into (-1, 1) via tanh (stationarity).
    - sigma2 and diag(Q_h): made strictly positive via softplus.
    """
    squash = jnp.tanh
    positive = jax.nn.softplus
    return DFSVParams(
        p_unc.lambda_r,
        squash(p_unc.Phi_f),
        squash(p_unc.Phi_h),
        p_unc.mu,
        positive(p_unc.sigma2),
        # Only the diagonal of Q_h is modeled; rebuild a diagonal matrix.
        jnp.diag(positive(jnp.diag(p_unc.Q_h))),
    )
def unconstrain_params_default(p: DFSVParams) -> DFSVParams:
    """Map constrained (model-space) parameters to unconstrained space.

    Exact inverse of :func:`constrain_params_default`:

    - lambda_r, mu: identity.
    - Phi_f, Phi_h: arctanh, with inputs clipped to avoid +/-inf at +/-1.
    - sigma2 and diag(Q_h): inverse softplus.

    Bug fix: Q_h previously used ``jnp.linalg.cholesky(p.Q_h)``, which is
    NOT the inverse of the softplus-on-diagonal applied by
    ``constrain_params_default`` — the round trip
    ``constrain(unconstrain(p))`` did not recover ``p``. It now applies the
    inverse softplus to the diagonal, mirroring the forward transform.
    """

    def inv_softplus(y):
        # Inverse of softplus: log(exp(y) - 1), rewritten in the
        # numerically stable form y + log(1 - exp(-y)) so large y does
        # not overflow exp(y).
        return y + jnp.log(-jnp.expm1(-y))

    def clip_arctanh(x):
        """Clip arctanh inputs away from +/-1 to avoid infinities."""
        return jnp.arctanh(jnp.clip(x, -0.999, 0.999))

    return DFSVParams(
        lambda_r=p.lambda_r,
        Phi_f=clip_arctanh(p.Phi_f),
        Phi_h=clip_arctanh(p.Phi_h),
        mu=p.mu,
        sigma2=inv_softplus(p.sigma2),
        Q_h=jnp.diag(inv_softplus(jnp.diag(p.Q_h))),
    )
def fit_mle(
    start_params: DFSVParams,
    observations: Float[Array, "T N"],
    learning_rate: float = 0.01,
    num_steps: int = 100,
    optimizer: optax.GradientTransformation | None = None,
    constrain_fn: Callable[[DFSVParams], DFSVParams] = constrain_params_default,
    unconstrain_fn: Callable[[DFSVParams], DFSVParams] = unconstrain_params_default,
    verbose: bool = True,
) -> tuple[DFSVParams, list[float]]:
    """Fits DFSV parameters using Maximum Likelihood Estimation (MLE).

    Args:
        start_params: Initial guess for parameters.
        observations: Observed data matrix (Time x N).
        learning_rate: Learning rate for Adam optimizer (default: 0.01).
        num_steps: Number of optimization steps.
        optimizer: Custom Optax optimizer (optional). If None, uses Adam.
        constrain_fn: Function to map unconstrained -> constrained params.
        unconstrain_fn: Function to map constrained -> unconstrained params.
        verbose: Whether to print progress.

    Returns:
        tuple: (optimized_params, loss_history)
    """
    # 1. Setup Optimizer
    if optimizer is None:
        optimizer = optax.adam(learning_rate=learning_rate)

    # 2. Unconstrain Initial Parameters (gradient descent runs in R^n)
    params_unc = unconstrain_fn(start_params)
    opt_state = optimizer.init(params_unc)

    # 3. Define Loss Function: negative log-likelihood under the filter
    def loss_fn(p_u, obs):
        p_c = constrain_fn(p_u)
        # Create filter dynamically (zero cost in JIT)
        bf = BellmanFilter(p_c)
        return -bf.filter(obs).log_likelihood

    # 4. JIT Compile Step
    @jax.jit
    def step(p_u, opt_s, obs):
        loss, grads = jax.value_and_grad(loss_fn)(p_u, obs)
        updates, opt_s = optimizer.update(grads, opt_s, p_u)
        p_u = eqx.apply_updates(p_u, updates)
        return p_u, opt_s, loss

    # 5. Optimization Loop
    loss_history: list[float] = []
    current_params = params_unc
    # Bug fix: `i % (num_steps // 10)` raised ZeroDivisionError whenever
    # num_steps < 10; floor the reporting interval at 1.
    report_every = max(1, num_steps // 10)

    if verbose:
        print(f"Starting MLE Optimization ({num_steps} steps)...")

    for i in range(num_steps):
        current_params, opt_state, loss = step(current_params, opt_state, observations)
        loss_val = float(loss)
        loss_history.append(loss_val)

        if verbose and (i % report_every == 0 or i == num_steps - 1):
            print(f"Step {i:4d} | Log-Likelihood: {-loss_val:.4f}")

    # 6. Return Constrained Parameters
    final_params = constrain_fn(current_params)
    return final_params, loss_history
def rbps_to_suffstats(
    rbps_result: RBPSResult, observations: Float[Array, "T N"], K: int
) -> EMSufficientStats:
    """Convert RBPS samples to EM sufficient statistics by averaging over
    trajectories.

    Args:
        rbps_result: Output of the RBPS smoother; ``h_samples`` has shape
            (M, T, K) and the ``f_smooth_*`` fields hold the
            Rao-Blackwellized factor moments per trajectory.
        observations: Observed data matrix (T x N).
        K: Expected number of factors; validated against the sample shape.

    Returns:
        EMSufficientStats with time-summed (trajectory-averaged) moments.

    Raises:
        ValueError: If ``K`` does not match the trailing dimension of
            ``rbps_result.h_samples``.
    """
    M, T, K_dim = rbps_result.h_samples.shape
    # Explicit validation instead of `assert` (asserts are stripped under -O).
    if K != K_dim:
        raise ValueError(f"K mismatch: expected {K}, got {K_dim}")

    # --- Factor moments (Rao-Blackwellized: means and covariances) ---
    E_f = jnp.mean(rbps_result.f_smooth_means, axis=0)
    # Outer products of the smoothed means, shape (M, T, K, K).
    m_mT = (
        rbps_result.f_smooth_means[..., None]
        * rbps_result.f_smooth_means[..., None, :]
    )
    # E[f f'] = Cov + mean mean'
    E_ff_T = jnp.mean(rbps_result.f_smooth_covs + m_mT, axis=0)

    # Lag-1 cross moments E[f_t f_{t-1}'] over t = 2..T.
    m_t = rbps_result.f_smooth_means[:, 1:, :, None]
    m_tm1 = rbps_result.f_smooth_means[:, :-1, None, :]
    m_lag_mT = m_t * m_tm1
    E_f_lag_f_T = jnp.mean(rbps_result.f_smooth_lag1_covs + m_lag_mT, axis=0)

    # --- Log-volatility moments (plain Monte Carlo over the M samples) ---
    E_h = jnp.mean(rbps_result.h_samples, axis=0)
    h_hT = rbps_result.h_samples[..., None] * rbps_result.h_samples[..., None, :]
    E_hh_T = jnp.mean(h_hT, axis=0)
    h_t = rbps_result.h_samples[:, 1:, :, None]
    h_tm1 = rbps_result.h_samples[:, :-1, None, :]
    E_h_lag_h_T = jnp.mean(h_t * h_tm1, axis=0)

    # --- Coupled terms: exp(-h_t) weighted factor second moments ---
    h_t2T = rbps_result.h_samples[:, 1:, :]
    exp_neg_h = jnp.exp(-h_t2T)
    E_exp_neg_h = jnp.mean(exp_neg_h, axis=0)

    f_lag_f_T_m = rbps_result.f_smooth_lag1_covs + m_lag_mT
    f_lag_f_diag = jnp.diagonal(f_lag_f_T_m, axis1=2, axis2=3)
    weighted_cross = exp_neg_h * f_lag_f_diag
    E_exp_neg_h_f_fprev_diag = jnp.mean(weighted_cross, axis=0)

    ff_T_m = rbps_result.f_smooth_covs + m_mT
    ff_T_prev_m = ff_T_m[:, :-1, :, :]
    ff_prev_diag = jnp.diagonal(ff_T_prev_m, axis1=2, axis2=3)
    weighted_prev_sq = exp_neg_h * ff_prev_diag
    E_exp_neg_h_fprev_sq = jnp.mean(weighted_prev_sq, axis=0)

    # --- Observation cross terms ---
    cross_y_f = observations.T @ E_f
    obs_sq_sum = jnp.sum(observations**2, axis=0)

    return EMSufficientStats(
        sum_r_f=cross_y_f,
        sum_f_f=jnp.sum(E_ff_T, axis=0),
        sum_r_r_diag=obs_sq_sum,
        sum_f_fprev=jnp.sum(E_f_lag_f_T, axis=0),
        sum_fprev_fprev=jnp.sum(E_ff_T[:-1], axis=0),
        sum_exp_neg_h=jnp.sum(E_exp_neg_h, axis=0),
        sum_exp_neg_h_f_fprev_diag=jnp.sum(E_exp_neg_h_f_fprev_diag, axis=0),
        sum_exp_neg_h_fprev_sq=jnp.sum(E_exp_neg_h_fprev_sq, axis=0),
        sum_h=jnp.sum(E_h[1:], axis=0),
        sum_hprev=jnp.sum(E_h[:-1], axis=0),
        sum_h_h=jnp.sum(E_hh_T[1:], axis=0),
        sum_h_hprev=jnp.sum(E_h_lag_h_T, axis=0),
        sum_hprev_hprev=jnp.sum(E_hh_T[:-1], axis=0),
        T=T,
    )
# Small ridge added to denominators / Gram matrices to keep divisions and
# matrix inversions numerically stable.
JITTER = 1e-6
# Floor applied to estimated variances to keep them strictly positive.
MIN_VARIANCE = 1e-6
def update_lambda_r(stats: EMSufficientStats) -> Float[Array, "N K"]:
    """M-step update for the factor-loading matrix lambda_r.

    Solves the normal equations ``lambda_r @ sum_f_f = sum_r_f`` with a
    small ridge (JITTER) on the factor Gram matrix for stability.
    """
    K = stats.sum_f_f.shape[0]
    sum_f_f_reg = stats.sum_f_f + jnp.eye(K) * JITTER
    # Solve the linear system instead of forming an explicit inverse
    # (jnp.linalg.inv @ B): better conditioned and cheaper.
    # lambda_r = sum_r_f @ A^{-1}  <=>  A^T x = sum_r_f^T, lambda_r = x^T.
    return jnp.linalg.solve(sum_f_f_reg.T, stats.sum_r_f.T).T
def update_sigma2(
    stats: EMSufficientStats, lambda_r: Float[Array, "N K"]
) -> Float[Array, "N"]:
    """M-step update for the idiosyncratic observation variances sigma2.

    Per-series residual variance of r_t - lambda_r f_t, floored at
    MIN_VARIANCE to stay strictly positive.
    """
    # Residual sum of squares per series:
    #   sum r^2 - 2 lambda' sum(r f) + lambda' sum(f f') lambda
    proj = lambda_r @ stats.sum_f_f
    quad_term = jnp.sum(lambda_r * proj, axis=1)
    cross_term = jnp.sum(lambda_r * stats.sum_r_f, axis=1)
    rss = stats.sum_r_r_diag - 2 * cross_term + quad_term
    return jnp.maximum(rss / stats.T, MIN_VARIANCE)
def update_Phi_h(
    stats: EMSufficientStats, mu: Float[Array, "K"]
) -> Float[Array, "K K"]:
    """M-step update for the diagonal log-volatility AR(1) matrix Phi_h.

    Per-component OLS slope of (h_t - mu) on (h_{t-1} - mu), clipped into
    (-1, 1) so the volatility process stays stationary.

    Fix: removed the unused local ``K = mu.shape[0]``.
    """
    T_minus_1 = stats.T - 1
    # sum_t (h_t - mu)(h_{t-1} - mu), per component.
    sum_a_b = (
        jnp.diag(stats.sum_h_hprev)
        - mu * stats.sum_hprev
        - mu * stats.sum_h
        + T_minus_1 * mu**2
    )
    # sum_t (h_{t-1} - mu)^2, per component.
    sum_b_sq = (
        jnp.diag(stats.sum_hprev_hprev)
        - 2 * mu * stats.sum_hprev
        + T_minus_1 * mu**2
    )
    phi_h_diag = sum_a_b / jnp.maximum(sum_b_sq, JITTER)
    # Clip away from +/-1 for stationarity / numerical safety.
    phi_h_diag = jnp.clip(phi_h_diag, -0.999, 0.999)
    return jnp.diag(phi_h_diag)
def update_mu(
    stats: EMSufficientStats, Phi_h: Float[Array, "K K"]
) -> Float[Array, "K"]:
    """M-step update for the long-run log-volatility mean mu (per component)."""
    n_trans = stats.T - 1
    phi_diag = jnp.diag(Phi_h)
    num = stats.sum_h - phi_diag * stats.sum_hprev
    den = n_trans * (1.0 - phi_diag)
    # Guard against phi ~ 1: divide by |den| floored at JITTER, then
    # restore the sign (the tiny offset breaks ties at exactly zero).
    safe_mag = jnp.maximum(jnp.abs(den), JITTER)
    den_sign = jnp.sign(den + 1e-10)
    return num / safe_mag * den_sign
def update_Q_h(
    stats: EMSufficientStats, mu: Float[Array, "K"], Phi_h: Float[Array, "K K"]
) -> Float[Array, "K K"]:
    """M-step update for the diagonal innovation covariance Q_h of the
    log-volatility AR(1) process."""
    n_trans = stats.T - 1
    phi_diag = jnp.diag(Phi_h)
    mu_sq_term = n_trans * mu**2
    # Centered second moments of h_t, h_{t-1}, and their cross product.
    a_sq = jnp.diag(stats.sum_h_h) - 2 * mu * stats.sum_h + mu_sq_term
    b_sq = jnp.diag(stats.sum_hprev_hprev) - 2 * mu * stats.sum_hprev + mu_sq_term
    a_b = jnp.diag(stats.sum_h_hprev) - mu * (stats.sum_h + stats.sum_hprev) + mu_sq_term
    # Residual sum of squares of eta_t = (h_t - mu) - phi (h_{t-1} - mu).
    S_eta = a_sq - 2 * phi_diag * a_b + phi_diag**2 * b_sq
    q_diag = jnp.maximum(S_eta / n_trans, MIN_VARIANCE)
    return jnp.diag(q_diag)
def update_Phi_f(stats: EMSufficientStats) -> Float[Array, "K K"]:
    """M-step update for the diagonal factor AR(1) matrix Phi_f.

    Weighted least-squares slope using exp(-h) precision weights; the
    denominator is floored at JITTER and the result clipped into (-1, 1).
    """
    denom = jnp.maximum(stats.sum_exp_neg_h_fprev_sq, JITTER)
    raw_diag = stats.sum_exp_neg_h_f_fprev_diag / denom
    return jnp.diag(jnp.clip(raw_diag, -0.999, 0.999))
def m_step(
    stats: EMSufficientStats, n_mu_phi_iters: int = 3
) -> tuple[Float[Array, "N K"], ...]:
    """Run one full M-step: closed-form updates for every DFSV parameter.

    ``mu`` and ``Phi_h`` depend on each other, so they are refined with a
    short fixed-point iteration (``n_mu_phi_iters`` rounds).

    Returns:
        (lambda_r, sigma2, Phi_f, mu, Phi_h, Q_h)
    """
    # Factor-equation updates (Phi_f is independent of the others).
    Phi_f = update_Phi_f(stats)
    lambda_r = update_lambda_r(stats)
    sigma2 = update_sigma2(stats, lambda_r)

    # Initial guesses: time-averaged h for mu, mildly persistent Phi_h
    # (only used if n_mu_phi_iters == 0).
    K = stats.sum_h.shape[0]
    mu = stats.sum_h / (stats.T - 1)
    Phi_h = jnp.eye(K) * 0.9
    for _ in range(n_mu_phi_iters):
        Phi_h = update_Phi_h(stats, mu)
        mu = update_mu(stats, Phi_h)

    Q_h = update_Q_h(stats, mu, Phi_h)
    return lambda_r, sigma2, Phi_f, mu, Phi_h, Q_h
def fit_em(
    observations: Float[Array, "T N"],
    init_params: DFSVParams,
    num_particles: int = 200,
    num_trajectories: int = 20,
    max_iters: int = 50,
    verbose: bool = True,
) -> tuple[DFSVParams, list[DFSVParams]]:
    """Fit the DFSV model via Expectation-Maximization with an RBPS E-step.

    Args:
        observations: Observed data matrix (T x N).
        init_params: Starting parameter values.
        num_particles: Particle count for the RBPS pass.
        num_trajectories: Number of backward-sampled trajectories.
        max_iters: Number of EM iterations to run.
        verbose: Whether to print progress each iteration.

    Returns:
        (final_params, history) where history holds the parameters after
        every EM iteration.
    """
    K = init_params.lambda_r.shape[1]
    params = init_params
    history: list[DFSVParams] = []

    # JIT the E-step once; only the parameters change between iterations.
    jit_rbps = jax.jit(
        lambda p: run_rbps(p, observations, num_particles, num_trajectories)
    )

    if verbose:
        print(
            f"Starting EM (RBPS) | Particles: {num_particles} | Trajs: {num_trajectories}"
        )

    for i in range(max_iters):
        # TODO: Add early stopping if converged
        # E-step: smoothed trajectories -> sufficient statistics.
        stats = rbps_to_suffstats(jit_rbps(params), observations, K)
        # M-step: closed-form parameter updates.
        lambda_r, sigma2, Phi_f, mu, Phi_h, Q_h = m_step(stats)
        params = DFSVParams(
            lambda_r=lambda_r,
            Phi_f=Phi_f,
            Phi_h=Phi_h,
            mu=mu,
            sigma2=sigma2,
            Q_h=Q_h,
        )
        history.append(params)

        if verbose:
            print(
                f"Iter {i:3d} | "
                f"Lambda[0,0]: {lambda_r[0, 0]:.4f} | "
                f"Sigma2[0]: {sigma2[0]:.4f}"
            )

    return params, history