@@ -35,6 +35,13 @@ def pytest_addoption(parser):
35
35
default = [],
36
36
help = "disable testing for Array API extension(s)" ,
37
37
)
38
+ # data-dependent shape
39
+ parser .addoption (
40
+ "--disable-data-dependent-shapes" ,
41
+ "--disable-dds" ,
42
+ action = "store_true" ,
43
+ help = "disable testing functions with output shapes dependent on input" ,
44
+ )
38
45
# CI
39
46
parser .addoption (
40
47
"--ci" ,
@@ -47,6 +54,9 @@ def pytest_configure(config):
47
54
config .addinivalue_line (
48
55
"markers" , "xp_extension(ext): tests an Array API extension"
49
56
)
57
+ config .addinivalue_line (
58
+ "markers" , "data_dependent_shapes: output shapes are dependent on inputs"
59
+ )
50
60
config .addinivalue_line ("markers" , "ci: primary test" )
51
61
# Hypothesis
52
62
hypothesis_max_examples = config .getoption ("--hypothesis-max-examples" )
@@ -83,9 +93,15 @@ def xp_has_ext(ext: str) -> bool:
83
93
84
94
def pytest_collection_modifyitems (config , items ):
85
95
disabled_exts = config .getoption ("--disable-extension" )
96
+ disabled_dds = config .getoption ("--disable-data-dependent-shapes" )
86
97
ci = config .getoption ("--ci" )
87
98
for item in items :
88
99
markers = list (item .iter_markers ())
100
+ # skip if specified in skips.txt
101
+ for id_ in skip_ids :
102
+ if item .nodeid .startswith (id_ ):
103
+ item .add_marker (mark .skip (reason = "skips.txt" ))
104
+ break
89
105
# skip if disabled or non-existent extension
90
106
ext_mark = next ((m for m in markers if m .name == "xp_extension" ), None )
91
107
if ext_mark is not None :
@@ -96,11 +112,14 @@ def pytest_collection_modifyitems(config, items):
96
112
)
97
113
elif not xp_has_ext (ext ):
98
114
item .add_marker (mark .skip (reason = f"{ ext } not found in array module" ))
99
- # skip if specified in skips.txt
100
- for id_ in skip_ids :
101
- if item .nodeid .startswith (id_ ):
102
- item .add_marker (mark .skip (reason = "skips.txt" ))
103
- break
115
+ # skip if disabled by dds flag
116
+ if disabled_dds :
117
+ for m in markers :
118
+ if m .name == "data_dependent_shapes" :
119
+ item .add_marker (
120
+ mark .skip (reason = "disabled via --disable-data-dependent-shapes" )
121
+ )
122
+ break
104
123
# skip if test not appropiate for CI
105
124
if ci :
106
125
ci_mark = next ((m for m in markers if m .name == "ci" ), None )
0 commit comments