Skip to content

Commit 70438ea

Browse files
committed
Add --disable-data-dependent-shapes option
1 parent 4420817 commit 70438ea

File tree

3 files changed

+26
-7
lines changed

3 files changed

+26
-7
lines changed

array_api_tests/test_searching_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def test_argmin(x, data):
7676
ph.assert_scalar_equals("argmin", int, out_idx, min_i, expected)
7777

7878

79-
# TODO: skip if opted out
79+
@pytest.mark.data_dependent_shapes
8080
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))
8181
def test_nonzero(x):
8282
out = xp.nonzero(x)

array_api_tests/test_set_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from . import shape_helpers as sh
1313
from . import xps
1414

15-
pytestmark = pytest.mark.ci
15+
pytestmark = [pytest.mark.ci, pytest.mark.data_dependent_shapes]
1616

1717

1818
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))

conftest.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ def pytest_addoption(parser):
3535
default=[],
3636
help="disable testing for Array API extension(s)",
3737
)
38+
# data-dependent shape
39+
parser.addoption(
40+
"--disable-data-dependent-shapes",
41+
"--disable-dds",
42+
action="store_true",
43+
help="disable testing functions with output shapes dependent on input",
44+
)
3845
# CI
3946
parser.addoption(
4047
"--ci",
@@ -47,6 +54,9 @@ def pytest_configure(config):
4754
config.addinivalue_line(
4855
"markers", "xp_extension(ext): tests an Array API extension"
4956
)
57+
config.addinivalue_line(
58+
"markers", "data_dependent_shapes: output shapes are dependent on inputs"
59+
)
5060
config.addinivalue_line("markers", "ci: primary test")
5161
# Hypothesis
5262
hypothesis_max_examples = config.getoption("--hypothesis-max-examples")
@@ -83,9 +93,15 @@ def xp_has_ext(ext: str) -> bool:
8393

8494
def pytest_collection_modifyitems(config, items):
8595
disabled_exts = config.getoption("--disable-extension")
96+
disabled_dds = config.getoption("--disable-data-dependent-shapes")
8697
ci = config.getoption("--ci")
8798
for item in items:
8899
markers = list(item.iter_markers())
100+
# skip if specified in skips.txt
101+
for id_ in skip_ids:
102+
if item.nodeid.startswith(id_):
103+
item.add_marker(mark.skip(reason="skips.txt"))
104+
break
89105
# skip if disabled or non-existent extension
90106
ext_mark = next((m for m in markers if m.name == "xp_extension"), None)
91107
if ext_mark is not None:
@@ -96,11 +112,14 @@ def pytest_collection_modifyitems(config, items):
96112
)
97113
elif not xp_has_ext(ext):
98114
item.add_marker(mark.skip(reason=f"{ext} not found in array module"))
99-
# skip if specified in skips.txt
100-
for id_ in skip_ids:
101-
if item.nodeid.startswith(id_):
102-
item.add_marker(mark.skip(reason="skips.txt"))
103-
break
115+
# skip if disabled by dds flag
116+
if disabled_dds:
117+
for m in markers:
118+
if m.name == "data_dependent_shapes":
119+
item.add_marker(
120+
mark.skip(reason="disabled via --disable-data-dependent-shapes")
121+
)
122+
break
104123
# skip if test not appropiate for CI
105124
if ci:
106125
ci_mark = next((m for m in markers if m.name == "ci"), None)

0 commit comments

Comments
 (0)