Skip to content

Commit 82a1e95

Browse files
committed
add test for cholesky rewrite
1 parent 6b43b43 commit 82a1e95

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

tests/sandbox/linalg/test_linalg.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,27 @@ def test_matrix_inverse_solve():
152152
node = matrix_inverse(A).dot(b).owner
153153
[out] = inv_as_solve.transform(None, node)
154154
assert isinstance(out.owner.op, Solve)
155+
156+
157+
def test_cholesky_dot_lower():
158+
cholesky_lower = Cholesky(lower=True)
159+
160+
L = matrix("L")
161+
L.tag.lower_triangular = True
162+
163+
C = cholesky_lower(L.dot(L.T))
164+
f = pytensor.function([L], C)
165+
if config.mode != "FAST_COMPILE":
166+
assert f.maker.fgraph.outputs[0].name == "L"
167+
168+
169+
def test_cholesky_dot_upper():
170+
cholesky_upper = Cholesky(lower=False)
171+
172+
U = matrix("U")
173+
U.tag.upper_triangular = True
174+
175+
C = cholesky_upper(U.T.dot(U))
176+
f = pytensor.function([U], C)
177+
if config.mode != "FAST_COMPILE":
178+
assert f.maker.fgraph.outputs[0].name == "U"

0 commit comments

Comments
 (0)