Skip to content

Commit 89d7e5b

Browse files
committed
Start implementing test_svd()
1 parent 0a3f3f3 commit 89d7e5b

File tree

2 files changed

+30
-5
lines changed

2 files changed

+30
-5
lines changed

array_api_tests/hypothesis_helpers.py

+5
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,11 @@ def matrix_shapes(draw, stack_shapes=shapes):
130130

131131
square_matrix_shapes = matrix_shapes().filter(lambda shape: shape[-1] == shape[-2])
132132

133+
finite_matrices = xps.arrays(dtype=xps.floating_dtypes(),
134+
shape=matrix_shapes(),
135+
elements=dict(allow_nan=False,
136+
allow_infinity=False))
137+
133138
def mutually_broadcastable_shapes(num_shapes: int) -> SearchStrategy[Tuple[Tuple]]:
134139
return (
135140
xps.mutually_broadcastable_shapes(num_shapes)

array_api_tests/test_linalg.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
positive_definite_matrices, MAX_ARRAY_SIZE,
2424
invertible_matrices, two_mutual_arrays,
2525
mutually_promotable_dtypes, one_d_shapes,
26-
two_mutually_broadcastable_shapes, SQRT_MAX_ARRAY_SIZE)
26+
two_mutually_broadcastable_shapes,
27+
SQRT_MAX_ARRAY_SIZE, finite_matrices)
2728
from .pytest_helpers import raises
2829
from . import dtype_helpers as dh
2930

@@ -476,12 +477,31 @@ def test_solve(x1, x2):
476477
pass
477478

478479
@given(
479-
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
480-
kw=kwargs(full_matrices=todo)
480+
x=finite_matrices,
481+
kw=kwargs(full_matrices=booleans())
481482
)
482483
def test_svd(x, kw):
483-
# res = linalg.svd(x, **kw)
484-
pass
484+
res = linalg.svd(x, **kw)
485+
full_matrices = kw.get('full_matrices', True)
486+
487+
*stack, M, N = x.shape
488+
K = min(M, N)
489+
490+
_test_namedtuple(res, ['u', 's', 'vh'], 'svd')
491+
492+
u, s, vh = res
493+
494+
assert u.dtype == x.dtype, "svd().u did not return the correct dtype"
495+
assert s.dtype == x.dtype, "svd().s did not return the correct dtype"
496+
assert vh.dtype == x.dtype, "svd().vh did not return the correct dtype"
497+
498+
assert s.shape == (*stack, K)
499+
if full_matrices:
500+
assert u.shape == (*stack, M, M)
501+
assert vh.shape == (*stack, N, N)
502+
else:
503+
assert u.shape == (*stack, M, K)
504+
assert vh.shape == (*stack, K, N)
485505

486506
@given(
487507
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),

0 commit comments

Comments
 (0)