1
1
import logging
2
+ import warnings
2
3
3
4
from collections .abc import Callable , Sequence
4
5
from typing import Any , Literal
@@ -98,6 +99,13 @@ class PyMCStateSpace:
98
99
compute the observation errors. If False, these errors are deterministically zero; if True, they are sampled
99
100
from a multivariate normal.
100
101
102
+ mode: str or Mode, optional
103
+ Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and
104
+ ``forecast``. The mode does **not** effect calls to ``pm.sample``.
105
+
106
+ Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
107
+ to all sampling methods.
108
+
101
109
Notes
102
110
-----
103
111
Based on the statsmodels statespace implementation https://github.com/statsmodels/statsmodels/blob/main/statsmodels/tsa/statespace/representation.py,
@@ -220,6 +228,7 @@ def __init__(
220
228
filter_type : str = "standard" ,
221
229
verbose : bool = True ,
222
230
measurement_error : bool = False ,
231
+ mode : str | None = None ,
223
232
):
224
233
self ._fit_coords : dict [str , Sequence [str ]] | None = None
225
234
self ._fit_dims : dict [str , Sequence [str ]] | None = None
@@ -235,6 +244,7 @@ def __init__(
235
244
self .k_states = k_states
236
245
self .k_posdef = k_posdef
237
246
self .measurement_error = measurement_error
247
+ self .mode = mode
238
248
239
249
# All models contain a state space representation and a Kalman filter
240
250
self .ssm = PytensorRepresentation (k_endog , k_states , k_posdef )
@@ -821,6 +831,7 @@ def build_statespace_graph(
821
831
cov_jitter : float | None = JITTER_DEFAULT ,
822
832
mvn_method : Literal ["cholesky" , "eigh" , "svd" ] = "svd" ,
823
833
save_kalman_filter_outputs_in_idata : bool = False ,
834
+ mode : str | None = None ,
824
835
) -> None :
825
836
"""
826
837
Given a parameter vector `theta`, constructs the full computational graph describing the state space model and
@@ -874,7 +885,25 @@ def build_statespace_graph(
874
885
save_kalman_filter_outputs_in_idata: bool, optional, default=False
875
886
If True, Kalman Filter outputs will be saved in the model as deterministics. Useful for debugging, but
876
887
should not be necessary for the majority of users.
888
+
889
+ mode: str, optional
890
+ Pytensor mode to use when compiling the graph. This will be saved as a model attribute and used when
891
+ compiling sampling functions (e.g. ``sample_conditional_prior``).
892
+
893
+ .. deprecated:: 0.2.5
894
+ The `mode` argument is deprecated and will be removed in a future version. Pass ``mode`` to the
895
+ model constructor, or manually specify ``compile_kwargs`` in sampling functions instead.
896
+
877
897
"""
898
+ if mode is not None :
899
+ warnings .warn (
900
+ "The `mode` argument is deprecated and will be removed in a future version. "
901
+ "Pass `mode` to the model constructor, or manually specify `compile_kwargs` in sampling functions"
902
+ " instead." ,
903
+ DeprecationWarning ,
904
+ )
905
+ self .mode = mode
906
+
878
907
pm_mod = modelcontext (None )
879
908
880
909
self ._insert_random_variables ()
@@ -1107,6 +1136,12 @@ def _kalman_filter_outputs_from_dummy_graph(
1107
1136
1108
1137
return [x0 , P0 , c , d , T , Z , R , H , Q ], grouped_outputs
1109
1138
1139
+ def _set_default_mode (self , compile_kwargs ):
1140
+ mode = compile_kwargs .get ("mode" , self .mode )
1141
+ compile_kwargs ["mode" ] = mode
1142
+
1143
+ return compile_kwargs
1144
+
1110
1145
def _sample_conditional (
1111
1146
self ,
1112
1147
idata : InferenceData ,
@@ -1158,6 +1193,9 @@ def _sample_conditional(
1158
1193
_verify_group (group )
1159
1194
group_idata = getattr (idata , group )
1160
1195
1196
+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
1197
+ compile_kwargs = self ._set_default_mode (compile_kwargs )
1198
+
1161
1199
with pm .Model (coords = self ._fit_coords ) as forward_model :
1162
1200
(
1163
1201
[
@@ -1224,6 +1262,7 @@ def _sample_conditional(
1224
1262
for suffix in ["" , "_observed" ]
1225
1263
],
1226
1264
random_seed = random_seed ,
1265
+ compile_kwargs = compile_kwargs ,
1227
1266
** kwargs ,
1228
1267
)
1229
1268
@@ -1289,6 +1328,10 @@ def _sample_unconditional(
1289
1328
the latent state trajectories: `y[t] = Z @ x[t] + nu[t]`, where `nu ~ N(0, H)`.
1290
1329
"""
1291
1330
_verify_group (group )
1331
+
1332
+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
1333
+ compile_kwargs = self ._set_default_mode (compile_kwargs )
1334
+
1292
1335
group_idata = getattr (idata , group )
1293
1336
dims = None
1294
1337
temp_coords = self ._fit_coords .copy ()
@@ -1347,6 +1390,7 @@ def _sample_unconditional(
1347
1390
group_idata ,
1348
1391
var_names = [f"{ group } _latent" , f"{ group } _observed" ],
1349
1392
random_seed = random_seed ,
1393
+ compile_kwargs = compile_kwargs ,
1350
1394
** kwargs ,
1351
1395
)
1352
1396
@@ -1574,7 +1618,7 @@ def sample_unconditional_posterior(
1574
1618
)
1575
1619
1576
1620
def sample_statespace_matrices (
1577
- self , idata , matrix_names : str | list [str ] | None , group : str = "posterior"
1621
+ self , idata , matrix_names : str | list [str ] | None , group : str = "posterior" , ** kwargs
1578
1622
):
1579
1623
"""
1580
1624
Draw samples of requested statespace matrices from provided idata
@@ -1591,12 +1635,18 @@ def sample_statespace_matrices(
1591
1635
group: str, one of "posterior" or "prior"
1592
1636
Whether to sample from priors or posteriors
1593
1637
1638
+ kwargs:
1639
+ Additional keyword arguments are passed to ``pymc.sample_posterior_predictive``
1640
+
1594
1641
Returns
1595
1642
-------
1596
1643
idata_matrices: az.InterenceData
1597
1644
"""
1598
1645
_verify_group (group )
1599
1646
1647
+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
1648
+ compile_kwargs = self ._set_default_mode (compile_kwargs )
1649
+
1600
1650
if matrix_names is None :
1601
1651
matrix_names = MATRIX_NAMES
1602
1652
elif isinstance (matrix_names , str ):
@@ -1628,6 +1678,8 @@ def sample_statespace_matrices(
1628
1678
idata if group == "posterior" else idata .prior ,
1629
1679
var_names = matrix_names ,
1630
1680
extend_inferencedata = False ,
1681
+ compile_kwargs = compile_kwargs ,
1682
+ ** kwargs ,
1631
1683
)
1632
1684
1633
1685
return matrix_idata
@@ -2096,6 +2148,10 @@ def forecast(
2096
2148
filter_time_dim = TIME_DIM
2097
2149
2098
2150
_validate_filter_arg (filter_output )
2151
+
2152
+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
2153
+ compile_kwargs = self ._set_default_mode (compile_kwargs )
2154
+
2099
2155
time_index = self ._get_fit_time_index ()
2100
2156
2101
2157
if start is None and verbose :
@@ -2198,6 +2254,7 @@ def forecast(
2198
2254
idata ,
2199
2255
var_names = ["forecast_latent" , "forecast_observed" ],
2200
2256
random_seed = random_seed ,
2257
+ compile_kwargs = compile_kwargs ,
2201
2258
** kwargs ,
2202
2259
)
2203
2260
@@ -2285,6 +2342,9 @@ def impulse_response_function(
2285
2342
n_options = sum (x is not None for x in options )
2286
2343
Q = None # No covariance matrix needed if a trajectory is provided. Will be overwritten later if needed.
2287
2344
2345
+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
2346
+ compile_kwargs = self ._set_default_mode (compile_kwargs )
2347
+
2288
2348
if n_options > 1 :
2289
2349
raise ValueError ("Specify exactly 0 or 1 of shock_size, shock_cov, or shock_trajectory" )
2290
2350
elif n_options == 1 :
@@ -2364,6 +2424,7 @@ def irf_step(shock, x, c, T, R):
2364
2424
idata ,
2365
2425
var_names = ["irf" ],
2366
2426
random_seed = random_seed ,
2427
+ compile_kwargs = compile_kwargs ,
2367
2428
** kwargs ,
2368
2429
)
2369
2430
0 commit comments