@@ -33,7 +33,7 @@ def squeeze(x, /, axis):
33
33
from ._array_module import _UndefinedStub
34
34
from ._array_module import mod as xp
35
35
from .stubs import array_methods , category_to_funcs , extension_to_funcs
36
- from .typing import DataType
36
+ from .typing import Array , DataType
37
37
38
38
pytestmark = pytest .mark .ci
39
39
@@ -112,7 +112,8 @@ def make_pretty_func(func_name: str, args: Sequence[Any], kwargs: Dict[str, Any]
112
112
113
113
114
114
matrixy_funcs : List [FunctionType ] = [
115
- * category_to_funcs ["linear_algebra" ], * extension_to_funcs ["linalg" ]
115
+ * category_to_funcs ["linear_algebra" ],
116
+ * extension_to_funcs ["linalg" ],
116
117
]
117
118
matrixy_names : List [str ] = [f .__name__ for f in matrixy_funcs ]
118
119
matrixy_names += ["__matmul__" , "triu" , "tril" ]
@@ -121,7 +122,7 @@ def make_pretty_func(func_name: str, args: Sequence[Any], kwargs: Dict[str, Any]
121
122
@given (data = st .data ())
122
123
@settings (max_examples = 1 )
123
124
def _test_uninspectable_func (
124
- func_name : str , func : Callable , stub_sig : Signature , data : DataObject
125
+ func_name : str , func : Callable , stub_sig : Signature , array : Array , data : DataObject
125
126
):
126
127
skip_msg = (
127
128
f"Signature for { func_name } () is not inspectable "
@@ -153,12 +154,15 @@ def _test_uninspectable_func(
153
154
value = data .draw (
154
155
xps .arrays (dtype = dtypes , shape = hh .shapes (min_side = 1 )), label = param .name
155
156
)
156
- elif param .name == "x2" :
157
- # sanity check
158
- assert "x1" in [p .name for p in param_to_value .keys ()]
159
- x1 = next (v for p , v in param_to_value .items () if p .name == "x1" )
157
+ elif param .name in ["x2" , "other" ]:
158
+ if param .name == "x2" :
159
+ assert "x1" in [p .name for p in param_to_value .keys ()] # sanity check
160
+ orig = next (v for p , v in param_to_value .items () if p .name == "x1" )
161
+ else :
162
+ assert array is not None # sanity check
163
+ orig = array
160
164
value = data .draw (
161
- xps .arrays (dtype = x1 .dtype , shape = x1 .shape ), label = param .name
165
+ xps .arrays (dtype = orig .dtype , shape = orig .shape ), label = param .name
162
166
)
163
167
else :
164
168
pytest .skip (
@@ -177,11 +181,11 @@ def _test_uninspectable_func(
177
181
func (* args , ** kwargs )
178
182
179
183
180
- def _test_func_signature (
181
- func : Callable , stub : FunctionType , ignore_first_stub_param : bool = False
182
- ):
184
+ def _test_func_signature (func : Callable , stub : FunctionType , array = None ):
183
185
stub_sig = signature (stub )
184
- if ignore_first_stub_param :
186
+ # If testing against array, ignore 'self' arg in stub as it won't be present
187
+ # in func (which should be an array method).
188
+ if array is not None :
185
189
stub_params = list (stub_sig .parameters .values ())
186
190
del stub_params [0 ]
187
191
stub_sig = Signature (
@@ -192,7 +196,7 @@ def _test_func_signature(
192
196
sig = signature (func )
193
197
_test_inspectable_func (sig , stub_sig )
194
198
except ValueError :
195
- _test_uninspectable_func (stub .__name__ , func , stub_sig )
199
+ _test_uninspectable_func (stub .__name__ , func , stub_sig , array )
196
200
197
201
198
202
@pytest .mark .parametrize (
@@ -233,5 +237,4 @@ def test_array_method_signature(stub: FunctionType, data: DataObject):
233
237
x = data .draw (xps .arrays (dtype = dtypes , shape = hh .shapes (min_side = 1 )), label = "x" )
234
238
assert hasattr (x , stub .__name__ ), f"{ stub .__name__ } not found in array object { x !r} "
235
239
method = getattr (x , stub .__name__ )
236
- # Ignore 'self' arg in stub, which won't be present in instantiated objects.
237
- _test_func_signature (method , stub , ignore_first_stub_param = True )
240
+ _test_func_signature (method , stub , array = x )
0 commit comments