File tree 3 files changed +4
-13
lines changed
tests/statespace/utilities
3 files changed +4
-13
lines changed Original file line number Diff line number Diff line change 19
19
from pymc_extras .distributions import DiscreteMarkovChain
20
20
21
21
22
- def get_support_axes (op ) -> tuple [tuple [int , ...], ...]:
23
- if hasattr (op , "support_axes" ):
24
- return op .support_axes
25
- else :
26
- # For vanilla RVs, the support axes are the last ndim_supp
27
- return (tuple (range (- op .ndim_supp , 0 )),)
28
-
29
-
30
22
class MarginalRV (OpFromGraph , MeasurableOp ):
31
23
"""Base class for Marginalized RVs"""
32
24
@@ -99,6 +91,7 @@ def reduce_batch_dependent_logps(
99
91
as well as transpose the remaining axis of dep1 logp before adding the two element-wise.
100
92
101
93
"""
94
+ from pymc_extras .model .marginal .graph_analysis import get_support_axes
102
95
103
96
reduced_logps = []
104
97
for dependent_op , dependent_logp , dependent_dims_connection in zip (
Original file line number Diff line number Diff line change @@ -1481,11 +1481,9 @@ def __init__(
1481
1481
k_endog = k_endog ,
1482
1482
k_states = k_states ,
1483
1483
k_posdef = k_posdef ,
1484
- state_names = self .state_names ,
1485
1484
measurement_error = False ,
1486
1485
combine_hidden_states = True ,
1487
- exog_names = [f"data_{ name } " ],
1488
- obs_state_idxs = np .ones (k_states ),
1486
+ obs_state_idxs = obs_state_idx ,
1489
1487
)
1490
1488
1491
1489
def make_symbolic_graph (self ) -> None :
Original file line number Diff line number Diff line change 10
10
from pymc_extras .statespace .filters .kalman_smoother import KalmanSmoother
11
11
from pymc_extras .statespace .utils .constants import (
12
12
JITTER_DEFAULT ,
13
- LONG_MATRIX_NAMES ,
13
+ MATRIX_NAMES ,
14
14
MISSING_FILL ,
15
15
SHORT_NAME_TO_LONG ,
16
16
)
@@ -210,7 +210,7 @@ def delete_rvs_from_model(rv_names: list[str]) -> None:
210
210
211
211
212
212
def unpack_statespace (ssm ):
213
- return [ssm [SHORT_NAME_TO_LONG [x ]] for x in LONG_MATRIX_NAMES ]
213
+ return [ssm [SHORT_NAME_TO_LONG [x ]] for x in MATRIX_NAMES ]
214
214
215
215
216
216
def unpack_symbolic_matrices_with_params (mod , param_dict , data_dict = None , mode = "FAST_COMPILE" ):
You can’t perform that action at this time.
0 commit comments