Skip to content

Commit eab41fa

Browse files
Remove mode argument passed from Statespace.build_statespace_graph to scan (#482)
* Remove `mode` argument from everywhere
* Add deprecation warning to mode argument of `build_statespace_graph`. Add mode argument to statespace constructors
* Rename tests -> test
* Remove _set_default_mode
1 parent 9589f28 commit eab41fa

16 files changed

+154
-87
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 59 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import warnings
23

34
from collections.abc import Callable, Sequence
45
from typing import Any, Literal
@@ -14,7 +15,6 @@
1415
from pymc.model.transform.optimization import freeze_dims_and_data
1516
from pymc.util import RandomState
1617
from pytensor import Variable, graph_replace
17-
from pytensor.compile import get_mode
1818
from rich.box import SIMPLE_HEAD
1919
from rich.console import Console
2020
from rich.table import Table
@@ -99,6 +99,13 @@ class PyMCStateSpace:
9999
compute the observation errors. If False, these errors are deterministically zero; if True, they are sampled
100100
from a multivariate normal.
101101
102+
mode: str or Mode, optional
103+
Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and
104+
``forecast``. The mode does **not** affect calls to ``pm.sample``.
105+
106+
Regardless of whether a mode is specified, it can always be overridden via the ``compile_kwargs`` argument
107+
to all sampling methods.
108+
102109
Notes
103110
-----
104111
Based on the statsmodels statespace implementation https://github.com/statsmodels/statsmodels/blob/main/statsmodels/tsa/statespace/representation.py,
@@ -221,8 +228,8 @@ def __init__(
221228
filter_type: str = "standard",
222229
verbose: bool = True,
223230
measurement_error: bool = False,
231+
mode: str | None = None,
224232
):
225-
self._fit_mode: str | None = None
226233
self._fit_coords: dict[str, Sequence[str]] | None = None
227234
self._fit_dims: dict[str, Sequence[str]] | None = None
228235
self._fit_data: pt.TensorVariable | None = None
@@ -237,6 +244,7 @@ def __init__(
237244
self.k_states = k_states
238245
self.k_posdef = k_posdef
239246
self.measurement_error = measurement_error
247+
self.mode = mode
240248

241249
# All models contain a state space representation and a Kalman filter
242250
self.ssm = PytensorRepresentation(k_endog, k_states, k_posdef)
@@ -819,11 +827,11 @@ def build_statespace_graph(
819827
self,
820828
data: np.ndarray | pd.DataFrame | pt.TensorVariable,
821829
register_data: bool = True,
822-
mode: str | None = None,
823830
missing_fill_value: float | None = None,
824831
cov_jitter: float | None = JITTER_DEFAULT,
825832
mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
826833
save_kalman_filter_outputs_in_idata: bool = False,
834+
mode: str | None = None,
827835
) -> None:
828836
"""
829837
Given a parameter vector `theta`, constructs the full computational graph describing the state space model and
@@ -877,7 +885,25 @@ def build_statespace_graph(
877885
save_kalman_filter_outputs_in_idata: bool, optional, default=False
878886
If True, Kalman Filter outputs will be saved in the model as deterministics. Useful for debugging, but
879887
should not be necessary for the majority of users.
888+
889+
mode: str, optional
890+
Pytensor mode to use when compiling the graph. This will be saved as a model attribute and used when
891+
compiling sampling functions (e.g. ``sample_conditional_prior``).
892+
893+
.. deprecated:: 0.2.5
894+
The `mode` argument is deprecated and will be removed in a future version. Pass ``mode`` to the
895+
model constructor, or manually specify ``compile_kwargs`` in sampling functions instead.
896+
880897
"""
898+
if mode is not None:
899+
warnings.warn(
900+
"The `mode` argument is deprecated and will be removed in a future version. "
901+
"Pass `mode` to the model constructor, or manually specify `compile_kwargs` in sampling functions"
902+
" instead.",
903+
DeprecationWarning,
904+
)
905+
self.mode = mode
906+
881907
pm_mod = modelcontext(None)
882908

883909
self._insert_random_variables()
@@ -898,7 +924,6 @@ def build_statespace_graph(
898924
filter_outputs = self.kalman_filter.build_graph(
899925
pt.as_tensor_variable(data),
900926
*self.unpack_statespace(),
901-
mode=mode,
902927
missing_fill_value=missing_fill_value,
903928
cov_jitter=cov_jitter,
904929
)
@@ -909,7 +934,7 @@ def build_statespace_graph(
909934
filtered_covariances, predicted_covariances, observed_covariances = covs
910935
if save_kalman_filter_outputs_in_idata:
911936
smooth_states, smooth_covariances = self._build_smoother_graph(
912-
filtered_states, filtered_covariances, self.unpack_statespace(), mode=mode
937+
filtered_states, filtered_covariances, self.unpack_statespace()
913938
)
914939
all_kf_outputs = [*states, smooth_states, *covs, smooth_covariances]
915940
self._register_kalman_filter_outputs_with_pymc_model(all_kf_outputs)
@@ -929,7 +954,6 @@ def build_statespace_graph(
929954

930955
self._fit_coords = pm_mod.coords.copy()
931956
self._fit_dims = pm_mod.named_vars_to_dims.copy()
932-
self._fit_mode = mode
933957

934958
def _build_smoother_graph(
935959
self,
@@ -974,7 +998,7 @@ def _build_smoother_graph(
974998
*_, T, Z, R, H, Q = matrices
975999

9761000
smooth_states, smooth_covariances = self.kalman_smoother.build_graph(
977-
T, R, Q, filtered_states, filtered_covariances, mode=mode, cov_jitter=cov_jitter
1001+
T, R, Q, filtered_states, filtered_covariances, cov_jitter=cov_jitter
9781002
)
9791003
smooth_states.name = "smooth_states"
9801004
smooth_covariances.name = "smooth_covariances"
@@ -1092,7 +1116,6 @@ def _kalman_filter_outputs_from_dummy_graph(
10921116
R,
10931117
H,
10941118
Q,
1095-
mode=self._fit_mode,
10961119
)
10971120

10981121
filter_outputs.pop(-1)
@@ -1102,7 +1125,7 @@ def _kalman_filter_outputs_from_dummy_graph(
11021125
filtered_covariances, predicted_covariances, _ = covariances
11031126

11041127
[smoothed_states, smoothed_covariances] = self.kalman_smoother.build_graph(
1105-
T, R, Q, filtered_states, filtered_covariances, mode=self._fit_mode
1128+
T, R, Q, filtered_states, filtered_covariances
11061129
)
11071130

11081131
grouped_outputs = [
@@ -1164,6 +1187,9 @@ def _sample_conditional(
11641187
_verify_group(group)
11651188
group_idata = getattr(idata, group)
11661189

1190+
compile_kwargs = kwargs.pop("compile_kwargs", {})
1191+
compile_kwargs.setdefault("mode", self.mode)
1192+
11671193
with pm.Model(coords=self._fit_coords) as forward_model:
11681194
(
11691195
[
@@ -1229,8 +1255,8 @@ def _sample_conditional(
12291255
for name in FILTER_OUTPUT_TYPES
12301256
for suffix in ["", "_observed"]
12311257
],
1232-
compile_kwargs={"mode": get_mode(self._fit_mode)},
12331258
random_seed=random_seed,
1259+
compile_kwargs=compile_kwargs,
12341260
**kwargs,
12351261
)
12361262

@@ -1296,6 +1322,10 @@ def _sample_unconditional(
12961322
the latent state trajectories: `y[t] = Z @ x[t] + nu[t]`, where `nu ~ N(0, H)`.
12971323
"""
12981324
_verify_group(group)
1325+
1326+
compile_kwargs = kwargs.pop("compile_kwargs", {})
1327+
compile_kwargs.setdefault("mode", self.mode)
1328+
12991329
group_idata = getattr(idata, group)
13001330
dims = None
13011331
temp_coords = self._fit_coords.copy()
@@ -1338,7 +1368,6 @@ def _sample_unconditional(
13381368
*matrices,
13391369
steps=steps,
13401370
dims=dims,
1341-
mode=self._fit_mode,
13421371
method=mvn_method,
13431372
sequence_names=self.kalman_filter.seq_names,
13441373
k_endog=self.k_endog,
@@ -1354,8 +1383,8 @@ def _sample_unconditional(
13541383
idata_unconditional = pm.sample_posterior_predictive(
13551384
group_idata,
13561385
var_names=[f"{group}_latent", f"{group}_observed"],
1357-
compile_kwargs={"mode": self._fit_mode},
13581386
random_seed=random_seed,
1387+
compile_kwargs=compile_kwargs,
13591388
**kwargs,
13601389
)
13611390

@@ -1583,7 +1612,7 @@ def sample_unconditional_posterior(
15831612
)
15841613

15851614
def sample_statespace_matrices(
1586-
self, idata, matrix_names: str | list[str] | None, group: str = "posterior"
1615+
self, idata, matrix_names: str | list[str] | None, group: str = "posterior", **kwargs
15871616
):
15881617
"""
15891618
Draw samples of requested statespace matrices from provided idata
@@ -1600,12 +1629,18 @@ def sample_statespace_matrices(
16001629
group: str, one of "posterior" or "prior"
16011630
Whether to sample from priors or posteriors
16021631
1632+
kwargs:
1633+
Additional keyword arguments are passed to ``pymc.sample_posterior_predictive``
1634+
16031635
Returns
16041636
-------
16051637
idata_matrices: az.InferenceData
16061638
"""
16071639
_verify_group(group)
16081640

1641+
compile_kwargs = kwargs.pop("compile_kwargs", {})
1642+
compile_kwargs.setdefault("mode", self.mode)
1643+
16091644
if matrix_names is None:
16101645
matrix_names = MATRIX_NAMES
16111646
elif isinstance(matrix_names, str):
@@ -1636,8 +1671,9 @@ def sample_statespace_matrices(
16361671
matrix_idata = pm.sample_posterior_predictive(
16371672
idata if group == "posterior" else idata.prior,
16381673
var_names=matrix_names,
1639-
compile_kwargs={"mode": self._fit_mode},
16401674
extend_inferencedata=False,
1675+
compile_kwargs=compile_kwargs,
1676+
**kwargs,
16411677
)
16421678

16431679
return matrix_idata
@@ -2106,6 +2142,10 @@ def forecast(
21062142
filter_time_dim = TIME_DIM
21072143

21082144
_validate_filter_arg(filter_output)
2145+
2146+
compile_kwargs = kwargs.pop("compile_kwargs", {})
2147+
compile_kwargs.setdefault("mode", self.mode)
2148+
21092149
time_index = self._get_fit_time_index()
21102150

21112151
if start is None and verbose:
@@ -2192,7 +2232,6 @@ def forecast(
21922232
*matrices,
21932233
steps=len(forecast_index),
21942234
dims=dims,
2195-
mode=self._fit_mode,
21962235
sequence_names=self.kalman_filter.seq_names,
21972236
k_endog=self.k_endog,
21982237
append_x0=False,
@@ -2208,8 +2247,8 @@ def forecast(
22082247
idata_forecast = pm.sample_posterior_predictive(
22092248
idata,
22102249
var_names=["forecast_latent", "forecast_observed"],
2211-
compile_kwargs={"mode": self._fit_mode},
22122250
random_seed=random_seed,
2251+
compile_kwargs=compile_kwargs,
22132252
**kwargs,
22142253
)
22152254

@@ -2297,6 +2336,9 @@ def impulse_response_function(
22972336
n_options = sum(x is not None for x in options)
22982337
Q = None # No covariance matrix needed if a trajectory is provided. Will be overwritten later if needed.
22992338

2339+
compile_kwargs = kwargs.pop("compile_kwargs", {})
2340+
compile_kwargs.setdefault("mode", self.mode)
2341+
23002342
if n_options > 1:
23012343
raise ValueError("Specify exactly 0 or 1 of shock_size, shock_cov, or shock_trajectory")
23022344
elif n_options == 1:
@@ -2368,29 +2410,15 @@ def irf_step(shock, x, c, T, R):
23682410
non_sequences=[c, T, R],
23692411
n_steps=n_steps,
23702412
strict=True,
2371-
mode=self._fit_mode,
23722413
)
23732414

23742415
pm.Deterministic("irf", irf, dims=[TIME_DIM, ALL_STATE_DIM])
23752416

2376-
compile_kwargs = kwargs.get("compile_kwargs", {})
2377-
if "mode" not in compile_kwargs.keys():
2378-
compile_kwargs = {"mode": self._fit_mode}
2379-
else:
2380-
mode = compile_kwargs.get("mode")
2381-
if mode is not None and mode != self._fit_mode:
2382-
raise ValueError(
2383-
f"User provided compile mode ({mode}) does not match the compile mode used to "
2384-
f"construct the model ({self._fit_mode})."
2385-
)
2386-
2387-
compile_kwargs.update({"mode": self._fit_mode})
2388-
23892417
irf_idata = pm.sample_posterior_predictive(
23902418
idata,
23912419
var_names=["irf"],
2392-
compile_kwargs=compile_kwargs,
23932420
random_seed=random_seed,
2421+
compile_kwargs=compile_kwargs,
23942422
**kwargs,
23952423
)
23962424

pymc_extras/statespace/filters/distributions.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def __new__(
6969
H,
7070
Q,
7171
steps=None,
72-
mode=None,
7372
sequence_names=None,
7473
append_x0=True,
7574
method="svd",
@@ -98,7 +97,6 @@ def __new__(
9897
H,
9998
Q,
10099
steps=steps,
101-
mode=mode,
102100
sequence_names=sequence_names,
103101
append_x0=append_x0,
104102
method=method,
@@ -118,7 +116,6 @@ def dist(
118116
H,
119117
Q,
120118
steps=None,
121-
mode=None,
122119
sequence_names=None,
123120
append_x0=True,
124121
method="svd",
@@ -135,7 +132,6 @@ def dist(
135132

136133
return super().dist(
137134
[a0, P0, c, d, T, Z, R, H, Q, steps],
138-
mode=mode,
139135
sequence_names=sequence_names,
140136
append_x0=append_x0,
141137
method=method,
@@ -156,7 +152,6 @@ def rv_op(
156152
Q,
157153
steps,
158154
size=None,
159-
mode=None,
160155
sequence_names=None,
161156
append_x0=True,
162157
method="svd",
@@ -240,7 +235,6 @@ def step_fn(*args):
240235
sequences=None if len(sequences) == 0 else sequences,
241236
non_sequences=[*non_sequences, rng],
242237
n_steps=steps,
243-
mode=mode,
244238
strict=True,
245239
)
246240

@@ -284,7 +278,6 @@ def __new__(
284278
steps,
285279
k_endog=None,
286280
sequence_names=None,
287-
mode=None,
288281
append_x0=True,
289282
method="svd",
290283
**kwargs,
@@ -313,7 +306,6 @@ def __new__(
313306
H,
314307
Q,
315308
steps=steps,
316-
mode=mode,
317309
sequence_names=sequence_names,
318310
append_x0=append_x0,
319311
method=method,

0 commit comments

Comments
 (0)