Skip to content

Commit bc30db8

Browse files
committed
added rewrite for determinant of blockdiag
1 parent c0c9163 commit bc30db8

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,3 +635,24 @@ def rewrite_diag_blockdiag(fgraph, node):
635635
output = [concatenate(submatrices_diag)]
636636

637637
return output
638+
639+
640+
@register_canonicalize
641+
@register_stabilize
642+
@node_rewriter([det])
643+
def rewrite_det_blockdiag(fgraph, node):
644+
# Check for inner block_diag operation
645+
potential_blockdiag = node.inputs[0].owner
646+
if not (
647+
potential_blockdiag
648+
and isinstance(potential_blockdiag.op, Blockwise)
649+
and isinstance(potential_blockdiag.op.core_op, BlockDiagonal)
650+
):
651+
return None
652+
653+
# Find the composing sub_matrices
654+
sub_matrices = potential_blockdiag.inputs
655+
det_sub_matrices = [det(sub_matrices[i]) for i in range(len(sub_matrices))]
656+
prod_det_sub_matrices = prod(det_sub_matrices)
657+
658+
return [prod_det_sub_matrices]

tests/tensor/rewriting/test_linalg.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,3 +595,30 @@ def test_diag_blockdiag_rewrite():
595595
atol=1e-3 if config.floatX == "float32" else 1e-8,
596596
rtol=1e-3 if config.floatX == "float32" else 1e-8,
597597
)
598+
599+
600+
def test_det_blockdiag_rewrite():
601+
n_matrices = 100
602+
matrix_size = (5, 5)
603+
sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
604+
bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
605+
det_output = pt.linalg.det(bd_output)
606+
f_rewritten = function([sub_matrices], det_output, mode="FAST_RUN")
607+
608+
# Rewrite Test
609+
nodes = f_rewritten.maker.fgraph.apply_nodes
610+
assert not any(isinstance(node.op, BlockDiagonal) for node in nodes)
611+
612+
# Value Test
613+
sub_matrices_test = np.random.rand(n_matrices, *matrix_size)
614+
bd_output_test = scipy.linalg.block_diag(
615+
*[sub_matrices_test[i] for i in range(n_matrices)]
616+
)
617+
det_output_test = np.linalg.det(bd_output_test)
618+
rewritten_val = f_rewritten(sub_matrices_test)
619+
assert_allclose(
620+
det_output_test,
621+
rewritten_val,
622+
atol=1e-3 if config.floatX == "float32" else 1e-8,
623+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
624+
)

0 commit comments

Comments
 (0)