Skip to content

Commit 3eea7d0

Browse files
authored
Added rewrites involving block diagonal matrices (#967)
* added rewrite for diag(block_diag) * added rewrite for determinant of blockdiag * Added rewrite for slogdet; added docstrings for all 3 rewrites * fixed typecasting for tests
1 parent 2086aeb commit 3eea7d0

File tree

2 files changed

+206
-0
lines changed

2 files changed

+206
-0
lines changed

pytensor/tensor/rewriting/linalg.py

+117
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
from pytensor.scalar.basic import Mul
1313
from pytensor.tensor.basic import (
1414
AllocDiag,
15+
ExtractDiag,
1516
Eye,
1617
TensorVariable,
18+
concatenate,
19+
diag,
1720
diagonal,
1821
)
1922
from pytensor.tensor.blas import Dot22
@@ -29,6 +32,7 @@
2932
inv,
3033
kron,
3134
pinv,
35+
slogdet,
3236
svd,
3337
)
3438
from pytensor.tensor.rewriting.basic import (
@@ -701,3 +705,116 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
701705
non_eye_input = pt.shape_padaxis(non_eye_diag, -2)
702706

703707
return [eye_input / non_eye_input]
708+
709+
710+
@register_canonicalize
@register_stabilize
@node_rewriter([ExtractDiag])
def rewrite_diag_blockdiag(fgraph, node):
    """
    Simplify the main diagonal of a block diagonal matrix by concatenating the diagonals of the individual sub matrices.

    diag(block_diag(a,b,c,....)) = concat(diag(a), diag(b), diag(c),...)

    The identity only holds for the main diagonal of a (non-batched) matrix
    whose blocks, except possibly the last one, are square. The rewrite is
    skipped whenever the graph shows that one of these conditions is violated.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # Check for inner block_diag operation
    potential_block_diag = node.inputs[0].owner
    if not (
        potential_block_diag
        and isinstance(potential_block_diag.op, Blockwise)
        and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
    ):
        return None

    # An off-diagonal (offset != 0) crosses block boundaries and is not the
    # concatenation of the blocks' diagonals, so leave those graphs untouched.
    if node.op.offset != 0:
        return None

    # `diag` below is applied to each block as a plain matrix and the results
    # are joined along the only axis; a batched block_diag output would not
    # keep its batch dimensions, so only the 2D case is rewritten.
    if node.inputs[0].type.ndim != 2:
        return None

    # Find the composing sub_matrices
    submatrices = potential_block_diag.inputs

    # A non-square block shifts every following block off the main diagonal.
    # Only the last block may be rectangular. Blocks whose static shape is
    # unknown are assumed square, as the original rewrite did.
    for submatrix in submatrices[:-1]:
        n_rows, n_cols = submatrix.type.shape[-2:]
        if n_rows is not None and n_cols is not None and n_rows != n_cols:
            return None

    submatrices_diag = [diag(submatrix) for submatrix in submatrices]

    return [concatenate(submatrices_diag)]
745+
746+
747+
@register_canonicalize
@register_stabilize
@node_rewriter([det])
def rewrite_det_blockdiag(fgraph, node):
    """
    Simplify the determinant of a block diagonal matrix into the product of the determinants of the individual sub matrices.

    det(block_diag(a,b,c,....)) = prod(det(a), det(b), det(c),...)

    The rewrite is skipped for batched block diagonal matrices and when a
    block is statically known to be non-square, since the identity above is
    only applied here to a single matrix made of square blocks.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # Check for inner block_diag operation
    potential_block_diag = node.inputs[0].owner
    if not (
        potential_block_diag
        and isinstance(potential_block_diag.op, Blockwise)
        and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
    ):
        return None

    # `prod` is called on the list of determinants without an axis, so any
    # batch dimension would be reduced away as well. Restrict the rewrite to
    # the plain matrix case, where each determinant is a scalar.
    if node.inputs[0].type.ndim != 2:
        return None

    # Find the composing sub_matrices
    sub_matrices = potential_block_diag.inputs

    # The determinant of a rectangular block is undefined. Blocks whose
    # static shape is unknown are assumed square, as the original rewrite did.
    for sub_matrix in sub_matrices:
        n_rows, n_cols = sub_matrix.type.shape[-2:]
        if n_rows is not None and n_cols is not None and n_rows != n_cols:
            return None

    det_sub_matrices = [det(sub_matrix) for sub_matrix in sub_matrices]

    return [prod(det_sub_matrices)]
782+
783+
784+
@register_canonicalize
@register_stabilize
@node_rewriter([slogdet])
def rewrite_slogdet_blockdiag(fgraph, node):
    """
    Simplify the slogdet of a block diagonal matrix by combining the sign and logdet values of the individual sub matrices.

    slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....)

    The rewrite is skipped for batched block diagonal matrices and when a
    block is statically known to be non-square, since the identity above is
    only applied here to a single matrix made of square blocks.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # Check for inner block_diag operation
    potential_block_diag = node.inputs[0].owner
    if not (
        potential_block_diag
        and isinstance(potential_block_diag.op, Blockwise)
        and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
    ):
        return None

    # `prod` is called on the collected signs without an axis, so any batch
    # dimension would be reduced away as well. Restrict the rewrite to the
    # plain matrix case, where each sign and logdet is a scalar.
    if node.inputs[0].type.ndim != 2:
        return None

    # Find the composing sub_matrices
    sub_matrices = potential_block_diag.inputs

    # slogdet of a rectangular block is undefined. Blocks whose static shape
    # is unknown are assumed square, as the original rewrite did.
    for sub_matrix in sub_matrices:
        n_rows, n_cols = sub_matrix.type.shape[-2:]
        if n_rows is not None and n_cols is not None and n_rows != n_cols:
            return None

    # Each slogdet call yields a (sign, logdet) pair; transpose the pairs into
    # one tuple of signs and one tuple of logdets.
    sign_sub_matrices, logdet_sub_matrices = zip(
        *[slogdet(sub_matrix) for sub_matrix in sub_matrices]
    )

    return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]

tests/tensor/rewriting/test_linalg.py

+89
Original file line numberDiff line numberDiff line change
@@ -662,3 +662,92 @@ def test_inv_diag_from_diag(inv_op):
662662
atol=ATOL,
663663
rtol=RTOL,
664664
)
665+
666+
667+
def test_diag_blockdiag_rewrite():
    """Compile diag(block_diag(...)) and check both the rewritten graph and its values."""
    n_matrices = 10
    matrix_size = (5, 5)
    sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
    blocks = [sub_matrices[idx] for idx in range(n_matrices)]
    diag_output = pt.diag(pt.linalg.block_diag(*blocks))
    f_rewritten = function([sub_matrices], diag_output, mode="FAST_RUN")

    # Rewrite Test: no BlockDiagonal op may remain in the compiled graph
    assert not any(
        isinstance(apply_node.op, BlockDiagonal)
        for apply_node in f_rewritten.maker.fgraph.apply_nodes
    )

    # Value Test: compare against the SciPy/NumPy reference computation
    sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
    expected_diag = np.diag(
        scipy.linalg.block_diag(*[sub_matrices_test[idx] for idx in range(n_matrices)])
    )
    tolerance = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(
        expected_diag,
        f_rewritten(sub_matrices_test),
        atol=tolerance,
        rtol=tolerance,
    )
692+
693+
694+
def test_det_blockdiag_rewrite():
    """Compile det(block_diag(...)) and check both the rewritten graph and its value."""
    n_matrices = 100
    matrix_size = (5, 5)
    sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
    blocks = [sub_matrices[idx] for idx in range(n_matrices)]
    det_output = pt.linalg.det(pt.linalg.block_diag(*blocks))
    f_rewritten = function([sub_matrices], det_output, mode="FAST_RUN")

    # Rewrite Test: no BlockDiagonal op may remain in the compiled graph
    assert not any(
        isinstance(apply_node.op, BlockDiagonal)
        for apply_node in f_rewritten.maker.fgraph.apply_nodes
    )

    # Value Test: compare against the SciPy/NumPy reference computation
    sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
    expected_det = np.linalg.det(
        scipy.linalg.block_diag(*[sub_matrices_test[idx] for idx in range(n_matrices)])
    )
    tolerance = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(
        expected_det,
        f_rewritten(sub_matrices_test),
        atol=tolerance,
        rtol=tolerance,
    )
719+
720+
721+
def test_slogdet_blockdiag_rewrite():
    """Compile slogdet(block_diag(...)) and check both the rewritten graph and its values."""
    n_matrices = 100
    matrix_size = (5, 5)
    sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
    blocks = [sub_matrices[idx] for idx in range(n_matrices)]
    sign_output, logdet_output = pt.linalg.slogdet(pt.linalg.block_diag(*blocks))
    f_rewritten = function(
        [sub_matrices], [sign_output, logdet_output], mode="FAST_RUN"
    )

    # Rewrite Test: no BlockDiagonal op may remain in the compiled graph
    assert not any(
        isinstance(apply_node.op, BlockDiagonal)
        for apply_node in f_rewritten.maker.fgraph.apply_nodes
    )

    # Value Test: compare against the SciPy/NumPy reference computation
    sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
    expected_sign, expected_logdet = np.linalg.slogdet(
        scipy.linalg.block_diag(*[sub_matrices_test[idx] for idx in range(n_matrices)])
    )
    rewritten_sign_val, rewritten_logdet_val = f_rewritten(sub_matrices_test)
    tolerance = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(
        expected_sign,
        rewritten_sign_val,
        atol=tolerance,
        rtol=tolerance,
    )
    assert_allclose(
        expected_logdet,
        rewritten_logdet_val,
        atol=tolerance,
        rtol=tolerance,
    )

0 commit comments

Comments
 (0)