Skip to content

Commit 32bc557

Browse files
Depreciate duplicate functions in statespace (#308)
* Delete `pytensor_scipy.py` * Delete `block_diagonal`, use `pt.linalg.block_diag` * Relax tolerance in loglikelihood tests * Use updated pandas frequency string
1 parent 4ee57e5 commit 32bc557

File tree

8 files changed

+18
-213
lines changed

8 files changed

+18
-213
lines changed

pymc_experimental/statespace/filters/kalman_filter.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@
99
from pytensor.raise_op import Assert
1010
from pytensor.tensor import TensorVariable
1111
from pytensor.tensor.nlinalg import matrix_dot
12-
from pytensor.tensor.slinalg import solve_triangular
12+
from pytensor.tensor.slinalg import solve_discrete_are, solve_triangular
1313

1414
from pymc_experimental.statespace.filters.utilities import (
1515
quad_form_sym,
1616
split_vars_into_seq_and_nonseq,
1717
stabilize,
1818
)
1919
from pymc_experimental.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL
20-
from pymc_experimental.statespace.utils.pytensor_scipy import solve_discrete_are
2120

2221
MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))
2322
PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"]

pymc_experimental/statespace/models/SARIMAX.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
import pytensor.tensor as pt
5+
from pytensor.tensor.slinalg import solve_discrete_lyapunov
56

67
from pymc_experimental.statespace.core.statespace import PyMCStateSpace, floatX
78
from pymc_experimental.statespace.models.utilities import (
@@ -19,7 +20,6 @@
1920
SEASONAL_AR_PARAM_DIM,
2021
SEASONAL_MA_PARAM_DIM,
2122
)
22-
from pymc_experimental.statespace.utils.pytensor_scipy import solve_discrete_lyapunov
2323

2424

2525
def _verify_order(p, d, q, P, D, Q, S):

pymc_experimental/statespace/models/VARMAX.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import pytensor
55
import pytensor.tensor as pt
6+
from pytensor.tensor.slinalg import solve_discrete_lyapunov
67

78
from pymc_experimental.statespace.core.statespace import PyMCStateSpace
89
from pymc_experimental.statespace.models.utilities import make_default_coords
@@ -16,7 +17,6 @@
1617
SHOCK_AUX_DIM,
1718
SHOCK_DIM,
1819
)
19-
from pymc_experimental.statespace.utils.pytensor_scipy import solve_discrete_lyapunov
2020

2121
floatX = pytensor.config.floatX
2222

pymc_experimental/statespace/models/structural.py

+5-22
Original file line numberDiff line numberDiff line change
@@ -44,23 +44,6 @@ def _frequency_transition_block(s, j):
4444
return pt.stack([[pt.cos(lam), pt.sin(lam)], [-pt.sin(lam), pt.cos(lam)]]).squeeze()
4545

4646

47-
def block_diagonal(matrices: list[pt.matrix]):
48-
rows = [x.shape[0] for x in matrices]
49-
cols = [x.shape[1] for x in matrices]
50-
out = pt.zeros((sum(rows), sum(cols)))
51-
row_cursor = 0
52-
col_cursor = 0
53-
54-
for row, col, mat in zip(rows, cols, matrices):
55-
row_slice = slice(row_cursor, row_cursor + row)
56-
col_slice = slice(col_cursor, col_cursor + col)
57-
row_cursor += row
58-
col_cursor += col
59-
60-
out = pt.set_subtensor(out[row_slice, col_slice], mat)
61-
return out
62-
63-
6447
class StructuralTimeSeries(PyMCStateSpace):
6548
r"""
6649
Structural Time Series Model
@@ -527,7 +510,7 @@ def make_slice(name, x, o_x):
527510
initial_state = pt.concatenate(conform_time_varying_and_time_invariant_matrices(x0, o_x0))
528511
initial_state.name = x0.name
529512

530-
initial_state_cov = block_diagonal([P0, o_P0])
513+
initial_state_cov = pt.linalg.block_diag(P0, o_P0)
531514
initial_state_cov.name = P0.name
532515

533516
state_intercept = pt.concatenate(conform_time_varying_and_time_invariant_matrices(c, o_c))
@@ -536,19 +519,19 @@ def make_slice(name, x, o_x):
536519
obs_intercept = d + o_d
537520
obs_intercept.name = d.name
538521

539-
transition = block_diagonal([T, o_T])
522+
transition = pt.linalg.block_diag(T, o_T)
540523
transition.name = T.name
541524

542525
design = pt.concatenate(conform_time_varying_and_time_invariant_matrices(Z, o_Z), axis=-1)
543526
design.name = Z.name
544527

545-
selection = block_diagonal([R, o_R])
528+
selection = pt.linalg.block_diag(R, o_R)
546529
selection.name = R.name
547530

548531
obs_cov = H + o_H
549532
obs_cov.name = H.name
550533

551-
state_cov = block_diagonal([Q, o_Q])
534+
state_cov = pt.linalg.block_diag(Q, o_Q)
552535
state_cov.name = Q.name
553536

554537
new_ssm = PytensorRepresentation(
@@ -1326,7 +1309,7 @@ def make_symbolic_graph(self) -> None:
13261309
self.ssm["initial_state", init_state_idx] = init_state
13271310

13281311
T_mats = [_frequency_transition_block(self.season_length, j + 1) for j in range(self.n)]
1329-
T = block_diagonal(T_mats)
1312+
T = pt.linalg.block_diag(*T_mats)
13301313
self.ssm["transition", :, :] = T
13311314

13321315
if self.innovations:

pymc_experimental/statespace/utils/pytensor_scipy.py

-85
This file was deleted.

pymc_experimental/tests/statespace/test_distributions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626

2727
# TODO: These are pretty loose because of all the stabilizing of covariance matrices that is done inside the kalman
2828
# filters. When that is improved, this should be tightened.
29-
ATOL = 1e-6 if floatX.endswith("64") else 1e-4
30-
RTOL = 1e-6 if floatX.endswith("64") else 1e-4
29+
ATOL = 1e-5 if floatX.endswith("64") else 1e-4
30+
RTOL = 1e-5 if floatX.endswith("64") else 1e-4
3131

3232
filter_names = [
3333
"standard",

pymc_experimental/tests/statespace/test_pytensor_scipy.py

-99
This file was deleted.

pymc_experimental/tests/statespace/utilities/test_helpers.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,15 @@
1919

2020

2121
def load_nile_test_data():
22+
from importlib.metadata import version
23+
2224
nile = pd.read_csv("pymc_experimental/tests/statespace/test_data/nile.csv", dtype={"x": floatX})
23-
nile.index = pd.date_range(start="1871-01-01", end="1970-01-01", freq="AS-Jan")
25+
major, minor, rev = map(int, version("pandas").split("."))
26+
if major >= 2 and minor >= 2 and rev >= 0:
27+
freq_str = "YS-JAN"
28+
else:
29+
freq_str = "AS-JAN"
30+
nile.index = pd.date_range(start="1871-01-01", end="1970-01-01", freq=freq_str)
2431
nile.rename(columns={"x": "height"}, inplace=True)
2532
nile = (nile - nile.mean()) / nile.std()
2633
nile = nile.astype(floatX)

0 commit comments

Comments
 (0)