Skip to content

Commit 81d4748

Browse files
committed
Added rewrite for slogdet; added docstrings for all 3 rewrites
1 parent bc30db8 commit 81d4748

File tree

2 files changed

+122
-15
lines changed

2 files changed

+122
-15
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 86 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
inv,
3232
kron,
3333
pinv,
34+
slogdet,
3435
svd,
3536
)
3637
from pytensor.tensor.rewriting.basic import (
@@ -620,39 +621,110 @@ def rewrite_inv_inv(fgraph, node):
620621
@register_stabilize
621622
@node_rewriter([ExtractDiag])
622623
def rewrite_diag_blockdiag(fgraph, node):
624+
"""
625+
This rewrite simplifies extracting the diagonal of a blockdiagonal matrix by concatening the diagonal values of all of the individual sub matrices.
626+
627+
diag(block_diag(a,b,c,....)) = concat(diag(a), diag(b), diag(c),...)
628+
629+
Parameters
630+
----------
631+
fgraph: FunctionGraph
632+
Function graph being optimized
633+
node: Apply
634+
Node of the function graph to be optimized
635+
636+
Returns
637+
-------
638+
list of Variable, optional
639+
List of optimized variables, or None if no optimization was performed
640+
"""
623641
# Check for inner block_diag operation
624-
potential_blockdiag = node.inputs[0].owner
642+
potential_block_diag = node.inputs[0].owner
625643
if not (
626-
potential_blockdiag
627-
and isinstance(potential_blockdiag.op, Blockwise)
628-
and isinstance(potential_blockdiag.op.core_op, BlockDiagonal)
644+
potential_block_diag
645+
and isinstance(potential_block_diag.op, Blockwise)
646+
and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
629647
):
630648
return None
631649

632650
# Find the composing sub_matrices
633-
submatrices = potential_blockdiag.inputs
651+
submatrices = potential_block_diag.inputs
634652
submatrices_diag = [diag(submatrices[i]) for i in range(len(submatrices))]
635-
output = [concatenate(submatrices_diag)]
636653

637-
return output
654+
return [concatenate(submatrices_diag)]
638655

639656

640657
@register_canonicalize
641658
@register_stabilize
642659
@node_rewriter([det])
643660
def rewrite_det_blockdiag(fgraph, node):
661+
"""
662+
This rewrite simplifies the determinant of a blockdiagonal matrix by extracting the individual sub matrices and returning the product of all individual determinant values.
663+
664+
det(block_diag(a,b,c,....)) = prod(det(a), det(b), det(c),...)
665+
666+
Parameters
667+
----------
668+
fgraph: FunctionGraph
669+
Function graph being optimized
670+
node: Apply
671+
Node of the function graph to be optimized
672+
673+
Returns
674+
-------
675+
list of Variable, optional
676+
List of optimized variables, or None if no optimization was performed
677+
"""
644678
# Check for inner block_diag operation
645-
potential_blockdiag = node.inputs[0].owner
679+
potential_block_diag = node.inputs[0].owner
646680
if not (
647-
potential_blockdiag
648-
and isinstance(potential_blockdiag.op, Blockwise)
649-
and isinstance(potential_blockdiag.op.core_op, BlockDiagonal)
681+
potential_block_diag
682+
and isinstance(potential_block_diag.op, Blockwise)
683+
and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
650684
):
651685
return None
652686

653687
# Find the composing sub_matrices
654-
sub_matrices = potential_blockdiag.inputs
688+
sub_matrices = potential_block_diag.inputs
655689
det_sub_matrices = [det(sub_matrices[i]) for i in range(len(sub_matrices))]
656-
prod_det_sub_matrices = prod(det_sub_matrices)
657690

658-
return [prod_det_sub_matrices]
691+
return [prod(det_sub_matrices)]
692+
693+
694+
@register_canonicalize
695+
@register_stabilize
696+
@node_rewriter([slogdet])
697+
def rewrite_slogdet_blockdiag(fgraph, node):
698+
"""
699+
This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
700+
701+
slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....)
702+
703+
Parameters
704+
----------
705+
fgraph: FunctionGraph
706+
Function graph being optimized
707+
node: Apply
708+
Node of the function graph to be optimized
709+
710+
Returns
711+
-------
712+
list of Variable, optional
713+
List of optimized variables, or None if no optimization was performed
714+
"""
715+
# Check for inner block_diag operation
716+
potential_block_diag = node.inputs[0].owner
717+
if not (
718+
potential_block_diag
719+
and isinstance(potential_block_diag.op, Blockwise)
720+
and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
721+
):
722+
return None
723+
724+
# Find the composing sub_matrices
725+
sub_matrices = potential_block_diag.inputs
726+
sign_sub_matrices, logdet_sub_matrices = zip(
727+
*[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))]
728+
)
729+
730+
return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ def get_pt_function(x, op_name):
571571

572572

573573
def test_diag_blockdiag_rewrite():
574-
n_matrices = 100
574+
n_matrices = 10
575575
matrix_size = (5, 5)
576576
sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
577577
bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
@@ -622,3 +622,38 @@ def test_det_blockdiag_rewrite():
622622
atol=1e-3 if config.floatX == "float32" else 1e-8,
623623
rtol=1e-3 if config.floatX == "float32" else 1e-8,
624624
)
625+
626+
627+
def test_slogdet_blockdiag_rewrite():
628+
n_matrices = 100
629+
matrix_size = (5, 5)
630+
sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
631+
bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
632+
sign_output, logdet_output = pt.linalg.slogdet(bd_output)
633+
f_rewritten = function(
634+
[sub_matrices], [sign_output, logdet_output], mode="FAST_RUN"
635+
)
636+
637+
# Rewrite Test
638+
nodes = f_rewritten.maker.fgraph.apply_nodes
639+
assert not any(isinstance(node.op, BlockDiagonal) for node in nodes)
640+
641+
# Value Test
642+
sub_matrices_test = np.random.rand(n_matrices, *matrix_size)
643+
bd_output_test = scipy.linalg.block_diag(
644+
*[sub_matrices_test[i] for i in range(n_matrices)]
645+
)
646+
sign_output_test, logdet_output_test = np.linalg.slogdet(bd_output_test)
647+
rewritten_sign_val, rewritten_logdet_val = f_rewritten(sub_matrices_test)
648+
assert_allclose(
649+
sign_output_test,
650+
rewritten_sign_val,
651+
atol=1e-3 if config.floatX == "float32" else 1e-8,
652+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
653+
)
654+
assert_allclose(
655+
logdet_output_test,
656+
rewritten_logdet_val,
657+
atol=1e-3 if config.floatX == "float32" else 1e-8,
658+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
659+
)

0 commit comments

Comments
 (0)