Skip to content

Commit 5fc76c2

Browse files
tanish1729jessegrabowski
authored andcommitted
added test for batched case and more cases of not applying rewrite
1 parent c54f3c1 commit 5fc76c2

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -962,5 +962,7 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
962962
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
963963
if non_eye_input.type.broadcastable[-2:] == (False, False):
964964
non_eye_input = non_eye_input.diagonal(axis1=-1, axis2=-2)
965+
if eye_input.type.ndim > 2:
966+
non_eye_input = pt.shape_padaxis(non_eye_input, -2)
965967

966968
return [eye_input * (non_eye_input**0.5)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -829,8 +829,8 @@ def test_cholesky_eye_rewrite():
829829

830830
@pytest.mark.parametrize(
831831
"shape",
832-
[(), (7,), (7, 7)],
833-
ids=["scalar", "vector", "matrix"],
832+
[(), (7,), (7, 7), (5, 7, 7)],
833+
ids=["scalar", "vector", "matrix", "batched"],
834834
)
835835
def test_cholesky_diag_from_eye_mul(shape):
836836
# Initializing x based on scalar/vector/matrix
@@ -888,13 +888,21 @@ def test_cholesky_diag_from_diag():
888888
)
889889

890890

891-
def test_dont_apply_cholesky():
891+
def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
892+
# Case 1 : y is not a diagonal matrix because of k = -1
892893
x = pt.tensor("x", shape=(7, 7))
893894
y = pt.eye(7, k=-1) * x
894-
# Here, y is not a diagonal matrix because of k = -1
895895
z_cholesky = pt.linalg.cholesky(y)
896896

897897
# REWRITE TEST (should not be applied)
898898
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
899899
nodes = f_rewritten.maker.fgraph.apply_nodes
900900
assert any(isinstance(node.op, Cholesky) for node in nodes)
901+
902+
# Case 2 : eye is degenerate
903+
x = pt.scalar("x")
904+
y = pt.eye(1) * x
905+
z_cholesky = pt.linalg.cholesky(y)
906+
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
907+
nodes = f_rewritten.maker.fgraph.apply_nodes
908+
assert any(isinstance(node.op, Cholesky) for node in nodes)

0 commit comments

Comments
 (0)