Skip to content

Commit a8a8a6b

Browse files
committed
Revert questionable changes
1 parent 2f245d7 commit a8a8a6b

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434
FILTER_OUTPUT_DIMS,
3535
FILTER_OUTPUT_TYPES,
3636
JITTER_DEFAULT,
37-
LONG_MATRIX_NAMES,
3837
MATRIX_DIMS,
38+
MATRIX_NAMES,
3939
OBS_STATE_DIM,
4040
SHOCK_DIM,
4141
SHORT_NAME_TO_LONG,
@@ -750,7 +750,7 @@ def _register_matrices_with_pymc_model(self) -> list[pt.TensorVariable]:
750750
matrices = self.unpack_statespace()
751751

752752
registered_matrices = []
753-
for i, (matrix, name) in enumerate(zip(matrices, LONG_MATRIX_NAMES)):
753+
for i, (matrix, name) in enumerate(zip(matrices, MATRIX_NAMES)):
754754
time_varying_ndim = 2 if name in VECTOR_VALUED else 3
755755
if not getattr(pm_mod, name, None):
756756
shape, dims = self._get_matrix_shape_and_dims(name)
@@ -1471,7 +1471,7 @@ def sample_statespace_matrices(
14711471
_verify_group(group)
14721472

14731473
if matrix_names is None:
1474-
matrix_names = LONG_MATRIX_NAMES
1474+
matrix_names = MATRIX_NAMES
14751475
elif isinstance(matrix_names, str):
14761476
matrix_names = [matrix_names]
14771477

@@ -1484,7 +1484,7 @@ def sample_statespace_matrices(
14841484

14851485
self._insert_data_variables()
14861486
matrices = self.unpack_statespace()
1487-
for short_name, matrix in zip(LONG_MATRIX_NAMES, matrices):
1487+
for short_name, matrix in zip(MATRIX_NAMES, matrices):
14881488
long_name = SHORT_NAME_TO_LONG[short_name]
14891489
if (long_name in matrix_names) or (short_name in matrix_names):
14901490
name = long_name if long_name in matrix_names else short_name
@@ -2038,10 +2038,7 @@ def forecast(
20382038
}
20392039

20402040
matrices = graph_replace(matrices, replace=sub_dict, strict=True)
2041-
[
2042-
setattr(matrix, "name", name)
2043-
for name, matrix in zip(LONG_MATRIX_NAMES[2:], matrices)
2044-
]
2041+
[setattr(matrix, "name", name) for name, matrix in zip(MATRIX_NAMES[2:], matrices)]
20452042

20462043
_ = LinearGaussianStateSpace(
20472044
"forecast",

0 commit comments

Comments
 (0)