Skip to content

Commit 971cd95

Browse files
committed
Use data draws over examples
More performative, even for `max_examples=1`
1 parent e78e3d5 commit 971cd95

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

array_api_tests/test_signatures.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ def squeeze(x, /, axis):
2424
from typing import Any, Callable, Dict, List, Literal, Sequence, get_args
2525

2626
import pytest
27+
from hypothesis import given, note, settings
2728
from hypothesis import strategies as st
29+
from hypothesis.strategies import DataObject
2830

2931
from . import dtype_helpers as dh
3032
from . import hypothesis_helpers as hh
@@ -117,7 +119,11 @@ def make_pretty_func(func_name: str, args: Sequence[Any], kwargs: Dict[str, Any]
117119
matrixy_funcs += ["__matmul__", "triu", "tril"]
118120

119121

120-
def _test_uninspectable_func(func_name: str, func: Callable, stub_sig: Signature):
122+
@given(data=st.data())
123+
@settings(max_examples=1)
124+
def _test_uninspectable_func(
125+
func_name: str, func: Callable, stub_sig: Signature, data: DataObject
126+
):
121127
skip_msg = (
122128
f"Signature for {func_name}() is not inspectable "
123129
"and is too troublesome to test for otherwise"
@@ -145,12 +151,16 @@ def _test_uninspectable_func(func_name: str, func: Callable, stub_sig: Signature
145151
value = param.default
146152
elif param.name in ["x", "x1"]:
147153
dtypes = get_dtypes_strategy(func_name)
148-
value = xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)).example()
154+
value = data.draw(
155+
xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)), label=param.name
156+
)
149157
elif param.name == "x2":
150158
# sanity check
151159
assert "x1" in [p.name for p in param_to_value.keys()]
152160
x1 = next(v for p, v in param_to_value.items() if p.name == "x1")
153-
value = xps.arrays(dtype=x1.dtype, shape=x1.shape).example()
161+
value = data.draw(
162+
xps.arrays(dtype=x1.dtype, shape=x1.shape), label=param.name
163+
)
154164
else:
155165
pytest.skip(
156166
skip_msg + f" (because no default was found for argument {param.name})"
@@ -164,7 +174,7 @@ def _test_uninspectable_func(func_name: str, func: Callable, stub_sig: Signature
164174
p.name: v for p, v in param_to_value.items() if p.kind == Parameter.KEYWORD_ONLY
165175
}
166176
f_func = make_pretty_func(func_name, args, kwargs)
167-
print(f"trying {f_func}")
177+
note(f"trying {f_func}")
168178
func(*args, **kwargs)
169179

170180

@@ -217,9 +227,11 @@ def test_extension_func_signature(extension: str, stub: FunctionType):
217227

218228

219229
@pytest.mark.parametrize("stub", array_methods, ids=lambda f: f.__name__)
220-
def test_array_method_signature(stub: FunctionType):
230+
@given(st.data())
231+
@settings(max_examples=1)
232+
def test_array_method_signature(stub: FunctionType, data: DataObject):
221233
dtypes = get_dtypes_strategy(stub.__name__)
222-
x = xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)).example()
234+
x = data.draw(xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)), label="x")
223235
assert hasattr(x, stub.__name__), f"{stub.__name__} not found in array object {x!r}"
224236
method = getattr(x, stub.__name__)
225237
# Ignore 'self' arg in stub, which won't be present in instantiated objects.

0 commit comments

Comments
 (0)