Skip to content

Commit 59628f1

Browse files
committed
Fix xp_extension() mark collection, apply to test_signature.py
1 parent 1217563 commit 59628f1

File tree

3 files changed

+51
-9
lines changed

3 files changed

+51
-9
lines changed

conftest.py renamed to array_api_tests/conftest.py

+22-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
from functools import lru_cache
2+
13
from hypothesis import settings
24
from pytest import mark
35

6+
from . import _array_module as xp
7+
from ._array_module import _UndefinedStub
8+
49

510
settings.register_profile('xp_default', deadline=800)
611

@@ -9,10 +14,10 @@ def pytest_addoption(parser):
914
# Enable extensions
1015
parser.addoption(
1116
'--ext',
12-
'--extensions',
17+
'--disable-extensions',
1318
nargs='+',
1419
default=[],
15-
help='enable testing for Array API extensions',
20+
help='disable testing for Array API extensions',
1621
)
1722
# Hypothesis max examples
1823
# See https://github.com/HypothesisWorks/hypothesis/issues/2434
@@ -51,10 +56,22 @@ def pytest_configure(config):
5156
settings.load_profile('xp_default')
5257

5358

59+
@lru_cache
60+
def xp_has_ext(ext: str) -> bool:
61+
try:
62+
return not isinstance(getattr(xp, ext), _UndefinedStub)
63+
except AttributeError:
64+
return False
65+
66+
5467
def pytest_collection_modifyitems(config, items):
55-
exts = config.getoption('--extensions')
68+
disabled_exts = config.getoption('--disable-extensions')
5669
for item in items:
5770
if 'xp_extension' in item.keywords:
5871
ext = item.keywords['xp_extension'].args[0]
59-
if ext not in exts:
60-
item.add_marker(mark.skip(reason=f'{ext} not enabled in --extensions'))
72+
if ext in disabled_exts:
73+
item.add_marker(
74+
mark.skip(reason=f'{ext} disabled in --disable-extensions')
75+
)
76+
elif not xp_has_ext(ext):
77+
item.add_marker(mark.skip(reason=f'{ext} not found in array module'))
+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from ..test_signatures import extension_module
2+
from ..conftest import xp_has_ext
3+
4+
5+
def test_extension_module_is_extension():
6+
assert extension_module('linalg')
7+
8+
9+
def test_extension_func_is_not_extension():
10+
assert not extension_module('linalg.cross')
11+
12+
13+
def test_xp_has_ext():
14+
assert not xp_has_ext('nonexistent_extension')

array_api_tests/test_signatures.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,18 @@ def extension_module(name):
2626
if extension_module(n):
2727
extension_module_names.extend([f'{n}.{i}' for i in getattr(function_stubs, n).__all__])
2828

29-
all_names = function_stubs.__all__ + extension_module_names
29+
30+
params = []
31+
for name in function_stubs.__all__:
32+
marks = []
33+
if extension_module(name):
34+
marks.append(pytest.mark.xp_extension(name))
35+
params.append(pytest.param(name, marks=marks))
36+
for name in extension_module_names:
37+
ext = name.split('.')[0]
38+
mark = pytest.mark.xp_extension(ext)
39+
params.append(pytest.param(name, marks=[mark]))
40+
3041

3142
def array_method(name):
3243
return stub_module(name) == 'array_object'
@@ -130,7 +141,7 @@ def example_argument(arg, func_name, dtype):
130141
else:
131142
raise RuntimeError(f"Don't know how to test argument {arg}. Please update test_signatures.py")
132143

133-
@pytest.mark.parametrize('name', all_names)
144+
@pytest.mark.parametrize('name', params)
134145
def test_has_names(name):
135146
if extension_module(name):
136147
assert hasattr(mod, name), f'{mod_name} is missing the {name} extension'
@@ -146,7 +157,7 @@ def test_has_names(name):
146157
else:
147158
assert hasattr(mod, name), f"{mod_name} is missing the {function_category(name)} function {name}()"
148159

149-
@pytest.mark.parametrize('name', all_names)
160+
@pytest.mark.parametrize('name', params)
150161
def test_function_positional_args(name):
151162
# Note: We can't actually test that positional arguments are
152163
# positional-only, as that would require knowing the argument name and
@@ -223,7 +234,7 @@ def test_function_positional_args(name):
223234
# NumPy ufuncs raise ValueError instead of TypeError
224235
raises((TypeError, ValueError), lambda: mod_func(*args[:n]), f"{name}() should not accept {n} positional arguments")
225236

226-
@pytest.mark.parametrize('name', all_names)
237+
@pytest.mark.parametrize('name', params)
227238
def test_function_keyword_only_args(name):
228239
if extension_module(name):
229240
return

0 commit comments

Comments
 (0)