Skip to content

Commit e2b5dd0

Browse files
Remove mode argument from everywhere
1 parent dacb6f4 commit e2b5dd0

File tree

11 files changed

+19
-82
lines changed

11 files changed

+19
-82
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from pymc.model.transform.optimization import freeze_dims_and_data
1515
from pymc.util import RandomState
1616
from pytensor import Variable, graph_replace
17-
from pytensor.compile import get_mode
1817
from rich.box import SIMPLE_HEAD
1918
from rich.console import Console
2019
from rich.table import Table
@@ -222,7 +221,6 @@ def __init__(
222221
verbose: bool = True,
223222
measurement_error: bool = False,
224223
):
225-
self._fit_mode: str | None = None
226224
self._fit_coords: dict[str, Sequence[str]] | None = None
227225
self._fit_dims: dict[str, Sequence[str]] | None = None
228226
self._fit_data: pt.TensorVariable | None = None
@@ -819,7 +817,6 @@ def build_statespace_graph(
819817
self,
820818
data: np.ndarray | pd.DataFrame | pt.TensorVariable,
821819
register_data: bool = True,
822-
mode: str | None = None,
823820
missing_fill_value: float | None = None,
824821
cov_jitter: float | None = JITTER_DEFAULT,
825822
save_kalman_filter_outputs_in_idata: bool = False,
@@ -889,7 +886,6 @@ def build_statespace_graph(
889886
filter_outputs = self.kalman_filter.build_graph(
890887
pt.as_tensor_variable(data),
891888
*self.unpack_statespace(),
892-
mode=mode,
893889
missing_fill_value=missing_fill_value,
894890
cov_jitter=cov_jitter,
895891
)
@@ -900,7 +896,7 @@ def build_statespace_graph(
900896
filtered_covariances, predicted_covariances, observed_covariances = covs
901897
if save_kalman_filter_outputs_in_idata:
902898
smooth_states, smooth_covariances = self._build_smoother_graph(
903-
filtered_states, filtered_covariances, self.unpack_statespace(), mode=mode
899+
filtered_states, filtered_covariances, self.unpack_statespace()
904900
)
905901
all_kf_outputs = [*states, smooth_states, *covs, smooth_covariances]
906902
self._register_kalman_filter_outputs_with_pymc_model(all_kf_outputs)
@@ -919,7 +915,6 @@ def build_statespace_graph(
919915

920916
self._fit_coords = pm_mod.coords.copy()
921917
self._fit_dims = pm_mod.named_vars_to_dims.copy()
922-
self._fit_mode = mode
923918

924919
def _build_smoother_graph(
925920
self,
@@ -964,7 +959,7 @@ def _build_smoother_graph(
964959
*_, T, Z, R, H, Q = matrices
965960

966961
smooth_states, smooth_covariances = self.kalman_smoother.build_graph(
967-
T, R, Q, filtered_states, filtered_covariances, mode=mode, cov_jitter=cov_jitter
962+
T, R, Q, filtered_states, filtered_covariances, cov_jitter=cov_jitter
968963
)
969964
smooth_states.name = "smooth_states"
970965
smooth_covariances.name = "smooth_covariances"
@@ -1082,7 +1077,6 @@ def _kalman_filter_outputs_from_dummy_graph(
10821077
R,
10831078
H,
10841079
Q,
1085-
mode=self._fit_mode,
10861080
)
10871081

10881082
filter_outputs.pop(-1)
@@ -1092,7 +1086,7 @@ def _kalman_filter_outputs_from_dummy_graph(
10921086
filtered_covariances, predicted_covariances, _ = covariances
10931087

10941088
[smoothed_states, smoothed_covariances] = self.kalman_smoother.build_graph(
1095-
T, R, Q, filtered_states, filtered_covariances, mode=self._fit_mode
1089+
T, R, Q, filtered_states, filtered_covariances
10961090
)
10971091

10981092
grouped_outputs = [
@@ -1208,7 +1202,6 @@ def _sample_conditional(
12081202
for name in FILTER_OUTPUT_TYPES
12091203
for suffix in ["", "_observed"]
12101204
],
1211-
compile_kwargs={"mode": get_mode(self._fit_mode)},
12121205
random_seed=random_seed,
12131206
**kwargs,
12141207
)
@@ -1308,7 +1301,6 @@ def _sample_unconditional(
13081301
*matrices,
13091302
steps=steps,
13101303
dims=dims,
1311-
mode=self._fit_mode,
13121304
sequence_names=self.kalman_filter.seq_names,
13131305
k_endog=self.k_endog,
13141306
)
@@ -1323,7 +1315,6 @@ def _sample_unconditional(
13231315
idata_unconditional = pm.sample_posterior_predictive(
13241316
group_idata,
13251317
var_names=[f"{group}_latent", f"{group}_observed"],
1326-
compile_kwargs={"mode": self._fit_mode},
13271318
random_seed=random_seed,
13281319
**kwargs,
13291320
)
@@ -1547,7 +1538,6 @@ def sample_statespace_matrices(
15471538
matrix_idata = pm.sample_posterior_predictive(
15481539
idata if group == "posterior" else idata.prior,
15491540
var_names=matrix_names,
1550-
compile_kwargs={"mode": self._fit_mode},
15511541
extend_inferencedata=False,
15521542
)
15531543

@@ -2094,7 +2084,6 @@ def forecast(
20942084
*matrices,
20952085
steps=len(forecast_index),
20962086
dims=dims,
2097-
mode=self._fit_mode,
20982087
sequence_names=self.kalman_filter.seq_names,
20992088
k_endog=self.k_endog,
21002089
append_x0=False,
@@ -2109,7 +2098,6 @@ def forecast(
21092098
idata_forecast = pm.sample_posterior_predictive(
21102099
idata,
21112100
var_names=["forecast_latent", "forecast_observed"],
2112-
compile_kwargs={"mode": self._fit_mode},
21132101
random_seed=random_seed,
21142102
**kwargs,
21152103
)
@@ -2260,28 +2248,13 @@ def irf_step(shock, x, c, T, R):
22602248
non_sequences=[c, T, R],
22612249
n_steps=n_steps,
22622250
strict=True,
2263-
mode=self._fit_mode,
22642251
)
22652252

22662253
pm.Deterministic("irf", irf, dims=[TIME_DIM, ALL_STATE_DIM])
22672254

2268-
compile_kwargs = kwargs.get("compile_kwargs", {})
2269-
if "mode" not in compile_kwargs.keys():
2270-
compile_kwargs = {"mode": self._fit_mode}
2271-
else:
2272-
mode = compile_kwargs.get("mode")
2273-
if mode is not None and mode != self._fit_mode:
2274-
raise ValueError(
2275-
f"User provided compile mode ({mode}) does not match the compile mode used to "
2276-
f"construct the model ({self._fit_mode})."
2277-
)
2278-
2279-
compile_kwargs.update({"mode": self._fit_mode})
2280-
22812255
irf_idata = pm.sample_posterior_predictive(
22822256
idata,
22832257
var_names=["irf"],
2284-
compile_kwargs=compile_kwargs,
22852258
random_seed=random_seed,
22862259
**kwargs,
22872260
)

pymc_extras/statespace/filters/distributions.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def __new__(
6969
H,
7070
Q,
7171
steps=None,
72-
mode=None,
7372
sequence_names=None,
7473
append_x0=True,
7574
**kwargs,
@@ -97,7 +96,6 @@ def __new__(
9796
H,
9897
Q,
9998
steps=steps,
100-
mode=mode,
10199
sequence_names=sequence_names,
102100
append_x0=append_x0,
103101
**kwargs,
@@ -116,7 +114,6 @@ def dist(
116114
H,
117115
Q,
118116
steps=None,
119-
mode=None,
120117
sequence_names=None,
121118
append_x0=True,
122119
**kwargs,
@@ -132,7 +129,6 @@ def dist(
132129

133130
return super().dist(
134131
[a0, P0, c, d, T, Z, R, H, Q, steps],
135-
mode=mode,
136132
sequence_names=sequence_names,
137133
append_x0=append_x0,
138134
**kwargs,
@@ -152,7 +148,6 @@ def rv_op(
152148
Q,
153149
steps,
154150
size=None,
155-
mode=None,
156151
sequence_names=None,
157152
append_x0=True,
158153
):
@@ -235,7 +230,6 @@ def step_fn(*args):
235230
sequences=None if len(sequences) == 0 else sequences,
236231
non_sequences=[*non_sequences, rng],
237232
n_steps=steps,
238-
mode=mode,
239233
strict=True,
240234
)
241235

@@ -279,7 +273,6 @@ def __new__(
279273
steps,
280274
k_endog=None,
281275
sequence_names=None,
282-
mode=None,
283276
append_x0=True,
284277
**kwargs,
285278
):
@@ -307,7 +300,6 @@ def __new__(
307300
H,
308301
Q,
309302
steps=steps,
310-
mode=mode,
311303
sequence_names=sequence_names,
312304
append_x0=append_x0,
313305
**kwargs,

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import pytensor.tensor as pt
66

77
from pymc.pytensorf import constant_fold
8-
from pytensor.compile.mode import get_mode
98
from pytensor.graph.basic import Variable
109
from pytensor.raise_op import Assert
1110
from pytensor.tensor import TensorVariable
@@ -28,25 +27,17 @@
2827

2928

3029
class BaseFilter(ABC):
31-
def __init__(self, mode=None):
30+
def __init__(self):
3231
"""
3332
Kalman Filter.
3433
35-
Parameters
36-
----------
37-
mode : str, optional
38-
The mode used for Pytensor compilation. Defaults to None.
39-
4034
Notes
4135
-----
4236
The BaseFilter class is an abstract base class (ABC) for implementing kalman filters.
4337
It defines common attributes and methods used by kalman filter implementations.
4438
4539
Attributes
4640
----------
47-
mode : str or None
48-
The mode used for Pytensor compilation.
49-
5041
seq_names : list[str]
5142
A list of name representing time-varying statespace matrices. That is, inputs that will need to be
5243
provided to the `sequences` argument of `pytensor.scan`
@@ -56,7 +47,6 @@ def __init__(self, mode=None):
5647
to the `non_sequences` argument of `pytensor.scan`
5748
"""
5849

59-
self.mode: str = mode
6050
self.seq_names: list[str] = []
6151
self.non_seq_names: list[str] = []
6252

@@ -153,7 +143,6 @@ def build_graph(
153143
R,
154144
H,
155145
Q,
156-
mode=None,
157146
return_updates=False,
158147
missing_fill_value=None,
159148
cov_jitter=None,
@@ -166,9 +155,6 @@ def build_graph(
166155
data : TensorVariable
167156
Data to be filtered
168157
169-
mode : optional, str
170-
Pytensor compile mode, passed to pytensor.scan
171-
172158
return_updates: bool, default False
173159
Whether to return updates associated with the pytensor scan. Should only be requried to debug pruposes.
174160
@@ -199,7 +185,6 @@ def build_graph(
199185
if cov_jitter is None:
200186
cov_jitter = JITTER_DEFAULT
201187

202-
self.mode = mode
203188
self.missing_fill_value = missing_fill_value
204189
self.cov_jitter = cov_jitter
205190

@@ -227,7 +212,6 @@ def build_graph(
227212
outputs_info=[None, a0, None, None, P0, None, None],
228213
non_sequences=non_sequences,
229214
name="forward_kalman_pass",
230-
mode=get_mode(self.mode),
231215
strict=False,
232216
)
233217

@@ -800,7 +784,6 @@ def kalman_step(self, y, a, P, c, d, T, Z, R, H, Q):
800784
self._univariate_inner_filter_step,
801785
sequences=[y_masked, Z_masked, d, pt.diag(H_masked), nan_mask],
802786
outputs_info=[a, P, None, None, None],
803-
mode=get_mode(self.mode),
804787
name="univariate_inner_scan",
805788
)
806789

pymc_extras/statespace/filters/kalman_smoother.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pytensor
22
import pytensor.tensor as pt
33

4-
from pytensor.compile import get_mode
54
from pytensor.tensor.nlinalg import matrix_dot
65

76
from pymc_extras.statespace.filters.utilities import (
@@ -18,8 +17,7 @@ class KalmanSmoother:
1817
1918
"""
2019

21-
def __init__(self, mode: str | None = None):
22-
self.mode = mode
20+
def __init__(self):
2321
self.cov_jitter = JITTER_DEFAULT
2422
self.seq_names = []
2523
self.non_seq_names = []
@@ -64,9 +62,8 @@ def unpack_args(self, args):
6462
return a, P, a_smooth, P_smooth, T, R, Q
6563

6664
def build_graph(
67-
self, T, R, Q, filtered_states, filtered_covariances, mode=None, cov_jitter=JITTER_DEFAULT
65+
self, T, R, Q, filtered_states, filtered_covariances, cov_jitter=JITTER_DEFAULT
6866
):
69-
self.mode = mode
7067
self.cov_jitter = cov_jitter
7168

7269
n, k = filtered_states.type.shape
@@ -88,7 +85,6 @@ def build_graph(
8885
non_sequences=non_sequences,
8986
go_backwards=True,
9087
name="kalman_smoother",
91-
mode=get_mode(self.mode),
9288
)
9389

9490
smoothed_states, smoothed_covariances = smoother_result

pymc_extras/statespace/models/SARIMAX.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ class BayesianSARIMA(PyMCStateSpace):
158158
rho = pm.Beta("ar_params", alpha=5, beta=1, dims=ss_mod.param_dims["ar_params"])
159159
theta = pm.Normal("ma_params", mu=0.0, sigma=0.5, dims=ss_mod.param_dims["ma_params"])
160160
161-
ss_mod.build_statespace_graph(df, mode="JAX")
161+
ss_mod.build_statespace_graph(df)
162162
idata = pm.sample(nuts_sampler='numpyro')
163163
164164
References
@@ -366,17 +366,15 @@ def coords(self) -> dict[str, Sequence]:
366366

367367
return coords
368368

369-
def _stationary_initialization(self, mode=None):
369+
def _stationary_initialization(self):
370370
# Solve for matrix quadratic for P0
371371
T = self.ssm["transition"]
372372
R = self.ssm["selection"]
373373
Q = self.ssm["state_cov"]
374374
c = self.ssm["state_intercept"]
375375

376376
x0 = pt.linalg.solve(pt.identity_like(T) - T, c, assume_a="gen", check_finite=True)
377-
378-
method = "direct" if (self.k_states < 5) or (mode == "JAX") else "bilinear"
379-
P0 = solve_discrete_lyapunov(T, pt.linalg.matrix_dot(R, Q, R.T), method=method)
377+
P0 = solve_discrete_lyapunov(T, pt.linalg.matrix_dot(R, Q, R.T), method="bilinear")
380378

381379
return x0, P0
382380

pymc_extras/statespace/models/VARMAX.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class BayesianVARMAX(PyMCStateSpace):
135135
ar_params = pm.Normal("ar_params", mu=0, sigma=1, dims=ar_dims)
136136
state_cov = pm.Deterministic("state_cov", state_chol @ state_chol.T, dims=state_cov_dims)
137137
138-
bvar_mod.build_statespace_graph(data, mode="JAX")
138+
bvar_mod.build_statespace_graph(data)
139139
idata = pm.sample(nuts_sampler="numpyro")
140140
"""
141141

0 commit comments

Comments
 (0)