File tree Expand file tree Collapse file tree 2 files changed +51
-0
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 2 files changed +51
-0
lines changed Original file line number Diff line number Diff line change 12
12
from pytensor .scalar .basic import Mul
13
13
from pytensor .tensor .basic import (
14
14
AllocDiag ,
15
+ ExtractDiag ,
15
16
Eye ,
16
17
TensorVariable ,
18
+ concatenate ,
19
+ diag ,
17
20
diagonal ,
18
21
)
19
22
from pytensor .tensor .blas import Dot22
@@ -701,3 +704,24 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
701
704
non_eye_input = pt .shape_padaxis (non_eye_diag , - 2 )
702
705
703
706
return [eye_input / non_eye_input ]
707
+
708
+
709
+ @register_canonicalize
710
+ @register_stabilize
711
+ @node_rewriter ([ExtractDiag ])
712
+ def rewrite_diag_blockdiag (fgraph , node ):
713
+ # Check for inner block_diag operation
714
+ potential_blockdiag = node .inputs [0 ].owner
715
+ if not (
716
+ potential_blockdiag
717
+ and isinstance (potential_blockdiag .op , Blockwise )
718
+ and isinstance (potential_blockdiag .op .core_op , BlockDiagonal )
719
+ ):
720
+ return None
721
+
722
+ # Find the composing sub_matrices
723
+ submatrices = potential_blockdiag .inputs
724
+ submatrices_diag = [diag (submatrices [i ]) for i in range (len (submatrices ))]
725
+ output = [concatenate (submatrices_diag )]
726
+
727
+ return output
Original file line number Diff line number Diff line change @@ -662,3 +662,30 @@ def test_inv_diag_from_diag(inv_op):
662
662
atol = ATOL ,
663
663
rtol = RTOL ,
664
664
)
665
+
666
+
667
+ def test_diag_blockdiag_rewrite ():
668
+ n_matrices = 100
669
+ matrix_size = (5 , 5 )
670
+ sub_matrices = pt .tensor ("sub_matrices" , shape = (n_matrices , * matrix_size ))
671
+ bd_output = pt .linalg .block_diag (* [sub_matrices [i ] for i in range (n_matrices )])
672
+ diag_output = pt .diag (bd_output )
673
+ f_rewritten = function ([sub_matrices ], diag_output , mode = "FAST_RUN" )
674
+
675
+ # Rewrite Test
676
+ nodes = f_rewritten .maker .fgraph .apply_nodes
677
+ assert not any (isinstance (node .op , BlockDiagonal ) for node in nodes )
678
+
679
+ # Value Test
680
+ sub_matrices_test = np .random .rand (n_matrices , * matrix_size )
681
+ bd_output_test = scipy .linalg .block_diag (
682
+ * [sub_matrices_test [i ] for i in range (n_matrices )]
683
+ )
684
+ diag_output_test = np .diag (bd_output_test )
685
+ rewritten_val = f_rewritten (sub_matrices_test )
686
+ assert_allclose (
687
+ diag_output_test ,
688
+ rewritten_val ,
689
+ atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
690
+ rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
691
+ )
You can’t perform that action at this time.
0 commit comments