Skip to content

Commit 8c23692

Browse files
Add deprecation warning to mode argument of build_statespace_graph
Add mode argument to statespace constructors
1 parent 27a8752 commit 8c23692

File tree

9 files changed

+140
-4
lines changed

9 files changed

+140
-4
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import warnings
23

34
from collections.abc import Callable, Sequence
45
from typing import Any, Literal
@@ -98,6 +99,13 @@ class PyMCStateSpace:
9899
compute the observation errors. If False, these errors are deterministically zero; if True, they are sampled
99100
from a multivariate normal.
100101
102+
mode: str or Mode, optional
103+
Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and
104+
``forecast``. The mode does **not** effect calls to ``pm.sample``.
105+
106+
Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
107+
to all sampling methods.
108+
101109
Notes
102110
-----
103111
Based on the statsmodels statespace implementation https://github.com/statsmodels/statsmodels/blob/main/statsmodels/tsa/statespace/representation.py,
@@ -220,6 +228,7 @@ def __init__(
220228
filter_type: str = "standard",
221229
verbose: bool = True,
222230
measurement_error: bool = False,
231+
mode: str | None = None,
223232
):
224233
self._fit_coords: dict[str, Sequence[str]] | None = None
225234
self._fit_dims: dict[str, Sequence[str]] | None = None
@@ -235,6 +244,7 @@ def __init__(
235244
self.k_states = k_states
236245
self.k_posdef = k_posdef
237246
self.measurement_error = measurement_error
247+
self.mode = mode
238248

239249
# All models contain a state space representation and a Kalman filter
240250
self.ssm = PytensorRepresentation(k_endog, k_states, k_posdef)
@@ -821,6 +831,7 @@ def build_statespace_graph(
821831
cov_jitter: float | None = JITTER_DEFAULT,
822832
mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
823833
save_kalman_filter_outputs_in_idata: bool = False,
834+
mode: str | None = None,
824835
) -> None:
825836
"""
826837
Given a parameter vector `theta`, constructs the full computational graph describing the state space model and
@@ -874,7 +885,25 @@ def build_statespace_graph(
874885
save_kalman_filter_outputs_in_idata: bool, optional, default=False
875886
If True, Kalman Filter outputs will be saved in the model as deterministics. Useful for debugging, but
876887
should not be necessary for the majority of users.
888+
889+
mode: str, optional
890+
Pytensor mode to use when compiling the graph. This will be saved as a model attribute and used when
891+
compiling sampling functions (e.g. ``sample_conditional_prior``).
892+
893+
.. deprecated:: 0.2.5
894+
The `mode` argument is deprecated and will be removed in a future version. Pass ``mode`` to the
895+
model constructor, or manually specify ``compile_kwargs`` in sampling functions instead.
896+
877897
"""
898+
if mode is not None:
899+
warnings.warn(
900+
"The `mode` argument is deprecated and will be removed in a future version. "
901+
"Pass `mode` to the model constructor, or manually specify `compile_kwargs` in sampling functions"
902+
" instead.",
903+
DeprecationWarning,
904+
)
905+
self.mode = mode
906+
878907
pm_mod = modelcontext(None)
879908

880909
self._insert_random_variables()
@@ -1107,6 +1136,12 @@ def _kalman_filter_outputs_from_dummy_graph(
11071136

11081137
return [x0, P0, c, d, T, Z, R, H, Q], grouped_outputs
11091138

1139+
def _set_default_mode(self, compile_kwargs):
1140+
mode = compile_kwargs.get("mode", self.mode)
1141+
compile_kwargs["mode"] = mode
1142+
1143+
return compile_kwargs
1144+
11101145
def _sample_conditional(
11111146
self,
11121147
idata: InferenceData,
@@ -1158,6 +1193,9 @@ def _sample_conditional(
11581193
_verify_group(group)
11591194
group_idata = getattr(idata, group)
11601195

1196+
compile_kwargs = kwargs.pop("compile_kwargs", {})
1197+
compile_kwargs = self._set_default_mode(compile_kwargs)
1198+
11611199
with pm.Model(coords=self._fit_coords) as forward_model:
11621200
(
11631201
[
@@ -1224,6 +1262,7 @@ def _sample_conditional(
12241262
for suffix in ["", "_observed"]
12251263
],
12261264
random_seed=random_seed,
1265+
compile_kwargs=compile_kwargs,
12271266
**kwargs,
12281267
)
12291268

@@ -1289,6 +1328,10 @@ def _sample_unconditional(
12891328
the latent state trajectories: `y[t] = Z @ x[t] + nu[t]`, where `nu ~ N(0, H)`.
12901329
"""
12911330
_verify_group(group)
1331+
1332+
compile_kwargs = kwargs.pop("compile_kwargs", {})
1333+
compile_kwargs = self._set_default_mode(compile_kwargs)
1334+
12921335
group_idata = getattr(idata, group)
12931336
dims = None
12941337
temp_coords = self._fit_coords.copy()
@@ -1347,6 +1390,7 @@ def _sample_unconditional(
13471390
group_idata,
13481391
var_names=[f"{group}_latent", f"{group}_observed"],
13491392
random_seed=random_seed,
1393+
compile_kwargs=compile_kwargs,
13501394
**kwargs,
13511395
)
13521396

@@ -1574,7 +1618,7 @@ def sample_unconditional_posterior(
15741618
)
15751619

15761620
def sample_statespace_matrices(
1577-
self, idata, matrix_names: str | list[str] | None, group: str = "posterior"
1621+
self, idata, matrix_names: str | list[str] | None, group: str = "posterior", **kwargs
15781622
):
15791623
"""
15801624
Draw samples of requested statespace matrices from provided idata
@@ -1591,12 +1635,18 @@ def sample_statespace_matrices(
15911635
group: str, one of "posterior" or "prior"
15921636
Whether to sample from priors or posteriors
15931637
1638+
kwargs:
1639+
Additional keyword arguments are passed to ``pymc.sample_posterior_predictive``
1640+
15941641
Returns
15951642
-------
15961643
idata_matrices: az.InterenceData
15971644
"""
15981645
_verify_group(group)
15991646

1647+
compile_kwargs = kwargs.pop("compile_kwargs", {})
1648+
compile_kwargs = self._set_default_mode(compile_kwargs)
1649+
16001650
if matrix_names is None:
16011651
matrix_names = MATRIX_NAMES
16021652
elif isinstance(matrix_names, str):
@@ -1628,6 +1678,8 @@ def sample_statespace_matrices(
16281678
idata if group == "posterior" else idata.prior,
16291679
var_names=matrix_names,
16301680
extend_inferencedata=False,
1681+
compile_kwargs=compile_kwargs,
1682+
**kwargs,
16311683
)
16321684

16331685
return matrix_idata
@@ -2096,6 +2148,10 @@ def forecast(
20962148
filter_time_dim = TIME_DIM
20972149

20982150
_validate_filter_arg(filter_output)
2151+
2152+
compile_kwargs = kwargs.pop("compile_kwargs", {})
2153+
compile_kwargs = self._set_default_mode(compile_kwargs)
2154+
20992155
time_index = self._get_fit_time_index()
21002156

21012157
if start is None and verbose:
@@ -2198,6 +2254,7 @@ def forecast(
21982254
idata,
21992255
var_names=["forecast_latent", "forecast_observed"],
22002256
random_seed=random_seed,
2257+
compile_kwargs=compile_kwargs,
22012258
**kwargs,
22022259
)
22032260

@@ -2285,6 +2342,9 @@ def impulse_response_function(
22852342
n_options = sum(x is not None for x in options)
22862343
Q = None # No covariance matrix needed if a trajectory is provided. Will be overwritten later if needed.
22872344

2345+
compile_kwargs = kwargs.pop("compile_kwargs", {})
2346+
compile_kwargs = self._set_default_mode(compile_kwargs)
2347+
22882348
if n_options > 1:
22892349
raise ValueError("Specify exactly 0 or 1 of shock_size, shock_cov, or shock_trajectory")
22902350
elif n_options == 1:
@@ -2364,6 +2424,7 @@ def irf_step(shock, x, c, T, R):
23642424
idata,
23652425
var_names=["irf"],
23662426
random_seed=random_seed,
2427+
compile_kwargs=compile_kwargs,
23672428
**kwargs,
23682429
)
23692430

pymc_extras/statespace/models/ETS.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytensor.tensor as pt
66

77
from pytensor import graph_replace
8+
from pytensor.compile.mode import Mode
89
from pytensor.tensor.slinalg import solve_discrete_lyapunov
910

1011
from pymc_extras.statespace.core.statespace import PyMCStateSpace, floatX
@@ -35,6 +36,7 @@ def __init__(
3536
initialization_dampening: float = 0.8,
3637
filter_type: str = "standard",
3738
verbose: bool = True,
39+
mode: str | Mode | None = None,
3840
):
3941
r"""
4042
Exponential Smoothing State Space Model
@@ -212,6 +214,13 @@ def __init__(
212214
and "cholesky". See the docs for kalman filters for more details.
213215
verbose: bool, default True
214216
If true, a message will be logged to the terminal explaining the variable names, dimensions, and supports.
217+
mode: str or Mode, optional
218+
Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and
219+
``forecast``. The mode does **not** effect calls to ``pm.sample``.
220+
221+
Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
222+
to all sampling methods.
223+
215224
216225
References
217226
----------
@@ -284,6 +293,7 @@ def __init__(
284293
filter_type,
285294
verbose=verbose,
286295
measurement_error=measurement_error,
296+
mode=mode,
287297
)
288298

289299
@property

pymc_extras/statespace/models/SARIMAX.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
import pytensor.tensor as pt
66

7+
from pytensor.compile.mode import Mode
78
from pytensor.tensor.slinalg import solve_discrete_lyapunov
89

910
from pymc_extras.statespace.core.statespace import PyMCStateSpace, floatX
@@ -91,6 +92,13 @@ class BayesianSARIMA(PyMCStateSpace):
9192
verbose: bool, default True
9293
If true, a message will be logged to the terminal explaining the variable names, dimensions, and supports.
9394
95+
mode: str or Mode, optional
96+
Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and
97+
``forecast``. The mode does **not** effect calls to ``pm.sample``.
98+
99+
Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
100+
to all sampling methods.
101+
94102
Notes
95103
-----
96104
The ARIMAX model is a univariate time series model that posits the future evolution of a stationary time series will
@@ -180,7 +188,21 @@ def __init__(
180188
state_structure: str = "fast",
181189
measurement_error: bool = False,
182190
verbose=True,
191+
mode: str | Mode | None = None,
183192
):
193+
"""
194+
195+
Parameters
196+
----------
197+
order
198+
seasonal_order
199+
stationary_initialization
200+
filter_type
201+
state_structure
202+
measurement_error
203+
verbose
204+
mode
205+
"""
184206
# Model order
185207
self.p, self.d, self.q = order
186208
if seasonal_order is None:
@@ -228,6 +250,7 @@ def __init__(
228250
filter_type,
229251
verbose=verbose,
230252
measurement_error=measurement_error,
253+
mode=mode,
231254
)
232255

233256
@property

pymc_extras/statespace/models/VARMAX.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytensor
66
import pytensor.tensor as pt
77

8+
from pytensor.compile.mode import Mode
89
from pytensor.tensor.slinalg import solve_discrete_lyapunov
910

1011
from pymc_extras.statespace.core.statespace import PyMCStateSpace
@@ -72,6 +73,13 @@ class BayesianVARMAX(PyMCStateSpace):
7273
verbose: bool, default True
7374
If true, a message will be logged to the terminal explaining the variable names, dimensions, and supports.
7475
76+
mode: str or Mode, optional
77+
Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and
78+
``forecast``. The mode does **not** effect calls to ``pm.sample``.
79+
80+
Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
81+
to all sampling methods.
82+
7583
Notes
7684
-----
7785
The VARMA model is a multivariate extension of the SARIMAX model. Given a set of timeseries :math:`\{x_t\}_{t=0}^T`,
@@ -147,7 +155,8 @@ def __init__(
147155
stationary_initialization: bool = False,
148156
filter_type: str = "standard",
149157
measurement_error: bool = False,
150-
verbose=True,
158+
verbose: bool = True,
159+
mode: str | Mode | None = None,
151160
):
152161
if (endog_names is None) and (k_endog is None):
153162
raise ValueError("Must specify either endog_names or k_endog")
@@ -174,6 +183,7 @@ def __init__(
174183
filter_type,
175184
verbose=verbose,
176185
measurement_error=measurement_error,
186+
mode=mode,
177187
)
178188

179189
# Save counts of the number of parameters in each category

pymc_extras/statespace/models/structural.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import xarray as xr
1313

1414
from pytensor import Variable
15+
from pytensor.compile.mode import Mode
1516

1617
from pymc_extras.statespace.core import PytensorRepresentation
1718
from pymc_extras.statespace.core.statespace import PyMCStateSpace
@@ -81,6 +82,7 @@ def __init__(
8182
name: str | None = None,
8283
verbose: bool = True,
8384
filter_type: str = "standard",
85+
mode: str | Mode | None = None,
8486
):
8587
# Add the initial state covariance to the parameters
8688
if name is None:
@@ -112,6 +114,7 @@ def __init__(
112114
filter_type=filter_type,
113115
verbose=verbose,
114116
measurement_error=measurement_error,
117+
mode=mode,
115118
)
116119
self.ssm = ssm.copy()
117120

@@ -644,7 +647,9 @@ def __add__(self, other):
644647

645648
return new_comp
646649

647-
def build(self, name=None, filter_type="standard", verbose=True):
650+
def build(
651+
self, name=None, filter_type="standard", verbose=True, mode: str | Mode | None = None
652+
):
648653
"""
649654
Build a StructuralTimeSeries statespace model from the current component(s)
650655
@@ -660,6 +665,13 @@ def build(self, name=None, filter_type="standard", verbose=True):
660665
verbose : bool, optional
661666
If True, displays information about the initialized model. Defaults to True.
662667
668+
mode: str or Mode, optional
669+
Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and
670+
``forecast``. The mode does **not** effect calls to ``pm.sample``.
671+
672+
Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
673+
to all sampling methods.
674+
663675
Returns
664676
-------
665677
PyMCStateSpace
@@ -685,6 +697,7 @@ def build(self, name=None, filter_type="standard", verbose=True):
685697
name_to_data=self._name_to_data,
686698
filter_type=filter_type,
687699
verbose=verbose,
700+
mode=mode,
688701
)
689702

690703

tests/statespace/models/test_ETS.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,12 @@ def test_order_flags(order, expected_flags):
8989
assert getattr(mod, key) == value
9090

9191

92+
def tests_mode_argument():
93+
# Mode argument should be passed to the parent class
94+
mod = BayesianETS(order=("A", "N", "N"), mode="FAST_RUN")
95+
assert mod.mode == "FAST_RUN"
96+
97+
9298
@pytest.mark.parametrize("order, expected_params", zip(orders, order_params), ids=order_names)
9399
def test_param_info(order: tuple[str, str, str], expected_params):
94100
mod = BayesianETS(order=order, seasonal_periods=4)

0 commit comments

Comments
 (0)