Skip to content

Commit 6eefbbe

Browse files
committed
added rewrite for diag(block_diag)
1 parent 1a1c62b commit 6eefbbe

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
from pytensor.scalar.basic import Mul
1313
from pytensor.tensor.basic import (
1414
AllocDiag,
15+
ExtractDiag,
1516
Eye,
1617
TensorVariable,
18+
concatenate,
19+
diag,
1720
diagonal,
1821
)
1922
from pytensor.tensor.blas import Dot22
@@ -701,3 +704,24 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
701704
non_eye_input = pt.shape_padaxis(non_eye_diag, -2)
702705

703706
return [eye_input / non_eye_input]
707+
708+
709+
@register_canonicalize
710+
@register_stabilize
711+
@node_rewriter([ExtractDiag])
712+
def rewrite_diag_blockdiag(fgraph, node):
713+
# Check for inner block_diag operation
714+
potential_blockdiag = node.inputs[0].owner
715+
if not (
716+
potential_blockdiag
717+
and isinstance(potential_blockdiag.op, Blockwise)
718+
and isinstance(potential_blockdiag.op.core_op, BlockDiagonal)
719+
):
720+
return None
721+
722+
# Find the composing sub_matrices
723+
submatrices = potential_blockdiag.inputs
724+
submatrices_diag = [diag(submatrices[i]) for i in range(len(submatrices))]
725+
output = [concatenate(submatrices_diag)]
726+
727+
return output

tests/tensor/rewriting/test_linalg.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,3 +662,30 @@ def test_inv_diag_from_diag(inv_op):
662662
atol=ATOL,
663663
rtol=RTOL,
664664
)
665+
666+
667+
def test_diag_blockdiag_rewrite():
668+
n_matrices = 100
669+
matrix_size = (5, 5)
670+
sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
671+
bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
672+
diag_output = pt.diag(bd_output)
673+
f_rewritten = function([sub_matrices], diag_output, mode="FAST_RUN")
674+
675+
# Rewrite Test
676+
nodes = f_rewritten.maker.fgraph.apply_nodes
677+
assert not any(isinstance(node.op, BlockDiagonal) for node in nodes)
678+
679+
# Value Test
680+
sub_matrices_test = np.random.rand(n_matrices, *matrix_size)
681+
bd_output_test = scipy.linalg.block_diag(
682+
*[sub_matrices_test[i] for i in range(n_matrices)]
683+
)
684+
diag_output_test = np.diag(bd_output_test)
685+
rewritten_val = f_rewritten(sub_matrices_test)
686+
assert_allclose(
687+
diag_output_test,
688+
rewritten_val,
689+
atol=1e-3 if config.floatX == "float32" else 1e-8,
690+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
691+
)

0 commit comments

Comments
 (0)