Skip to content

Commit 81b4b7b

Browse files
Remove _set_default_mode
1 parent 5acca07 commit 81b4b7b

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,12 +1136,6 @@ def _kalman_filter_outputs_from_dummy_graph(
11361136

11371137
return [x0, P0, c, d, T, Z, R, H, Q], grouped_outputs
11381138

1139-
def _set_default_mode(self, compile_kwargs):
1140-
mode = compile_kwargs.get("mode", self.mode)
1141-
compile_kwargs["mode"] = mode
1142-
1143-
return compile_kwargs
1144-
11451139
def _sample_conditional(
11461140
self,
11471141
idata: InferenceData,
@@ -1194,7 +1188,7 @@ def _sample_conditional(
11941188
group_idata = getattr(idata, group)
11951189

11961190
compile_kwargs = kwargs.pop("compile_kwargs", {})
1197-
compile_kwargs = self._set_default_mode(compile_kwargs)
1191+
compile_kwargs.setdefault("mode", self.mode)
11981192

11991193
with pm.Model(coords=self._fit_coords) as forward_model:
12001194
(
@@ -1330,7 +1324,7 @@ def _sample_unconditional(
13301324
_verify_group(group)
13311325

13321326
compile_kwargs = kwargs.pop("compile_kwargs", {})
1333-
compile_kwargs = self._set_default_mode(compile_kwargs)
1327+
compile_kwargs.setdefault("mode", self.mode)
13341328

13351329
group_idata = getattr(idata, group)
13361330
dims = None
@@ -1645,7 +1639,7 @@ def sample_statespace_matrices(
16451639
_verify_group(group)
16461640

16471641
compile_kwargs = kwargs.pop("compile_kwargs", {})
1648-
compile_kwargs = self._set_default_mode(compile_kwargs)
1642+
compile_kwargs.setdefault("mode", self.mode)
16491643

16501644
if matrix_names is None:
16511645
matrix_names = MATRIX_NAMES
@@ -2150,7 +2144,7 @@ def forecast(
21502144
_validate_filter_arg(filter_output)
21512145

21522146
compile_kwargs = kwargs.pop("compile_kwargs", {})
2153-
compile_kwargs = self._set_default_mode(compile_kwargs)
2147+
compile_kwargs.setdefault("mode", self.mode)
21542148

21552149
time_index = self._get_fit_time_index()
21562150

@@ -2343,7 +2337,7 @@ def impulse_response_function(
23432337
Q = None # No covariance matrix needed if a trajectory is provided. Will be overwritten later if needed.
23442338

23452339
compile_kwargs = kwargs.pop("compile_kwargs", {})
2346-
compile_kwargs = self._set_default_mode(compile_kwargs)
2340+
compile_kwargs.setdefault("mode", self.mode)
23472341

23482342
if n_options > 1:
23492343
raise ValueError("Specify exactly 0 or 1 of shock_size, shock_cov, or shock_trajectory")

0 commit comments

Comments
 (0)