Skip to content

Commit c0892e0

Browse files
tanish1729jessegrabowski
authored andcommitted
minor changes; added test to not apply rewrite
1 parent 7074e50 commit c0892e0

File tree

2 files changed

+25
-21
lines changed

2 files changed

+25
-21
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from collections.abc import Callable
33
from typing import cast
44

5-
import pytensor.tensor as pt
65
from pytensor import Variable
76
from pytensor import tensor as pt
87
from pytensor.graph import Apply, FunctionGraph
@@ -824,7 +823,7 @@ def rewrite_slogdet_blockdiag(fgraph, node):
824823
@register_canonicalize
825824
@register_stabilize
826825
@node_rewriter([Blockwise])
827-
def rewrite_cholesky_eye_to_eye(fgraph, node):
826+
def rewrite_remove_useless_cholesky(fgraph, node):
828827
"""
829828
This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself
830829
@@ -847,14 +846,15 @@ def rewrite_cholesky_eye_to_eye(fgraph, node):
847846
return None
848847

849848
# Check whether input to Cholesky is Eye and the 1's are on main diagonal
850-
eye_check = node.inputs[0]
849+
potential_eye = node.inputs[0]
851850
if not (
852-
eye_check.owner
853-
and isinstance(eye_check.owner.op, Eye)
854-
and getattr(eye_check.owner.inputs[-1], "data", -1).item() == 0
851+
potential_eye.owner
852+
and isinstance(potential_eye.owner.op, Eye)
853+
and hasattr(potential_eye.owner.inputs[-1], "data")
854+
and potential_eye.owner.inputs[-1].data.item() == 0
855855
):
856856
return None
857-
return [eye_check]
857+
return [potential_eye]
858858

859859

860860
@register_canonicalize
@@ -872,10 +872,9 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
872872
and isinstance(inputs.owner.op, AllocDiag)
873873
and AllocDiag.is_offset_zero(inputs.owner)
874874
):
875-
cholesky_input = inputs.owner.inputs[0]
876-
if cholesky_input.type.ndim == 1:
877-
cholesky_val = pt.diag(cholesky_input**0.5)
878-
return [cholesky_val]
875+
diag_input = inputs.owner.inputs[0]
876+
cholesky_val = pt.diag(diag_input**0.5)
877+
return [cholesky_val]
879878

880879
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
881880
inputs_or_none = _find_diag_from_eye_mul(inputs)
@@ -893,8 +892,6 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
893892
# Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements
894893
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
895894
if non_eye_input.type.broadcastable[-2:] == (False, False):
896-
# For Matrix
897-
return [eye_input * (non_eye_input.diagonal(axis1=-1, axis2=-2) ** 0.5)]
898-
else:
899-
# For Vector or Scalar
900-
return [eye_input * (non_eye_input**0.5)]
895+
non_eye_input = non_eye_input.diagonal(axis1=-1, axis2=-2)
896+
897+
return [eye_input * (non_eye_input**0.5)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -755,17 +755,12 @@ def test_slogdet_blockdiag_rewrite():
755755

756756
def test_cholesky_eye_rewrite():
757757
x = pt.eye(10)
758-
x_mat = pt.matrix("x")
759758
L = pt.linalg.cholesky(x)
760-
L_mat = pt.linalg.cholesky(x_mat)
761759
f_rewritten = function([], L, mode="FAST_RUN")
762-
f_rewritten_mat = function([x_mat], L_mat, mode="FAST_RUN")
763760
nodes = f_rewritten.maker.fgraph.apply_nodes
764-
nodes_mat = f_rewritten_mat.maker.fgraph.apply_nodes
765761

766762
# Rewrite Test
767763
assert not any(isinstance(node.op, Cholesky) for node in nodes)
768-
assert any(isinstance(node.op, Cholesky) for node in nodes_mat)
769764

770765
# Value Test
771766
x_test = np.eye(10)
@@ -839,3 +834,15 @@ def test_cholesky_diag_from_diag():
839834
atol=1e-3 if config.floatX == "float32" else 1e-8,
840835
rtol=1e-3 if config.floatX == "float32" else 1e-8,
841836
)
837+
838+
839+
def test_dont_apply_cholesky():
840+
x = pt.tensor("x", shape=(7, 7))
841+
y = pt.eye(7, k=-1) * x
842+
# Here, y is not a diagonal matrix because of k = -1
843+
z_cholesky = pt.linalg.cholesky(y)
844+
845+
# REWRITE TEST (should not be applied)
846+
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
847+
nodes = f_rewritten.maker.fgraph.apply_nodes
848+
assert any(isinstance(node.op, Cholesky) for node in nodes)

0 commit comments

Comments
 (0)