Skip to content

Commit 544d4fc

Browse files
committed
Remove xp_extension() for top-level linalg tests
1 parent 996f11d commit 544d4fc

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

array_api_tests/test_linalg.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,6 @@
3434
from ._array_module import linalg
3535

3636

37-
pytestmark = [pytest.mark.xp_extension('linalg')]
38-
39-
4037
# Standin strategy for not yet implemented tests
4138
todo = none()
4239

@@ -78,6 +75,7 @@ def _test_namedtuple(res, fields, func_name):
7875
assert hasattr(res, field), f"{func_name}() result namedtuple doesn't have the '{field}' field"
7976
assert res[i] is getattr(res, field), f"{func_name}() result namedtuple '{field}' field is not in position {i}"
8077

78+
@pytest.mark.xp_extension('linalg')
8179
@given(
8280
x=positive_definite_matrices(),
8381
kw=kwargs(upper=booleans())
@@ -125,6 +123,7 @@ def cross_args(draw, dtype_objects=dh.numeric_dtypes):
125123
)
126124
return draw(arrays1), draw(arrays2), kw
127125

126+
@pytest.mark.xp_extension('linalg')
128127
@given(
129128
cross_args()
130129
)
@@ -163,6 +162,7 @@ def test_cross(x1_x2_kw):
163162
], dtype=res.dtype)
164163
assert_exactly_equal(res_stack, exact_cross)
165164

165+
@pytest.mark.xp_extension('linalg')
166166
@given(
167167
x=xps.arrays(dtype=xps.floating_dtypes(), shape=square_matrix_shapes),
168168
)
@@ -176,6 +176,7 @@ def test_det(x):
176176

177177
# TODO: Test that res actually corresponds to the determinant of x
178178

179+
@pytest.mark.xp_extension('linalg')
179180
@given(
180181
x=xps.arrays(dtype=dtypes, shape=matrix_shapes),
181182
# offset may produce an overflow if it is too large. Supporting offsets
@@ -210,6 +211,7 @@ def true_diag(x_stack):
210211

211212
_test_stacks(linalg.diagonal, x, **kw, res=res, dims=1, true_val=true_diag)
212213

214+
@pytest.mark.xp_extension('linalg')
213215
@given(x=symmetric_matrices(finite=True))
214216
def test_eigh(x):
215217
res = linalg.eigh(x)
@@ -233,6 +235,7 @@ def test_eigh(x):
233235
# TODO: Test that res actually corresponds to the eigenvalues and
234236
# eigenvectors of x
235237

238+
@pytest.mark.xp_extension('linalg')
236239
@given(x=symmetric_matrices(finite=True))
237240
def test_eigvalsh(x):
238241
res = linalg.eigvalsh(x)
@@ -246,6 +249,7 @@ def test_eigvalsh(x):
246249

247250
# TODO: Test that res actually corresponds to the eigenvalues of x
248251

252+
@pytest.mark.xp_extension('linalg')
249253
@given(x=invertible_matrices())
250254
def test_inv(x):
251255
res = linalg.inv(x)
@@ -290,6 +294,7 @@ def test_matmul(x1, x2):
290294
assert res.shape == stack_shape + (x1.shape[-2], x2.shape[-1])
291295
_test_stacks(linalg.matmul, x1, x2, res=res)
292296

297+
@pytest.mark.xp_extension('linalg')
293298
@given(
294299
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
295300
kw=kwargs(axis=todo, keepdims=todo, ord=todo)
@@ -299,6 +304,7 @@ def test_matrix_norm(x, kw):
299304
pass
300305

301306
matrix_power_n = shared(integers(-1000, 1000), key='matrix_power n')
307+
@pytest.mark.xp_extension('linalg')
302308
@given(
303309
# Generate any square matrix if n >= 0 but only invertible matrices if n < 0
304310
x=matrix_power_n.flatmap(lambda n: invertible_matrices() if n < 0 else
@@ -320,6 +326,7 @@ def test_matrix_power(x, n):
320326
func = lambda x: linalg.matrix_power(x, n)
321327
_test_stacks(func, x, res=res, true_val=true_val)
322328

329+
@pytest.mark.xp_extension('linalg')
323330
@given(
324331
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
325332
kw=kwargs(rtol=todo)
@@ -345,6 +352,7 @@ def test_matrix_transpose(x):
345352

346353
_test_stacks(linalg.matrix_transpose, x, res=res, true_val=true_val)
347354

355+
@pytest.mark.xp_extension('linalg')
348356
@given(
349357
*two_mutual_arrays(dtype_objs=dh.numeric_dtypes,
350358
two_shapes=tuples(one_d_shapes, one_d_shapes))
@@ -368,6 +376,7 @@ def test_outer(x1, x2):
368376

369377
assert_exactly_equal(res, true_res)
370378

379+
@pytest.mark.xp_extension('linalg')
371380
@given(
372381
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
373382
kw=kwargs(rtol=todo)
@@ -376,6 +385,7 @@ def test_pinv(x, kw):
376385
# res = linalg.pinv(x, **kw)
377386
pass
378387

388+
@pytest.mark.xp_extension('linalg')
379389
@given(
380390
x=xps.arrays(dtype=xps.floating_dtypes(), shape=matrix_shapes),
381391
kw=kwargs(mode=sampled_from(['reduced', 'complete']))
@@ -411,6 +421,7 @@ def test_qr(x, kw):
411421
# Check that r is upper-triangular.
412422
assert_exactly_equal(r, _array_module.triu(r))
413423

424+
@pytest.mark.xp_extension('linalg')
414425
@given(
415426
x=xps.arrays(dtype=xps.floating_dtypes(), shape=square_matrix_shapes),
416427
)
@@ -468,6 +479,7 @@ def x2_shapes(draw):
468479
x2 = xps.arrays(dtype=xps.floating_dtypes(), shape=x2_shapes())
469480
return x1, x2
470481

482+
@pytest.mark.xp_extension('linalg')
471483
@given(*solve_args())
472484
def test_solve(x1, x2):
473485
# TODO: solve() is currently ambiguous, in that some inputs can be
@@ -480,6 +492,7 @@ def test_solve(x1, x2):
480492
# res = linalg.solve(x1, x2)
481493
pass
482494

495+
@pytest.mark.xp_extension('linalg')
483496
@given(
484497
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
485498
kw=kwargs(full_matrices=todo)
@@ -488,6 +501,7 @@ def test_svd(x, kw):
488501
# res = linalg.svd(x, **kw)
489502
pass
490503

504+
@pytest.mark.xp_extension('linalg')
491505
@given(
492506
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
493507
)
@@ -504,6 +518,7 @@ def test_tensordot(x1, x2, kw):
504518
# res = linalg.tensordot(x1, x2, **kw)
505519
pass
506520

521+
@pytest.mark.xp_extension('linalg')
507522
@given(
508523
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
509524
kw=kwargs(offset=todo)
@@ -521,6 +536,7 @@ def test_vecdot(x1, x2, kw):
521536
# res = linalg.vecdot(x1, x2, **kw)
522537
pass
523538

539+
@pytest.mark.xp_extension('linalg')
524540
@given(
525541
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
526542
kw=kwargs(axis=todo, keepdims=todo, ord=todo)

0 commit comments

Comments
 (0)