Skip to content

Commit c54f3c1

Browse files
tanish1729jessegrabowski
authored andcommitted
minor changes; added test to not apply rewrite
1 parent d573d57 commit c54f3c1

File tree

2 files changed

+25
-21
lines changed

2 files changed

+25
-21
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from collections.abc import Callable
33
from typing import cast
44

5-
import pytensor.tensor as pt
65
from pytensor import Variable
76
from pytensor import tensor as pt
87
from pytensor.graph import Apply, FunctionGraph
@@ -893,7 +892,7 @@ def rewrite_slogdet_kronecker(fgraph, node):
893892
@register_canonicalize
894893
@register_stabilize
895894
@node_rewriter([Blockwise])
896-
def rewrite_cholesky_eye_to_eye(fgraph, node):
895+
def rewrite_remove_useless_cholesky(fgraph, node):
897896
"""
898897
This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself
899898
@@ -916,14 +915,15 @@ def rewrite_cholesky_eye_to_eye(fgraph, node):
916915
return None
917916

918917
# Check whether input to Cholesky is Eye and the 1's are on main diagonal
919-
eye_check = node.inputs[0]
918+
potential_eye = node.inputs[0]
920919
if not (
921-
eye_check.owner
922-
and isinstance(eye_check.owner.op, Eye)
923-
and getattr(eye_check.owner.inputs[-1], "data", -1).item() == 0
920+
potential_eye.owner
921+
and isinstance(potential_eye.owner.op, Eye)
922+
and hasattr(potential_eye.owner.inputs[-1], "data")
923+
and potential_eye.owner.inputs[-1].data.item() == 0
924924
):
925925
return None
926-
return [eye_check]
926+
return [potential_eye]
927927

928928

929929
@register_canonicalize
@@ -941,10 +941,9 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
941941
and isinstance(inputs.owner.op, AllocDiag)
942942
and AllocDiag.is_offset_zero(inputs.owner)
943943
):
944-
cholesky_input = inputs.owner.inputs[0]
945-
if cholesky_input.type.ndim == 1:
946-
cholesky_val = pt.diag(cholesky_input**0.5)
947-
return [cholesky_val]
944+
diag_input = inputs.owner.inputs[0]
945+
cholesky_val = pt.diag(diag_input**0.5)
946+
return [cholesky_val]
948947

949948
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
950949
inputs_or_none = _find_diag_from_eye_mul(inputs)
@@ -962,8 +961,6 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
962961
# Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements
963962
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
964963
if non_eye_input.type.broadcastable[-2:] == (False, False):
965-
# For Matrix
966-
return [eye_input * (non_eye_input.diagonal(axis1=-1, axis2=-2) ** 0.5)]
967-
else:
968-
# For Vector or Scalar
969-
return [eye_input * (non_eye_input**0.5)]
964+
non_eye_input = non_eye_input.diagonal(axis1=-1, axis2=-2)
965+
966+
return [eye_input * (non_eye_input**0.5)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -807,17 +807,12 @@ def test_slogdet_kronecker_rewrite():
807807

808808
def test_cholesky_eye_rewrite():
809809
x = pt.eye(10)
810-
x_mat = pt.matrix("x")
811810
L = pt.linalg.cholesky(x)
812-
L_mat = pt.linalg.cholesky(x_mat)
813811
f_rewritten = function([], L, mode="FAST_RUN")
814-
f_rewritten_mat = function([x_mat], L_mat, mode="FAST_RUN")
815812
nodes = f_rewritten.maker.fgraph.apply_nodes
816-
nodes_mat = f_rewritten_mat.maker.fgraph.apply_nodes
817813

818814
# Rewrite Test
819815
assert not any(isinstance(node.op, Cholesky) for node in nodes)
820-
assert any(isinstance(node.op, Cholesky) for node in nodes_mat)
821816

822817
# Value Test
823818
x_test = np.eye(10)
@@ -891,3 +886,15 @@ def test_cholesky_diag_from_diag():
891886
atol=1e-3 if config.floatX == "float32" else 1e-8,
892887
rtol=1e-3 if config.floatX == "float32" else 1e-8,
893888
)
889+
890+
891+
def test_dont_apply_cholesky():
892+
x = pt.tensor("x", shape=(7, 7))
893+
y = pt.eye(7, k=-1) * x
894+
# Here, y is not a diagonal matrix because of k = -1
895+
z_cholesky = pt.linalg.cholesky(y)
896+
897+
# REWRITE TEST (should not be applied)
898+
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
899+
nodes = f_rewritten.maker.fgraph.apply_nodes
900+
assert any(isinstance(node.op, Cholesky) for node in nodes)

0 commit comments

Comments
 (0)