Skip to content

Commit c2c6a9f

Browse files
authored
Merge pull request #177 from honno/fallback-signature-tests
Specify/infer arguments for testing uninspectable signatures
2 parents d2d43a7 + 0d54b71 commit c2c6a9f

File tree

2 files changed

+154
-99
lines changed

2 files changed

+154
-99
lines changed

array_api_tests/dtype_helpers.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import re
2+
from collections import defaultdict
23
from collections.abc import Mapping
34
from functools import lru_cache
4-
from inspect import signature
5-
from typing import Any, Dict, NamedTuple, Sequence, Tuple, Union
5+
from typing import Any, DefaultDict, NamedTuple, Sequence, Tuple, Union
66
from warnings import warn
77

88
from . import _array_module as xp
@@ -323,16 +323,14 @@ def result_type(*dtypes: DataType):
323323
"numeric": numeric_dtypes,
324324
"integer or boolean": bool_and_all_int_dtypes,
325325
}
326-
func_in_dtypes: Dict[str, Tuple[DataType, ...]] = {}
326+
func_in_dtypes: DefaultDict[str, Tuple[DataType, ...]] = defaultdict(lambda: all_dtypes)
327327
for name, func in name_to_func.items():
328328
if m := r_in_dtypes.search(func.__doc__):
329329
dtype_category = m.group(1)
330330
if dtype_category == "numeric" and r_int_note.search(func.__doc__):
331331
dtype_category = "floating-point"
332332
dtypes = category_to_dtypes[dtype_category]
333333
func_in_dtypes[name] = dtypes
334-
elif any("x" in name for name in signature(func).parameters.keys()):
335-
func_in_dtypes[name] = all_dtypes
336334
# See https://github.com/data-apis/array-api/pull/413
337335
func_in_dtypes["expm1"] = float_dtypes
338336

array_api_tests/test_signatures.py

+151-94
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,18 @@ def squeeze(x, /, axis):
2020
...
2121
2222
"""
23+
from collections import defaultdict
24+
from copy import copy
2325
from inspect import Parameter, Signature, signature
2426
from types import FunctionType
25-
from typing import Any, Callable, Dict, List, Literal, get_args
27+
from typing import Any, Callable, Dict, Literal, get_args
28+
from warnings import warn
2629

2730
import pytest
28-
from hypothesis import given, note, settings
29-
from hypothesis import strategies as st
30-
from hypothesis.strategies import DataObject
3131

3232
from . import dtype_helpers as dh
33-
from . import hypothesis_helpers as hh
34-
from . import xps
35-
from ._array_module import _UndefinedStub
3633
from ._array_module import mod as xp
37-
from .stubs import array_methods, category_to_funcs, extension_to_funcs
38-
from .typing import Array, DataType
34+
from .stubs import array_methods, category_to_funcs, extension_to_funcs, name_to_func
3935

4036
pytestmark = pytest.mark.ci
4137

@@ -93,24 +89,15 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature):
9389
stub_param.name in sig.parameters.keys()
9490
), f"Argument '{stub_param.name}' missing from signature"
9591
param = next(p for p in params if p.name == stub_param.name)
92+
f_stub_kind = kind_to_str[stub_param.kind]
9693
assert param.kind in [stub_param.kind, Parameter.POSITIONAL_OR_KEYWORD,], (
9794
f"{param.name} is a {kind_to_str[param.kind]}, "
9895
f"but should be a {f_stub_kind} "
9996
f"(or at least a {kind_to_str[ParameterKind.POSITIONAL_OR_KEYWORD]})"
10097
)
10198

10299

103-
def get_dtypes_strategy(func_name: str) -> st.SearchStrategy[DataType]:
104-
if func_name in dh.func_in_dtypes.keys():
105-
dtypes = dh.func_in_dtypes[func_name]
106-
if hh.FILTER_UNDEFINED_DTYPES:
107-
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
108-
return st.sampled_from(dtypes)
109-
else:
110-
return xps.scalar_dtypes()
111-
112-
113-
def make_pretty_func(func_name: str, *args: Any, **kwargs: Any):
100+
def make_pretty_func(func_name: str, *args: Any, **kwargs: Any) -> str:
114101
f_sig = f"{func_name}("
115102
f_sig += ", ".join(str(a) for a in args)
116103
if len(kwargs) != 0:
@@ -121,96 +108,165 @@ def make_pretty_func(func_name: str, *args: Any, **kwargs: Any):
121108
return f_sig
122109

123110

124-
matrixy_funcs: List[FunctionType] = [
125-
*category_to_funcs["linear_algebra"],
126-
*extension_to_funcs["linalg"],
111+
# We test uninspectable signatures by passing valid, manually-defined arguments
112+
# to the signature's function/method.
113+
#
114+
# Arguments which require use of the array module are specified as string
115+
# expressions to be eval()'d on runtime. This is as opposed to just using the
116+
# array module whilst setting up the tests, which is prone to halt the entire
117+
# test suite if an array module doesn't support a given expression.
118+
func_to_specified_args = defaultdict(
119+
dict,
120+
{
121+
"permute_dims": {"axes": 0},
122+
"reshape": {"shape": (1, 5)},
123+
"broadcast_to": {"shape": (1, 5)},
124+
"asarray": {"obj": [0, 1, 2, 3, 4]},
125+
"full_like": {"fill_value": 42},
126+
"matrix_power": {"n": 2},
127+
},
128+
)
129+
func_to_specified_arg_exprs = defaultdict(
130+
dict,
131+
{
132+
"stack": {"arrays": "[xp.ones((5,)), xp.ones((5,))]"},
133+
"iinfo": {"type": "xp.int64"},
134+
"finfo": {"type": "xp.float64"},
135+
"cholesky": {"x": "xp.asarray([[1, 0], [0, 1]], dtype=xp.float64)"},
136+
"inv": {"x": "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)"},
137+
"solve": {
138+
a: "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)" for a in ["x1", "x2"]
139+
},
140+
},
141+
)
142+
# We default most array arguments heuristically. As functions/methods work only
143+
# with arrays of certain dtypes and shapes, we specify only supported arrays
144+
# respective to the function.
145+
casty_names = ["__bool__", "__int__", "__float__", "__complex__", "__index__"]
146+
matrixy_names = [
147+
f.__name__
148+
for f in category_to_funcs["linear_algebra"] + extension_to_funcs["linalg"]
127149
]
128-
matrixy_names: List[str] = [f.__name__ for f in matrixy_funcs]
129150
matrixy_names += ["__matmul__", "triu", "tril"]
151+
for func_name, func in name_to_func.items():
152+
stub_sig = signature(func)
153+
array_argnames = set(stub_sig.parameters.keys()) & {"x", "x1", "x2", "other"}
154+
if func in array_methods:
155+
array_argnames.add("self")
156+
array_argnames -= set(func_to_specified_arg_exprs[func_name].keys())
157+
if len(array_argnames) > 0:
158+
in_dtypes = dh.func_in_dtypes[func_name]
159+
for dtype_name in ["float64", "bool", "int64", "complex128"]:
160+
# We try float64 first because uninspectable numerical functions
161+
# tend to support float inputs first-and-foremost (i.e. PyTorch)
162+
try:
163+
dtype = getattr(xp, dtype_name)
164+
except AttributeError:
165+
pass
166+
else:
167+
if dtype in in_dtypes:
168+
if func_name in casty_names:
169+
shape = ()
170+
elif func_name in matrixy_names:
171+
shape = (3, 3)
172+
else:
173+
shape = (5,)
174+
fallback_array_expr = f"xp.ones({shape}, dtype=xp.{dtype_name})"
175+
break
176+
else:
177+
warn(
178+
f"{dh.func_in_dtypes['{func_name}']}={in_dtypes} seemingly does "
179+
"not contain any assumed dtypes, so skipping specifying fallback array."
180+
)
181+
continue
182+
for argname in array_argnames:
183+
func_to_specified_arg_exprs[func_name][argname] = fallback_array_expr
184+
130185

186+
def _test_uninspectable_func(func_name: str, func: Callable, stub_sig: Signature):
187+
params = list(stub_sig.parameters.values())
131188

132-
@given(data=st.data())
133-
@settings(max_examples=1)
134-
def _test_uninspectable_func(
135-
func_name: str, func: Callable, stub_sig: Signature, array: Array, data: DataObject
136-
):
137-
skip_msg = (
138-
f"Signature for {func_name}() is not inspectable "
139-
"and is too troublesome to test for otherwise"
189+
if len(params) == 0:
190+
func()
191+
return
192+
193+
uninspectable_msg = (
194+
f"Note {func_name}() is not inspectable so arguments are passed "
195+
"manually to test the signature."
140196
)
141-
if func_name in [
142-
# 0d shapes
143-
"__bool__",
144-
"__int__",
145-
"__index__",
146-
"__float__",
147-
# x2 elements must be >=0
148-
"pow",
149-
"bitwise_left_shift",
150-
"bitwise_right_shift",
151-
# axis default invalid with 0d shapes
152-
"sort",
153-
# shape requirements
154-
*matrixy_names,
155-
]:
156-
pytest.skip(skip_msg)
157-
158-
param_to_value: Dict[Parameter, Any] = {}
159-
for param in stub_sig.parameters.values():
160-
if param.kind in [Parameter.POSITIONAL_OR_KEYWORD, *VAR_KINDS]:
197+
198+
argname_to_arg = copy(func_to_specified_args[func_name])
199+
argname_to_expr = func_to_specified_arg_exprs[func_name]
200+
for argname, expr in argname_to_expr.items():
201+
assert argname not in argname_to_arg.keys() # sanity check
202+
try:
203+
argname_to_arg[argname] = eval(expr, {"xp": xp})
204+
except Exception as e:
161205
pytest.skip(
162-
skip_msg + f" (because '{param.name}' is a {kind_to_str[param.kind]})"
163-
)
164-
elif param.default != Parameter.empty:
165-
value = param.default
166-
elif param.name in ["x", "x1"]:
167-
dtypes = get_dtypes_strategy(func_name)
168-
value = data.draw(
169-
xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)), label=param.name
206+
f"Exception occured when evaluating {argname}={expr}: {e}\n"
207+
f"{uninspectable_msg}"
170208
)
171-
elif param.name in ["x2", "other"]:
172-
if param.name == "x2":
173-
assert "x1" in [p.name for p in param_to_value.keys()] # sanity check
174-
orig = next(v for p, v in param_to_value.items() if p.name == "x1")
209+
210+
posargs = []
211+
posorkw_args = {}
212+
kwargs = {}
213+
no_arg_msg = (
214+
"We have no argument specified for '{}'. Please ensure you're using "
215+
"the latest version of array-api-tests, then open an issue if one "
216+
f"doesn't already exist. {uninspectable_msg}"
217+
)
218+
for param in params:
219+
if param.kind == Parameter.POSITIONAL_ONLY:
220+
try:
221+
posargs.append(argname_to_arg[param.name])
222+
except KeyError:
223+
pytest.skip(no_arg_msg.format(param.name))
224+
elif param.kind == Parameter.POSITIONAL_OR_KEYWORD:
225+
if param.default == Parameter.empty:
226+
try:
227+
posorkw_args[param.name] = argname_to_arg[param.name]
228+
except KeyError:
229+
pytest.skip(no_arg_msg.format(param.name))
175230
else:
176-
assert array is not None # sanity check
177-
orig = array
178-
value = data.draw(
179-
xps.arrays(dtype=orig.dtype, shape=orig.shape), label=param.name
180-
)
231+
assert argname_to_arg[param.name]
232+
posorkw_args[param.name] = param.default
233+
elif param.kind == Parameter.KEYWORD_ONLY:
234+
assert param.default != Parameter.empty # sanity check
235+
kwargs[param.name] = param.default
181236
else:
182-
pytest.skip(
183-
skip_msg + f" (because no default was found for argument {param.name})"
184-
)
185-
param_to_value[param] = value
186-
187-
args: List[Any] = [
188-
v for p, v in param_to_value.items() if p.kind == Parameter.POSITIONAL_ONLY
189-
]
190-
kwargs: Dict[str, Any] = {
191-
p.name: v for p, v in param_to_value.items() if p.kind == Parameter.KEYWORD_ONLY
192-
}
193-
f_func = make_pretty_func(func_name, *args, **kwargs)
194-
note(f"trying {f_func}")
195-
func(*args, **kwargs)
237+
assert param.kind in VAR_KINDS # sanity check
238+
pytest.skip(no_arg_msg.format(param.name))
239+
if len(posorkw_args) == 0:
240+
func(*posargs, **kwargs)
241+
else:
242+
posorkw_name_to_arg_pairs = list(posorkw_args.items())
243+
for i in range(len(posorkw_name_to_arg_pairs), -1, -1):
244+
extra_posargs = [arg for _, arg in posorkw_name_to_arg_pairs[:i]]
245+
extra_kwargs = dict(posorkw_name_to_arg_pairs[i:])
246+
func(*posargs, *extra_posargs, **kwargs, **extra_kwargs)
196247

197248

198-
def _test_func_signature(func: Callable, stub: FunctionType, array=None):
249+
def _test_func_signature(func: Callable, stub: FunctionType, is_method=False):
199250
stub_sig = signature(stub)
200251
# If testing against array, ignore 'self' arg in stub as it won't be present
201252
# in func (which should be a method).
202-
if array is not None:
253+
if is_method:
203254
stub_params = list(stub_sig.parameters.values())
204-
del stub_params[0]
255+
if stub_params[0].name == "self":
256+
del stub_params[0]
205257
stub_sig = Signature(
206258
parameters=stub_params, return_annotation=stub_sig.return_annotation
207259
)
208260

209261
try:
210262
sig = signature(func)
211-
_test_inspectable_func(sig, stub_sig)
212263
except ValueError:
213-
_test_uninspectable_func(stub.__name__, func, stub_sig, array)
264+
try:
265+
_test_uninspectable_func(stub.__name__, func, stub_sig)
266+
except Exception as e:
267+
raise e from None # suppress parent exception for cleaner pytest output
268+
else:
269+
_test_inspectable_func(sig, stub_sig)
214270

215271

216272
@pytest.mark.parametrize(
@@ -244,11 +300,12 @@ def test_extension_func_signature(extension: str, stub: FunctionType):
244300

245301

246302
@pytest.mark.parametrize("stub", array_methods, ids=lambda f: f.__name__)
247-
@given(st.data())
248-
@settings(max_examples=1)
249-
def test_array_method_signature(stub: FunctionType, data: DataObject):
250-
dtypes = get_dtypes_strategy(stub.__name__)
251-
x = data.draw(xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)), label="x")
303+
def test_array_method_signature(stub: FunctionType):
304+
x_expr = func_to_specified_arg_exprs[stub.__name__]["self"]
305+
try:
306+
x = eval(x_expr, {"xp": xp})
307+
except Exception as e:
308+
pytest.skip(f"Exception occured when evaluating x={x_expr}: {e}")
252309
assert hasattr(x, stub.__name__), f"{stub.__name__} not found in array object {x!r}"
253310
method = getattr(x, stub.__name__)
254-
_test_func_signature(method, stub, array=x)
311+
_test_func_signature(method, stub, is_method=True)

0 commit comments

Comments
 (0)