Skip to content

Commit 2da9ea1

Browse files
committed
Revert functional changes during rename
1 parent 088a477 commit 2da9ea1

15 files changed

+32
-53
lines changed

docs/statespace/filters.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ Kalman Filters
99
StandardFilter
1010
UnivariateFilter
1111
KalmanSmoother
12+
SquareRootFilter
1213
LinearGaussianStateSpace

pymc_extras/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,3 @@
2727
if len(_log.handlers) == 0:
2828
handler = logging.StreamHandler()
2929
_log.addHandler(handler)
30-
31-
__all__ = ["fit", "MarginalModel", "marginalize", "as_model"]

pymc_extras/model/marginal/graph_analysis.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from pymc import SymbolicRandomVariable
77
from pytensor.compile import SharedVariable
8-
from pytensor.compile.builders import OpFromGraph
98
from pytensor.graph import Constant, Variable, ancestors
109
from pytensor.graph.basic import io_toposort
1110
from pytensor.tensor import TensorType, TensorVariable
@@ -17,6 +16,8 @@
1716
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor, get_idx_list
1817
from pytensor.tensor.type_other import NoneTypeT
1918

19+
from pymc_extras.model.marginal.distributions import MarginalRV
20+
2021

2122
def static_shape_ancestors(vars):
2223
"""Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
@@ -62,7 +63,7 @@ def find_conditional_dependent_rvs(dependable_rv, all_rvs):
6263

6364

6465
def get_support_axes(op) -> tuple[tuple[int, ...], ...]:
65-
if hasattr(op, "support_axes"):
66+
if isinstance(op, MarginalRV):
6667
return op.support_axes
6768
else:
6869
# For vanilla RVs, the support axes are the last ndim_supp
@@ -145,7 +146,7 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
145146
output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)
146147
var_dims[node.outputs[0]] = output_dims
147148

148-
elif (isinstance(node.op, OpFromGraph) and hasattr(node.op, "support_axes")) or (
149+
elif isinstance(node.op, MarginalRV) or (
149150
isinstance(node.op, SymbolicRandomVariable) and node.op.extended_signature is None
150151
):
151152
# MarginalRV and SymbolicRandomVariables without signature are a wild-card,
@@ -159,7 +160,7 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
159160
)
160161

161162
support_axes = iter(get_support_axes(op))
162-
if hasattr(op, "support_axes"):
163+
if isinstance(op, MarginalRV):
163164
# The first output is the marginalized variable for which we don't compute support axes
164165
support_axes = itertools.chain(((),), support_axes)
165166
for i, (out, inner_out) in enumerate(zip(node.outputs, inner_outputs)):

pymc_extras/statespace/core/statespace.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pymc_extras.statespace.core.representation import PytensorRepresentation
2020
from pymc_extras.statespace.filters import (
2121
KalmanSmoother,
22+
SquareRootFilter,
2223
StandardFilter,
2324
UnivariateFilter,
2425
)
@@ -50,6 +51,7 @@
5051
FILTER_FACTORY = {
5152
"standard": StandardFilter,
5253
"univariate": UnivariateFilter,
54+
"cholesky": SquareRootFilter,
5355
}
5456

5557

pymc_extras/statespace/filters/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace
22
from pymc_extras.statespace.filters.kalman_filter import (
3+
SquareRootFilter,
34
StandardFilter,
45
UnivariateFilter,
56
)
@@ -9,5 +10,6 @@
910
"StandardFilter",
1011
"UnivariateFilter",
1112
"KalmanSmoother",
13+
"SquareRootFilter",
1214
"LinearGaussianStateSpace",
1315
]

pymc_extras/statespace/models/utilities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,8 @@ def make_SARIMA_transition_matrix(
233233
0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \end{bmatrix}
234234
235235
When ARIMA differences and seasonal differences are mixed, the seasonal differences will be written in terms of the
236-
highest ARIMA difference order, and recovery of the level state will require the use of all the ARIMA
237-
differences, as well as the seasonal differences. In addition, the seasonal differences are needed to back out the ARIMA
236+
highest ARIMA difference order, and recovery of the level state will require the use of all the ARIMA differences,
237+
as well as the seasonal differences. In addition, the seasonal differences are needed to back out the ARIMA
238238
differences from :math:`x_t^\star`. Here is the differencing block for a SARIMA(0,2,0)x(0,2,0,4) -- the identites
239239
of the states is left an exercise for the motivated reader:
240240

tests/statespace/test_coord_assignment.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,12 @@
1111
from pymc_extras.statespace.utils.constants import (
1212
FILTER_OUTPUT_DIMS,
1313
FILTER_OUTPUT_NAMES,
14-
JITTER_DEFAULT,
15-
LONG_MATRIX_NAMES,
16-
MISSING_FILL,
17-
SHORT_NAME_TO_LONG,
1814
SMOOTHER_OUTPUT_NAMES,
1915
TIME_DIM,
2016
)
2117
from pymc_extras.statespace.utils.data_tools import (
2218
NO_FREQ_INFO_WARNING,
2319
NO_TIME_INDEX_WARNING,
24-
register_data_with_pymc,
2520
)
2621
from tests.statespace.utilities.test_helpers import load_nile_test_data
2722

tests/statespace/test_distributions.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,12 @@
1010
from pymc_extras.statespace import structural
1111
from pymc_extras.statespace.filters.distributions import (
1212
LinearGaussianStateSpace,
13-
LinearGaussianStateSpaceRV,
1413
SequenceMvNormal,
1514
_LinearGaussianStateSpace,
1615
)
1716
from pymc_extras.statespace.utils.constants import (
1817
ALL_STATE_DIM,
19-
JITTER_DEFAULT,
20-
LONG_MATRIX_NAMES,
21-
MISSING_FILL,
2218
OBS_STATE_DIM,
23-
SHORT_NAME_TO_LONG,
2419
TIME_DIM,
2520
)
2621
from tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import
@@ -41,7 +36,7 @@
4136

4237
filter_names = [
4338
"standard",
44-
# "cholesky",
39+
"cholesky",
4540
"univariate",
4641
]
4742

tests/statespace/test_kalman_filter.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from pymc_extras.statespace.filters import (
99
KalmanSmoother,
10+
SquareRootFilter,
1011
StandardFilter,
1112
UnivariateFilter,
1213
)
@@ -30,17 +31,18 @@
3031
RTOL = 1e-6 if floatX.endswith("64") else 1e-3
3132

3233
standard_inout = initialize_filter(StandardFilter())
33-
# cholesky_inout = initialize_filter(CholeskyFilter())
34+
cholesky_inout = initialize_filter(SquareRootFilter())
3435
univariate_inout = initialize_filter(UnivariateFilter())
3536

3637
f_standard = pytensor.function(*standard_inout, on_unused_input="ignore")
37-
# f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore")
38+
f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore")
3839
f_univariate = pytensor.function(*univariate_inout, on_unused_input="ignore")
3940

40-
filter_funcs = [f_standard, f_univariate]
41+
filter_funcs = [f_standard, f_cholesky, f_univariate]
4142

4243
filter_names = [
4344
"StandardFilter",
45+
"CholeskyFilter",
4446
"UnivariateFilter",
4547
]
4648

@@ -229,8 +231,8 @@ def test_last_smoother_is_last_filtered(filter_func, output_idx, rng):
229231
@pytest.mark.skipif(floatX == "float32", reason="Tests are too sensitive for float32")
230232
def test_filters_match_statsmodel_output(filter_func, filter_name, n_missing, rng):
231233
fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper(rng, n_missing)
232-
# if filter_name == "CholeskyFilter":
233-
# P0 = np.linalg.cholesky(P0)
234+
if filter_name == "CholeskyFilter":
235+
P0 = np.linalg.cholesky(P0)
234236
inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
235237
outputs = filter_func(*inputs)
236238

@@ -278,8 +280,8 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob
278280
pytest.skip("Univariate filter not stable at half precision without measurement error")
279281

280282
fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper(rng, n_missing)
281-
# if filter_name == "CholeskyFilter":
282-
# P0 = np.linalg.cholesky(P0)
283+
if filter_name == "CholeskyFilter":
284+
P0 = np.linalg.cholesky(P0)
283285

284286
H *= int(obs_noise)
285287
inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
@@ -301,8 +303,8 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob
301303

302304
@pytest.mark.parametrize(
303305
"filter",
304-
[StandardFilter],
305-
ids=["standard"],
306+
[StandardFilter, SquareRootFilter],
307+
ids=["standard", "cholesky"],
306308
)
307309
def test_kalman_filter_jax(filter):
308310
pytest.importorskip("jax")

tests/statespace/test_statespace.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,8 @@
1414
from pymc_extras.statespace.models.utilities import make_default_coords
1515
from pymc_extras.statespace.utils.constants import (
1616
FILTER_OUTPUT_NAMES,
17-
JITTER_DEFAULT,
18-
LONG_MATRIX_NAMES,
1917
MATRIX_NAMES,
20-
MISSING_FILL,
21-
NEVER_TIME_VARYING,
22-
SHORT_NAME_TO_LONG,
2318
SMOOTHER_OUTPUT_NAMES,
24-
VECTOR_VALUED,
2519
)
2620
from tests.statespace.utilities.shared_fixtures import (
2721
rng,

tests/statespace/test_statespace_JAX.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,7 @@
1010

1111
from pymc_extras.statespace.utils.constants import (
1212
FILTER_OUTPUT_NAMES,
13-
JITTER_DEFAULT,
14-
LONG_MATRIX_NAMES,
1513
MATRIX_NAMES,
16-
MISSING_FILL,
17-
SHORT_NAME_TO_LONG,
1814
SMOOTHER_OUTPUT_NAMES,
1915
)
2016
from tests.statespace.test_statespace import ( # pylint: disable=unused-import

tests/statespace/test_structural.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@
2020
ALL_STATE_AUX_DIM,
2121
ALL_STATE_DIM,
2222
AR_PARAM_DIM,
23-
JITTER_DEFAULT,
24-
LONG_MATRIX_NAMES,
25-
MISSING_FILL,
2623
OBS_STATE_AUX_DIM,
2724
OBS_STATE_DIM,
2825
SHOCK_AUX_DIM,

tests/statespace/utilities/test_helpers.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@
99

1010
from pymc_extras.statespace.filters.kalman_smoother import KalmanSmoother
1111
from pymc_extras.statespace.utils.constants import (
12-
JITTER_DEFAULT,
1312
MATRIX_NAMES,
14-
MISSING_FILL,
1513
SHORT_NAME_TO_LONG,
1614
)
1715
from tests.statespace.utilities.statsmodel_local_level import LocalLinearTrend

tests/test_blackjax_smc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
from numpy import dtype
2121
from xarray.core.utils import Frozen
2222

23+
jax = pytest.importorskip("jax")
24+
pytest.importorskip("blackjax")
25+
2326
from pymc_extras.inference.smc.sampling import (
2427
arviz_from_particles,
2528
blackjax_particles_from_pymc_population,
@@ -28,9 +31,6 @@
2831
sample_smc_blackjax,
2932
)
3033

31-
jax = pytest.importorskip("jax")
32-
pytest.importorskip("blackjax")
33-
3434

3535
def two_gaussians_model():
3636
n = 4

tests/test_find_map.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
from typing import Literal
2-
31
import numpy as np
42
import pymc as pm
53
import pytensor.tensor as pt
64
import pytest
75

8-
from pymc_extras.inference.find_map import find_MAP, scipy_optimize_funcs_from_loss
6+
from pymc_extras.inference.find_map import (
7+
GradientBackend,
8+
find_MAP,
9+
scipy_optimize_funcs_from_loss,
10+
)
911

1012
pytest.importorskip("jax")
1113

@@ -16,10 +18,6 @@ def rng():
1618
return np.random.default_rng(seed)
1719

1820

19-
# Define GradientBackend type alias
20-
GradientBackend = Literal["jax", "pytensor"]
21-
22-
2321
@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
2422
def test_jax_functions_from_graph(gradient_backend: GradientBackend):
2523
x = pt.tensor("x", shape=(2,))

0 commit comments

Comments
 (0)