Skip to content

Commit a4d419f

Browse files
committed
Update linalg tests to use assert_dtype and assert_shape helpers
1 parent 5ceb81d commit a4d419f

File tree

1 file changed

+100
-50
lines changed

1 file changed

+100
-50
lines changed

array_api_tests/test_linalg.py

Lines changed: 100 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,9 @@ def _test_namedtuple(res, fields, func_name):
126126
def test_cholesky(x, kw):
127127
res = linalg.cholesky(x, **kw)
128128

129-
assert res.shape == x.shape, "cholesky() did not return the correct shape"
130-
assert res.dtype == x.dtype, "cholesky() did not return the correct dtype"
129+
ph.assert_dtype("cholesky", in_dtype=x.dtype, out_dtype=res.dtype)
130+
ph.assert_result_shape("cholesky", in_shapes=[x.shape],
131+
out_shape=res.shape, expected=x.shape)
131132

132133
_test_stacks(linalg.cholesky, x, **kw, res=res)
133134

@@ -192,7 +193,7 @@ def test_cross(x1_x2_kw):
192193

193194
ph.assert_dtype("cross", in_dtype=[x1.dtype, x2.dtype],
194195
out_dtype=res.dtype)
195-
ph.assert_shape("cross", out_shape=res.shape, expected=broadcasted_shape)
196+
ph.assert_result_shape("cross", in_shapes=[x1.shape, x2.shape], out_shape=res.shape, expected=broadcasted_shape)
196197

197198
def exact_cross(a, b):
198199
assert a.shape == b.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
@@ -214,8 +215,9 @@ def exact_cross(a, b):
214215
def test_det(x):
215216
res = linalg.det(x)
216217

217-
assert res.dtype == x.dtype, "det() did not return the correct dtype"
218-
assert res.shape == x.shape[:-2], "det() did not return the correct shape"
218+
ph.assert_dtype("det", in_dtype=x.dtype, out_dtype=res.dtype)
219+
ph.assert_result_shape("det", in_shapes=[x.shape], out_shape=res.shape,
220+
expected=x.shape[:-2])
219221

220222
_test_stacks(linalg.det, x, res=res, dims=0)
221223

@@ -231,7 +233,7 @@ def test_det(x):
231233
def test_diagonal(x, kw):
232234
res = linalg.diagonal(x, **kw)
233235

234-
assert res.dtype == x.dtype, "diagonal() returned the wrong dtype"
236+
ph.assert_dtype("diagonal", in_dtype=x.dtype, out_dtype=res.dtype)
235237

236238
n, m = x.shape[-2:]
237239
offset = kw.get('offset', 0)
@@ -245,7 +247,9 @@ def test_diagonal(x, kw):
245247
else:
246248
diag_size = min(n, m, max(m - offset, 0))
247249

248-
assert res.shape == (*x.shape[:-2], diag_size), "diagonal() returned the wrong shape"
250+
expected_shape = (*x.shape[:-2], diag_size)
251+
ph.assert_result_shape("diagonal", in_shapes=[x.shape],
252+
out_shape=res.shape, expected=expected_shape)
249253

250254
def true_diag(x_stack, offset=0):
251255
if offset >= 0:
@@ -266,11 +270,18 @@ def test_eigh(x):
266270
eigenvalues = res.eigenvalues
267271
eigenvectors = res.eigenvectors
268272

269-
assert eigenvalues.dtype == x.dtype, "eigh().eigenvalues did not return the correct dtype"
270-
assert eigenvalues.shape == x.shape[:-1], "eigh().eigenvalues did not return the correct shape"
273+
ph.assert_dtype("eigh", in_dtype=x.dtype, out_dtype=eigenvalues.dtype,
274+
expected=x.dtype, repr_name="eigenvalues.dtype")
275+
ph.assert_result_shape("eigh", in_shapes=[x.shape],
276+
out_shape=eigenvalues.shape,
277+
expected=x.shape[:-1],
278+
repr_name="eigenvalues.shape")
271279

272-
assert eigenvectors.dtype == x.dtype, "eigh().eigenvectors did not return the correct dtype"
273-
assert eigenvectors.shape == x.shape, "eigh().eigenvectors did not return the correct shape"
280+
ph.assert_dtype("eigh", in_dtype=x.dtype, out_dtype=eigenvectors.dtype,
281+
expected=x.dtype, repr_name="eigenvectors.dtype")
282+
ph.assert_result_shape("eigh", in_shapes=[x.shape],
283+
out_shape=eigenvectors.shape, expected=x.shape,
284+
repr_name="eigenvectors.shape")
274285

275286
# Note: _test_stacks here is only testing the shape and dtype. The actual
276287
# eigenvalues and eigenvectors may not be equal at all, since there is not
@@ -292,8 +303,9 @@ def test_eigh(x):
292303
def test_eigvalsh(x):
293304
res = linalg.eigvalsh(x)
294305

295-
assert res.dtype == x.dtype, "eigvalsh() did not return the correct dtype"
296-
assert res.shape == x.shape[:-1], "eigvalsh() did not return the correct shape"
306+
ph.assert_dtype("eigvalsh", in_dtype=x.dtype, out_dtype=res.dtype)
307+
ph.assert_result_shape("eigvalsh", in_shapes=[x.shape],
308+
out_shape=res.shape, expected=x.shape[:-1])
297309

298310
# Note: _test_stacks here is only testing the shape and dtype. The actual
299311
# eigenvalues may not be equal at all, since there is not requirements or
@@ -311,8 +323,9 @@ def test_eigvalsh(x):
311323
def test_inv(x):
312324
res = linalg.inv(x)
313325

314-
assert res.shape == x.shape, "inv() did not return the correct shape"
315-
assert res.dtype == x.dtype, "inv() did not return the correct dtype"
326+
ph.assert_dtype("inv", in_dtype=x.dtype, out_dtype=res.dtype)
327+
ph.assert_result_shape("inv", in_shapes=[x.shape], out_shape=res.shape,
328+
expected=x.shape)
316329

317330
_test_stacks(linalg.inv, x, res=res)
318331

@@ -339,18 +352,24 @@ def test_matmul(x1, x2):
339352
ph.assert_dtype("matmul", in_dtype=[x1.dtype, x2.dtype], out_dtype=res.dtype)
340353

341354
if len(x1.shape) == len(x2.shape) == 1:
342-
assert res.shape == ()
355+
ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
356+
out_shape=res.shape, expected=())
343357
elif len(x1.shape) == 1:
344-
assert res.shape == x2.shape[:-2] + x2.shape[-1:]
358+
ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
359+
out_shape=res.shape,
360+
expected=x2.shape[:-2] + x2.shape[-1:])
345361
_test_stacks(_array_module.matmul, x1, x2, res=res, dims=1,
346362
matrix_axes=[(0,), (-2, -1)], res_axes=[-1])
347363
elif len(x2.shape) == 1:
348-
assert res.shape == x1.shape[:-1]
364+
ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
365+
out_shape=res.shape, expected=x1.shape[:-1])
349366
_test_stacks(_array_module.matmul, x1, x2, res=res, dims=1,
350367
matrix_axes=[(-2, -1), (0,)], res_axes=[-1])
351368
else:
352369
stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
353-
assert res.shape == stack_shape + (x1.shape[-2], x2.shape[-1])
370+
ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
371+
out_shape=res.shape,
372+
expected=stack_shape + (x1.shape[-2], x2.shape[-1]))
354373
_test_stacks(_array_module.matmul, x1, x2, res=res)
355374

356375
@pytest.mark.xp_extension('linalg')
@@ -370,8 +389,9 @@ def test_matrix_norm(x, kw):
370389
expected_shape = x.shape[:-2] + (1, 1)
371390
else:
372391
expected_shape = x.shape[:-2]
373-
assert res.shape == expected_shape, f"matrix_norm({keepdims=}) did not return the correct shape"
374-
assert res.dtype == x.dtype, "matrix_norm() did not return the correct dtype"
392+
ph.assert_dtype("matrix_norm", in_dtype=x.dtype, out_dtype=res.dtype)
393+
ph.assert_result_shape("matrix_norm", in_shapes=[x.shape],
394+
out_shape=res.shape, expected=expected_shape)
375395

376396
_test_stacks(linalg.matrix_norm, x, **kw, dims=2 if keepdims else 0,
377397
res=res)
@@ -388,8 +408,9 @@ def test_matrix_norm(x, kw):
388408
def test_matrix_power(x, n):
389409
res = linalg.matrix_power(x, n)
390410

391-
assert res.shape == x.shape, "matrix_power() did not return the correct shape"
392-
assert res.dtype == x.dtype, "matrix_power() did not return the correct dtype"
411+
ph.assert_dtype("matrix_power", in_dtype=x.dtype, out_dtype=res.dtype)
412+
ph.assert_result_shape("matrix_power", in_shapes=[x.shape],
413+
out_shape=res.shape, expected=x.shape)
393414

394415
if n == 0:
395416
true_val = lambda x: _array_module.eye(x.shape[0], dtype=x.dtype)
@@ -419,8 +440,9 @@ def test_matrix_transpose(x):
419440
shape = list(x.shape)
420441
shape[-1], shape[-2] = shape[-2], shape[-1]
421442
shape = tuple(shape)
422-
assert res.shape == shape, "matrix_transpose() did not return the correct shape"
423-
assert res.dtype == x.dtype, "matrix_transpose() did not return the correct dtype"
443+
ph.assert_dtype("matrix_transpose", in_dtype=x.dtype, out_dtype=res.dtype)
444+
ph.assert_result_shape("matrix_transpose", in_shapes=[x.shape],
445+
out_shape=res.shape, expected=shape)
424446

425447
_test_stacks(_array_module.matrix_transpose, x, res=res, true_val=true_val)
426448

@@ -435,8 +457,9 @@ def test_outer(x1, x2):
435457
res = linalg.outer(x1, x2)
436458

437459
shape = (x1.shape[0], x2.shape[0])
438-
assert res.shape == shape, "outer() did not return the correct shape"
439-
assert res.dtype == dh.result_type(x1.dtype, x2.dtype), "outer() did not return the correct dtype"
460+
ph.assert_dtype("outer", in_dtype=[x1.dtype, x2.dtype], out_dtype=res.dtype)
461+
ph.assert_result_shape("outer", in_shapes=[x1.shape, x2.shape],
462+
out_shape=res.shape, expected=shape)
440463

441464
if 0 in shape:
442465
true_res = _array_module.empty(shape, dtype=res.dtype)
@@ -472,17 +495,23 @@ def test_qr(x, kw):
472495
Q = res.Q
473496
R = res.R
474497

475-
assert Q.dtype == x.dtype, "qr().Q did not return the correct dtype"
498+
ph.assert_dtype("qr", in_dtype=x.dtype, out_dtype=Q.dtype,
499+
expected=x.dtype, repr_name="Q.dtype")
476500
if mode == 'complete':
477-
assert Q.shape == x.shape[:-2] + (M, M), "qr().Q did not return the correct shape"
501+
expected_Q_shape = x.shape[:-2] + (M, M)
478502
else:
479-
assert Q.shape == x.shape[:-2] + (M, K), "qr().Q did not return the correct shape"
503+
expected_Q_shape = x.shape[:-2] + (M, K)
504+
ph.assert_result_shape("qr", in_shapes=[x.shape], out_shape=Q.shape,
505+
expected=expected_Q_shape, repr_name="Q.shape")
480506

481-
assert R.dtype == x.dtype, "qr().R did not return the correct dtype"
507+
ph.assert_dtype("qr", in_dtype=x.dtype, out_dtype=R.dtype,
508+
expected=x.dtype, repr_name="R.dtype")
482509
if mode == 'complete':
483-
assert R.shape == x.shape[:-2] + (M, N), "qr().R did not return the correct shape"
510+
expected_R_shape = x.shape[:-2] + (M, N)
484511
else:
485-
assert R.shape == x.shape[:-2] + (K, N), "qr().R did not return the correct shape"
512+
expected_R_shape = x.shape[:-2] + (K, N)
513+
ph.assert_result_shape("qr", in_shapes=[x.shape], out_shape=R.shape,
514+
expected=expected_R_shape, repr_name="R.shape")
486515

487516
_test_stacks(lambda x: linalg.qr(x, **kw).Q, x, res=Q)
488517
_test_stacks(lambda x: linalg.qr(x, **kw).R, x, res=R)
@@ -505,14 +534,17 @@ def test_slogdet(x):
505534

506535
ph.assert_dtype("slogdet", in_dtype=x.dtype, out_dtype=sign.dtype,
507536
expected=x.dtype, repr_name="sign.dtype")
508-
ph.assert_shape("slogdet", out_shape=sign.shape, expected=x.shape[:-2],
509-
repr_name="sign.shape")
537+
ph.assert_result_shape("slogdet", in_shapes=[x.shape],
538+
out_shape=sign.shape,
539+
expected=x.shape[:-2],
540+
repr_name="sign.shape")
510541
expected_dtype = dh.as_real_dtype(x.dtype)
511542
ph.assert_dtype("slogdet", in_dtype=x.dtype, out_dtype=logabsdet.dtype,
512543
expected=expected_dtype, repr_name="logabsdet.dtype")
513-
ph.assert_shape("slogdet", out_shape=logabsdet.shape,
514-
expected=x.shape[:-2],
515-
repr_name="logabsdet.shape")
544+
ph.assert_result_shape("slogdet", in_shapes=[x.shape],
545+
out_shape=logabsdet.shape,
546+
expected=x.shape[:-2],
547+
repr_name="logabsdet.shape")
516548

517549
_test_stacks(lambda x: linalg.slogdet(x).sign, x,
518550
res=sign, dims=0)
@@ -584,17 +616,31 @@ def test_svd(x, kw):
584616

585617
U, S, Vh = res
586618

587-
assert U.dtype == x.dtype, "svd().U did not return the correct dtype"
588-
assert S.dtype == x.dtype, "svd().S did not return the correct dtype"
589-
assert Vh.dtype == x.dtype, "svd().Vh did not return the correct dtype"
619+
ph.assert_dtype("svd", in_dtype=x.dtype, out_dtype=U.dtype,
620+
expected=x.dtype, repr_name="U.dtype")
621+
ph.assert_dtype("svd", in_dtype=x.dtype, out_dtype=S.dtype,
622+
expected=x.dtype, repr_name="S.dtype")
623+
ph.assert_dtype("svd", in_dtype=x.dtype, out_dtype=Vh.dtype,
624+
expected=x.dtype, repr_name="Vh.dtype")
590625

591626
if full_matrices:
592-
assert U.shape == (*stack, M, M), "svd().U did not return the correct shape"
593-
assert Vh.shape == (*stack, N, N), "svd().Vh did not return the correct shape"
627+
expected_U_shape = (*stack, M, M)
628+
expected_Vh_shape = (*stack, N, N)
594629
else:
595-
assert U.shape == (*stack, M, K), "svd(full_matrices=False).U did not return the correct shape"
596-
assert Vh.shape == (*stack, K, N), "svd(full_matrices=False).Vh did not return the correct shape"
597-
assert S.shape == (*stack, K), "svd().S did not return the correct shape"
630+
expected_U_shape = (*stack, M, K)
631+
expected_Vh_shape = (*stack, K, N)
632+
ph.assert_result_shape("svd", in_shapes=[x.shape],
633+
out_shape=U.shape,
634+
expected=expected_U_shape,
635+
repr_name="U.shape")
636+
ph.assert_result_shape("svd", in_shapes=[x.shape],
637+
out_shape=Vh.shape,
638+
expected=expected_Vh_shape,
639+
repr_name="Vh.shape")
640+
ph.assert_result_shape("svd", in_shapes=[x.shape],
641+
out_shape=S.shape,
642+
expected=(*stack, K),
643+
repr_name="S.shape")
598644

599645
# The values of s must be sorted from largest to smallest
600646
if K >= 1:
@@ -614,8 +660,11 @@ def test_svdvals(x):
614660
*stack, M, N = x.shape
615661
K = min(M, N)
616662

617-
assert res.dtype == x.dtype, "svdvals() did not return the correct dtype"
618-
assert res.shape == (*stack, K), "svdvals() did not return the correct shape"
663+
ph.assert_dtype("svdvals", in_dtype=x.dtype, out_dtype=res.dtype,
664+
expected=x.dtype)
665+
ph.assert_result_shape("svdvals", in_shapes=[x.shape],
666+
out_shape=res.shape,
667+
expected=(*stack, K))
619668

620669
# SVD values must be sorted from largest to smallest
621670
assert _array_module.all(res[..., :-1] >= res[..., 1:]), "svdvals() values are not sorted from largest to smallest"
@@ -753,7 +802,7 @@ def test_trace(x, kw):
753802
# assert res.dtype == x.dtype, "trace() returned the wrong dtype"
754803

755804
n, m = x.shape[-2:]
756-
assert res.shape == x.shape[:-2], "trace() returned the wrong shape"
805+
ph.assert_result_shape('trace', x.shape, res.shape, expected=x.shape[:-2])
757806

758807
def true_trace(x_stack, offset=0):
759808
# Note: the spec does not specify that offset must be within the
@@ -799,7 +848,8 @@ def test_vecdot(x1, x2, data):
799848

800849
ph.assert_dtype("vecdot", in_dtype=[x1.dtype, x2.dtype],
801850
out_dtype=res.dtype)
802-
ph.assert_shape("vecdot", out_shape=res.shape, expected=expected_shape)
851+
ph.assert_result_shape("vecdot", in_shapes=[x1.shape, x2.shape],
852+
out_shape=res.shape, expected=expected_shape)
803853

804854
if x1.dtype in dh.int_dtypes:
805855
def true_val(x, y, axis=-1):

0 commit comments

Comments
 (0)