1
1
import logging
2
+ import warnings
2
3
3
4
from collections .abc import Callable , Sequence
4
5
from typing import Any , Literal
14
15
from pymc .model .transform .optimization import freeze_dims_and_data
15
16
from pymc .util import RandomState
16
17
from pytensor import Variable , graph_replace
17
- from pytensor .compile import get_mode
18
18
from rich .box import SIMPLE_HEAD
19
19
from rich .console import Console
20
20
from rich .table import Table
@@ -99,6 +99,13 @@ class PyMCStateSpace:
99
99
compute the observation errors. If False, these errors are deterministically zero; if True, they are sampled
100
100
from a multivariate normal.
101
101
102
+ mode: str or Mode, optional
103
+ Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and
104
+ ``forecast``. The mode does **not** effect calls to ``pm.sample``.
105
+
106
+ Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
107
+ to all sampling methods.
108
+
102
109
Notes
103
110
-----
104
111
Based on the statsmodels statespace implementation https://github.com/statsmodels/statsmodels/blob/main/statsmodels/tsa/statespace/representation.py,
@@ -221,8 +228,8 @@ def __init__(
221
228
filter_type : str = "standard" ,
222
229
verbose : bool = True ,
223
230
measurement_error : bool = False ,
231
+ mode : str | None = None ,
224
232
):
225
- self ._fit_mode : str | None = None
226
233
self ._fit_coords : dict [str , Sequence [str ]] | None = None
227
234
self ._fit_dims : dict [str , Sequence [str ]] | None = None
228
235
self ._fit_data : pt .TensorVariable | None = None
@@ -237,6 +244,7 @@ def __init__(
237
244
self .k_states = k_states
238
245
self .k_posdef = k_posdef
239
246
self .measurement_error = measurement_error
247
+ self .mode = mode
240
248
241
249
# All models contain a state space representation and a Kalman filter
242
250
self .ssm = PytensorRepresentation (k_endog , k_states , k_posdef )
@@ -819,11 +827,11 @@ def build_statespace_graph(
819
827
self ,
820
828
data : np .ndarray | pd .DataFrame | pt .TensorVariable ,
821
829
register_data : bool = True ,
822
- mode : str | None = None ,
823
830
missing_fill_value : float | None = None ,
824
831
cov_jitter : float | None = JITTER_DEFAULT ,
825
832
mvn_method : Literal ["cholesky" , "eigh" , "svd" ] = "svd" ,
826
833
save_kalman_filter_outputs_in_idata : bool = False ,
834
+ mode : str | None = None ,
827
835
) -> None :
828
836
"""
829
837
Given a parameter vector `theta`, constructs the full computational graph describing the state space model and
@@ -877,7 +885,25 @@ def build_statespace_graph(
877
885
save_kalman_filter_outputs_in_idata: bool, optional, default=False
878
886
If True, Kalman Filter outputs will be saved in the model as deterministics. Useful for debugging, but
879
887
should not be necessary for the majority of users.
888
+
889
+ mode: str, optional
890
+ Pytensor mode to use when compiling the graph. This will be saved as a model attribute and used when
891
+ compiling sampling functions (e.g. ``sample_conditional_prior``).
892
+
893
+ .. deprecated:: 0.2.5
894
+ The `mode` argument is deprecated and will be removed in a future version. Pass ``mode`` to the
895
+ model constructor, or manually specify ``compile_kwargs`` in sampling functions instead.
896
+
880
897
"""
898
+ if mode is not None :
899
+ warnings .warn (
900
+ "The `mode` argument is deprecated and will be removed in a future version. "
901
+ "Pass `mode` to the model constructor, or manually specify `compile_kwargs` in sampling functions"
902
+ " instead." ,
903
+ DeprecationWarning ,
904
+ )
905
+ self .mode = mode
906
+
881
907
pm_mod = modelcontext (None )
882
908
883
909
self ._insert_random_variables ()
@@ -898,7 +924,6 @@ def build_statespace_graph(
898
924
filter_outputs = self .kalman_filter .build_graph (
899
925
pt .as_tensor_variable (data ),
900
926
* self .unpack_statespace (),
901
- mode = mode ,
902
927
missing_fill_value = missing_fill_value ,
903
928
cov_jitter = cov_jitter ,
904
929
)
@@ -909,7 +934,7 @@ def build_statespace_graph(
909
934
filtered_covariances , predicted_covariances , observed_covariances = covs
910
935
if save_kalman_filter_outputs_in_idata :
911
936
smooth_states , smooth_covariances = self ._build_smoother_graph (
912
- filtered_states , filtered_covariances , self .unpack_statespace (), mode = mode
937
+ filtered_states , filtered_covariances , self .unpack_statespace ()
913
938
)
914
939
all_kf_outputs = [* states , smooth_states , * covs , smooth_covariances ]
915
940
self ._register_kalman_filter_outputs_with_pymc_model (all_kf_outputs )
@@ -929,7 +954,6 @@ def build_statespace_graph(
929
954
930
955
self ._fit_coords = pm_mod .coords .copy ()
931
956
self ._fit_dims = pm_mod .named_vars_to_dims .copy ()
932
- self ._fit_mode = mode
933
957
934
958
def _build_smoother_graph (
935
959
self ,
@@ -974,7 +998,7 @@ def _build_smoother_graph(
974
998
* _ , T , Z , R , H , Q = matrices
975
999
976
1000
smooth_states , smooth_covariances = self .kalman_smoother .build_graph (
977
- T , R , Q , filtered_states , filtered_covariances , mode = mode , cov_jitter = cov_jitter
1001
+ T , R , Q , filtered_states , filtered_covariances , cov_jitter = cov_jitter
978
1002
)
979
1003
smooth_states .name = "smooth_states"
980
1004
smooth_covariances .name = "smooth_covariances"
@@ -1092,7 +1116,6 @@ def _kalman_filter_outputs_from_dummy_graph(
1092
1116
R ,
1093
1117
H ,
1094
1118
Q ,
1095
- mode = self ._fit_mode ,
1096
1119
)
1097
1120
1098
1121
filter_outputs .pop (- 1 )
@@ -1102,7 +1125,7 @@ def _kalman_filter_outputs_from_dummy_graph(
1102
1125
filtered_covariances , predicted_covariances , _ = covariances
1103
1126
1104
1127
[smoothed_states , smoothed_covariances ] = self .kalman_smoother .build_graph (
1105
- T , R , Q , filtered_states , filtered_covariances , mode = self . _fit_mode
1128
+ T , R , Q , filtered_states , filtered_covariances
1106
1129
)
1107
1130
1108
1131
grouped_outputs = [
@@ -1164,6 +1187,9 @@ def _sample_conditional(
1164
1187
_verify_group (group )
1165
1188
group_idata = getattr (idata , group )
1166
1189
1190
+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
1191
+ compile_kwargs .setdefault ("mode" , self .mode )
1192
+
1167
1193
with pm .Model (coords = self ._fit_coords ) as forward_model :
1168
1194
(
1169
1195
[
@@ -1229,8 +1255,8 @@ def _sample_conditional(
1229
1255
for name in FILTER_OUTPUT_TYPES
1230
1256
for suffix in ["" , "_observed" ]
1231
1257
],
1232
- compile_kwargs = {"mode" : get_mode (self ._fit_mode )},
1233
1258
random_seed = random_seed ,
1259
+ compile_kwargs = compile_kwargs ,
1234
1260
** kwargs ,
1235
1261
)
1236
1262
@@ -1296,6 +1322,10 @@ def _sample_unconditional(
1296
1322
the latent state trajectories: `y[t] = Z @ x[t] + nu[t]`, where `nu ~ N(0, H)`.
1297
1323
"""
1298
1324
_verify_group (group )
1325
+
1326
+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
1327
+ compile_kwargs .setdefault ("mode" , self .mode )
1328
+
1299
1329
group_idata = getattr (idata , group )
1300
1330
dims = None
1301
1331
temp_coords = self ._fit_coords .copy ()
@@ -1338,7 +1368,6 @@ def _sample_unconditional(
1338
1368
* matrices ,
1339
1369
steps = steps ,
1340
1370
dims = dims ,
1341
- mode = self ._fit_mode ,
1342
1371
method = mvn_method ,
1343
1372
sequence_names = self .kalman_filter .seq_names ,
1344
1373
k_endog = self .k_endog ,
@@ -1354,8 +1383,8 @@ def _sample_unconditional(
1354
1383
idata_unconditional = pm .sample_posterior_predictive (
1355
1384
group_idata ,
1356
1385
var_names = [f"{ group } _latent" , f"{ group } _observed" ],
1357
- compile_kwargs = {"mode" : self ._fit_mode },
1358
1386
random_seed = random_seed ,
1387
+ compile_kwargs = compile_kwargs ,
1359
1388
** kwargs ,
1360
1389
)
1361
1390
@@ -1583,7 +1612,7 @@ def sample_unconditional_posterior(
1583
1612
)
1584
1613
1585
1614
def sample_statespace_matrices (
1586
- self , idata , matrix_names : str | list [str ] | None , group : str = "posterior"
1615
+ self , idata , matrix_names : str | list [str ] | None , group : str = "posterior" , ** kwargs
1587
1616
):
1588
1617
"""
1589
1618
Draw samples of requested statespace matrices from provided idata
@@ -1600,12 +1629,18 @@ def sample_statespace_matrices(
1600
1629
group: str, one of "posterior" or "prior"
1601
1630
Whether to sample from priors or posteriors
1602
1631
1632
+ kwargs:
1633
+ Additional keyword arguments are passed to ``pymc.sample_posterior_predictive``
1634
+
1603
1635
Returns
1604
1636
-------
1605
1637
idata_matrices: az.InterenceData
1606
1638
"""
1607
1639
_verify_group (group )
1608
1640
1641
+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
1642
+ compile_kwargs .setdefault ("mode" , self .mode )
1643
+
1609
1644
if matrix_names is None :
1610
1645
matrix_names = MATRIX_NAMES
1611
1646
elif isinstance (matrix_names , str ):
@@ -1636,8 +1671,9 @@ def sample_statespace_matrices(
1636
1671
matrix_idata = pm .sample_posterior_predictive (
1637
1672
idata if group == "posterior" else idata .prior ,
1638
1673
var_names = matrix_names ,
1639
- compile_kwargs = {"mode" : self ._fit_mode },
1640
1674
extend_inferencedata = False ,
1675
+ compile_kwargs = compile_kwargs ,
1676
+ ** kwargs ,
1641
1677
)
1642
1678
1643
1679
return matrix_idata
@@ -2106,6 +2142,10 @@ def forecast(
2106
2142
filter_time_dim = TIME_DIM
2107
2143
2108
2144
_validate_filter_arg (filter_output )
2145
+
2146
+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
2147
+ compile_kwargs .setdefault ("mode" , self .mode )
2148
+
2109
2149
time_index = self ._get_fit_time_index ()
2110
2150
2111
2151
if start is None and verbose :
@@ -2192,7 +2232,6 @@ def forecast(
2192
2232
* matrices ,
2193
2233
steps = len (forecast_index ),
2194
2234
dims = dims ,
2195
- mode = self ._fit_mode ,
2196
2235
sequence_names = self .kalman_filter .seq_names ,
2197
2236
k_endog = self .k_endog ,
2198
2237
append_x0 = False ,
@@ -2208,8 +2247,8 @@ def forecast(
2208
2247
idata_forecast = pm .sample_posterior_predictive (
2209
2248
idata ,
2210
2249
var_names = ["forecast_latent" , "forecast_observed" ],
2211
- compile_kwargs = {"mode" : self ._fit_mode },
2212
2250
random_seed = random_seed ,
2251
+ compile_kwargs = compile_kwargs ,
2213
2252
** kwargs ,
2214
2253
)
2215
2254
@@ -2297,6 +2336,9 @@ def impulse_response_function(
2297
2336
n_options = sum (x is not None for x in options )
2298
2337
Q = None # No covariance matrix needed if a trajectory is provided. Will be overwritten later if needed.
2299
2338
2339
+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
2340
+ compile_kwargs .setdefault ("mode" , self .mode )
2341
+
2300
2342
if n_options > 1 :
2301
2343
raise ValueError ("Specify exactly 0 or 1 of shock_size, shock_cov, or shock_trajectory" )
2302
2344
elif n_options == 1 :
@@ -2368,29 +2410,15 @@ def irf_step(shock, x, c, T, R):
2368
2410
non_sequences = [c , T , R ],
2369
2411
n_steps = n_steps ,
2370
2412
strict = True ,
2371
- mode = self ._fit_mode ,
2372
2413
)
2373
2414
2374
2415
pm .Deterministic ("irf" , irf , dims = [TIME_DIM , ALL_STATE_DIM ])
2375
2416
2376
- compile_kwargs = kwargs .get ("compile_kwargs" , {})
2377
- if "mode" not in compile_kwargs .keys ():
2378
- compile_kwargs = {"mode" : self ._fit_mode }
2379
- else :
2380
- mode = compile_kwargs .get ("mode" )
2381
- if mode is not None and mode != self ._fit_mode :
2382
- raise ValueError (
2383
- f"User provided compile mode ({ mode } ) does not match the compile mode used to "
2384
- f"construct the model ({ self ._fit_mode } )."
2385
- )
2386
-
2387
- compile_kwargs .update ({"mode" : self ._fit_mode })
2388
-
2389
2417
irf_idata = pm .sample_posterior_predictive (
2390
2418
idata ,
2391
2419
var_names = ["irf" ],
2392
- compile_kwargs = compile_kwargs ,
2393
2420
random_seed = random_seed ,
2421
+ compile_kwargs = compile_kwargs ,
2394
2422
** kwargs ,
2395
2423
)
2396
2424
0 commit comments