Skip to content

Commit 15114b7

Browse files
committed
Test uninspectable array methods
1 parent 17a56ba commit 15114b7

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

array_api_tests/test_signatures.py

+18-15
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def squeeze(x, /, axis):
3333
from ._array_module import _UndefinedStub
3434
from ._array_module import mod as xp
3535
from .stubs import array_methods, category_to_funcs, extension_to_funcs
36-
from .typing import DataType
36+
from .typing import Array, DataType
3737

3838
pytestmark = pytest.mark.ci
3939

@@ -112,7 +112,8 @@ def make_pretty_func(func_name: str, args: Sequence[Any], kwargs: Dict[str, Any]
112112

113113

114114
matrixy_funcs: List[FunctionType] = [
115-
*category_to_funcs["linear_algebra"], *extension_to_funcs["linalg"]
115+
*category_to_funcs["linear_algebra"],
116+
*extension_to_funcs["linalg"],
116117
]
117118
matrixy_names: List[str] = [f.__name__ for f in matrixy_funcs]
118119
matrixy_names += ["__matmul__", "triu", "tril"]
@@ -121,7 +122,7 @@ def make_pretty_func(func_name: str, args: Sequence[Any], kwargs: Dict[str, Any]
121122
@given(data=st.data())
122123
@settings(max_examples=1)
123124
def _test_uninspectable_func(
124-
func_name: str, func: Callable, stub_sig: Signature, data: DataObject
125+
func_name: str, func: Callable, stub_sig: Signature, array: Array, data: DataObject
125126
):
126127
skip_msg = (
127128
f"Signature for {func_name}() is not inspectable "
@@ -153,12 +154,15 @@ def _test_uninspectable_func(
153154
value = data.draw(
154155
xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)), label=param.name
155156
)
156-
elif param.name == "x2":
157-
# sanity check
158-
assert "x1" in [p.name for p in param_to_value.keys()]
159-
x1 = next(v for p, v in param_to_value.items() if p.name == "x1")
157+
elif param.name in ["x2", "other"]:
158+
if param.name == "x2":
159+
assert "x1" in [p.name for p in param_to_value.keys()] # sanity check
160+
orig = next(v for p, v in param_to_value.items() if p.name == "x1")
161+
else:
162+
assert array is not None # sanity check
163+
orig = array
160164
value = data.draw(
161-
xps.arrays(dtype=x1.dtype, shape=x1.shape), label=param.name
165+
xps.arrays(dtype=orig.dtype, shape=orig.shape), label=param.name
162166
)
163167
else:
164168
pytest.skip(
@@ -177,11 +181,11 @@ def _test_uninspectable_func(
177181
func(*args, **kwargs)
178182

179183

180-
def _test_func_signature(
181-
func: Callable, stub: FunctionType, ignore_first_stub_param: bool = False
182-
):
184+
def _test_func_signature(func: Callable, stub: FunctionType, array=None):
183185
stub_sig = signature(stub)
184-
if ignore_first_stub_param:
186+
# If testing against array, ignore 'self' arg in stub as it won't be present
187+
# in func (which should be an array method).
188+
if array is not None:
185189
stub_params = list(stub_sig.parameters.values())
186190
del stub_params[0]
187191
stub_sig = Signature(
@@ -192,7 +196,7 @@ def _test_func_signature(
192196
sig = signature(func)
193197
_test_inspectable_func(sig, stub_sig)
194198
except ValueError:
195-
_test_uninspectable_func(stub.__name__, func, stub_sig)
199+
_test_uninspectable_func(stub.__name__, func, stub_sig, array)
196200

197201

198202
@pytest.mark.parametrize(
@@ -233,5 +237,4 @@ def test_array_method_signature(stub: FunctionType, data: DataObject):
233237
x = data.draw(xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)), label="x")
234238
assert hasattr(x, stub.__name__), f"{stub.__name__} not found in array object {x!r}"
235239
method = getattr(x, stub.__name__)
236-
# Ignore 'self' arg in stub, which won't be present in instantiated objects.
237-
_test_func_signature(method, stub, ignore_first_stub_param=True)
240+
_test_func_signature(method, stub, array=x)

0 commit comments

Comments
 (0)