Skip to content

Commit 667596a

Browse files
Delete block_diagonal, use pt.linalg.block_diag
1 parent 660da7c commit 667596a

File tree

1 file changed

+5
-22
lines changed

1 file changed

+5
-22
lines changed

pymc_experimental/statespace/models/structural.py

+5-22
Original file line numberDiff line numberDiff line change
@@ -44,23 +44,6 @@ def _frequency_transition_block(s, j):
4444
return pt.stack([[pt.cos(lam), pt.sin(lam)], [-pt.sin(lam), pt.cos(lam)]]).squeeze()
4545

4646

47-
def block_diagonal(matrices: list[pt.matrix]):
48-
rows = [x.shape[0] for x in matrices]
49-
cols = [x.shape[1] for x in matrices]
50-
out = pt.zeros((sum(rows), sum(cols)))
51-
row_cursor = 0
52-
col_cursor = 0
53-
54-
for row, col, mat in zip(rows, cols, matrices):
55-
row_slice = slice(row_cursor, row_cursor + row)
56-
col_slice = slice(col_cursor, col_cursor + col)
57-
row_cursor += row
58-
col_cursor += col
59-
60-
out = pt.set_subtensor(out[row_slice, col_slice], mat)
61-
return out
62-
63-
6447
class StructuralTimeSeries(PyMCStateSpace):
6548
r"""
6649
Structural Time Series Model
@@ -527,7 +510,7 @@ def make_slice(name, x, o_x):
527510
initial_state = pt.concatenate(conform_time_varying_and_time_invariant_matrices(x0, o_x0))
528511
initial_state.name = x0.name
529512

530-
initial_state_cov = block_diagonal([P0, o_P0])
513+
initial_state_cov = pt.linalg.block_diag(P0, o_P0)
531514
initial_state_cov.name = P0.name
532515

533516
state_intercept = pt.concatenate(conform_time_varying_and_time_invariant_matrices(c, o_c))
@@ -536,19 +519,19 @@ def make_slice(name, x, o_x):
536519
obs_intercept = d + o_d
537520
obs_intercept.name = d.name
538521

539-
transition = block_diagonal([T, o_T])
522+
transition = pt.linalg.block_diag(T, o_T)
540523
transition.name = T.name
541524

542525
design = pt.concatenate(conform_time_varying_and_time_invariant_matrices(Z, o_Z), axis=-1)
543526
design.name = Z.name
544527

545-
selection = block_diagonal([R, o_R])
528+
selection = pt.linalg.block_diag(R, o_R)
546529
selection.name = R.name
547530

548531
obs_cov = H + o_H
549532
obs_cov.name = H.name
550533

551-
state_cov = block_diagonal([Q, o_Q])
534+
state_cov = pt.linalg.block_diag(Q, o_Q)
552535
state_cov.name = Q.name
553536

554537
new_ssm = PytensorRepresentation(
@@ -1326,7 +1309,7 @@ def make_symbolic_graph(self) -> None:
13261309
self.ssm["initial_state", init_state_idx] = init_state
13271310

13281311
T_mats = [_frequency_transition_block(self.season_length, j + 1) for j in range(self.n)]
1329-
T = block_diagonal(T_mats)
1312+
T = pt.linalg.block_diag(*T_mats)
13301313
self.ssm["transition", :, :] = T
13311314

13321315
if self.innovations:

0 commit comments

Comments
 (0)