Skip to content

Commit 48f6527

Browse files
committed
Added rewrite for slogdet; added docstrings for all 3 rewrites
1 parent 28d397b commit 48f6527

File tree

2 files changed

+122
-15
lines changed

2 files changed

+122
-15
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 86 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
inv,
3333
kron,
3434
pinv,
35+
slogdet,
3536
svd,
3637
)
3738
from pytensor.tensor.rewriting.basic import (
@@ -710,39 +711,110 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
710711
@register_stabilize
711712
@node_rewriter([ExtractDiag])
712713
def rewrite_diag_blockdiag(fgraph, node):
714+
"""
715+
This rewrite simplifies extracting the diagonal of a blockdiagonal matrix by concatening the diagonal values of all of the individual sub matrices.
716+
717+
diag(block_diag(a,b,c,....)) = concat(diag(a), diag(b), diag(c),...)
718+
719+
Parameters
720+
----------
721+
fgraph: FunctionGraph
722+
Function graph being optimized
723+
node: Apply
724+
Node of the function graph to be optimized
725+
726+
Returns
727+
-------
728+
list of Variable, optional
729+
List of optimized variables, or None if no optimization was performed
730+
"""
713731
# Check for inner block_diag operation
714-
potential_blockdiag = node.inputs[0].owner
732+
potential_block_diag = node.inputs[0].owner
715733
if not (
716-
potential_blockdiag
717-
and isinstance(potential_blockdiag.op, Blockwise)
718-
and isinstance(potential_blockdiag.op.core_op, BlockDiagonal)
734+
potential_block_diag
735+
and isinstance(potential_block_diag.op, Blockwise)
736+
and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
719737
):
720738
return None
721739

722740
# Find the composing sub_matrices
723-
submatrices = potential_blockdiag.inputs
741+
submatrices = potential_block_diag.inputs
724742
submatrices_diag = [diag(submatrices[i]) for i in range(len(submatrices))]
725-
output = [concatenate(submatrices_diag)]
726743

727-
return output
744+
return [concatenate(submatrices_diag)]
728745

729746

730747
@register_canonicalize
731748
@register_stabilize
732749
@node_rewriter([det])
733750
def rewrite_det_blockdiag(fgraph, node):
751+
"""
752+
This rewrite simplifies the determinant of a blockdiagonal matrix by extracting the individual sub matrices and returning the product of all individual determinant values.
753+
754+
det(block_diag(a,b,c,....)) = prod(det(a), det(b), det(c),...)
755+
756+
Parameters
757+
----------
758+
fgraph: FunctionGraph
759+
Function graph being optimized
760+
node: Apply
761+
Node of the function graph to be optimized
762+
763+
Returns
764+
-------
765+
list of Variable, optional
766+
List of optimized variables, or None if no optimization was performed
767+
"""
734768
# Check for inner block_diag operation
735-
potential_blockdiag = node.inputs[0].owner
769+
potential_block_diag = node.inputs[0].owner
736770
if not (
737-
potential_blockdiag
738-
and isinstance(potential_blockdiag.op, Blockwise)
739-
and isinstance(potential_blockdiag.op.core_op, BlockDiagonal)
771+
potential_block_diag
772+
and isinstance(potential_block_diag.op, Blockwise)
773+
and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
740774
):
741775
return None
742776

743777
# Find the composing sub_matrices
744-
sub_matrices = potential_blockdiag.inputs
778+
sub_matrices = potential_block_diag.inputs
745779
det_sub_matrices = [det(sub_matrices[i]) for i in range(len(sub_matrices))]
746-
prod_det_sub_matrices = prod(det_sub_matrices)
747780

748-
return [prod_det_sub_matrices]
781+
return [prod(det_sub_matrices)]
782+
783+
784+
@register_canonicalize
785+
@register_stabilize
786+
@node_rewriter([slogdet])
787+
def rewrite_slogdet_blockdiag(fgraph, node):
788+
"""
789+
This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
790+
791+
slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....)
792+
793+
Parameters
794+
----------
795+
fgraph: FunctionGraph
796+
Function graph being optimized
797+
node: Apply
798+
Node of the function graph to be optimized
799+
800+
Returns
801+
-------
802+
list of Variable, optional
803+
List of optimized variables, or None if no optimization was performed
804+
"""
805+
# Check for inner block_diag operation
806+
potential_block_diag = node.inputs[0].owner
807+
if not (
808+
potential_block_diag
809+
and isinstance(potential_block_diag.op, Blockwise)
810+
and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
811+
):
812+
return None
813+
814+
# Find the composing sub_matrices
815+
sub_matrices = potential_block_diag.inputs
816+
sign_sub_matrices, logdet_sub_matrices = zip(
817+
*[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))]
818+
)
819+
820+
return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ def test_inv_diag_from_diag(inv_op):
665665

666666

667667
def test_diag_blockdiag_rewrite():
668-
n_matrices = 100
668+
n_matrices = 10
669669
matrix_size = (5, 5)
670670
sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
671671
bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
@@ -716,3 +716,38 @@ def test_det_blockdiag_rewrite():
716716
atol=1e-3 if config.floatX == "float32" else 1e-8,
717717
rtol=1e-3 if config.floatX == "float32" else 1e-8,
718718
)
719+
720+
721+
def test_slogdet_blockdiag_rewrite():
722+
n_matrices = 100
723+
matrix_size = (5, 5)
724+
sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
725+
bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
726+
sign_output, logdet_output = pt.linalg.slogdet(bd_output)
727+
f_rewritten = function(
728+
[sub_matrices], [sign_output, logdet_output], mode="FAST_RUN"
729+
)
730+
731+
# Rewrite Test
732+
nodes = f_rewritten.maker.fgraph.apply_nodes
733+
assert not any(isinstance(node.op, BlockDiagonal) for node in nodes)
734+
735+
# Value Test
736+
sub_matrices_test = np.random.rand(n_matrices, *matrix_size)
737+
bd_output_test = scipy.linalg.block_diag(
738+
*[sub_matrices_test[i] for i in range(n_matrices)]
739+
)
740+
sign_output_test, logdet_output_test = np.linalg.slogdet(bd_output_test)
741+
rewritten_sign_val, rewritten_logdet_val = f_rewritten(sub_matrices_test)
742+
assert_allclose(
743+
sign_output_test,
744+
rewritten_sign_val,
745+
atol=1e-3 if config.floatX == "float32" else 1e-8,
746+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
747+
)
748+
assert_allclose(
749+
logdet_output_test,
750+
rewritten_logdet_val,
751+
atol=1e-3 if config.floatX == "float32" else 1e-8,
752+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
753+
)

0 commit comments

Comments
 (0)