@@ -44,23 +44,6 @@ def _frequency_transition_block(s, j):
44
44
return pt .stack ([[pt .cos (lam ), pt .sin (lam )], [- pt .sin (lam ), pt .cos (lam )]]).squeeze ()
45
45
46
46
47
- def block_diagonal (matrices : list [pt .matrix ]):
48
- rows = [x .shape [0 ] for x in matrices ]
49
- cols = [x .shape [1 ] for x in matrices ]
50
- out = pt .zeros ((sum (rows ), sum (cols )))
51
- row_cursor = 0
52
- col_cursor = 0
53
-
54
- for row , col , mat in zip (rows , cols , matrices ):
55
- row_slice = slice (row_cursor , row_cursor + row )
56
- col_slice = slice (col_cursor , col_cursor + col )
57
- row_cursor += row
58
- col_cursor += col
59
-
60
- out = pt .set_subtensor (out [row_slice , col_slice ], mat )
61
- return out
62
-
63
-
64
47
class StructuralTimeSeries (PyMCStateSpace ):
65
48
r"""
66
49
Structural Time Series Model
@@ -527,7 +510,7 @@ def make_slice(name, x, o_x):
527
510
initial_state = pt .concatenate (conform_time_varying_and_time_invariant_matrices (x0 , o_x0 ))
528
511
initial_state .name = x0 .name
529
512
530
- initial_state_cov = block_diagonal ([ P0 , o_P0 ] )
513
+ initial_state_cov = pt . linalg . block_diag ( P0 , o_P0 )
531
514
initial_state_cov .name = P0 .name
532
515
533
516
state_intercept = pt .concatenate (conform_time_varying_and_time_invariant_matrices (c , o_c ))
@@ -536,19 +519,19 @@ def make_slice(name, x, o_x):
536
519
obs_intercept = d + o_d
537
520
obs_intercept .name = d .name
538
521
539
- transition = block_diagonal ([ T , o_T ] )
522
+ transition = pt . linalg . block_diag ( T , o_T )
540
523
transition .name = T .name
541
524
542
525
design = pt .concatenate (conform_time_varying_and_time_invariant_matrices (Z , o_Z ), axis = - 1 )
543
526
design .name = Z .name
544
527
545
- selection = block_diagonal ([ R , o_R ] )
528
+ selection = pt . linalg . block_diag ( R , o_R )
546
529
selection .name = R .name
547
530
548
531
obs_cov = H + o_H
549
532
obs_cov .name = H .name
550
533
551
- state_cov = block_diagonal ([ Q , o_Q ] )
534
+ state_cov = pt . linalg . block_diag ( Q , o_Q )
552
535
state_cov .name = Q .name
553
536
554
537
new_ssm = PytensorRepresentation (
@@ -1326,7 +1309,7 @@ def make_symbolic_graph(self) -> None:
1326
1309
self .ssm ["initial_state" , init_state_idx ] = init_state
1327
1310
1328
1311
T_mats = [_frequency_transition_block (self .season_length , j + 1 ) for j in range (self .n )]
1329
- T = block_diagonal ( T_mats )
1312
+ T = pt . linalg . block_diag ( * T_mats )
1330
1313
self .ssm ["transition" , :, :] = T
1331
1314
1332
1315
if self .innovations :
0 commit comments