Skip to content

Commit b044c03

Browse files
committed
Circular dependency fix
1 parent 58bc697 commit b044c03

File tree

3 files changed

+4
-13
lines changed

3 files changed

+4
-13
lines changed

pymc_extras/model/marginal/distributions.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,6 @@
1919
from pymc_extras.distributions import DiscreteMarkovChain
2020

2121

22-
def get_support_axes(op) -> tuple[tuple[int, ...], ...]:
23-
if hasattr(op, "support_axes"):
24-
return op.support_axes
25-
else:
26-
# For vanilla RVs, the support axes are the last ndim_supp
27-
return (tuple(range(-op.ndim_supp, 0)),)
28-
29-
3022
class MarginalRV(OpFromGraph, MeasurableOp):
3123
"""Base class for Marginalized RVs"""
3224

@@ -99,6 +91,7 @@ def reduce_batch_dependent_logps(
9991
as well as transpose the remaining axis of dep1 logp before adding the two element-wise.
10092
10193
"""
94+
from pymc_extras.model.marginal.graph_analysis import get_support_axes
10295

10396
reduced_logps = []
10497
for dependent_op, dependent_logp, dependent_dims_connection in zip(

pymc_extras/statespace/models/structural.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,11 +1481,9 @@ def __init__(
14811481
k_endog=k_endog,
14821482
k_states=k_states,
14831483
k_posdef=k_posdef,
1484-
state_names=self.state_names,
14851484
measurement_error=False,
14861485
combine_hidden_states=True,
1487-
exog_names=[f"data_{name}"],
1488-
obs_state_idxs=np.ones(k_states),
1486+
obs_state_idxs=obs_state_idx,
14891487
)
14901488

14911489
def make_symbolic_graph(self) -> None:

tests/statespace/utilities/test_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pymc_extras.statespace.filters.kalman_smoother import KalmanSmoother
1111
from pymc_extras.statespace.utils.constants import (
1212
JITTER_DEFAULT,
13-
LONG_MATRIX_NAMES,
13+
MATRIX_NAMES,
1414
MISSING_FILL,
1515
SHORT_NAME_TO_LONG,
1616
)
@@ -210,7 +210,7 @@ def delete_rvs_from_model(rv_names: list[str]) -> None:
210210

211211

212212
def unpack_statespace(ssm):
213-
return [ssm[SHORT_NAME_TO_LONG[x]] for x in LONG_MATRIX_NAMES]
213+
return [ssm[SHORT_NAME_TO_LONG[x]] for x in MATRIX_NAMES]
214214

215215

216216
def unpack_symbolic_matrices_with_params(mod, param_dict, data_dict=None, mode="FAST_COMPILE"):

0 commit comments

Comments
 (0)