Skip to content

Commit 27a8752

Browse files
Remove mode argument from everywhere
1 parent 9589f28 commit 27a8752

File tree

11 files changed

+19
-82
lines changed

11 files changed

+19
-82
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from pymc.model.transform.optimization import freeze_dims_and_data
1515
from pymc.util import RandomState
1616
from pytensor import Variable, graph_replace
17-
from pytensor.compile import get_mode
1817
from rich.box import SIMPLE_HEAD
1918
from rich.console import Console
2019
from rich.table import Table
@@ -222,7 +221,6 @@ def __init__(
222221
verbose: bool = True,
223222
measurement_error: bool = False,
224223
):
225-
self._fit_mode: str | None = None
226224
self._fit_coords: dict[str, Sequence[str]] | None = None
227225
self._fit_dims: dict[str, Sequence[str]] | None = None
228226
self._fit_data: pt.TensorVariable | None = None
@@ -819,7 +817,6 @@ def build_statespace_graph(
819817
self,
820818
data: np.ndarray | pd.DataFrame | pt.TensorVariable,
821819
register_data: bool = True,
822-
mode: str | None = None,
823820
missing_fill_value: float | None = None,
824821
cov_jitter: float | None = JITTER_DEFAULT,
825822
mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
@@ -898,7 +895,6 @@ def build_statespace_graph(
898895
filter_outputs = self.kalman_filter.build_graph(
899896
pt.as_tensor_variable(data),
900897
*self.unpack_statespace(),
901-
mode=mode,
902898
missing_fill_value=missing_fill_value,
903899
cov_jitter=cov_jitter,
904900
)
@@ -909,7 +905,7 @@ def build_statespace_graph(
909905
filtered_covariances, predicted_covariances, observed_covariances = covs
910906
if save_kalman_filter_outputs_in_idata:
911907
smooth_states, smooth_covariances = self._build_smoother_graph(
912-
filtered_states, filtered_covariances, self.unpack_statespace(), mode=mode
908+
filtered_states, filtered_covariances, self.unpack_statespace()
913909
)
914910
all_kf_outputs = [*states, smooth_states, *covs, smooth_covariances]
915911
self._register_kalman_filter_outputs_with_pymc_model(all_kf_outputs)
@@ -929,7 +925,6 @@ def build_statespace_graph(
929925

930926
self._fit_coords = pm_mod.coords.copy()
931927
self._fit_dims = pm_mod.named_vars_to_dims.copy()
932-
self._fit_mode = mode
933928

934929
def _build_smoother_graph(
935930
self,
@@ -974,7 +969,7 @@ def _build_smoother_graph(
974969
*_, T, Z, R, H, Q = matrices
975970

976971
smooth_states, smooth_covariances = self.kalman_smoother.build_graph(
977-
T, R, Q, filtered_states, filtered_covariances, mode=mode, cov_jitter=cov_jitter
972+
T, R, Q, filtered_states, filtered_covariances, cov_jitter=cov_jitter
978973
)
979974
smooth_states.name = "smooth_states"
980975
smooth_covariances.name = "smooth_covariances"
@@ -1092,7 +1087,6 @@ def _kalman_filter_outputs_from_dummy_graph(
10921087
R,
10931088
H,
10941089
Q,
1095-
mode=self._fit_mode,
10961090
)
10971091

10981092
filter_outputs.pop(-1)
@@ -1102,7 +1096,7 @@ def _kalman_filter_outputs_from_dummy_graph(
11021096
filtered_covariances, predicted_covariances, _ = covariances
11031097

11041098
[smoothed_states, smoothed_covariances] = self.kalman_smoother.build_graph(
1105-
T, R, Q, filtered_states, filtered_covariances, mode=self._fit_mode
1099+
T, R, Q, filtered_states, filtered_covariances
11061100
)
11071101

11081102
grouped_outputs = [
@@ -1229,7 +1223,6 @@ def _sample_conditional(
12291223
for name in FILTER_OUTPUT_TYPES
12301224
for suffix in ["", "_observed"]
12311225
],
1232-
compile_kwargs={"mode": get_mode(self._fit_mode)},
12331226
random_seed=random_seed,
12341227
**kwargs,
12351228
)
@@ -1338,7 +1331,6 @@ def _sample_unconditional(
13381331
*matrices,
13391332
steps=steps,
13401333
dims=dims,
1341-
mode=self._fit_mode,
13421334
method=mvn_method,
13431335
sequence_names=self.kalman_filter.seq_names,
13441336
k_endog=self.k_endog,
@@ -1354,7 +1346,6 @@ def _sample_unconditional(
13541346
idata_unconditional = pm.sample_posterior_predictive(
13551347
group_idata,
13561348
var_names=[f"{group}_latent", f"{group}_observed"],
1357-
compile_kwargs={"mode": self._fit_mode},
13581349
random_seed=random_seed,
13591350
**kwargs,
13601351
)
@@ -1636,7 +1627,6 @@ def sample_statespace_matrices(
16361627
matrix_idata = pm.sample_posterior_predictive(
16371628
idata if group == "posterior" else idata.prior,
16381629
var_names=matrix_names,
1639-
compile_kwargs={"mode": self._fit_mode},
16401630
extend_inferencedata=False,
16411631
)
16421632

@@ -2192,7 +2182,6 @@ def forecast(
21922182
*matrices,
21932183
steps=len(forecast_index),
21942184
dims=dims,
2195-
mode=self._fit_mode,
21962185
sequence_names=self.kalman_filter.seq_names,
21972186
k_endog=self.k_endog,
21982187
append_x0=False,
@@ -2208,7 +2197,6 @@ def forecast(
22082197
idata_forecast = pm.sample_posterior_predictive(
22092198
idata,
22102199
var_names=["forecast_latent", "forecast_observed"],
2211-
compile_kwargs={"mode": self._fit_mode},
22122200
random_seed=random_seed,
22132201
**kwargs,
22142202
)
@@ -2368,28 +2356,13 @@ def irf_step(shock, x, c, T, R):
23682356
non_sequences=[c, T, R],
23692357
n_steps=n_steps,
23702358
strict=True,
2371-
mode=self._fit_mode,
23722359
)
23732360

23742361
pm.Deterministic("irf", irf, dims=[TIME_DIM, ALL_STATE_DIM])
23752362

2376-
compile_kwargs = kwargs.get("compile_kwargs", {})
2377-
if "mode" not in compile_kwargs.keys():
2378-
compile_kwargs = {"mode": self._fit_mode}
2379-
else:
2380-
mode = compile_kwargs.get("mode")
2381-
if mode is not None and mode != self._fit_mode:
2382-
raise ValueError(
2383-
f"User provided compile mode ({mode}) does not match the compile mode used to "
2384-
f"construct the model ({self._fit_mode})."
2385-
)
2386-
2387-
compile_kwargs.update({"mode": self._fit_mode})
2388-
23892363
irf_idata = pm.sample_posterior_predictive(
23902364
idata,
23912365
var_names=["irf"],
2392-
compile_kwargs=compile_kwargs,
23932366
random_seed=random_seed,
23942367
**kwargs,
23952368
)

pymc_extras/statespace/filters/distributions.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def __new__(
6969
H,
7070
Q,
7171
steps=None,
72-
mode=None,
7372
sequence_names=None,
7473
append_x0=True,
7574
method="svd",
@@ -98,7 +97,6 @@ def __new__(
9897
H,
9998
Q,
10099
steps=steps,
101-
mode=mode,
102100
sequence_names=sequence_names,
103101
append_x0=append_x0,
104102
method=method,
@@ -118,7 +116,6 @@ def dist(
118116
H,
119117
Q,
120118
steps=None,
121-
mode=None,
122119
sequence_names=None,
123120
append_x0=True,
124121
method="svd",
@@ -135,7 +132,6 @@ def dist(
135132

136133
return super().dist(
137134
[a0, P0, c, d, T, Z, R, H, Q, steps],
138-
mode=mode,
139135
sequence_names=sequence_names,
140136
append_x0=append_x0,
141137
method=method,
@@ -156,7 +152,6 @@ def rv_op(
156152
Q,
157153
steps,
158154
size=None,
159-
mode=None,
160155
sequence_names=None,
161156
append_x0=True,
162157
method="svd",
@@ -240,7 +235,6 @@ def step_fn(*args):
240235
sequences=None if len(sequences) == 0 else sequences,
241236
non_sequences=[*non_sequences, rng],
242237
n_steps=steps,
243-
mode=mode,
244238
strict=True,
245239
)
246240

@@ -284,7 +278,6 @@ def __new__(
284278
steps,
285279
k_endog=None,
286280
sequence_names=None,
287-
mode=None,
288281
append_x0=True,
289282
method="svd",
290283
**kwargs,
@@ -313,7 +306,6 @@ def __new__(
313306
H,
314307
Q,
315308
steps=steps,
316-
mode=mode,
317309
sequence_names=sequence_names,
318310
append_x0=append_x0,
319311
method=method,

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import pytensor.tensor as pt
66

77
from pymc.pytensorf import constant_fold
8-
from pytensor.compile.mode import get_mode
98
from pytensor.graph.basic import Variable
109
from pytensor.raise_op import Assert
1110
from pytensor.tensor import TensorVariable
@@ -28,25 +27,17 @@
2827

2928

3029
class BaseFilter(ABC):
31-
def __init__(self, mode=None):
30+
def __init__(self):
3231
"""
3332
Kalman Filter.
3433
35-
Parameters
36-
----------
37-
mode : str, optional
38-
The mode used for Pytensor compilation. Defaults to None.
39-
4034
Notes
4135
-----
4236
The BaseFilter class is an abstract base class (ABC) for implementing kalman filters.
4337
It defines common attributes and methods used by kalman filter implementations.
4438
4539
Attributes
4640
----------
47-
mode : str or None
48-
The mode used for Pytensor compilation.
49-
5041
seq_names : list[str]
5142
A list of name representing time-varying statespace matrices. That is, inputs that will need to be
5243
provided to the `sequences` argument of `pytensor.scan`
@@ -56,7 +47,6 @@ def __init__(self, mode=None):
5647
to the `non_sequences` argument of `pytensor.scan`
5748
"""
5849

59-
self.mode: str = mode
6050
self.seq_names: list[str] = []
6151
self.non_seq_names: list[str] = []
6252

@@ -153,7 +143,6 @@ def build_graph(
153143
R,
154144
H,
155145
Q,
156-
mode=None,
157146
return_updates=False,
158147
missing_fill_value=None,
159148
cov_jitter=None,
@@ -166,9 +155,6 @@ def build_graph(
166155
data : TensorVariable
167156
Data to be filtered
168157
169-
mode : optional, str
170-
Pytensor compile mode, passed to pytensor.scan
171-
172158
return_updates: bool, default False
173159
Whether to return updates associated with the pytensor scan. Should only be requried to debug pruposes.
174160
@@ -199,7 +185,6 @@ def build_graph(
199185
if cov_jitter is None:
200186
cov_jitter = JITTER_DEFAULT
201187

202-
self.mode = mode
203188
self.missing_fill_value = missing_fill_value
204189
self.cov_jitter = cov_jitter
205190

@@ -227,7 +212,6 @@ def build_graph(
227212
outputs_info=[None, a0, None, None, P0, None, None],
228213
non_sequences=non_sequences,
229214
name="forward_kalman_pass",
230-
mode=get_mode(self.mode),
231215
strict=False,
232216
)
233217

@@ -800,7 +784,6 @@ def kalman_step(self, y, a, P, c, d, T, Z, R, H, Q):
800784
self._univariate_inner_filter_step,
801785
sequences=[y_masked, Z_masked, d, pt.diag(H_masked), nan_mask],
802786
outputs_info=[a, P, None, None, None],
803-
mode=get_mode(self.mode),
804787
name="univariate_inner_scan",
805788
)
806789

pymc_extras/statespace/filters/kalman_smoother.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pytensor
22
import pytensor.tensor as pt
33

4-
from pytensor.compile import get_mode
54
from pytensor.tensor.nlinalg import matrix_dot
65

76
from pymc_extras.statespace.filters.utilities import (
@@ -18,8 +17,7 @@ class KalmanSmoother:
1817
1918
"""
2019

21-
def __init__(self, mode: str | None = None):
22-
self.mode = mode
20+
def __init__(self):
2321
self.cov_jitter = JITTER_DEFAULT
2422
self.seq_names = []
2523
self.non_seq_names = []
@@ -64,9 +62,8 @@ def unpack_args(self, args):
6462
return a, P, a_smooth, P_smooth, T, R, Q
6563

6664
def build_graph(
67-
self, T, R, Q, filtered_states, filtered_covariances, mode=None, cov_jitter=JITTER_DEFAULT
65+
self, T, R, Q, filtered_states, filtered_covariances, cov_jitter=JITTER_DEFAULT
6866
):
69-
self.mode = mode
7067
self.cov_jitter = cov_jitter
7168

7269
n, k = filtered_states.type.shape
@@ -88,7 +85,6 @@ def build_graph(
8885
non_sequences=non_sequences,
8986
go_backwards=True,
9087
name="kalman_smoother",
91-
mode=get_mode(self.mode),
9288
)
9389

9490
smoothed_states, smoothed_covariances = smoother_result

pymc_extras/statespace/models/SARIMAX.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ class BayesianSARIMA(PyMCStateSpace):
158158
rho = pm.Beta("ar_params", alpha=5, beta=1, dims=ss_mod.param_dims["ar_params"])
159159
theta = pm.Normal("ma_params", mu=0.0, sigma=0.5, dims=ss_mod.param_dims["ma_params"])
160160
161-
ss_mod.build_statespace_graph(df, mode="JAX")
161+
ss_mod.build_statespace_graph(df)
162162
idata = pm.sample(nuts_sampler='numpyro')
163163
164164
References
@@ -366,17 +366,15 @@ def coords(self) -> dict[str, Sequence]:
366366

367367
return coords
368368

369-
def _stationary_initialization(self, mode=None):
369+
def _stationary_initialization(self):
370370
# Solve for matrix quadratic for P0
371371
T = self.ssm["transition"]
372372
R = self.ssm["selection"]
373373
Q = self.ssm["state_cov"]
374374
c = self.ssm["state_intercept"]
375375

376376
x0 = pt.linalg.solve(pt.identity_like(T) - T, c, assume_a="gen", check_finite=True)
377-
378-
method = "direct" if (self.k_states < 5) or (mode == "JAX") else "bilinear"
379-
P0 = solve_discrete_lyapunov(T, pt.linalg.matrix_dot(R, Q, R.T), method=method)
377+
P0 = solve_discrete_lyapunov(T, pt.linalg.matrix_dot(R, Q, R.T), method="bilinear")
380378

381379
return x0, P0
382380

pymc_extras/statespace/models/VARMAX.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class BayesianVARMAX(PyMCStateSpace):
135135
ar_params = pm.Normal("ar_params", mu=0, sigma=1, dims=ar_dims)
136136
state_cov = pm.Deterministic("state_cov", state_chol @ state_chol.T, dims=state_cov_dims)
137137
138-
bvar_mod.build_statespace_graph(data, mode="JAX")
138+
bvar_mod.build_statespace_graph(data)
139139
idata = pm.sample(nuts_sampler="numpyro")
140140
"""
141141

0 commit comments

Comments
 (0)