34
34
FILTER_OUTPUT_DIMS ,
35
35
FILTER_OUTPUT_TYPES ,
36
36
JITTER_DEFAULT ,
37
- LONG_MATRIX_NAMES ,
38
37
MATRIX_DIMS ,
38
+ MATRIX_NAMES ,
39
39
OBS_STATE_DIM ,
40
40
SHOCK_DIM ,
41
41
SHORT_NAME_TO_LONG ,
@@ -750,7 +750,7 @@ def _register_matrices_with_pymc_model(self) -> list[pt.TensorVariable]:
750
750
matrices = self .unpack_statespace ()
751
751
752
752
registered_matrices = []
753
- for i , (matrix , name ) in enumerate (zip (matrices , LONG_MATRIX_NAMES )):
753
+ for i , (matrix , name ) in enumerate (zip (matrices , MATRIX_NAMES )):
754
754
time_varying_ndim = 2 if name in VECTOR_VALUED else 3
755
755
if not getattr (pm_mod , name , None ):
756
756
shape , dims = self ._get_matrix_shape_and_dims (name )
@@ -1471,7 +1471,7 @@ def sample_statespace_matrices(
1471
1471
_verify_group (group )
1472
1472
1473
1473
if matrix_names is None :
1474
- matrix_names = LONG_MATRIX_NAMES
1474
+ matrix_names = MATRIX_NAMES
1475
1475
elif isinstance (matrix_names , str ):
1476
1476
matrix_names = [matrix_names ]
1477
1477
@@ -1484,7 +1484,7 @@ def sample_statespace_matrices(
1484
1484
1485
1485
self ._insert_data_variables ()
1486
1486
matrices = self .unpack_statespace ()
1487
- for short_name , matrix in zip (LONG_MATRIX_NAMES , matrices ):
1487
+ for short_name , matrix in zip (MATRIX_NAMES , matrices ):
1488
1488
long_name = SHORT_NAME_TO_LONG [short_name ]
1489
1489
if (long_name in matrix_names ) or (short_name in matrix_names ):
1490
1490
name = long_name if long_name in matrix_names else short_name
@@ -2038,10 +2038,7 @@ def forecast(
2038
2038
}
2039
2039
2040
2040
matrices = graph_replace (matrices , replace = sub_dict , strict = True )
2041
- [
2042
- setattr (matrix , "name" , name )
2043
- for name , matrix in zip (LONG_MATRIX_NAMES [2 :], matrices )
2044
- ]
2041
+ [setattr (matrix , "name" , name ) for name , matrix in zip (MATRIX_NAMES [2 :], matrices )]
2045
2042
2046
2043
_ = LinearGaussianStateSpace (
2047
2044
"forecast" ,
0 commit comments