File tree 2 files changed +48
-0
lines changed
pytensor/tensor/rewriting 2 files changed +48
-0
lines changed Original file line number Diff line number Diff line change @@ -725,3 +725,24 @@ def rewrite_diag_blockdiag(fgraph, node):
725
725
output = [concatenate (submatrices_diag )]
726
726
727
727
return output
728
+
729
+
730
+ @register_canonicalize
731
+ @register_stabilize
732
+ @node_rewriter ([det ])
733
+ def rewrite_det_blockdiag (fgraph , node ):
734
+ # Check for inner block_diag operation
735
+ potential_blockdiag = node .inputs [0 ].owner
736
+ if not (
737
+ potential_blockdiag
738
+ and isinstance (potential_blockdiag .op , Blockwise )
739
+ and isinstance (potential_blockdiag .op .core_op , BlockDiagonal )
740
+ ):
741
+ return None
742
+
743
+ # Find the composing sub_matrices
744
+ sub_matrices = potential_blockdiag .inputs
745
+ det_sub_matrices = [det (sub_matrices [i ]) for i in range (len (sub_matrices ))]
746
+ prod_det_sub_matrices = prod (det_sub_matrices )
747
+
748
+ return [prod_det_sub_matrices ]
Original file line number Diff line number Diff line change @@ -689,3 +689,30 @@ def test_diag_blockdiag_rewrite():
689
689
atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
690
690
rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
691
691
)
692
+
693
+
694
+ def test_det_blockdiag_rewrite ():
695
+ n_matrices = 100
696
+ matrix_size = (5 , 5 )
697
+ sub_matrices = pt .tensor ("sub_matrices" , shape = (n_matrices , * matrix_size ))
698
+ bd_output = pt .linalg .block_diag (* [sub_matrices [i ] for i in range (n_matrices )])
699
+ det_output = pt .linalg .det (bd_output )
700
+ f_rewritten = function ([sub_matrices ], det_output , mode = "FAST_RUN" )
701
+
702
+ # Rewrite Test
703
+ nodes = f_rewritten .maker .fgraph .apply_nodes
704
+ assert not any (isinstance (node .op , BlockDiagonal ) for node in nodes )
705
+
706
+ # Value Test
707
+ sub_matrices_test = np .random .rand (n_matrices , * matrix_size )
708
+ bd_output_test = scipy .linalg .block_diag (
709
+ * [sub_matrices_test [i ] for i in range (n_matrices )]
710
+ )
711
+ det_output_test = np .linalg .det (bd_output_test )
712
+ rewritten_val = f_rewritten (sub_matrices_test )
713
+ assert_allclose (
714
+ det_output_test ,
715
+ rewritten_val ,
716
+ atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
717
+ rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
718
+ )
You can’t perform that action at this time.
0 commit comments