Skip to content

Commit 14ac81a

Browse files
committed
Add unvectorized markers to test_linalg.py
1 parent a168e5a commit 14ac81a

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

array_api_tests/test_linalg.py

+25
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def _test_namedtuple(res, fields, func_name):
119119
assert hasattr(res, field), f"{func_name}() result namedtuple doesn't have the '{field}' field"
120120
assert res[i] is getattr(res, field), f"{func_name}() result namedtuple '{field}' field is not in position {i}"
121121

122+
@pytest.mark.unvectorized
122123
@pytest.mark.xp_extension('linalg')
123124
@given(
124125
x=positive_definite_matrices(),
@@ -175,6 +176,7 @@ def cross_args(draw, dtype_objects=dh.real_dtypes):
175176
)
176177
return draw(arrays1), draw(arrays2), kw
177178

179+
@pytest.mark.unvectorized
178180
@pytest.mark.xp_extension('linalg')
179181
@given(
180182
cross_args()
@@ -209,6 +211,7 @@ def exact_cross(a, b):
209211
# vectors.
210212
_test_stacks(linalg.cross, x1, x2, dims=1, matrix_axes=(axis,), res=res, true_val=exact_cross)
211213

214+
@pytest.mark.unvectorized
212215
@pytest.mark.xp_extension('linalg')
213216
@given(
214217
x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes),
@@ -224,6 +227,7 @@ def test_det(x):
224227

225228
# TODO: Test that res actually corresponds to the determinant of x
226229

230+
@pytest.mark.unvectorized
227231
@pytest.mark.xp_extension('linalg')
228232
@given(
229233
x=arrays(dtype=xps.scalar_dtypes(), shape=matrix_shapes()),
@@ -261,6 +265,7 @@ def true_diag(x_stack, offset=0):
261265

262266
_test_stacks(linalg.diagonal, x, **kw, res=res, dims=1, true_val=true_diag)
263267

268+
@pytest.mark.unvectorized
264269
@pytest.mark.xp_extension('linalg')
265270
@given(x=symmetric_matrices(finite=True))
266271
def test_eigh(x):
@@ -299,6 +304,7 @@ def test_eigh(x):
299304
# TODO: Test that res actually corresponds to the eigenvalues and
300305
# eigenvectors of x
301306

307+
@pytest.mark.unvectorized
302308
@pytest.mark.xp_extension('linalg')
303309
@given(x=symmetric_matrices(finite=True))
304310
def test_eigvalsh(x):
@@ -319,6 +325,7 @@ def test_eigvalsh(x):
319325

320326
# TODO: Test that res actually corresponds to the eigenvalues of x
321327

328+
@pytest.mark.unvectorized
322329
@pytest.mark.xp_extension('linalg')
323330
@given(x=invertible_matrices())
324331
def test_inv(x):
@@ -372,19 +379,22 @@ def _test_matmul(namespace, x1, x2):
372379
expected=stack_shape + (x1.shape[-2], x2.shape[-1]))
373380
_test_stacks(matmul, x1, x2, res=res)
374381

382+
@pytest.mark.unvectorized
375383
@pytest.mark.xp_extension('linalg')
376384
@given(
377385
*two_mutual_arrays(dh.real_dtypes)
378386
)
379387
def test_linalg_matmul(x1, x2):
380388
return _test_matmul(linalg, x1, x2)
381389

390+
@pytest.mark.unvectorized
382391
@given(
383392
*two_mutual_arrays(dh.real_dtypes)
384393
)
385394
def test_matmul(x1, x2):
386395
return _test_matmul(_array_module, x1, x2)
387396

397+
@pytest.mark.unvectorized
388398
@pytest.mark.xp_extension('linalg')
389399
@given(
390400
x=finite_matrices(),
@@ -410,6 +420,7 @@ def test_matrix_norm(x, kw):
410420
res=res)
411421

412422
matrix_power_n = shared(integers(-100, 100), key='matrix_power n')
423+
@pytest.mark.unvectorized
413424
@pytest.mark.xp_extension('linalg')
414425
@given(
415426
# Generate any square matrix if n >= 0 but only invertible matrices if n < 0
@@ -433,6 +444,7 @@ def test_matrix_power(x, n):
433444
func = lambda x: linalg.matrix_power(x, n)
434445
_test_stacks(func, x, res=res, true_val=true_val)
435446

447+
@pytest.mark.unvectorized
436448
@pytest.mark.xp_extension('linalg')
437449
@given(
438450
x=finite_matrices(shape=rtol_shared_matrix_shapes),
@@ -457,13 +469,15 @@ def _test_matrix_transpose(namespace, x):
457469

458470
_test_stacks(matrix_transpose, x, res=res, true_val=true_val)
459471

472+
@pytest.mark.unvectorized
460473
@pytest.mark.xp_extension('linalg')
461474
@given(
462475
x=arrays(dtype=xps.scalar_dtypes(), shape=matrix_shapes()),
463476
)
464477
def test_linalg_matrix_transpose(x):
465478
return _test_matrix_transpose(linalg, x)
466479

480+
@pytest.mark.unvectorized
467481
@given(
468482
x=arrays(dtype=xps.scalar_dtypes(), shape=matrix_shapes()),
469483
)
@@ -503,6 +517,7 @@ def test_outer(x1, x2):
503517
def test_pinv(x, kw):
504518
linalg.pinv(x, **kw)
505519

520+
@pytest.mark.unvectorized
506521
@pytest.mark.xp_extension('linalg')
507522
@given(
508523
x=arrays(dtype=all_floating_dtypes(), shape=matrix_shapes()),
@@ -545,6 +560,7 @@ def test_qr(x, kw):
545560
# Check that R is upper-triangular.
546561
assert_exactly_equal(R, _array_module.triu(R))
547562

563+
@pytest.mark.unvectorized
548564
@pytest.mark.xp_extension('linalg')
549565
@given(
550566
x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes),
@@ -617,6 +633,7 @@ def _x2_shapes(draw):
617633
x2 = arrays(shape=x2_shapes, dtype=mutual_dtypes.map(lambda pair: pair[1]))
618634
return x1, x2
619635

636+
@pytest.mark.unvectorized
620637
@pytest.mark.xp_extension('linalg')
621638
@given(*solve_args())
622639
def test_solve(x1, x2):
@@ -635,6 +652,7 @@ def test_solve(x1, x2):
635652
ph.assert_result_shape("solve", in_shapes=[x1.shape, x2.shape],
636653
out_shape=res.shape, expected=expected_shape)
637654

655+
@pytest.mark.unvectorized
638656
@pytest.mark.xp_extension('linalg')
639657
@given(
640658
x=finite_matrices(),
@@ -685,6 +703,7 @@ def test_svd(x, kw):
685703
_test_stacks(lambda x: linalg.svd(x, **kw).S, x, dims=1, res=S)
686704
_test_stacks(lambda x: linalg.svd(x, **kw).Vh, x, res=Vh)
687705

706+
@pytest.mark.unvectorized
688707
@pytest.mark.xp_extension('linalg')
689708
@given(
690709
x=finite_matrices(),
@@ -818,6 +837,7 @@ def _test_tensordot(namespace, x1, x2, kw):
818837
expected=result_shape)
819838
_test_tensordot_stacks(x1, x2, kw, res)
820839

840+
@pytest.mark.unvectorized
821841
@pytest.mark.xp_extension('linalg')
822842
@given(
823843
*two_mutual_arrays(dh.numeric_dtypes, two_shapes=tensordot_shapes()),
@@ -826,13 +846,15 @@ def _test_tensordot(namespace, x1, x2, kw):
826846
def test_linalg_tensordot(x1, x2, kw):
827847
_test_tensordot(linalg, x1, x2, kw)
828848

849+
@pytest.mark.unvectorized
829850
@given(
830851
*two_mutual_arrays(dh.numeric_dtypes, two_shapes=tensordot_shapes()),
831852
tensordot_kw,
832853
)
833854
def test_tensordot(x1, x2, kw):
834855
_test_tensordot(_array_module, x1, x2, kw)
835856

857+
@pytest.mark.unvectorized
836858
@pytest.mark.xp_extension('linalg')
837859
@given(
838860
x=arrays(dtype=xps.numeric_dtypes(), shape=matrix_shapes()),
@@ -910,6 +932,7 @@ def true_val(x, y, axis=-1):
910932
matrix_axes=(axis,), true_val=true_val)
911933

912934

935+
@pytest.mark.unvectorized
913936
@pytest.mark.xp_extension('linalg')
914937
@given(
915938
*two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
@@ -918,6 +941,7 @@ def true_val(x, y, axis=-1):
918941
def test_linalg_vecdot(x1, x2, data):
919942
_test_vecdot(linalg, x1, x2, data)
920943

944+
@pytest.mark.unvectorized
921945
@given(
922946
*two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
923947
data(),
@@ -929,6 +953,7 @@ def test_vecdot(x1, x2, data):
929953
# spec, so we just limit to reasonable values here.
930954
max_ord = 100
931955

956+
@pytest.mark.unvectorized
932957
@pytest.mark.xp_extension('linalg')
933958
@given(
934959
x=arrays(dtype=all_floating_dtypes(), shape=shapes(min_side=1)),

0 commit comments

Comments
 (0)