Skip to content

Commit 44c13d9

Browse files
tanish1729jessegrabowski
authored andcommitted
added test for batched case and more cases of not applying rewrite
1 parent c0892e0 commit 44c13d9

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,5 +893,7 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
893893
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
894894
if non_eye_input.type.broadcastable[-2:] == (False, False):
895895
non_eye_input = non_eye_input.diagonal(axis1=-1, axis2=-2)
896+
if eye_input.type.ndim > 2:
897+
non_eye_input = pt.shape_padaxis(non_eye_input, -2)
896898

897899
return [eye_input * (non_eye_input**0.5)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -777,8 +777,8 @@ def test_cholesky_eye_rewrite():
777777

778778
@pytest.mark.parametrize(
779779
"shape",
780-
[(), (7,), (7, 7)],
781-
ids=["scalar", "vector", "matrix"],
780+
[(), (7,), (7, 7), (5, 7, 7)],
781+
ids=["scalar", "vector", "matrix", "batched"],
782782
)
783783
def test_cholesky_diag_from_eye_mul(shape):
784784
# Initializing x based on scalar/vector/matrix
@@ -836,13 +836,21 @@ def test_cholesky_diag_from_diag():
836836
)
837837

838838

839-
def test_dont_apply_cholesky():
839+
def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
840+
# Case 1 : y is not a diagonal matrix because of k = -1
840841
x = pt.tensor("x", shape=(7, 7))
841842
y = pt.eye(7, k=-1) * x
842-
# Here, y is not a diagonal matrix because of k = -1
843843
z_cholesky = pt.linalg.cholesky(y)
844844

845845
# REWRITE TEST (should not be applied)
846846
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
847847
nodes = f_rewritten.maker.fgraph.apply_nodes
848848
assert any(isinstance(node.op, Cholesky) for node in nodes)
849+
850+
# Case 2 : eye is degenerate
851+
x = pt.scalar("x")
852+
y = pt.eye(1) * x
853+
z_cholesky = pt.linalg.cholesky(y)
854+
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
855+
nodes = f_rewritten.maker.fgraph.apply_nodes
856+
assert any(isinstance(node.op, Cholesky) for node in nodes)

0 commit comments

Comments
 (0)