API Documentation
This page documents the API of bellman_filter_dfsv, which follows a functional architecture built on Equinox.
Types
- class bellman_filter_dfsv.types.BIFState(mean: Float[Array, '2K'], info: Float[Array, '2K 2K'])
Bases: NamedTuple
State container for the Bellman Information Filter.
- mean
Information state mean vector α_{t|t} (2K,).
- Type:
jaxtyping.Float[Array, '2K']
- info
Information matrix Ω_{t|t} (2K, 2K).
- Type:
jaxtyping.Float[Array, '2K 2K']
- info: Float[Array, '2K 2K']
Alias for field number 1
- mean: Float[Array, '2K']
Alias for field number 0
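A BIFState stores the information (precision) matrix Ω_{t|t} rather than a covariance. The minimal sketch below, in plain NumPy, shows one way to recover the covariance Ω_{t|t}⁻¹ through a Cholesky factor, and that positional indexing follows the field numbers listed above. BIFStateSketch and info_to_moments are hypothetical stand-ins, not part of the package.

```python
from typing import NamedTuple

import numpy as np


class BIFStateSketch(NamedTuple):
    """Hypothetical NumPy mirror of BIFState (field 0: mean, field 1: info)."""
    mean: np.ndarray  # (2K,)
    info: np.ndarray  # (2K, 2K), symmetric positive definite


def info_to_moments(state: BIFStateSketch) -> tuple[np.ndarray, np.ndarray]:
    """Return (mean, covariance) from an information-form state.

    The covariance is Ω^{-1}, computed via Ω = L Lᵀ so that a matrix that is
    not positive definite raises LinAlgError instead of silently inverting.
    """
    L = np.linalg.cholesky(state.info)              # Ω = L Lᵀ
    L_inv = np.linalg.solve(L, np.eye(L.shape[0]))  # L^{-1}
    cov = L_inv.T @ L_inv                           # Ω^{-1} = L^{-T} L^{-1}
    return state.mean, cov
```

Going through the Cholesky factor doubles as a cheap positive-definiteness check on the filtered information matrix.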
- class bellman_filter_dfsv.types.DFSVParams(lambda_r: Float[Array, 'N K'], Phi_f: Float[Array, 'K K'], Phi_h: Float[Array, 'K K'], mu: Float[Array, 'K'], sigma2: Float[Array, 'N'], Q_h: Float[Array, 'K K'])
Bases: NamedTuple
Parameters for the Dynamic Factor Stochastic Volatility Model.
- lambda_r
Factor loadings Λ (N, K).
- Type:
jaxtyping.Float[Array, 'N K']
- Phi_f
Factor autoregression matrix Φ_f (K, K).
- Type:
jaxtyping.Float[Array, 'K K']
- Phi_h
Log-volatility autoregression matrix Φ_h (K, K).
- Type:
jaxtyping.Float[Array, 'K K']
- mu
Long-run mean of log-volatilities μ (K,).
- Type:
jaxtyping.Float[Array, 'K']
- sigma2
Idiosyncratic variances diag(Σ_ε) (N,).
- Type:
jaxtyping.Float[Array, 'N']
- Q_h
Log-volatility innovation covariance Q_h (K, K).
- Type:
jaxtyping.Float[Array, 'K K']
- Phi_f: Float[Array, 'K K']
Alias for field number 1
- Phi_h: Float[Array, 'K K']
Alias for field number 2
- Q_h: Float[Array, 'K K']
Alias for field number 5
- lambda_r: Float[Array, 'N K']
Alias for field number 0
- mu: Float[Array, 'K']
Alias for field number 3
- sigma2: Float[Array, 'N']
Alias for field number 4
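Every DFSVParams field is tied to the observation dimension N and the factor dimension K. The sketch below uses two hypothetical NumPy helpers: check_dfsv_shapes verifies that shape contract, and stationary_logvol_cov computes the stationary covariance S of the log-volatilities, which solves S = Φ_h S Φ_hᵀ + Q_h whenever all eigenvalues of Φ_h lie inside the unit circle. Neither helper is part of the package.

```python
import numpy as np


def check_dfsv_shapes(lambda_r, Phi_f, Phi_h, mu, sigma2, Q_h) -> tuple[int, int]:
    """Hypothetical helper: verify the (N, K) shape contract of DFSVParams."""
    N, K = lambda_r.shape
    expected = {
        "Phi_f": (Phi_f, (K, K)),
        "Phi_h": (Phi_h, (K, K)),
        "mu": (mu, (K,)),
        "sigma2": (sigma2, (N,)),
        "Q_h": (Q_h, (K, K)),
    }
    for name, (arr, shape) in expected.items():
        if arr.shape != shape:
            raise ValueError(f"{name} has shape {arr.shape}, expected {shape}")
    return N, K


def stationary_logvol_cov(Phi_h, Q_h):
    """Solve S = Phi_h S Phi_hᵀ + Q_h via vec(S) = (I - Phi_h ⊗ Phi_h)^{-1} vec(Q_h)."""
    K = Phi_h.shape[0]
    A = np.eye(K * K) - np.kron(Phi_h, Phi_h)
    return np.linalg.solve(A, Q_h.reshape(-1)).reshape(K, K)
```

For the one-factor parameters used in the simulate_dfsv example (Φ_h = 0.95, Q_h = 0.01), the stationary log-volatility variance is 0.01 / (1 - 0.95²) ≈ 0.103.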
- class bellman_filter_dfsv.types.EMSufficientStats(sum_r_f: Float[Array, 'N K'], sum_f_f: Float[Array, 'K K'], sum_r_r_diag: Float[Array, 'N'], sum_f_fprev: Float[Array, 'K K'], sum_fprev_fprev: Float[Array, 'K K'], sum_exp_neg_h: Float[Array, 'K'], sum_exp_neg_h_f_fprev_diag: Float[Array, 'K'], sum_exp_neg_h_fprev_sq: Float[Array, 'K'], sum_h: Float[Array, 'K'], sum_hprev: Float[Array, 'K'], sum_h_h: Float[Array, 'K K'], sum_h_hprev: Float[Array, 'K K'], sum_hprev_hprev: Float[Array, 'K K'], T: int)
Bases: NamedTuple
Sufficient statistics for the M-step of the EM algorithm.
- sum_exp_neg_h: Float[Array, 'K']
Alias for field number 5
- sum_exp_neg_h_f_fprev_diag: Float[Array, 'K']
Alias for field number 6
- sum_exp_neg_h_fprev_sq: Float[Array, 'K']
Alias for field number 7
- sum_f_f: Float[Array, 'K K']
Alias for field number 1
- sum_f_fprev: Float[Array, 'K K']
Alias for field number 3
- sum_fprev_fprev: Float[Array, 'K K']
Alias for field number 4
- sum_h: Float[Array, 'K']
Alias for field number 8
- sum_h_h: Float[Array, 'K K']
Alias for field number 10
- sum_h_hprev: Float[Array, 'K K']
Alias for field number 11
- sum_hprev: Float[Array, 'K']
Alias for field number 9
- sum_hprev_hprev: Float[Array, 'K K']
Alias for field number 12
- sum_r_f: Float[Array, 'N K']
Alias for field number 0
- sum_r_r_diag: Float[Array, 'N']
Alias for field number 2
- class bellman_filter_dfsv.types.FilterResult(means: Float[Array, 'T 2K'], infos: Float[Array, 'T 2K 2K'], log_likelihood: Float[Array, ''])
Bases: NamedTuple
Result container for a full filter run.
- means
Filtered state means α_{t|t} (T, 2K).
- Type:
jaxtyping.Float[Array, 'T 2K']
- infos
Filtered information matrices Ω_{t|t} (T, 2K, 2K).
- Type:
jaxtyping.Float[Array, 'T 2K 2K']
- log_likelihood
Total log-likelihood scalar.
- Type:
jaxtyping.Float[Array, '']
- infos: Float[Array, 'T 2K 2K']
Alias for field number 1
- log_likelihood: Float[Array, '']
Alias for field number 2
- means: Float[Array, 'T 2K']
Alias for field number 0
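Because FilterResult stores information matrices, uncertainty bands require inverting each Ω_{t|t}. The NumPy sketch below (hypothetical helpers marginal_std and logvol_bands) inverts the whole (T, 2K, 2K) stack in one broadcast call and forms approximate 95% bands for the log-volatilities. It ASSUMES the state vector is ordered as [f_t, h_t], so that h_t occupies the last K coordinates; this page does not state the ordering.

```python
import numpy as np


def marginal_std(infos: np.ndarray) -> np.ndarray:
    """Per-time marginal standard deviations from information matrices.

    infos: (T, 2K, 2K) -> (T, 2K). np.linalg.inv broadcasts over the time axis.
    """
    covs = np.linalg.inv(infos)
    return np.sqrt(np.diagonal(covs, axis1=1, axis2=2))


def logvol_bands(means: np.ndarray, infos: np.ndarray, K: int, z: float = 1.96):
    """Approximate (lower, upper) bands of shape (T, K) for the log-volatility block.

    ASSUMPTION: state ordering [f_t, h_t], i.e. h_t is means[:, K:].
    """
    std = marginal_std(infos)
    h_mean, h_std = means[:, K:], std[:, K:]
    return h_mean - z * h_std, h_mean + z * h_std
```

If the filter's information matrices are ill-conditioned, solving against an identity per time step, or the Cholesky route shown for BIFState, is numerically safer than a plain inverse.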
- class bellman_filter_dfsv.types.ParticleFilterResult(means: Float[Array, 'T 2K'], covs: Float[Array, 'T 2K 2K'], log_likelihood: Float[Array, ''])
Bases: NamedTuple
Result container for a Particle Filter run.
- means
Weighted mean estimates (T, 2K).
- Type:
jaxtyping.Float[Array, 'T 2K']
- covs
Weighted covariance estimates (T, 2K, 2K).
- Type:
jaxtyping.Float[Array, 'T 2K 2K']
- log_likelihood
Total log-likelihood scalar.
- Type:
jaxtyping.Float[Array, '']
- covs: Float[Array, 'T 2K 2K']
Alias for field number 1
- log_likelihood: Float[Array, '']
Alias for field number 2
- means: Float[Array, 'T 2K']
Alias for field number 0
- class bellman_filter_dfsv.types.ParticleState(particles: Float[Array, '2K P'], log_weights: Float[Array, 'P'])
Bases: NamedTuple
State container for the Particle Filter.
- particles
Particle states (2K, num_particles).
- Type:
jaxtyping.Float[Array, '2K P']
- log_weights
Log-weights for each particle (num_particles,).
- Type:
jaxtyping.Float[Array, 'P']
- log_weights: Float[Array, 'P']
Alias for field number 1
- particles: Float[Array, '2K P']
Alias for field number 0
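ParticleFilterResult reports weighted means and covariances. A minimal NumPy sketch of how such moments can be formed from a ParticleState-shaped pair, particles (2K, P) and log_weights (P,), is shown below. weighted_moments is a hypothetical helper; the weights are normalized with the log-sum-exp shift so that very negative log-weights do not underflow to an all-zero vector.

```python
import numpy as np


def weighted_moments(particles: np.ndarray, log_weights: np.ndarray):
    """Weighted mean (2K,) and covariance (2K, 2K) from particles (2K, P)."""
    # Subtract the max before exponentiating (log-sum-exp trick), then normalize.
    w = np.exp(log_weights - log_weights.max())
    w /= w.sum()
    mean = particles @ w                       # Σ_p w_p x_p
    centered = particles - mean[:, None]
    cov = (centered * w) @ centered.T          # Σ_p w_p (x_p - m)(x_p - m)ᵀ
    return mean, cov
```

With uniform weights this reduces to the ordinary sample mean and the 1/P-normalized sample covariance.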
- class bellman_filter_dfsv.types.RBPSResult(h_samples: Float[Array, 'M T K'], f_smooth_means: Float[Array, 'M T K'], f_smooth_covs: Float[Array, 'M T K K'], f_smooth_lag1_covs: Float[Array, 'M T_minus_1 K K'])
Bases: NamedTuple
Result of Rao-Blackwellized Particle Smoothing. Contains M sampled log-volatility trajectories and, for each, the conditional factor statistics (smoothed means, covariances, and lag-one covariances).
- f_smooth_covs: Float[Array, 'M T K K']
Alias for field number 2
- f_smooth_lag1_covs: Float[Array, 'M T_minus_1 K K']
Alias for field number 3
- f_smooth_means: Float[Array, 'M T K']
Alias for field number 1
- h_samples: Float[Array, 'M T K']
Alias for field number 0
- class bellman_filter_dfsv.types.RBParticleState(h_particles: Float[Array, 'K P'], f_means: Float[Array, 'K P'], f_covs: Float[Array, 'K K P'], log_weights: Float[Array, 'P'])
Bases: NamedTuple
State for the Rao-Blackwellized Particle Filter.
- f_covs: Float[Array, 'K K P']
Alias for field number 2
- f_means: Float[Array, 'K P']
Alias for field number 1
- h_particles: Float[Array, 'K P']
Alias for field number 0
- log_weights: Float[Array, 'P']
Alias for field number 3
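RBParticleState pairs each sampled log-volatility path with a Gaussian (f_means, f_covs) for the factors. This works because, conditional on h_t, the DFSV model is linear and Gaussian in f_t, so a Kalman step can propagate the factor moments exactly. The sketch below shows that step for a single particle in plain NumPy; rb_kalman_step is a hypothetical helper, and how the package batches it over the P particles is not shown here.

```python
import numpy as np


def rb_kalman_step(m, P, h_t, r_t, lambda_r, Phi_f, sigma2):
    """One conditional Kalman step for the factor block of a single particle.

    m: (K,), P: (K, K) factor moments at t-1; h_t: (K,) sampled log-volatility;
    r_t: (N,) observation. Returns the updated (m, P) and log p(r_t | h_t, past).
    """
    # Predict: f_t = Phi_f f_{t-1} + diag(exp(h_t / 2)) eps_t
    m_pred = Phi_f @ m
    P_pred = Phi_f @ P @ Phi_f.T + np.diag(np.exp(h_t))
    # Update with r_t = lambda_r f_t + e_t, e_t ~ N(0, diag(sigma2))
    v = r_t - lambda_r @ m_pred                        # innovation (N,)
    S = lambda_r @ P_pred @ lambda_r.T + np.diag(sigma2)
    gain = np.linalg.solve(S, lambda_r @ P_pred).T     # P_pred Λᵀ S^{-1}, shape (K, N)
    m_new = m_pred + gain @ v
    P_new = P_pred - gain @ S @ gain.T
    _, logdet = np.linalg.slogdet(S)
    loglik = -0.5 * (v.size * np.log(2 * np.pi) + logdet + v @ np.linalg.solve(S, v))
    return m_new, P_new, loglik
```

In a bootstrap-style filter, where h_t is drawn from its transition density (an assumption; the proposal used by the package is not documented on this page), the returned loglik is what gets added to that particle's log-weight.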
Filters
- class bellman_filter_dfsv.filters.BellmanFilter(params: DFSVParams)
Bases: Module
Bellman Information Filter for DFSV models.
- __init__(params: DFSVParams)
Initializes the filter with model parameters.
- filter(observations: Float[Array, 'T N']) → FilterResult
Runs the filter over a sequence of observations.
- params: DFSVParams
- class bellman_filter_dfsv.filters.ParticleFilter(params: DFSVParams, num_particles: int = 1000, resample_threshold_frac: float = 0.5, seed: int = 42)
Bases: Module
Particle Filter (SISR) for DFSV models.
- __init__(params: DFSVParams, num_particles: int = 1000, resample_threshold_frac: float = 0.5, seed: int = 42)
Initializes the particle filter.
- filter(observations: Float[Array, 'T N']) → ParticleFilterResult
Runs the particle filter over a sequence of observations.
- params: DFSVParams
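The resample_threshold_frac argument suggests the usual adaptive SISR rule: resample only when the effective sample size (ESS) of the normalized weights drops below that fraction of num_particles. The NumPy sketch below illustrates that rule with systematic resampling. Both helper names are hypothetical, and systematic resampling is an assumed choice; this page does not say which scheme the package uses.

```python
import numpy as np


def effective_sample_size(log_weights: np.ndarray) -> float:
    """ESS = 1 / Σ w_p² for normalized weights w."""
    w = np.exp(log_weights - log_weights.max())
    w /= w.sum()
    return 1.0 / np.sum(w ** 2)


def maybe_resample(particles, log_weights, rng, threshold_frac=0.5):
    """Systematically resample particles (D, P) when ESS < threshold_frac * P.

    Returns (particles, log_weights); after resampling, weights are uniform.
    """
    P = log_weights.shape[0]
    if effective_sample_size(log_weights) >= threshold_frac * P:
        return particles, log_weights
    w = np.exp(log_weights - log_weights.max())
    w /= w.sum()
    positions = (rng.uniform() + np.arange(P)) / P   # one uniform draw, P strata
    idx = np.searchsorted(np.cumsum(w), positions)
    idx = np.minimum(idx, P - 1)                     # guard against round-off in cumsum
    return particles[:, idx], np.full(P, -np.log(P))
```

Resampling only below a threshold avoids the extra Monte Carlo noise of resampling at every step while still preventing weight degeneracy.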
Smoothing
- class bellman_filter_dfsv.smoothing.SmootherResult(smoothed_means: Float[Array, 'T 2K'], smoothed_covs: Float[Array, 'T 2K 2K'], smoothed_lag1_covs: Float[Array, 'T 2K 2K'])
Bases: Module
Result container for the RTS smoother.
- __init__(smoothed_means: Float[Array, 'T 2K'], smoothed_covs: Float[Array, 'T 2K 2K'], smoothed_lag1_covs: Float[Array, 'T 2K 2K']) → None
- smoothed_covs: Float[Array, 'T 2K 2K']
- smoothed_lag1_covs: Float[Array, 'T 2K 2K']
- smoothed_means: Float[Array, 'T 2K']
- bellman_filter_dfsv.smoothing.rts_smoother(params: DFSVParams, filter_means: Float[Array, 'T 2K'], filter_infos: Float[Array, 'T 2K 2K']) → SmootherResult
Runs the Rauch-Tung-Striebel (RTS) smoother, adapted to take information-filter output (filtered means and information matrices) as input.
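For reference, the textbook RTS backward pass for a linear transition x_{t+1} = F x_t + c + w_t, w_t ~ N(0, Q), can be written as below in plain NumPy, starting from filtered means and information matrices. For DFSV one would take F = blockdiag(Φ_f, Φ_h) and c = [0; (I - Φ_h) μ], but the factor noise variance depends on h_t, so a fixed Q is only an approximation; how the package's "adapted" smoother handles this is not documented here. rts_linear is hypothetical, and storing Cov(x_t, x_{t-1} | all data) at index t, with index 0 left at zero, is an assumed convention for the (T, 2K, 2K) lag-one array.

```python
import numpy as np


def rts_linear(filter_means, filter_infos, F, c, Q):
    """Textbook RTS smoother for x_{t+1} = F x_t + c + w_t, w_t ~ N(0, Q).

    filter_means: (T, n); filter_infos: (T, n, n) information matrices.
    Returns smoothed means (T, n), covariances (T, n, n), and lag-one
    cross-covariances Cov(x_t, x_{t-1} | all data) at index t (index 0 is zero).
    """
    T, n = filter_means.shape
    covs = np.linalg.inv(filter_infos)               # filtered covariances
    sm, sP = filter_means.copy(), covs.copy()
    lag1 = np.zeros_like(covs)
    for t in range(T - 2, -1, -1):
        m_pred = F @ filter_means[t] + c
        P_pred = F @ covs[t] @ F.T + Q
        J = np.linalg.solve(P_pred, F @ covs[t]).T   # smoother gain covs[t] Fᵀ P_pred^{-1}
        sm[t] = filter_means[t] + J @ (sm[t + 1] - m_pred)
        sP[t] = covs[t] + J @ (sP[t + 1] - P_pred) @ J.T
        lag1[t + 1] = sP[t + 1] @ J.T
    return sm, sP, lag1
```

The final time step is left equal to the filtered moments, since no future observations remain to condition on.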
- bellman_filter_dfsv.smoothing.run_rbps(params: DFSVParams, observations: Float[Array, 'T N'], num_particles: int = 100, num_trajectories: int = 20, seed: int = 42) → RBPSResult
Runs Rao-Blackwellized particle smoothing and returns the sampled trajectories as an RBPSResult.
Parameter Estimation
- bellman_filter_dfsv.estimation.constrain_params_default(p_unc: DFSVParams) → DFSVParams
Default transformation from unconstrained to constrained parameters.
- bellman_filter_dfsv.estimation.fit_em(observations: Float[Array, 'T N'], init_params: DFSVParams, num_particles: int = 200, num_trajectories: int = 20, max_iters: int = 50, verbose: bool = True) → tuple[DFSVParams, list[DFSVParams]]
Fits the DFSV model by Expectation-Maximization with RBPS.
- bellman_filter_dfsv.estimation.fit_mle(start_params: DFSVParams, observations: Float[Array, 'T N'], learning_rate: float = 0.01, num_steps: int = 100, optimizer: optax.GradientTransformation | None = None, constrain_fn: Callable[[DFSVParams], DFSVParams] = constrain_params_default, unconstrain_fn: Callable[[DFSVParams], DFSVParams] = unconstrain_params_default, verbose: bool = True) → tuple[DFSVParams, list[float]]
Fits DFSV parameters by Maximum Likelihood Estimation (MLE).
- Parameters:
start_params – Initial guess for the parameters.
observations – Observed data matrix of shape (T, N).
learning_rate – Learning rate for the default Adam optimizer (default: 0.01).
num_steps – Number of optimization steps (default: 100).
optimizer – Custom Optax optimizer (optional). If None, Adam is used.
constrain_fn – Function mapping unconstrained to constrained parameters.
unconstrain_fn – Function mapping constrained to unconstrained parameters.
verbose – Whether to print progress.
- Returns:
(optimized_params, loss_history)
- Return type:
tuple[DFSVParams, list[float]]
- bellman_filter_dfsv.estimation.m_step(stats: EMSufficientStats, n_mu_phi_iters: int = 3) → tuple[Float[Array, 'N K'], ...]
- bellman_filter_dfsv.estimation.rbps_to_suffstats(rbps_result: RBPSResult, observations: Float[Array, 'T N'], K: int) → EMSufficientStats
Convert RBPS samples to EM sufficient statistics by averaging over trajectories.
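Under each sampled trajectory, the factors are Gaussian with smoothed mean m_t and covariance P_t, so second moments pick up the covariance: E[f_t f_tᵀ] = m_t m_tᵀ + P_t. The NumPy sketch below computes three of the EMSufficientStats fields this way and averages over the M trajectories. factor_suffstats is hypothetical, and the definitions it assumes (sum_r_f = Σ_t r_t E[f_t]ᵀ, sum_f_f = Σ_t E[f_t f_tᵀ], sum_r_r_diag = Σ_t r_t²) are inferred from the field names rather than taken from the package.

```python
import numpy as np


def factor_suffstats(f_means, f_covs, observations):
    """Trajectory-averaged factor statistics (a subset of EMSufficientStats).

    f_means: (M, T, K), f_covs: (M, T, K, K), observations: (T, N).
    """
    M = f_means.shape[0]
    # Σ_t r_t m_tᵀ, summed over trajectories then divided by M -> (N, K)
    sum_r_f = np.einsum("tn,mtk->nk", observations, f_means) / M
    # E[f_t f_tᵀ] = m_t m_tᵀ + P_t, summed over t and averaged over trajectories -> (K, K)
    second = np.einsum("mtk,mtl->mtkl", f_means, f_means) + f_covs
    sum_f_f = second.sum(axis=1).mean(axis=0)
    # Observations do not depend on the trajectory -> (N,)
    sum_r_r_diag = np.sum(observations ** 2, axis=0)
    return sum_r_f, sum_f_f, sum_r_r_diag
```

Dropping the P_t term would treat the smoothed means as if they were the true factors and bias the M-step toward too little factor variance.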
- bellman_filter_dfsv.estimation.unconstrain_params_default(p: DFSVParams) → DFSVParams
Default transformation from constrained to unconstrained parameters.
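fit_mle optimizes over unconstrained reals and maps back through constrain_fn, so a custom pair passed as constrain_fn and unconstrain_fn should be mutual inverses. The sketch below illustrates the idea on a hypothetical two-field ToyParams: tanh keeps an autoregressive coefficient in (-1, 1) and exp keeps a variance positive. These specific transforms are assumptions for illustration; this page does not document what constrain_params_default actually applies.

```python
from typing import NamedTuple

import numpy as np


class ToyParams(NamedTuple):
    """Hypothetical two-field stand-in for DFSVParams."""
    phi: np.ndarray     # AR coefficients, must lie in (-1, 1)
    sigma2: np.ndarray  # variances, must be > 0


def constrain(p_unc: ToyParams) -> ToyParams:
    """Unconstrained reals -> valid parameters (assumed tanh / exp maps)."""
    return ToyParams(phi=np.tanh(p_unc.phi), sigma2=np.exp(p_unc.sigma2))


def unconstrain(p: ToyParams) -> ToyParams:
    """Inverse of `constrain`: arctanh and log."""
    return ToyParams(phi=np.arctanh(p.phi), sigma2=np.log(p.sigma2))
```

Any step the optimizer takes in unconstrained space then lands on valid parameters, and unconstrain(start_params) gives the optimizer's starting point.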
- bellman_filter_dfsv.estimation.update_Phi_f(stats: EMSufficientStats) → Float[Array, 'K K']
- bellman_filter_dfsv.estimation.update_Phi_h(stats: EMSufficientStats, mu: Float[Array, 'K']) → Float[Array, 'K K']
- bellman_filter_dfsv.estimation.update_Q_h(stats: EMSufficientStats, mu: Float[Array, 'K'], Phi_h: Float[Array, 'K K']) → Float[Array, 'K K']
- bellman_filter_dfsv.estimation.update_lambda_r(stats: EMSufficientStats) → Float[Array, 'N K']
- bellman_filter_dfsv.estimation.update_mu(stats: EMSufficientStats, Phi_h: Float[Array, 'K K']) → Float[Array, 'K']
- bellman_filter_dfsv.estimation.update_sigma2(stats: EMSufficientStats, lambda_r: Float[Array, 'N K']) → Float[Array, 'N']
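The update signatures above suggest a coordinate-wise M-step: Φ_h depends on μ, μ depends on Φ_h, and m_step alternates them n_mu_phi_iters times before Q_h is computed. The NumPy sketch below shows textbook closed forms consistent with that structure: a regression of r_t on f_t for Λ and σ², and an alternating least-squares fit of h_t - μ = Φ_h (h_{t-1} - μ) + η_t. These formulas are assumptions rather than the package's code, T here means "number of terms in each sum", and the heteroskedastic Φ_f update (which uses the sum_exp_neg_h fields) is not covered.

```python
import numpy as np


def update_lambda_sigma2(sum_r_f, sum_f_f, sum_r_r_diag, T):
    """Λ = S_rf S_ff^{-1}; σ²_i = (Σ_t r_it² - λ_iᵀ s_i) / T at the optimum."""
    lam = np.linalg.solve(sum_f_f, sum_r_f.T).T
    sigma2 = (sum_r_r_diag - np.sum(lam * sum_r_f, axis=1)) / T
    return lam, sigma2


def update_logvol_block(sum_h, sum_hprev, sum_h_h, sum_h_hprev, sum_hprev_hprev,
                        T, mu, n_iters=3):
    """Alternate Φ_h | μ and μ | Φ_h (n_iters >= 1), then compute Q_h."""
    K = mu.shape[0]

    def centered(mu):
        # Second moments of h_t - μ and h_{t-1} - μ, expanded from the raw sums.
        mm = T * np.outer(mu, mu)
        A = sum_h_h - np.outer(sum_h, mu) - np.outer(mu, sum_h) + mm
        B = sum_h_hprev - np.outer(sum_h, mu) - np.outer(mu, sum_hprev) + mm
        C = sum_hprev_hprev - np.outer(sum_hprev, mu) - np.outer(mu, sum_hprev) + mm
        return A, B, C

    for _ in range(n_iters):
        _, B, C = centered(mu)
        Phi_h = np.linalg.solve(C.T, B.T).T                       # Φ_h = B C^{-1}
        # h_t - Φ_h h_{t-1} = (I - Φ_h) μ + η_t, averaged over t
        mu = np.linalg.solve(np.eye(K) - Phi_h, (sum_h - Phi_h @ sum_hprev) / T)
    A, B, C = centered(mu)
    Q_h = (A - B @ Phi_h.T - Phi_h @ B.T + Phi_h @ C @ Phi_h.T) / T
    return mu, Phi_h, 0.5 * (Q_h + Q_h.T)                         # symmetrize round-off
```

Because the μ step is a mean over transitions, the alternation typically settles within a few iterations, which is consistent with the small default of n_mu_phi_iters.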
Simulation
Simulation utilities for Dynamic Factor Stochastic Volatility models.
This module provides utilities to simulate DFSV models using JAX for the v2 architecture.
- bellman_filter_dfsv.simulation.simulate_DFSV(params: DFSVParams, T: int, *, key: Key[Array, ''] | UInt32[Array, '2'] | int = 42, f0: Float[Array, 'K'] | None = None, h0: Float[Array, 'K'] | None = None) → tuple[Float[Array, 'T N'], Float[Array, 'T K'], Float[Array, 'T K']]
Same signature and docstring as simulate_dfsv.
- bellman_filter_dfsv.simulation.simulate_dfsv(params: DFSVParams, T: int, *, key: Key[Array, ''] | UInt32[Array, '2'] | int = 42, f0: Float[Array, 'K'] | None = None, h0: Float[Array, 'K'] | None = None) → tuple[Float[Array, 'T N'], Float[Array, 'T K'], Float[Array, 'T K']]
Simulate a Dynamic Factor Stochastic Volatility model.
- The DFSV model consists of:
Observation: r_t = λ_r f_t + e_t,  e_t ~ N(0, Σ)
Factor: f_t = Φ_f f_{t-1} + diag(exp(h_t/2)) ε_t,  ε_t ~ N(0, I_K)
Log-vol: h_t = μ + Φ_h (h_{t-1} - μ) + η_t,  η_t ~ N(0, Q_h)
- Parameters:
params – DFSV model parameters.
T – Number of time steps to simulate.
key – JAX random key or integer seed. Default: 42.
f0 – Initial factors (K,). If None, defaults to zeros.
h0 – Initial log-volatilities (K,). If None, defaults to long-run mean μ.
- Returns:
returns: Simulated returns (T, N)
factors: Simulated latent factors (T, K)
log_vols: Simulated log-volatilities (T, K)
- Return type:
Tuple of (returns, factors, log_vols)
Example
>>> import jax.numpy as jnp
>>> from bellman_filter_dfsv import DFSVParams, simulate_dfsv
>>> params = DFSVParams(
...     lambda_r=jnp.array([[0.8], [0.7]]),
...     Phi_f=jnp.array([[0.7]]),
...     Phi_h=jnp.array([[0.95]]),
...     mu=jnp.array([-1.2]),
...     sigma2=jnp.array([0.3, 0.25]),
...     Q_h=jnp.array([[0.01]])
... )
>>> returns, factors, log_vols = simulate_dfsv(params, T=100, key=42)
>>> returns.shape
(100, 2)
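The three model equations translate directly into a loop. The plain-NumPy sketch below is illustrative and not the package implementation: simulate_dfsv_np is hypothetical, it takes the parameter arrays directly instead of a DFSVParams, it uses NumPy's Generator in place of a JAX key, and it ASSUMES f0 and h0 act as the values at time -1, so the first returned row is already one transition past them. Σ is taken to be diag(sigma2), matching the DFSVParams description of sigma2.

```python
import numpy as np


def simulate_dfsv_np(lambda_r, Phi_f, Phi_h, mu, sigma2, Q_h, T, seed=42, f0=None, h0=None):
    """NumPy sketch of the DFSV recursions; returns (returns, factors, log_vols).

    ASSUMPTION: f0 / h0 are the states before the first simulated step.
    Q_h must be positive definite (a Cholesky factor is used for η_t).
    """
    rng = np.random.default_rng(seed)
    N, K = lambda_r.shape
    f = np.zeros(K) if f0 is None else f0            # default: zeros
    h = mu.copy() if h0 is None else h0              # default: long-run mean μ
    chol_Q = np.linalg.cholesky(Q_h)
    returns, factors, log_vols = np.empty((T, N)), np.empty((T, K)), np.empty((T, K))
    for t in range(T):
        h = mu + Phi_h @ (h - mu) + chol_Q @ rng.standard_normal(K)    # log-vol
        f = Phi_f @ f + np.exp(h / 2) * rng.standard_normal(K)        # factor
        r = lambda_r @ f + np.sqrt(sigma2) * rng.standard_normal(N)   # observation
        returns[t], factors[t], log_vols[t] = r, f, h
    return returns, factors, log_vols
```

With the parameters from the doctest above, returns - factors @ lambda_r.T should have per-series variances close to sigma2 = (0.3, 0.25) for long samples.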