Skip to content

Commit d85d90d

Browse files
committed
test concrete values and tranpose op
1 parent e1e77b0 commit d85d90d

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

tests/sandbox/linalg/test_linalg.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import numpy.linalg
3+
import scipy.linalg
34

45
import pytensor
56
from pytensor import function
@@ -157,21 +158,53 @@ def test_matrix_inverse_solve():
157158

158159
def test_cholesky_dot_lower():
159160
cholesky_lower = Cholesky(lower=True)
161+
cholesky_upper = Cholesky(lower=False)
160162

161163
L = matrix("L")
162164
L.tag.lower_triangular = True
163165

164166
C = cholesky_lower(L.dot(L.T))
165167
f = pytensor.function([L], C)
168+
166169
if config.mode != "FAST_COMPILE":
167170
assert (f.maker.fgraph.outputs[0] == f.maker.fgraph.inputs[0]) or (
168171
(o := f.maker.fgraph.outputs[0].owner)
169172
and isinstance(o.op, (DeepCopyOp, ViewOp))
170173
and o.inputs[0] == f.maker.fgraph.inputs[0]
171174
)
172175

176+
# Test some concrete value through f:
177+
Lv = np.array([[2, 0], [1, 4]])
178+
assert np.all(
179+
np.isclose(
180+
scipy.linalg.cholesky(np.dot(Lv, Lv.T), lower=True),
181+
f(Lv),
182+
)
183+
)
184+
185+
# Test upper decomposition factors down to a transpose
186+
C = cholesky_upper(L.dot(L.T))
187+
f = pytensor.function([L], C)
188+
if config.mode != "FAST_COMPILE":
189+
assert (
190+
(o1 := f.maker.fgraph.outputs[0].owner)
191+
and isinstance(o1.op, (DeepCopyOp, ViewOp))
192+
and (o2 := o1.inputs[0].owner)
193+
and isinstance(o2.op, DimShuffle)
194+
and o2.op.new_order == (1, 0)
195+
and o2.inputs[0] == f.maker.fgraph.inputs[0]
196+
)
197+
198+
assert np.all(
199+
np.isclose(
200+
scipy.linalg.cholesky(np.dot(Lv, Lv.T), lower=False),
201+
f(Lv),
202+
)
203+
)
204+
173205

174206
def test_cholesky_dot_upper():
207+
cholesky_lower = Cholesky(lower=True)
175208
cholesky_upper = Cholesky(lower=False)
176209

177210
U = matrix("U")
@@ -185,3 +218,32 @@ def test_cholesky_dot_upper():
185218
and isinstance(o.op, (DeepCopyOp, ViewOp))
186219
and o.inputs[0] == f.maker.fgraph.inputs[0]
187220
)
221+
222+
# Test some concrete value through f:
223+
Uv = np.array([[2, 1], [0, 4]])
224+
assert np.all(
225+
np.isclose(
226+
scipy.linalg.cholesky(np.dot(Uv.T, Uv), lower=False),
227+
f(Uv),
228+
)
229+
)
230+
231+
# Test lower decomposition factors down to a transpose
232+
C = cholesky_lower(U.T.dot(U))
233+
f = pytensor.function([U], C)
234+
if config.mode != "FAST_COMPILE":
235+
assert (
236+
(o1 := f.maker.fgraph.outputs[0].owner)
237+
and isinstance(o1.op, (DeepCopyOp, ViewOp))
238+
and (o2 := o1.inputs[0].owner)
239+
and isinstance(o2.op, DimShuffle)
240+
and o2.op.new_order == (1, 0)
241+
and o2.inputs[0] == f.maker.fgraph.inputs[0]
242+
)
243+
244+
assert np.all(
245+
np.isclose(
246+
scipy.linalg.cholesky(np.dot(Uv.T, Uv), lower=True),
247+
f(Uv),
248+
)
249+
)

0 commit comments

Comments
 (0)