Skip to content

Commit 7d2203b

Browse files
committed
Fix xp_extension() collection, apply to test_linalg.py
1 parent 59628f1 commit 7d2203b

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

array_api_tests/conftest.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ def pytest_addoption(parser):
1414
# Enable extensions
1515
parser.addoption(
1616
'--ext',
17-
'--disable-extensions',
17+
'--disable-extension',
1818
nargs='+',
1919
default=[],
20-
help='disable testing for Array API extensions',
20+
help='disable testing for Array API extension(s)',
2121
)
2222
# Hypothesis max examples
2323
# See https://github.com/HypothesisWorks/hypothesis/issues/2434
@@ -65,13 +65,18 @@ def xp_has_ext(ext: str) -> bool:
6565

6666

6767
def pytest_collection_modifyitems(config, items):
68-
disabled_exts = config.getoption('--disable-extensions')
68+
disabled_exts = config.getoption('--disable-extension')
6969
for item in items:
70-
if 'xp_extension' in item.keywords:
71-
ext = item.keywords['xp_extension'].args[0]
72-
if ext in disabled_exts:
73-
item.add_marker(
74-
mark.skip(reason=f'{ext} disabled in --disable-extensions')
75-
)
76-
elif not xp_has_ext(ext):
77-
item.add_marker(mark.skip(reason=f'{ext} not found in array module'))
70+
try:
71+
ext_mark = next(
72+
mark for mark in item.iter_markers() if mark.name == 'xp_extension'
73+
)
74+
except StopIteration:
75+
continue
76+
ext = ext_mark.args[0]
77+
if ext in disabled_exts:
78+
item.add_marker(
79+
mark.skip(reason=f'{ext} disabled in --disable-extensions')
80+
)
81+
elif not xp_has_ext(ext):
82+
item.add_marker(mark.skip(reason=f'{ext} not found in array module'))

array_api_tests/test_linalg.py

+5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
1414
"""
1515

16+
import pytest
1617
from hypothesis import assume, given
1718
from hypothesis.strategies import (booleans, composite, none, tuples, integers,
1819
shared, sampled_from)
@@ -33,6 +34,10 @@
3334
from . import _array_module
3435
from ._array_module import linalg
3536

37+
38+
pytestmark = [pytest.mark.xp_extension('linalg')]
39+
40+
3641
# Standin strategy for not yet implemented tests
3742
todo = none()
3843

0 commit comments

Comments
 (0)