Skip to content

Commit 4431440

Browse files
committed
Rudimentary manual arguments for uninspectable signatures
1 parent d9c4fe0 commit 4431440

File tree

2 files changed

+148
-99
lines changed

2 files changed

+148
-99
lines changed

array_api_tests/dtype_helpers.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import re
2+
from collections import defaultdict
23
from collections.abc import Mapping
34
from functools import lru_cache
4-
from inspect import signature
5-
from typing import Any, Dict, NamedTuple, Sequence, Tuple, Union
5+
from typing import Any, DefaultDict, NamedTuple, Sequence, Tuple, Union
66
from warnings import warn
77

88
from . import _array_module as xp
@@ -323,16 +323,14 @@ def result_type(*dtypes: DataType):
323323
"numeric": numeric_dtypes,
324324
"integer or boolean": bool_and_all_int_dtypes,
325325
}
326-
func_in_dtypes: Dict[str, Tuple[DataType, ...]] = {}
326+
func_in_dtypes: DefaultDict[str, Tuple[DataType, ...]] = defaultdict(lambda: all_dtypes)
327327
for name, func in name_to_func.items():
328328
if m := r_in_dtypes.search(func.__doc__):
329329
dtype_category = m.group(1)
330330
if dtype_category == "numeric" and r_int_note.search(func.__doc__):
331331
dtype_category = "floating-point"
332332
dtypes = category_to_dtypes[dtype_category]
333333
func_in_dtypes[name] = dtypes
334-
elif any("x" in name for name in signature(func).parameters.keys()):
335-
func_in_dtypes[name] = all_dtypes
336334
# See https://github.com/data-apis/array-api/pull/413
337335
func_in_dtypes["expm1"] = float_dtypes
338336

array_api_tests/test_signatures.py

+145-94
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,17 @@ def squeeze(x, /, axis):
2020
...
2121
2222
"""
23+
from collections import defaultdict
2324
from inspect import Parameter, Signature, signature
2425
from types import FunctionType
25-
from typing import Any, Callable, Dict, List, Literal, get_args
26+
from typing import Any, Callable, Dict, Literal, get_args
27+
from warnings import warn
2628

2729
import pytest
28-
from hypothesis import given, note, settings
29-
from hypothesis import strategies as st
30-
from hypothesis.strategies import DataObject
3130

3231
from . import dtype_helpers as dh
33-
from . import hypothesis_helpers as hh
34-
from . import xps
35-
from ._array_module import _UndefinedStub
3632
from ._array_module import mod as xp
37-
from .stubs import array_methods, category_to_funcs, extension_to_funcs
38-
from .typing import Array, DataType
33+
from .stubs import array_methods, category_to_funcs, extension_to_funcs, name_to_func
3934

4035
pytestmark = pytest.mark.ci
4136

@@ -101,17 +96,7 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature):
10196
)
10297

10398

104-
def get_dtypes_strategy(func_name: str) -> st.SearchStrategy[DataType]:
105-
if func_name in dh.func_in_dtypes.keys():
106-
dtypes = dh.func_in_dtypes[func_name]
107-
if hh.FILTER_UNDEFINED_DTYPES:
108-
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
109-
return st.sampled_from(dtypes)
110-
else:
111-
return xps.scalar_dtypes()
112-
113-
114-
def make_pretty_func(func_name: str, *args: Any, **kwargs: Any):
99+
def make_pretty_func(func_name: str, *args: Any, **kwargs: Any) -> str:
115100
f_sig = f"{func_name}("
116101
f_sig += ", ".join(str(a) for a in args)
117102
if len(kwargs) != 0:
@@ -122,96 +107,161 @@ def make_pretty_func(func_name: str, *args: Any, **kwargs: Any):
122107
return f_sig
123108

124109

125-
matrixy_funcs: List[FunctionType] = [
126-
*category_to_funcs["linear_algebra"],
127-
*extension_to_funcs["linalg"],
110+
# We test uninspectable signatures by passing valid, manually-defined arguments
111+
# to the signature's function/method.
112+
#
113+
# Arguments which require use of the array module are specified as string
114+
# expressions to be eval()'d on runtime. This is as opposed to just using the
115+
# array module whilst setting up the tests, which is prone to halt the entire
116+
# test suite if an array module doesn't support a given expression.
117+
func_to_specified_args = defaultdict(
118+
dict,
119+
{
120+
"permute_dims": {"axes": 0},
121+
"reshape": {"shape": (1, 5)},
122+
"broadcast_to": {"shape": (1, 5)},
123+
"asarray": {"obj": [0, 1, 2, 3, 4]},
124+
"full_like": {"fill_value": 42},
125+
"matrix_power": {"n": 2},
126+
},
127+
)
128+
func_to_specified_arg_exprs = defaultdict(
129+
dict,
130+
{
131+
"stack": {"arrays": "[xp.ones((5,)), xp.ones((5,))]"},
132+
"iinfo": {"type": "xp.int64"},
133+
"finfo": {"type": "xp.float64"},
134+
"logaddexp": {a: "xp.ones((5,), dtype=xp.float64)" for a in ["x1", "x2"]},
135+
},
136+
)
137+
# We default most array arguments heuristically. As functions/methods work only
138+
# with arrays of certain dtypes and shapes, we specify only supported arrays
139+
# respective to the function.
140+
casty_names = ["__bool__", "__int__", "__float__", "__complex__", "__index__"]
141+
matrixy_names = [
142+
f.__name__
143+
for f in category_to_funcs["linear_algebra"] + extension_to_funcs["linalg"]
128144
]
129-
matrixy_names: List[str] = [f.__name__ for f in matrixy_funcs]
130145
matrixy_names += ["__matmul__", "triu", "tril"]
146+
for func_name, func in name_to_func.items():
147+
stub_sig = signature(func)
148+
array_argnames = set(stub_sig.parameters.keys()) & {"x", "x1", "x2", "other"}
149+
if func in array_methods:
150+
array_argnames.add("self")
151+
array_argnames -= set(func_to_specified_arg_exprs[func_name].keys())
152+
if len(array_argnames) > 0:
153+
in_dtypes = dh.func_in_dtypes[func_name]
154+
for dtype_name in ["float64", "bool", "int64", "complex128"]:
155+
# We try float64 first because uninspectable numerical functions
156+
# tend to support float inputs first-and-foremost (i.e. PyTorch)
157+
try:
158+
dtype = getattr(xp, dtype_name)
159+
except AttributeError:
160+
pass
161+
else:
162+
if dtype in in_dtypes:
163+
if func_name in casty_names:
164+
shape = ()
165+
elif func_name in matrixy_names:
166+
shape = (2, 2)
167+
else:
168+
shape = (5,)
169+
fallback_array_expr = f"xp.ones({shape}, dtype=xp.{dtype_name})"
170+
break
171+
else:
172+
warn(
173+
f"{dh.func_in_dtypes['{func_name}']}={in_dtypes} seemingly does "
174+
"not contain any assumed dtypes, so skipping specifying fallback array."
175+
)
176+
continue
177+
for argname in array_argnames:
178+
func_to_specified_arg_exprs[func_name][argname] = fallback_array_expr
179+
180+
181+
def _test_uninspectable_func(func_name: str, func: Callable, stub_sig: Signature):
182+
if func_name in matrixy_names:
183+
pytest.xfail("TODO")
131184

185+
params = list(stub_sig.parameters.values())
132186

133-
@given(data=st.data())
134-
@settings(max_examples=1)
135-
def _test_uninspectable_func(
136-
func_name: str, func: Callable, stub_sig: Signature, array: Array, data: DataObject
137-
):
138-
skip_msg = (
139-
f"Signature for {func_name}() is not inspectable "
140-
"and is too troublesome to test for otherwise"
187+
if len(params) == 0:
188+
func()
189+
return
190+
191+
uninspectable_msg = (
192+
f"Note {func_name}() is not inspectable so arguments are passed "
193+
"manually to test the signature."
141194
)
142-
if func_name in [
143-
# 0d shapes
144-
"__bool__",
145-
"__int__",
146-
"__index__",
147-
"__float__",
148-
# x2 elements must be >=0
149-
"pow",
150-
"bitwise_left_shift",
151-
"bitwise_right_shift",
152-
# axis default invalid with 0d shapes
153-
"sort",
154-
# shape requirements
155-
*matrixy_names,
156-
]:
157-
pytest.skip(skip_msg)
158-
159-
param_to_value: Dict[Parameter, Any] = {}
160-
for param in stub_sig.parameters.values():
161-
if param.kind in [Parameter.POSITIONAL_OR_KEYWORD, *VAR_KINDS]:
195+
196+
argname_to_arg = func_to_specified_args[func_name]
197+
argname_to_expr = func_to_specified_arg_exprs[func_name]
198+
for argname, expr in argname_to_expr.items():
199+
assert argname not in argname_to_arg.keys() # sanity check
200+
try:
201+
argname_to_arg[argname] = eval(expr, {"xp": xp})
202+
except Exception as e:
162203
pytest.skip(
163-
skip_msg + f" (because '{param.name}' is a {kind_to_str[param.kind]})"
164-
)
165-
elif param.default != Parameter.empty:
166-
value = param.default
167-
elif param.name in ["x", "x1"]:
168-
dtypes = get_dtypes_strategy(func_name)
169-
value = data.draw(
170-
xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)), label=param.name
204+
f"Exception occured when evaluating {argname}={expr}: {e}\n"
205+
f"{uninspectable_msg}"
171206
)
172-
elif param.name in ["x2", "other"]:
173-
if param.name == "x2":
174-
assert "x1" in [p.name for p in param_to_value.keys()] # sanity check
175-
orig = next(v for p, v in param_to_value.items() if p.name == "x1")
207+
208+
posargs = []
209+
posorkw_args = {}
210+
kwargs = {}
211+
no_arg_msg = (
212+
"We have no argument specified for '{}'. Please ensure you're using "
213+
"the latest version of array-api-tests, then open an issue if one "
214+
f"doesn't already exist. {uninspectable_msg}"
215+
)
216+
for param in params:
217+
if param.kind == Parameter.POSITIONAL_ONLY:
218+
try:
219+
posargs.append(argname_to_arg[param.name])
220+
except KeyError:
221+
pytest.skip(no_arg_msg.format(param.name))
222+
elif param.kind == Parameter.POSITIONAL_OR_KEYWORD:
223+
if param.default == Parameter.empty:
224+
try:
225+
posorkw_args[param.name] = argname_to_arg[param.name]
226+
except KeyError:
227+
pytest.skip(no_arg_msg.format(param.name))
176228
else:
177-
assert array is not None # sanity check
178-
orig = array
179-
value = data.draw(
180-
xps.arrays(dtype=orig.dtype, shape=orig.shape), label=param.name
181-
)
229+
assert argname_to_arg[param.name]
230+
posorkw_args[param.name] = param.default
231+
elif param.kind == Parameter.KEYWORD_ONLY:
232+
assert param.default != Parameter.empty # sanity check
233+
kwargs[param.name] = param.default
182234
else:
183-
pytest.skip(
184-
skip_msg + f" (because no default was found for argument {param.name})"
185-
)
186-
param_to_value[param] = value
187-
188-
args: List[Any] = [
189-
v for p, v in param_to_value.items() if p.kind == Parameter.POSITIONAL_ONLY
190-
]
191-
kwargs: Dict[str, Any] = {
192-
p.name: v for p, v in param_to_value.items() if p.kind == Parameter.KEYWORD_ONLY
193-
}
194-
f_func = make_pretty_func(func_name, *args, **kwargs)
195-
note(f"trying {f_func}")
196-
func(*args, **kwargs)
235+
assert param.kind in VAR_KINDS # sanity check
236+
pytest.skip(no_arg_msg.format(param.name))
237+
if len(posorkw_args) == 0:
238+
func(*posargs, **kwargs)
239+
else:
240+
func(*posargs, **posorkw_args, **kwargs)
241+
# TODO: test all positional and keyword permutations of pos-or-kw args
197242

198243

199-
def _test_func_signature(func: Callable, stub: FunctionType, array=None):
244+
def _test_func_signature(func: Callable, stub: FunctionType, is_method=False):
200245
stub_sig = signature(stub)
201246
# If testing against array, ignore 'self' arg in stub as it won't be present
202247
# in func (which should be a method).
203-
if array is not None:
248+
if is_method:
204249
stub_params = list(stub_sig.parameters.values())
205-
del stub_params[0]
250+
if stub_params[0].name == "self":
251+
del stub_params[0]
206252
stub_sig = Signature(
207253
parameters=stub_params, return_annotation=stub_sig.return_annotation
208254
)
209255

210256
try:
211257
sig = signature(func)
212-
_test_inspectable_func(sig, stub_sig)
213258
except ValueError:
214-
_test_uninspectable_func(stub.__name__, func, stub_sig, array)
259+
try:
260+
_test_uninspectable_func(stub.__name__, func, stub_sig)
261+
except Exception as e:
262+
raise e from None # suppress parent exception for cleaner pytest output
263+
else:
264+
_test_inspectable_func(sig, stub_sig)
215265

216266

217267
@pytest.mark.parametrize(
@@ -245,11 +295,12 @@ def test_extension_func_signature(extension: str, stub: FunctionType):
245295

246296

247297
@pytest.mark.parametrize("stub", array_methods, ids=lambda f: f.__name__)
248-
@given(st.data())
249-
@settings(max_examples=1)
250-
def test_array_method_signature(stub: FunctionType, data: DataObject):
251-
dtypes = get_dtypes_strategy(stub.__name__)
252-
x = data.draw(xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)), label="x")
298+
def test_array_method_signature(stub: FunctionType):
299+
x_expr = func_to_specified_arg_exprs[stub.__name__]["self"]
300+
try:
301+
x = eval(x_expr, {"xp": xp})
302+
except Exception as e:
303+
pytest.skip(f"Exception occured when evaluating x={x_expr}: {e}")
253304
assert hasattr(x, stub.__name__), f"{stub.__name__} not found in array object {x!r}"
254305
method = getattr(x, stub.__name__)
255-
_test_func_signature(method, stub, array=x)
306+
_test_func_signature(method, stub, is_method=True)

0 commit comments

Comments
 (0)