Skip to content

Commit 656b800

Browse files
Bugfixes for statespace/models/structural.py (#287)
Expand test coverage and fix bugs in `structural.py`
1 parent 430c344 commit 656b800

File tree

9 files changed

+758
-180
lines changed

9 files changed

+758
-180
lines changed

docs/statespace/models.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Statespace Models
66
.. autosummary::
77
:toctree: generated
88

9-
BayesianARIMA
9+
BayesianSARIMA
1010
BayesianVARMAX
1111

1212
*********************

docs/statespace/models/structural.rst

+1-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,4 @@ Structural Components
1111
TimeSeasonality
1212
FrequencySeasonality
1313
MeasurementError
14-
15-
StructuralTimeSeries
16-
Component
14+
CycleComponent

pymc_experimental/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
handler = logging.StreamHandler()
2424
_log.addHandler(handler)
2525

26-
from pymc_experimental import distributions, gp, utils
26+
from pymc_experimental import distributions, gp, statespace, utils
2727
from pymc_experimental.inference.fit import fit
2828
from pymc_experimental.model.marginal_model import MarginalModel
2929
from pymc_experimental.model.model_api import as_model

pymc_experimental/statespace/core/statespace.py

+32-15
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,37 @@ def _insert_random_variables(self) -> List[Variable]:
636636
}
637637
self.subbed_ssm = graph_replace(matrices, replace=replacement_dict, strict=True)
638638

639+
def _register_matrices_with_pymc_model(self) -> List[pt.TensorVariable]:
640+
"""
641+
Add all statespace matrices to the PyMC model currently on the context stack as pm.Deterministic nodes, and
642+
add named dimensions if they are found.
643+
644+
Returns
645+
-------
646+
registered_matrices: list of pt.TensorVariable
647+
List of statespace matrices, wrapped in pm.Deterministic
648+
"""
649+
650+
pm_mod = modelcontext(None)
651+
matrices = self.unpack_statespace()
652+
653+
registered_matrices = []
654+
for i, (matrix, name) in enumerate(zip(matrices, MATRIX_NAMES)):
655+
time_varying_ndim = 2 if name in VECTOR_VALUED else 3
656+
if not getattr(pm_mod, name, None):
657+
shape, dims = self._get_matrix_shape_and_dims(name)
658+
has_dims = dims is not None
659+
660+
if matrix.ndim == time_varying_ndim and has_dims:
661+
dims = (TIME_DIM,) + dims
662+
663+
x = pm.Deterministic(name, matrix, dims=dims)
664+
registered_matrices.append(x)
665+
else:
666+
registered_matrices.append(matrices[i])
667+
668+
return registered_matrices
669+
639670
def add_exogenous(self, exog: pt.TensorVariable) -> None:
640671
"""
641672
Add an exogenous process to the statespace model
@@ -746,7 +777,6 @@ def build_statespace_graph(
746777
pm_mod = modelcontext(None)
747778

748779
self._insert_random_variables()
749-
matrices = self.unpack_statespace()
750780
obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None)
751781

752782
self.data_len = data.shape[0]
@@ -758,20 +788,7 @@ def build_statespace_graph(
758788
missing_fill_value=missing_fill_value,
759789
)
760790

761-
registered_matrices = []
762-
for i, (matrix, name) in enumerate(zip(matrices, MATRIX_NAMES)):
763-
time_varying_ndim = 2 if name in VECTOR_VALUED else 3
764-
if not getattr(pm_mod, name, None):
765-
shape, dims = self._get_matrix_shape_and_dims(name)
766-
has_dims = dims is not None
767-
768-
if matrix.ndim == time_varying_ndim and has_dims:
769-
dims = (TIME_DIM,) + dims
770-
771-
x = pm.Deterministic(name, matrix, dims=dims)
772-
registered_matrices.append(x)
773-
else:
774-
registered_matrices.append(matrices[i])
791+
registered_matrices = self._register_matrices_with_pymc_model()
775792

776793
filter_outputs = self.kalman_filter.build_graph(
777794
pt.as_tensor_variable(data),

pymc_experimental/statespace/filters/distributions.py

+33-6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pytensor.graph.basic import Node
1111

1212
floatX = pytensor.config.floatX
13+
COV_ZERO_TOL = 0
1314

1415
lgss_shape_message = (
1516
"The LinearGaussianStateSpace distribution needs shape information to be constructed. "
@@ -157,8 +158,11 @@ def step_fn(*args):
157158
middle_rng, a_innovation = pm.MvNormal.dist(mu=0, cov=Q, rng=rng).owner.outputs
158159
next_rng, y_innovation = pm.MvNormal.dist(mu=0, cov=H, rng=middle_rng).owner.outputs
159160

160-
a_next = c + T @ a + R @ a_innovation
161-
y_next = d + Z @ a_next + y_innovation
161+
a_mu = c + T @ a
162+
a_next = pt.switch(pt.all(pt.le(Q, COV_ZERO_TOL)), a_mu, a_mu + R @ a_innovation)
163+
164+
y_mu = d + Z @ a_next
165+
y_next = pt.switch(pt.all(pt.le(H, COV_ZERO_TOL)), y_mu, y_mu + y_innovation)
162166

163167
next_state = pt.concatenate([a_next, y_next], axis=0)
164168

@@ -168,7 +172,11 @@ def step_fn(*args):
168172
Z_init = Z_ if Z_ in non_sequences else Z_[0]
169173
H_init = H_ if H_ in non_sequences else H_[0]
170174

171-
init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng)
175+
init_y_ = pt.switch(
176+
pt.all(pt.le(H_init, COV_ZERO_TOL)),
177+
Z_init @ init_x_,
178+
pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng),
179+
)
172180
init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)
173181

174182
statespace, updates = pytensor.scan(
@@ -216,6 +224,7 @@ def __new__(
216224
steps=None,
217225
mode=None,
218226
sequence_names=None,
227+
k_endog=None,
219228
**kwargs,
220229
):
221230
dims = kwargs.pop("dims", None)
@@ -239,11 +248,29 @@ def __new__(
239248
sequence_names=sequence_names,
240249
**kwargs,
241250
)
242-
243251
k_states = T.type.shape[0]
244252

245-
latent_states = latent_obs_combined[..., :k_states]
246-
obs_states = latent_obs_combined[..., k_states:]
253+
if k_endog is None and k_states is None:
254+
raise ValueError("Could not infer number of observed states, explicitly pass k_endog.")
255+
if k_endog is not None and k_states is not None:
256+
total_shape = latent_obs_combined.type.shape[-1]
257+
inferred_endog = total_shape - k_states
258+
if inferred_endog != k_endog:
259+
raise ValueError(
260+
f"Inferred k_endog does not agree with provided value ({inferred_endog} != {k_endog}). "
261+
f"It is not necessary to provide k_endog when the value can be inferred."
262+
)
263+
latent_slice = slice(None, -k_endog)
264+
obs_slice = slice(-k_endog, None)
265+
elif k_endog is None:
266+
latent_slice = slice(None, k_states)
267+
obs_slice = slice(k_states, None)
268+
else:
269+
latent_slice = slice(None, -k_endog)
270+
obs_slice = slice(-k_endog, None)
271+
272+
latent_states = latent_obs_combined[..., latent_slice]
273+
obs_states = latent_obs_combined[..., obs_slice]
247274

248275
latent_states = pm.Deterministic(f"{name}_latent", latent_states, dims=latent_dims)
249276
obs_states = pm.Deterministic(f"{name}_observed", obs_states, dims=obs_dims)

0 commit comments

Comments
 (0)