Source code for bellman_filter_dfsv.simulation

"""Simulation utilities for Dynamic Factor Stochastic Volatility models.

This module provides utilities to simulate DFSV models using JAX for the v2 architecture.
"""

import jax
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import Array, Float, PRNGKeyArray

from .types import DFSVParams


[docs] def simulate_dfsv( params: DFSVParams, T: int, *, key: PRNGKeyArray | int = 42, f0: Float[Array, "K"] | None = None, h0: Float[Array, "K"] | None = None, ) -> tuple[Float[Array, "T N"], Float[Array, "T K"], Float[Array, "T K"]]: """Simulate a Dynamic Factor Stochastic Volatility model. The DFSV model consists of: Observation: r_t = λ_r f_t + e_t, e_t ~ N(0, Σ) Factor: f_t = Φ_f f_{t-1} + diag(exp(h_t/2)) ε_t, ε_t ~ N(0, I_K) Log-vol: h_t = μ + Φ_h (h_{t-1} - μ) + η_t, η_t ~ N(0, Q_h) Args: params: DFSV model parameters. T: Number of time steps to simulate. key: JAX random key or integer seed. Default: 42. f0: Initial factors (K,). If None, defaults to zeros. h0: Initial log-volatilities (K,). If None, defaults to long-run mean μ. Returns: Tuple of (returns, factors, log_vols): - returns: Simulated returns (T, N) - factors: Simulated latent factors (T, K) - log_vols: Simulated log-volatilities (T, K) Example: >>> import jax.numpy as jnp >>> from bellman_filter_dfsv import DFSVParams, simulate_dfsv >>> params = DFSVParams( ... lambda_r=jnp.array([[0.8], [0.7]]), ... Phi_f=jnp.array([[0.7]]), ... Phi_h=jnp.array([[0.95]]), ... mu=jnp.array([-1.2]), ... sigma2=jnp.array([0.3, 0.25]), ... Q_h=jnp.array([[0.01]]) ... ) >>> returns, factors, log_vols = simulate_dfsv(params, T=100, key=42) >>> returns.shape (100, 2) """ # Handle key if isinstance(key, int): key = jr.PRNGKey(key) # Extract dimensions N, K = params.lambda_r.shape # Initialize states if f0 is None: f0 = jnp.zeros(K) if h0 is None: h0 = params.mu # Prepare covariance matrices Sigma = jnp.diag(params.sigma2) # Idiosyncratic covariance (N, N) chol_Sigma = jnp.linalg.cholesky(Sigma) # Lower triangular chol_Q_h = jnp.linalg.cholesky(params.Q_h) # Lower triangular # Split key for random draws key_h, key_f, key_r = jr.split(key, 3) # Generate all random shocks at once eta = jr.normal(key_h, (T, K)) # Log-vol innovations eps = jr.normal(key_f, (T, K)) # Factor innovations e = jr.normal(key_r, (T, N)) # Observation noise def step(carry, t): """Single simulation step.""" f_prev, h_prev = carry # Log-volatility transition: h_t = μ + Φ_h(h_{t-1} - μ) + η_t h_t = params.mu + params.Phi_h @ (h_prev - params.mu) + chol_Q_h @ eta[t] # Factor transition: f_t = Φ_f f_{t-1} + diag(exp(h_t/2)) ε_t vol_scale = jnp.exp(h_t / 2.0) f_t = params.Phi_f @ f_prev + vol_scale * eps[t] # Observation: r_t = λ_r f_t + e_t r_t = params.lambda_r @ f_t + chol_Sigma @ e[t] return (f_t, h_t), (r_t, f_t, h_t) # Run simulation _, (returns, factors, log_vols) = jax.lax.scan(step, (f0, h0), jnp.arange(T)) return returns, factors, log_vols
# Alias for backwards compatibility with examples simulate_DFSV = simulate_dfsv