File tree 2 files changed +48
-0
lines changed
pytensor/tensor/rewriting 2 files changed +48
-0
lines changed Original file line number Diff line number Diff line change @@ -635,3 +635,24 @@ def rewrite_diag_blockdiag(fgraph, node):
635
635
output = [concatenate (submatrices_diag )]
636
636
637
637
return output
638
+
639
+
640
+ @register_canonicalize
641
+ @register_stabilize
642
+ @node_rewriter ([det ])
643
+ def rewrite_det_blockdiag (fgraph , node ):
644
+ # Check for inner block_diag operation
645
+ potential_blockdiag = node .inputs [0 ].owner
646
+ if not (
647
+ potential_blockdiag
648
+ and isinstance (potential_blockdiag .op , Blockwise )
649
+ and isinstance (potential_blockdiag .op .core_op , BlockDiagonal )
650
+ ):
651
+ return None
652
+
653
+ # Find the composing sub_matrices
654
+ sub_matrices = potential_blockdiag .inputs
655
+ det_sub_matrices = [det (sub_matrices [i ]) for i in range (len (sub_matrices ))]
656
+ prod_det_sub_matrices = prod (det_sub_matrices )
657
+
658
+ return [prod_det_sub_matrices ]
Original file line number Diff line number Diff line change @@ -595,3 +595,30 @@ def test_diag_blockdiag_rewrite():
595
595
atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
596
596
rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
597
597
)
598
+
599
+
600
+ def test_det_blockdiag_rewrite ():
601
+ n_matrices = 100
602
+ matrix_size = (5 , 5 )
603
+ sub_matrices = pt .tensor ("sub_matrices" , shape = (n_matrices , * matrix_size ))
604
+ bd_output = pt .linalg .block_diag (* [sub_matrices [i ] for i in range (n_matrices )])
605
+ det_output = pt .linalg .det (bd_output )
606
+ f_rewritten = function ([sub_matrices ], det_output , mode = "FAST_RUN" )
607
+
608
+ # Rewrite Test
609
+ nodes = f_rewritten .maker .fgraph .apply_nodes
610
+ assert not any (isinstance (node .op , BlockDiagonal ) for node in nodes )
611
+
612
+ # Value Test
613
+ sub_matrices_test = np .random .rand (n_matrices , * matrix_size )
614
+ bd_output_test = scipy .linalg .block_diag (
615
+ * [sub_matrices_test [i ] for i in range (n_matrices )]
616
+ )
617
+ det_output_test = np .linalg .det (bd_output_test )
618
+ rewritten_val = f_rewritten (sub_matrices_test )
619
+ assert_allclose (
620
+ det_output_test ,
621
+ rewritten_val ,
622
+ atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
623
+ rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
624
+ )
You can’t perform that action at this time.
0 commit comments