From ca9b8b63e72e2bbbe0ad2e62be6ef6f82ae2ad7f Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Thu, 21 Nov 2024 23:57:15 +0800 Subject: [PATCH 1/5] Remove `SteadyStateFilder` Rename `CholeskyFilter` to `SquareRootFilter` to match the literature --- .../statespace/core/statespace.py | 4 +- .../statespace/filters/__init__.py | 4 +- .../statespace/filters/kalman_filter.py | 186 ++---------------- tests/statespace/test_kalman_filter.py | 15 +- 4 files changed, 22 insertions(+), 187 deletions(-) diff --git a/pymc_experimental/statespace/core/statespace.py b/pymc_experimental/statespace/core/statespace.py index c87e2ffd..9ffb60aa 100644 --- a/pymc_experimental/statespace/core/statespace.py +++ b/pymc_experimental/statespace/core/statespace.py @@ -18,9 +18,9 @@ from pymc_experimental.statespace.core.representation import PytensorRepresentation from pymc_experimental.statespace.filters import ( - CholeskyFilter, KalmanSmoother, SingleTimeseriesFilter, + SquareRootFilter, StandardFilter, SteadyStateFilter, UnivariateFilter, @@ -55,7 +55,7 @@ "univariate": UnivariateFilter, "steady_state": SteadyStateFilter, "single": SingleTimeseriesFilter, - "cholesky": CholeskyFilter, + "cholesky": SquareRootFilter, } diff --git a/pymc_experimental/statespace/filters/__init__.py b/pymc_experimental/statespace/filters/__init__.py index 15e3d899..13f1d7a4 100644 --- a/pymc_experimental/statespace/filters/__init__.py +++ b/pymc_experimental/statespace/filters/__init__.py @@ -1,7 +1,7 @@ from pymc_experimental.statespace.filters.distributions import LinearGaussianStateSpace from pymc_experimental.statespace.filters.kalman_filter import ( - CholeskyFilter, SingleTimeseriesFilter, + SquareRootFilter, StandardFilter, SteadyStateFilter, UnivariateFilter, @@ -14,6 +14,6 @@ "SteadyStateFilter", "KalmanSmoother", "SingleTimeseriesFilter", - "CholeskyFilter", + "SquareRootFilter", "LinearGaussianStateSpace", ] diff --git a/pymc_experimental/statespace/filters/kalman_filter.py b/pymc_experimental/statespace/filters/kalman_filter.py index 0d955d02..37c3137b 100644 --- a/pymc_experimental/statespace/filters/kalman_filter.py +++ b/pymc_experimental/statespace/filters/kalman_filter.py @@ -8,8 +8,7 @@ from pytensor.graph.basic import Variable from pytensor.raise_op import Assert from pytensor.tensor import TensorVariable -from pytensor.tensor.nlinalg import matrix_dot -from pytensor.tensor.slinalg import solve_discrete_are, solve_triangular +from pytensor.tensor.slinalg import solve_triangular from pymc_experimental.statespace.filters.utilities import ( quad_form_sym, @@ -55,15 +54,6 @@ def __init__(self, mode=None): non_seq_names : list[str] A list of names representing static statespace matrices. That is, inputs that will need to be provided to the `non_sequences` argument of `pytensor.scan` - - eye_states : TensorVariable - An identity matrix of shape (k_states, k_states), stored for computational efficiency - - eye_posdef : TensorVariable - An identity matrix of shape (k_posdef, k_posdef), stored for computational efficiency - - eye_endog : TensorVariable - An identity matrix of shape (k_endog, k_endog), stored for computational efficiency """ self.mode: str = mode @@ -74,44 +64,9 @@ def __init__(self, mode=None): self.n_posdef = None self.n_endog = None - self.eye_states: TensorVariable | None = None - self.eye_posdef: TensorVariable | None = None - self.eye_endog: TensorVariable | None = None self.missing_fill_value: float | None = None self.cov_jitter = None - def initialize_eyes(self, R: TensorVariable, Z: TensorVariable) -> None: - """ - Initialize identity matrices for of shapes repeated used in the kalman filtering equations and store them. - - It's surprisingly expensive for pytensor to create an identity matrix every time we need one - (see [1] for benchmarks). This function creates some identity matrices of useful sizes for the model - to re-use as a small optimization. - - Parameters - ---------- - R : TensorVariable - The tensor representing the selection matrix, called R in [2] - - Z : TensorVariable - The tensor representing the design matrix, called Z in [2]. - - Returns - ------- - None - - References - ---------- - .. [1] https://gist.github.com/jessegrabowski/acd3235833163943a11654d78a72f04b - .. [2] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods. - 2nd ed, Oxford University Press, 2012. - """ - - self.n_states, self.n_posdef, self.n_endog = R.shape[-2], R.shape[-1], Z.shape[-2] - self.eye_states = pt.eye(self.n_states) - self.eye_posdef = pt.eye(self.n_posdef) - self.eye_endog = pt.eye(self.n_endog) - def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q): """ Apply any checks on validity of inputs. For most filters this is just the identity function. @@ -141,10 +96,10 @@ def add_check_on_time_varying_shapes( list[TensorVariable] A list of tensors wrapped in an `Assert` `Op` that checks the shape of the 0th dimension on each is equal to the shape of the 0th dimension on the data. - - # TODO: The PytensorRepresentation object puts the time dimension last, should the reshaping happen here in - the Kalman filter, or in the StateSpaceModel, before passing into the KF? """ + # TODO: The PytensorRepresentation object puts the time dimension last, should the reshaping happen here in + # the Kalman filter, or in the StateSpaceModel, before passing into the KF? + params_with_assert = [ assert_time_varying_dim_correct(param, pt.eq(param.shape[0], data.shape[0])) for param in sequence_params @@ -166,7 +121,7 @@ def unpack_args(self, args) -> tuple: args = list(args) n_seq = len(self.seq_names) if n_seq == 0: - return args + return tuple(args) # The first arg is always y y = args.pop(0) @@ -202,7 +157,7 @@ def build_graph( return_updates=False, missing_fill_value=None, cov_jitter=None, - ) -> list[TensorVariable]: + ) -> list[TensorVariable] | tuple[list[TensorVariable], dict]: """ Construct the computation graph for the Kalman filter. See [1] for details. @@ -246,9 +201,11 @@ def build_graph( self.mode = mode self.missing_fill_value = missing_fill_value - self.initialize_eyes(R, Z) self.cov_jitter = cov_jitter + self.n_states, self.n_shocks = R.shape[-2:] + self.n_endog = Z.shape[-2] + data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q) sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq( @@ -643,7 +600,7 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag): F = Z.dot(PZT) + stabilize(H, self.cov_jitter) K = pt.linalg.solve(F.T, PZT.T, assume_a="pos", check_finite=False).T - I_KZ = self.eye_states - K.dot(Z) + I_KZ = pt.eye(self.n_states) - K.dot(Z) a_filtered = a + K.dot(v) P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H) @@ -662,7 +619,7 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag): return a_filtered, P_filtered, y_hat, F, ll -class CholeskyFilter(BaseFilter): +class SquareRootFilter(BaseFilter): """ Kalman filter with Cholesky factorization @@ -686,7 +643,7 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag): # If everything is missing, K = 0, IKZ = I K = solve_triangular(F_chol.T, solve_triangular(F_chol, PZT.T)).T - I_KZ = self.eye_states - K.dot(Z) + I_KZ = pt.eye(self.n_states) - K.dot(Z) a_filtered = a + K.dot(v) P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H) @@ -732,7 +689,7 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag): F = stabilize(Z.dot(PZT) + H, self.cov_jitter).ravel() K = PZT / F - I_KZ = self.eye_states - K.dot(Z) + I_KZ = pt.eye(self.n_states) - K.dot(Z) a_filtered = a + (K * v).ravel() @@ -743,123 +700,6 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag): return a_filtered, P_filtered, pt.atleast_1d(y_hat), pt.atleast_2d(F), ll -class SteadyStateFilter(BaseFilter): - """ - Kalman Filter using Steady State Covariance - - This filter avoids the need to invert the covariance matrix of innovations at each time step by solving the - Discrete Algebraic Riccati Equation associated with the filtering problem once and for all at initialization and - uses the resulting steady-state covariance matrix in each step. - - The innovation covariance matrix will always converge to the steady state value as T -> oo, so this filter will - only have differences from the standard approach in the early steps (T < 10?). A process of "learning" is lost. - """ - - def build_graph( - self, - data, - a0, - P0, - c, - d, - T, - Z, - R, - H, - Q, - mode=None, - return_updates=False, - missing_fill_value=None, - cov_jitter=None, - ) -> list[TensorVariable]: - """ - Need to override the base step to add an argument to self.update, passing F_inv at every step. - """ - if missing_fill_value is None: - missing_fill_value = MISSING_FILL - if cov_jitter is None: - cov_jitter = JITTER_DEFAULT - - self.mode = mode - self.missing_fill_value = missing_fill_value - self.cov_jitter = cov_jitter - self.initialize_eyes(R, Z) - - data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q) - sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq( - params, PARAM_NAMES - ) - self.seq_names = seq_names - self.non_seq_names = non_seq_names - c, d, T, Z, R, H, Q = params - - if len(sequences) > 0: - assert ValueError( - "All system matrices must be time-invariant to use the SteadyStateFilter" - ) - - P_steady = solve_discrete_are(T.T, Z.T, matrix_dot(R, Q, R.T), H) - F = matrix_dot(Z, P_steady, Z.T) + H - F_inv = pt.linalg.solve(F, pt.eye(F.shape[0]), assume_a="pos", check_finite=False) - - results, updates = pytensor.scan( - self.kalman_step, - sequences=[data], - outputs_info=[None, a0, None, None, P_steady, None, None], - non_sequences=[c, d, F_inv, T, Z, R, H, Q], - name="forward_kalman_pass", - mode=get_mode(self.mode), - ) - - return self._postprocess_scan_results(results, a0, P0, n=data.shape[0]) - - def update(self, a, P, c, d, F_inv, y, Z, H, all_nan_flag): - y_hat = Z.dot(a) + d - v = y - y_hat - - PZT = P.dot(Z.T) - - F = Z.dot(PZT) + stabilize(H, self.cov_jitter) - K = PZT.dot(F_inv) - - I_KZ = self.eye_states - K.dot(Z) - - a_filtered = a + K.dot(v) - P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H) - - inner_term = matrix_dot(v.T, F_inv, v) - ll = pt.switch( - all_nan_flag, - 0.0, - -0.5 * (MVN_CONST + pt.log(pt.linalg.det(F)) + inner_term).ravel()[0], - ) - - return a_filtered, P_filtered, y_hat, F, ll - - def kalman_step(self, y, a, P, c, d, F_inv, T, Z, R, H, Q): - """ - Need to override the base step to add an argument to self.update, passing F_inv at every step. - """ - - y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H) - a_filtered, P_filtered, obs_mu, obs_cov, ll = self.update( - y=y_masked, - a=a, - P=P, - c=c, - d=d, - F_inv=F_inv, - Z=Z_masked, - H=H_masked, - all_nan_flag=all_nan_flag, - ) - - P_filtered = stabilize(P_filtered, self.cov_jitter) - a_hat, P_hat = self.predict(a=a_filtered, P=P_filtered, c=c, T=T, R=R, Q=Q) - - return a_filtered, a_hat, obs_mu, P_filtered, P_hat, obs_cov, ll - - class UnivariateFilter(BaseFilter): """ The univariate kalman filter, described in [1], section 6.4.2, avoids inversion of the F matrix, as well as two diff --git a/tests/statespace/test_kalman_filter.py b/tests/statespace/test_kalman_filter.py index a8582e2f..6a9f4ec0 100644 --- a/tests/statespace/test_kalman_filter.py +++ b/tests/statespace/test_kalman_filter.py @@ -6,11 +6,10 @@ from numpy.testing import assert_allclose, assert_array_less from pymc_experimental.statespace.filters import ( - CholeskyFilter, KalmanSmoother, SingleTimeseriesFilter, + SquareRootFilter, StandardFilter, - SteadyStateFilter, UnivariateFilter, ) from pymc_experimental.statespace.filters.kalman_filter import BaseFilter @@ -33,25 +32,22 @@ RTOL = 1e-6 if floatX.endswith("64") else 1e-3 standard_inout = initialize_filter(StandardFilter()) -cholesky_inout = initialize_filter(CholeskyFilter()) +cholesky_inout = initialize_filter(SquareRootFilter()) univariate_inout = initialize_filter(UnivariateFilter()) single_inout = initialize_filter(SingleTimeseriesFilter()) -steadystate_inout = initialize_filter(SteadyStateFilter()) f_standard = pytensor.function(*standard_inout, on_unused_input="ignore") f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore") f_univariate = pytensor.function(*univariate_inout, on_unused_input="ignore") f_single_ts = pytensor.function(*single_inout, on_unused_input="ignore") -f_steady = pytensor.function(*steadystate_inout, on_unused_input="ignore") -filter_funcs = [f_standard, f_cholesky, f_univariate, f_single_ts, f_steady] +filter_funcs = [f_standard, f_cholesky, f_univariate, f_single_ts] filter_names = [ "StandardFilter", "CholeskyFilter", "UnivariateFilter", "SingleTimeSeriesFilter", - "SteadyStateFilter", ] output_names = [ @@ -247,8 +243,7 @@ def test_last_smoother_is_last_filtered(filter_func, output_idx, rng): assert_allclose(filtered[-1], smoothed[-1]) -# TODO: These tests omit the SteadyStateFilter, because it gives different results to StatsModels (reason to dump it?) -@pytest.mark.parametrize("filter_func", filter_funcs[:-1], ids=filter_names[:-1]) +@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) @pytest.mark.parametrize("n_missing", [0, 5], ids=["n_missing=0", "n_missing=5"]) @pytest.mark.skipif(floatX == "float32", reason="Tests are too sensitive for float32") def test_filters_match_statsmodel_output(filter_func, n_missing, rng): @@ -320,7 +315,7 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob @pytest.mark.parametrize( "filter", - [StandardFilter, SingleTimeseriesFilter, CholeskyFilter], + [StandardFilter, SingleTimeseriesFilter, SquareRootFilter], ids=["standard", "single_ts", "cholesky"], ) def test_kalman_filter_jax(filter): From aedd3a7cb6901595ee64fed65cd10e9fd697f1b7 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Thu, 21 Nov 2024 23:59:00 +0800 Subject: [PATCH 2/5] Remove `SteadyState` Filter --- pymc_experimental/statespace/core/statespace.py | 2 -- pymc_experimental/statespace/filters/__init__.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/pymc_experimental/statespace/core/statespace.py b/pymc_experimental/statespace/core/statespace.py index 9ffb60aa..1df66837 100644 --- a/pymc_experimental/statespace/core/statespace.py +++ b/pymc_experimental/statespace/core/statespace.py @@ -22,7 +22,6 @@ SingleTimeseriesFilter, SquareRootFilter, StandardFilter, - SteadyStateFilter, UnivariateFilter, ) from pymc_experimental.statespace.filters.distributions import ( @@ -53,7 +52,6 @@ FILTER_FACTORY = { "standard": StandardFilter, "univariate": UnivariateFilter, - "steady_state": SteadyStateFilter, "single": SingleTimeseriesFilter, "cholesky": SquareRootFilter, } diff --git a/pymc_experimental/statespace/filters/__init__.py b/pymc_experimental/statespace/filters/__init__.py index 13f1d7a4..b44b8380 100644 --- a/pymc_experimental/statespace/filters/__init__.py +++ b/pymc_experimental/statespace/filters/__init__.py @@ -3,7 +3,6 @@ SingleTimeseriesFilter, SquareRootFilter, StandardFilter, - SteadyStateFilter, UnivariateFilter, ) from pymc_experimental.statespace.filters.kalman_smoother import KalmanSmoother @@ -11,7 +10,6 @@ __all__ = [ "StandardFilter", "UnivariateFilter", - "SteadyStateFilter", "KalmanSmoother", "SingleTimeseriesFilter", "SquareRootFilter", From c3bc3659b332227e2d8377460b5d62d6f45e290a Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sat, 23 Nov 2024 00:41:39 +0800 Subject: [PATCH 3/5] Use square root filter equations in `SquareRootFilter` --- .../statespace/filters/kalman_filter.py | 157 ++++++++++++++---- tests/statespace/test_kalman_filter.py | 27 ++- tests/statespace/utilities/test_helpers.py | 22 +-- 3 files changed, 147 insertions(+), 59 deletions(-) diff --git a/pymc_experimental/statespace/filters/kalman_filter.py b/pymc_experimental/statespace/filters/kalman_filter.py index 37c3137b..25e4837a 100644 --- a/pymc_experimental/statespace/filters/kalman_filter.py +++ b/pymc_experimental/statespace/filters/kalman_filter.py @@ -4,6 +4,7 @@ import pytensor import pytensor.tensor as pt +from pymc.pytensorf import constant_fold from pytensor.compile.mode import get_mode from pytensor.graph.basic import Variable from pytensor.raise_op import Assert @@ -203,8 +204,11 @@ def build_graph( self.missing_fill_value = missing_fill_value self.cov_jitter = cov_jitter - self.n_states, self.n_shocks = R.shape[-2:] - self.n_endog = Z.shape[-2] + [R_shape] = constant_fold([R.shape], raise_not_constant=False) + [Z_shape] = constant_fold([Z.shape], raise_not_constant=False) + + self.n_states, self.n_shocks = R_shape[-2:] + self.n_endog = Z_shape[-2] data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q) @@ -408,7 +412,7 @@ def predict(a, P, c, T, R, Q) -> tuple[TensorVariable, TensorVariable]: @staticmethod def update( - a, P, y, c, d, Z, H, all_nan_flag + a, P, y, d, Z, H, all_nan_flag ) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable]: """ Perform the update step of the Kalman filter. @@ -419,7 +423,7 @@ def update( .. math:: \begin{align} - \\hat{y}_t &= Z_t a_{t | t-1} \\ + \\hat{y}_t &= Z_t a_{t | t-1} + d_t \\ v_t &= y_t - \\hat{y}_t \\ F_t &= Z_t P_{t | t-1} Z_t^T + H_t \\ a_{t|t} &= a_{t | t-1} + P_{t | t-1} Z_t^T F_t^{-1} v_t \\ @@ -435,8 +439,6 @@ def update( The current covariance matrix estimate, conditioned on information up to time t-1. y : TensorVariable The observation data at time t. - c : TensorVariable - The matrix c. d : TensorVariable The matrix d. Z : TensorVariable @@ -529,7 +531,7 @@ def kalman_step(self, *args) -> tuple: y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H) a_filtered, P_filtered, obs_mu, obs_cov, ll = self.update( - y=y_masked, a=a, c=c, d=d, P=P, Z=Z_masked, H=H_masked, all_nan_flag=all_nan_flag + y=y_masked, a=a, d=d, P=P, Z=Z_masked, H=H_masked, all_nan_flag=all_nan_flag ) P_filtered = stabilize(P_filtered, self.cov_jitter) @@ -545,7 +547,7 @@ class StandardFilter(BaseFilter): Basic Kalman Filter """ - def update(self, a, P, y, c, d, Z, H, all_nan_flag): + def update(self, a, P, y, d, Z, H, all_nan_flag): """ Compute one-step forecasts for observed states conditioned on information up to, but not including, the current timestep, `y_hat`, along with the forcast covariance matrix, `F`. Marginalize over observed states to obtain @@ -566,9 +568,6 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag): y : TensorVariable Observations at time t. - c : TensorVariable - Latent state bias term. - d : TensorVariable Observed state bias term. @@ -628,38 +627,128 @@ class SquareRootFilter(BaseFilter): """ - # TODO: Can the entire Kalman filter process be re-written, starting from P0_chol, so it's not necessary to compute - # cholesky(F) at every iteration? + def predict(self, a, P, c, T, R, Q): + """ + Compute one-step forecasts for the hidden states conditioned on information up to, but not including, the current + timestep, `a_hat`, along with the forcast covariance matrix, `P_hat`. + + .. warning:: + Very important -- In this function, $P$ is the **cholesky factor** of the covariance matrix, not the + covariance matrix itself. The name `P` is kept for consistency with the superclass. + """ + # Rename P to P_chol for clarity + P_chol = P + + a_hat = T.dot(a) + c + Q_chol = pt.linalg.cholesky(Q, lower=True) + + M = pt.horizontal_stack(T @ P_chol, R @ Q_chol).T + R_decomp = pt.linalg.qr(M, mode="r") + P_chol_hat = R_decomp[: self.n_states, : self.n_states].T + + return a_hat, P_chol_hat + + def update(self, a, P, y, d, Z, H, all_nan_flag): + """ + Compute posterior estimates of the hidden state distributions conditioned on the observed data, up to and + including the present timestep. Also compute the log-likelihood of the data given the one-step forecasts. + + .. warning:: + Very important -- In this function, $P$ is the **cholesky factor** of the covariance matrix, not the + covariance matrix itself. The name `P` is kept for consistency with the superclass. + """ + + # Rename P to P_chol for clarity + P_chol = P - def update(self, a, P, y, c, d, Z, H, all_nan_flag): y_hat = Z.dot(a) + d v = y - y_hat - PZT = P.dot(Z.T) + H_chol = pytensor.ifelse(pt.all(pt.eq(H, 0.0)), H, pt.linalg.cholesky(H, lower=True)) + + # The following notation comes from https://ipnpr.jpl.nasa.gov/progress_report/42-233/42-233A.pdf + # Construct upper-triangular block matrix A = [[chol(H), Z @ L_pred], + # [0, L_pred]] + # The Schur decomposition of this matrix will be B (upper triangular). We are + # more insterested in B^T: + # Structure of B^T = [[chol(F), 0 ], + # [K @ chol(F), chol(P_filtered)] + zeros = pt.zeros((self.n_states, self.n_endog)) + upper = pt.horizontal_stack(H_chol, Z @ P_chol) + lower = pt.horizontal_stack(zeros, P_chol) + A_T = pt.vertical_stack(upper, lower) + B = pt.linalg.qr(A_T.T, mode="r").T + + F_chol = B[: self.n_endog, : self.n_endog] + K_F_chol = B[self.n_endog :, : self.n_endog] + P_chol_filtered = B[self.n_endog :, self.n_endog :] + + def compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v): + a_filtered = a + K_F_chol @ solve_triangular(F_chol, v, lower=True) + + inner_term = solve_triangular( + F_chol, solve_triangular(F_chol, v, lower=True), lower=True + ) + loss = (v.T @ inner_term).ravel() + + # abs necessary because we're not guaranteed a positive diagonal from the schur decomposition + logdet = 2 * pt.log(pt.abs(pt.diag(F_chol))).sum() + + ll = -0.5 * (self.n_endog * (MVN_CONST + logdet) + loss)[0] + + return [a_filtered, P_chol_filtered, ll] + + def compute_degenerate(P_chol_filtered, F_chol, K_F_chol, v): + """ + If F is zero (usually because there were no observations this period), then we want: + K = 0, a = a, P = P, ll = 0 + """ + return [a, P_chol, pt.zeros(())] + + [a_filtered, P_chol_filtered, ll] = pytensor.ifelse( + pt.eq(all_nan_flag, 1.0), + compute_degenerate(P_chol_filtered, F_chol, K_F_chol, v), + compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v), + ) - # If everything is missing, F will be [[0]] and F_chol will raise an error, so add identity to avoid the error - F = Z.dot(PZT) + stabilize(H, self.cov_jitter) - F_chol = pt.linalg.cholesky(F) + a_filtered = pt.specify_shape(a_filtered, (self.n_states,)) + P_chol_filtered = pt.specify_shape(P_chol_filtered, (self.n_states, self.n_states)) - # If everything is missing, K = 0, IKZ = I - K = solve_triangular(F_chol.T, solve_triangular(F_chol, PZT.T)).T - I_KZ = pt.eye(self.n_states) - K.dot(Z) + return a_filtered, P_chol_filtered, y_hat, F_chol, ll - a_filtered = a + K.dot(v) - P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H) + def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]: + """ + Convert the Cholesky factor of the covariance matrix back to the covariance matrix itself. + """ + results = super()._postprocess_scan_results(results, a0, P0, n) + ( + filtered_states, + predicted_states, + observed_states, + filtered_covariances_cholesky, + predicted_covariances_cholesky, + observed_covariances_cholesky, + loglike_obs, + ) = results - inner_term = solve_triangular(F_chol.T, solve_triangular(F_chol, v)) - n = y.shape[0] + def square_sequnece(L): + X = pt.einsum("...ij,...kj->...ik", L, L.copy()) + X = pt.specify_shape(X, (n, self.n_states, self.n_states)) + return X - ll = pt.switch( - all_nan_flag, - 0.0, - ( - -0.5 * (n * MVN_CONST + (v.T @ inner_term).ravel()) - pt.log(pt.diag(F_chol)).sum() - ).ravel()[0], - ) + filtered_covariances = square_sequnece(filtered_covariances_cholesky) + predicted_covariances = square_sequnece(predicted_covariances_cholesky) + observed_covariances = square_sequnece(observed_covariances_cholesky) - return a_filtered, P_filtered, y_hat, F, ll + return [ + filtered_states, + predicted_states, + observed_states, + filtered_covariances, + predicted_covariances, + observed_covariances, + loglike_obs, + ] class SingleTimeseriesFilter(BaseFilter): @@ -679,7 +768,7 @@ def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q): return data, a0, P0, c, d, T, Z, R, H, Q - def update(self, a, P, y, c, d, Z, H, all_nan_flag): + def update(self, a, P, y, d, Z, H, all_nan_flag): y_hat = d + Z.dot(a) v = y - y_hat.ravel() diff --git a/tests/statespace/test_kalman_filter.py b/tests/statespace/test_kalman_filter.py index 6a9f4ec0..02aba182 100644 --- a/tests/statespace/test_kalman_filter.py +++ b/tests/statespace/test_kalman_filter.py @@ -64,7 +64,7 @@ def test_base_class_update_raises(): filter = BaseFilter() - inputs = [None] * 8 + inputs = [None] * 7 with pytest.raises(NotImplementedError): filter.update(*inputs) @@ -214,6 +214,7 @@ def test_output_with_multiple_observed(filter_func, filter_name, rng): def test_missing_data(filter_func, filter_name, p, rng): m, r, n = 5, 1, 10 inputs = make_test_inputs(p, m, r, n, rng, missing_data=1) + if p > 1 and filter_name == "SingleTimeSeriesFilter": with pytest.raises( AssertionError, @@ -243,11 +244,16 @@ def test_last_smoother_is_last_filtered(filter_func, output_idx, rng): assert_allclose(filtered[-1], smoothed[-1]) -@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) +@pytest.mark.parametrize( + "filter_func, filter_name", zip(filter_funcs, filter_names), ids=filter_names +) @pytest.mark.parametrize("n_missing", [0, 5], ids=["n_missing=0", "n_missing=5"]) @pytest.mark.skipif(floatX == "float32", reason="Tests are too sensitive for float32") -def test_filters_match_statsmodel_output(filter_func, n_missing, rng): - fit_sm_mod, inputs = nile_test_test_helper(rng, n_missing) +def test_filters_match_statsmodel_output(filter_func, filter_name, n_missing, rng): + fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper(rng, n_missing) + if filter_name == "CholeskyFilter": + P0 = np.linalg.cholesky(P0) + inputs = [data, a0, P0, c, d, T, Z, R, H, Q] outputs = filter_func(*inputs) for output_idx, name in enumerate(output_names): @@ -294,6 +300,8 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob pytest.skip("Univariate filter not stable at half precision without measurement error") fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper(rng, n_missing) + if filter_name == "CholeskyFilter": + P0 = np.linalg.cholesky(P0) H *= int(obs_noise) inputs = [data, a0, P0, c, d, T, Z, R, H, Q] @@ -325,16 +333,7 @@ def test_kalman_filter_jax(filter): # TODO: Add UnivariateFilter to test; need to figure out the broadcasting issue when 2nd data dim is defined p, m, r, n = 1, 5, 1, 10 - inputs, outputs = initialize_filter(filter(), mode="JAX") - - # Shape of the data must be static for jax to know how long the scan is - data = inputs.pop(0) - data_specified = pt.specify_shape(data, (n, None)) - data_specified.name = "data" - inputs = [data, *inputs] - - outputs = pytensor.graph.clone_replace(outputs, {data: data_specified}) - + inputs, outputs = initialize_filter(filter(), mode="JAX", p=p, m=m, r=r, n=n) inputs_np = make_test_inputs(p, m, r, n, rng) f_jax = get_jaxified_graph(inputs, outputs) diff --git a/tests/statespace/utilities/test_helpers.py b/tests/statespace/utilities/test_helpers.py index 58da14b4..bac578bc 100644 --- a/tests/statespace/utilities/test_helpers.py +++ b/tests/statespace/utilities/test_helpers.py @@ -34,18 +34,18 @@ def load_nile_test_data(): return nile -def initialize_filter(kfilter, mode=None): +def initialize_filter(kfilter, mode=None, p=None, m=None, r=None, n=None): ksmoother = KalmanSmoother() - data = pt.matrix(name="data", dtype=floatX) - a0 = pt.vector(name="a0", dtype=floatX) - P0 = pt.matrix(name="P0", dtype=floatX) - c = pt.vector(name="c", dtype=floatX) - d = pt.vector(name="d", dtype=floatX) - Q = pt.matrix(name="Q", dtype=floatX) - H = pt.matrix(name="H", dtype=floatX) - T = pt.matrix(name="T", dtype=floatX) - R = pt.matrix(name="R", dtype=floatX) - Z = pt.matrix(name="Z", dtype=floatX) + data = pt.tensor(name="data", dtype=floatX, shape=(n, p)) + a0 = pt.tensor(name="x0", dtype=floatX, shape=(m,)) + P0 = pt.tensor(name="P0", dtype=floatX, shape=(m, m)) + c = pt.tensor(name="c", dtype=floatX, shape=(m,)) + d = pt.tensor(name="d", dtype=floatX, shape=(p,)) + Q = pt.tensor(name="Q", dtype=floatX, shape=(r, r)) + H = pt.tensor(name="H", dtype=floatX, shape=(p, p)) + T = pt.tensor(name="T", dtype=floatX, shape=(m, m)) + R = pt.tensor(name="R", dtype=floatX, shape=(m, r)) + Z = pt.tensor(name="Z", dtype=floatX, shape=(p, m)) inputs = [data, a0, P0, c, d, T, Z, R, H, Q] From 08995be8800f2bdf4378123a8e80f7cee9b8eba4 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sat, 23 Nov 2024 00:44:25 +0800 Subject: [PATCH 4/5] Remove `SingleTimeSeriesFilter` --- .../statespace/core/statespace.py | 2 - .../statespace/filters/__init__.py | 2 - .../statespace/filters/kalman_filter.py | 43 +--------------- tests/statespace/test_kalman_filter.py | 50 ++++++------------- 4 files changed, 17 insertions(+), 80 deletions(-) diff --git a/pymc_experimental/statespace/core/statespace.py b/pymc_experimental/statespace/core/statespace.py index 1df66837..c90bb9b4 100644 --- a/pymc_experimental/statespace/core/statespace.py +++ b/pymc_experimental/statespace/core/statespace.py @@ -19,7 +19,6 @@ from pymc_experimental.statespace.core.representation import PytensorRepresentation from pymc_experimental.statespace.filters import ( KalmanSmoother, - SingleTimeseriesFilter, SquareRootFilter, StandardFilter, UnivariateFilter, @@ -52,7 +51,6 @@ FILTER_FACTORY = { "standard": StandardFilter, "univariate": UnivariateFilter, - "single": SingleTimeseriesFilter, "cholesky": SquareRootFilter, } diff --git a/pymc_experimental/statespace/filters/__init__.py b/pymc_experimental/statespace/filters/__init__.py index b44b8380..1ee3c707 100644 --- a/pymc_experimental/statespace/filters/__init__.py +++ b/pymc_experimental/statespace/filters/__init__.py @@ -1,6 +1,5 @@ from pymc_experimental.statespace.filters.distributions import LinearGaussianStateSpace from pymc_experimental.statespace.filters.kalman_filter import ( - SingleTimeseriesFilter, SquareRootFilter, StandardFilter, UnivariateFilter, @@ -11,7 +10,6 @@ "StandardFilter", "UnivariateFilter", "KalmanSmoother", - "SingleTimeseriesFilter", "SquareRootFilter", "LinearGaussianStateSpace", ] diff --git a/pymc_experimental/statespace/filters/kalman_filter.py b/pymc_experimental/statespace/filters/kalman_filter.py index 25e4837a..d5f806ec 100644 --- a/pymc_experimental/statespace/filters/kalman_filter.py +++ b/pymc_experimental/statespace/filters/kalman_filter.py @@ -21,7 +21,6 @@ MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64")) PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"] -assert_data_is_1d = Assert("UnivariateTimeSeries filter requires data be at most 1-dimensional") assert_time_varying_dim_correct = Assert( "The first dimension of a time varying matrix (the time dimension) must be " "equal to the first dimension of the data (the time dimension)." @@ -751,50 +750,12 @@ def square_sequnece(L): ] -class SingleTimeseriesFilter(BaseFilter): - """ - Kalman filter optimized for univariate timeseries - - If there is only a single observed timeseries, regardless of the number of hidden states, there is no need to - perform a matrix inversion anywhere in the filter. - """ - - # TODO: This class should eventually be made irrelevant by pytensor re-writes. - def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q): - """ - Wrap the data in an `Assert` `Op` to ensure there is only one observed state. - """ - data = assert_data_is_1d(data, pt.eq(data.shape[1], 1)) - - return data, a0, P0, c, d, T, Z, R, H, Q - - def update(self, a, P, y, d, Z, H, all_nan_flag): - y_hat = d + Z.dot(a) - v = y - y_hat.ravel() - - PZT = P.dot(Z.T) - - # F is scalar, K is a column vector - F = stabilize(Z.dot(PZT) + H, self.cov_jitter).ravel() - - K = PZT / F - I_KZ = pt.eye(self.n_states) - K.dot(Z) - - a_filtered = a + (K * v).ravel() - - P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H) - - ll = pt.switch(all_nan_flag, 0.0, -0.5 * (MVN_CONST + pt.log(F) + v**2 / F)).ravel()[0] - - return a_filtered, P_filtered, pt.atleast_1d(y_hat), pt.atleast_2d(F), ll - - class UnivariateFilter(BaseFilter): """ The univariate kalman filter, described in [1], section 6.4.2, avoids inversion of the F matrix, as well as two matrix multiplications, at the cost of an additional loop. Note that the name doesn't mean there's only one - observed time series, that's the SingleTimeSeries filter. This is called univariate because it updates the state - mean and covariance matrices one variable at a time, using an inner-inner loop. + observed time series. This is called univariate because it updates the state mean and covariance matrices one + variable at a time, using an inner-inner loop. This is useful when states are perfectly observed, because the F matrix can easily become degenerate in these cases. diff --git a/tests/statespace/test_kalman_filter.py b/tests/statespace/test_kalman_filter.py index 02aba182..6b221657 100644 --- a/tests/statespace/test_kalman_filter.py +++ b/tests/statespace/test_kalman_filter.py @@ -7,7 +7,6 @@ from pymc_experimental.statespace.filters import ( KalmanSmoother, - SingleTimeseriesFilter, SquareRootFilter, StandardFilter, UnivariateFilter, @@ -34,20 +33,17 @@ standard_inout = initialize_filter(StandardFilter()) cholesky_inout = initialize_filter(SquareRootFilter()) univariate_inout = initialize_filter(UnivariateFilter()) -single_inout = initialize_filter(SingleTimeseriesFilter()) f_standard = pytensor.function(*standard_inout, on_unused_input="ignore") f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore") f_univariate = pytensor.function(*univariate_inout, on_unused_input="ignore") -f_single_ts = pytensor.function(*single_inout, on_unused_input="ignore") -filter_funcs = [f_standard, f_cholesky, f_univariate, f_single_ts] +filter_funcs = [f_standard, f_cholesky, f_univariate] filter_names = [ "StandardFilter", "CholeskyFilter", "UnivariateFilter", - "SingleTimeSeriesFilter", ] output_names = [ @@ -191,20 +187,12 @@ def test_output_with_multiple_observed(filter_func, filter_name, rng): p, m, r, n = 5, 5, 1, 10 inputs = make_test_inputs(p, m, r, n, rng) - if filter_name == "SingleTimeSeriesFilter": - with pytest.raises( - AssertionError, - match="UnivariateTimeSeries filter requires data be at most 1-dimensional", - ): - filter_func(*inputs) - - else: - outputs = filter_func(*inputs) - for output_idx, name in enumerate(output_names): - expected_output = get_expected_shape(name, p, m, r, n) - assert ( - outputs[output_idx].shape == expected_output - ), f"Shape of {name} does not match expected" + outputs = filter_func(*inputs) + for output_idx, name in enumerate(output_names): + expected_output = get_expected_shape(name, p, m, r, n) + assert ( + outputs[output_idx].shape == expected_output + ), f"Shape of {name} does not match expected" @pytest.mark.parametrize( @@ -215,20 +203,12 @@ def test_missing_data(filter_func, filter_name, p, rng): m, r, n = 5, 1, 10 inputs = make_test_inputs(p, m, r, n, rng, missing_data=1) - if p > 1 and filter_name == "SingleTimeSeriesFilter": - with pytest.raises( - AssertionError, - match="UnivariateTimeSeries filter requires data be at most 1-dimensional", - ): - filter_func(*inputs) - - else: - outputs = filter_func(*inputs) - for output_idx, name in enumerate(output_names): - expected_output = get_expected_shape(name, p, m, r, n) - assert ( - outputs[output_idx].shape == expected_output - ), f"Shape of {name} does not match expected" + outputs = filter_func(*inputs) + for output_idx, name in enumerate(output_names): + expected_output = get_expected_shape(name, p, m, r, n) + assert ( + outputs[output_idx].shape == expected_output + ), f"Shape of {name} does not match expected" @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) @@ -323,8 +303,8 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob @pytest.mark.parametrize( "filter", - [StandardFilter, SingleTimeseriesFilter, SquareRootFilter], - ids=["standard", "single_ts", "cholesky"], + [StandardFilter, SquareRootFilter], + ids=["standard", "cholesky"], ) def test_kalman_filter_jax(filter): pytest.importorskip("jax") From 513a9d035e737db4c86c66bedc2d85986ecf6d7f Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sat, 23 Nov 2024 11:52:10 +0800 Subject: [PATCH 5/5] Remove tests referencing old code --- pymc_experimental/statespace/filters/kalman_filter.py | 10 +++++----- tests/statespace/test_distributions.py | 2 -- tests/statespace/test_statespace.py | 6 ------ 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/pymc_experimental/statespace/filters/kalman_filter.py b/pymc_experimental/statespace/filters/kalman_filter.py index d5f806ec..edc15e29 100644 --- a/pymc_experimental/statespace/filters/kalman_filter.py +++ b/pymc_experimental/statespace/filters/kalman_filter.py @@ -730,14 +730,14 @@ def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]: loglike_obs, ) = results - def square_sequnece(L): + def square_sequnece(L, k): X = pt.einsum("...ij,...kj->...ik", L, L.copy()) - X = pt.specify_shape(X, (n, self.n_states, self.n_states)) + X = pt.specify_shape(X, (n, k, k)) return X - filtered_covariances = square_sequnece(filtered_covariances_cholesky) - predicted_covariances = square_sequnece(predicted_covariances_cholesky) - observed_covariances = square_sequnece(observed_covariances_cholesky) + filtered_covariances = square_sequnece(filtered_covariances_cholesky, k=self.n_states) + predicted_covariances = square_sequnece(predicted_covariances_cholesky, k=self.n_states) + observed_covariances = square_sequnece(observed_covariances_cholesky, k=self.n_endog) return [ filtered_states, diff --git a/tests/statespace/test_distributions.py b/tests/statespace/test_distributions.py index ab55eeba..9deaa3d6 100644 --- a/tests/statespace/test_distributions.py +++ b/tests/statespace/test_distributions.py @@ -38,8 +38,6 @@ "standard", "cholesky", "univariate", - "single", - "steady_state", ] diff --git a/tests/statespace/test_statespace.py b/tests/statespace/test_statespace.py index 83d0babc..0024bd2e 100644 --- a/tests/statespace/test_statespace.py +++ b/tests/statespace/test_statespace.py @@ -234,12 +234,6 @@ def test_invalid_filter_name_raises(): mod = make_statespace_mod(k_endog=1, k_states=5, k_posdef=1, filter_type="invalid_filter") -def test_singleseriesfilter_raises_if_k_endog_gt_one(): - msg = 'Cannot use filter_type = "single" with multiple observed time series' - with pytest.raises(ValueError, match=msg): - mod = make_statespace_mod(k_endog=10, k_states=5, k_posdef=1, filter_type="single") - - def test_unpack_before_insert_raises(rng): p, m, r, n = 2, 5, 1, 10 data, *inputs = make_test_inputs(p, m, r, n, rng, missing_data=0)