Skip to content

Commit b75e0f6

Browse files
Test SVD gradients
1 parent aba800f commit b75e0f6

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

tests/tensor/test_nlinalg.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,13 @@ def test_grad(self, compute_uv, full_matrices, shape, batched):
247247
pytensor.grad(s.sum(), self.A)
248248

249249
elif compute_uv:
250+
251+
def svd_fn(A):
252+
U, s, V = svd(A, compute_uv=compute_uv, full_matrices=full_matrices)
253+
return U.sum() + s.sum() + V.sum()
254+
250255
utt.verify_grad(
251-
partial(svd, compute_uv=compute_uv, full_matrices=full_matrices),
256+
svd_fn,
252257
[A_v],
253258
rng=rng,
254259
)

0 commit comments

Comments
 (0)