Skip to content

Bugfixes for statespace/models/structural.py #287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 18, 2023
2 changes: 1 addition & 1 deletion docs/statespace/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Statespace Models
.. autosummary::
:toctree: generated

BayesianARIMA
BayesianSARIMA
BayesianVARMAX

*********************
Expand Down
4 changes: 1 addition & 3 deletions docs/statespace/models/structural.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,4 @@ Structural Components
TimeSeasonality
FrequencySeasonality
MeasurementError

StructuralTimeSeries
Component
CycleComponent
2 changes: 1 addition & 1 deletion pymc_experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
handler = logging.StreamHandler()
_log.addHandler(handler)

from pymc_experimental import distributions, gp, utils
from pymc_experimental import distributions, gp, statespace, utils
from pymc_experimental.inference.fit import fit
from pymc_experimental.model.marginal_model import MarginalModel
from pymc_experimental.model.model_api import as_model
47 changes: 32 additions & 15 deletions pymc_experimental/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,37 @@ def _insert_random_variables(self) -> List[Variable]:
}
self.subbed_ssm = graph_replace(matrices, replace=replacement_dict, strict=True)

def _register_matrices_with_pymc_model(self) -> List[pt.TensorVariable]:
"""
Add all statespace matrices to the PyMC model currently on the context stack as pm.Deterministic nodes, and
adds named dimensions if they are found.

Returns
-------
registered_matrices: list of pt.TensorVariable
List of statespace matrices, wrapped in pm.Deterministic
"""

pm_mod = modelcontext(None)
matrices = self.unpack_statespace()

registered_matrices = []
for i, (matrix, name) in enumerate(zip(matrices, MATRIX_NAMES)):
time_varying_ndim = 2 if name in VECTOR_VALUED else 3
if not getattr(pm_mod, name, None):
shape, dims = self._get_matrix_shape_and_dims(name)
has_dims = dims is not None

if matrix.ndim == time_varying_ndim and has_dims:
dims = (TIME_DIM,) + dims

x = pm.Deterministic(name, matrix, dims=dims)
registered_matrices.append(x)
else:
registered_matrices.append(matrices[i])

return registered_matrices

def add_exogenous(self, exog: pt.TensorVariable) -> None:
"""
Add an exogenous process to the statespace model
Expand Down Expand Up @@ -746,7 +777,6 @@ def build_statespace_graph(
pm_mod = modelcontext(None)

self._insert_random_variables()
matrices = self.unpack_statespace()
obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None)

self.data_len = data.shape[0]
Expand All @@ -758,20 +788,7 @@ def build_statespace_graph(
missing_fill_value=missing_fill_value,
)

registered_matrices = []
for i, (matrix, name) in enumerate(zip(matrices, MATRIX_NAMES)):
time_varying_ndim = 2 if name in VECTOR_VALUED else 3
if not getattr(pm_mod, name, None):
shape, dims = self._get_matrix_shape_and_dims(name)
has_dims = dims is not None

if matrix.ndim == time_varying_ndim and has_dims:
dims = (TIME_DIM,) + dims

x = pm.Deterministic(name, matrix, dims=dims)
registered_matrices.append(x)
else:
registered_matrices.append(matrices[i])
registered_matrices = self._register_matrices_with_pymc_model()

filter_outputs = self.kalman_filter.build_graph(
pt.as_tensor_variable(data),
Expand Down
39 changes: 33 additions & 6 deletions pymc_experimental/statespace/filters/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pytensor.graph.basic import Node

floatX = pytensor.config.floatX
COV_ZERO_TOL = 0

lgss_shape_message = (
"The LinearGaussianStateSpace distribution needs shape information to be constructed. "
Expand Down Expand Up @@ -157,8 +158,11 @@ def step_fn(*args):
middle_rng, a_innovation = pm.MvNormal.dist(mu=0, cov=Q, rng=rng).owner.outputs
next_rng, y_innovation = pm.MvNormal.dist(mu=0, cov=H, rng=middle_rng).owner.outputs

a_next = c + T @ a + R @ a_innovation
y_next = d + Z @ a_next + y_innovation
a_mu = c + T @ a
a_next = pt.switch(pt.all(pt.le(Q, COV_ZERO_TOL)), a_mu, a_mu + R @ a_innovation)

y_mu = d + Z @ a_next
y_next = pt.switch(pt.all(pt.le(H, COV_ZERO_TOL)), y_mu, y_mu + y_innovation)

next_state = pt.concatenate([a_next, y_next], axis=0)

Expand All @@ -168,7 +172,11 @@ def step_fn(*args):
Z_init = Z_ if Z_ in non_sequences else Z_[0]
H_init = H_ if H_ in non_sequences else H_[0]

init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng)
init_y_ = pt.switch(
pt.all(pt.le(H_init, COV_ZERO_TOL)),
Z_init @ init_x_,
pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng),
)
init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)

statespace, updates = pytensor.scan(
Expand Down Expand Up @@ -216,6 +224,7 @@ def __new__(
steps=None,
mode=None,
sequence_names=None,
k_endog=None,
**kwargs,
):
dims = kwargs.pop("dims", None)
Expand All @@ -239,11 +248,29 @@ def __new__(
sequence_names=sequence_names,
**kwargs,
)

k_states = T.type.shape[0]

latent_states = latent_obs_combined[..., :k_states]
obs_states = latent_obs_combined[..., k_states:]
if k_endog is None and k_states is None:
raise ValueError("Could not infer number of observed states, explicitly pass k_endog.")
if k_endog is not None and k_states is not None:
total_shape = latent_obs_combined.type.shape[-1]
inferred_endog = total_shape - k_states
if inferred_endog != k_endog:
raise ValueError(
f"Inferred k_endog does not agree with provided value ({inferred_endog} != {k_endog}). "
f"It is not necessary to provide k_endog when the value can be inferred."
)
latent_slice = slice(None, -k_endog)
obs_slice = slice(-k_endog, None)
elif k_endog is None:
latent_slice = slice(None, k_states)
obs_slice = slice(k_states, None)
else:
latent_slice = slice(None, -k_endog)
obs_slice = slice(-k_endog, None)

latent_states = latent_obs_combined[..., latent_slice]
obs_states = latent_obs_combined[..., obs_slice]

latent_states = pm.Deterministic(f"{name}_latent", latent_states, dims=latent_dims)
obs_states = pm.Deterministic(f"{name}_observed", obs_states, dims=obs_dims)
Expand Down
Loading