Skip to content

Commit 28d397b

Browse files
committed
added rewrite for determinant of blockdiag
1 parent 6eefbbe commit 28d397b

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,3 +725,24 @@ def rewrite_diag_blockdiag(fgraph, node):
725725
output = [concatenate(submatrices_diag)]
726726

727727
return output
728+
729+
730+
@register_canonicalize
731+
@register_stabilize
732+
@node_rewriter([det])
733+
def rewrite_det_blockdiag(fgraph, node):
734+
# Check for inner block_diag operation
735+
potential_blockdiag = node.inputs[0].owner
736+
if not (
737+
potential_blockdiag
738+
and isinstance(potential_blockdiag.op, Blockwise)
739+
and isinstance(potential_blockdiag.op.core_op, BlockDiagonal)
740+
):
741+
return None
742+
743+
# Find the composing sub_matrices
744+
sub_matrices = potential_blockdiag.inputs
745+
det_sub_matrices = [det(sub_matrices[i]) for i in range(len(sub_matrices))]
746+
prod_det_sub_matrices = prod(det_sub_matrices)
747+
748+
return [prod_det_sub_matrices]

tests/tensor/rewriting/test_linalg.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,3 +689,30 @@ def test_diag_blockdiag_rewrite():
689689
atol=1e-3 if config.floatX == "float32" else 1e-8,
690690
rtol=1e-3 if config.floatX == "float32" else 1e-8,
691691
)
692+
693+
694+
def test_det_blockdiag_rewrite():
695+
n_matrices = 100
696+
matrix_size = (5, 5)
697+
sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
698+
bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
699+
det_output = pt.linalg.det(bd_output)
700+
f_rewritten = function([sub_matrices], det_output, mode="FAST_RUN")
701+
702+
# Rewrite Test
703+
nodes = f_rewritten.maker.fgraph.apply_nodes
704+
assert not any(isinstance(node.op, BlockDiagonal) for node in nodes)
705+
706+
# Value Test
707+
sub_matrices_test = np.random.rand(n_matrices, *matrix_size)
708+
bd_output_test = scipy.linalg.block_diag(
709+
*[sub_matrices_test[i] for i in range(n_matrices)]
710+
)
711+
det_output_test = np.linalg.det(bd_output_test)
712+
rewritten_val = f_rewritten(sub_matrices_test)
713+
assert_allclose(
714+
det_output_test,
715+
rewritten_val,
716+
atol=1e-3 if config.floatX == "float32" else 1e-8,
717+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
718+
)

0 commit comments

Comments
 (0)