@@ -20,22 +20,18 @@ def squeeze(x, /, axis):
20
20
...
21
21
22
22
"""
23
+ from collections import defaultdict
24
+ from copy import copy
23
25
from inspect import Parameter , Signature , signature
24
26
from types import FunctionType
25
- from typing import Any , Callable , Dict , List , Literal , get_args
27
+ from typing import Any , Callable , Dict , Literal , get_args
28
+ from warnings import warn
26
29
27
30
import pytest
28
- from hypothesis import given , note , settings
29
- from hypothesis import strategies as st
30
- from hypothesis .strategies import DataObject
31
31
32
32
from . import dtype_helpers as dh
33
- from . import hypothesis_helpers as hh
34
- from . import xps
35
- from ._array_module import _UndefinedStub
36
33
from ._array_module import mod as xp
37
- from .stubs import array_methods , category_to_funcs , extension_to_funcs
38
- from .typing import Array , DataType
34
+ from .stubs import array_methods , category_to_funcs , extension_to_funcs , name_to_func
39
35
40
36
pytestmark = pytest .mark .ci
41
37
@@ -93,24 +89,15 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature):
93
89
stub_param .name in sig .parameters .keys ()
94
90
), f"Argument '{ stub_param .name } ' missing from signature"
95
91
param = next (p for p in params if p .name == stub_param .name )
92
+ f_stub_kind = kind_to_str [stub_param .kind ]
96
93
assert param .kind in [stub_param .kind , Parameter .POSITIONAL_OR_KEYWORD ,], (
97
94
f"{ param .name } is a { kind_to_str [param .kind ]} , "
98
95
f"but should be a { f_stub_kind } "
99
96
f"(or at least a { kind_to_str [ParameterKind .POSITIONAL_OR_KEYWORD ]} )"
100
97
)
101
98
102
99
103
- def get_dtypes_strategy (func_name : str ) -> st .SearchStrategy [DataType ]:
104
- if func_name in dh .func_in_dtypes .keys ():
105
- dtypes = dh .func_in_dtypes [func_name ]
106
- if hh .FILTER_UNDEFINED_DTYPES :
107
- dtypes = [d for d in dtypes if not isinstance (d , _UndefinedStub )]
108
- return st .sampled_from (dtypes )
109
- else :
110
- return xps .scalar_dtypes ()
111
-
112
-
113
- def make_pretty_func (func_name : str , * args : Any , ** kwargs : Any ):
100
+ def make_pretty_func (func_name : str , * args : Any , ** kwargs : Any ) -> str :
114
101
f_sig = f"{ func_name } ("
115
102
f_sig += ", " .join (str (a ) for a in args )
116
103
if len (kwargs ) != 0 :
@@ -121,96 +108,165 @@ def make_pretty_func(func_name: str, *args: Any, **kwargs: Any):
121
108
return f_sig
122
109
123
110
124
- matrixy_funcs : List [FunctionType ] = [
125
- * category_to_funcs ["linear_algebra" ],
126
- * extension_to_funcs ["linalg" ],
111
+ # We test uninspectable signatures by passing valid, manually-defined arguments
112
+ # to the signature's function/method.
113
+ #
114
+ # Arguments which require use of the array module are specified as string
115
+ # expressions to be eval()'d on runtime. This is as opposed to just using the
116
+ # array module whilst setting up the tests, which is prone to halt the entire
117
+ # test suite if an array module doesn't support a given expression.
118
+ func_to_specified_args = defaultdict (
119
+ dict ,
120
+ {
121
+ "permute_dims" : {"axes" : 0 },
122
+ "reshape" : {"shape" : (1 , 5 )},
123
+ "broadcast_to" : {"shape" : (1 , 5 )},
124
+ "asarray" : {"obj" : [0 , 1 , 2 , 3 , 4 ]},
125
+ "full_like" : {"fill_value" : 42 },
126
+ "matrix_power" : {"n" : 2 },
127
+ },
128
+ )
129
+ func_to_specified_arg_exprs = defaultdict (
130
+ dict ,
131
+ {
132
+ "stack" : {"arrays" : "[xp.ones((5,)), xp.ones((5,))]" },
133
+ "iinfo" : {"type" : "xp.int64" },
134
+ "finfo" : {"type" : "xp.float64" },
135
+ "cholesky" : {"x" : "xp.asarray([[1, 0], [0, 1]], dtype=xp.float64)" },
136
+ "inv" : {"x" : "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)" },
137
+ "solve" : {
138
+ a : "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)" for a in ["x1" , "x2" ]
139
+ },
140
+ },
141
+ )
142
+ # We default most array arguments heuristically. As functions/methods work only
143
+ # with arrays of certain dtypes and shapes, we specify only supported arrays
144
+ # respective to the function.
145
+ casty_names = ["__bool__" , "__int__" , "__float__" , "__complex__" , "__index__" ]
146
+ matrixy_names = [
147
+ f .__name__
148
+ for f in category_to_funcs ["linear_algebra" ] + extension_to_funcs ["linalg" ]
127
149
]
128
- matrixy_names : List [str ] = [f .__name__ for f in matrixy_funcs ]
129
150
matrixy_names += ["__matmul__" , "triu" , "tril" ]
151
+ for func_name , func in name_to_func .items ():
152
+ stub_sig = signature (func )
153
+ array_argnames = set (stub_sig .parameters .keys ()) & {"x" , "x1" , "x2" , "other" }
154
+ if func in array_methods :
155
+ array_argnames .add ("self" )
156
+ array_argnames -= set (func_to_specified_arg_exprs [func_name ].keys ())
157
+ if len (array_argnames ) > 0 :
158
+ in_dtypes = dh .func_in_dtypes [func_name ]
159
+ for dtype_name in ["float64" , "bool" , "int64" , "complex128" ]:
160
+ # We try float64 first because uninspectable numerical functions
161
+ # tend to support float inputs first-and-foremost (i.e. PyTorch)
162
+ try :
163
+ dtype = getattr (xp , dtype_name )
164
+ except AttributeError :
165
+ pass
166
+ else :
167
+ if dtype in in_dtypes :
168
+ if func_name in casty_names :
169
+ shape = ()
170
+ elif func_name in matrixy_names :
171
+ shape = (3 , 3 )
172
+ else :
173
+ shape = (5 ,)
174
+ fallback_array_expr = f"xp.ones({ shape } , dtype=xp.{ dtype_name } )"
175
+ break
176
+ else :
177
+ warn (
178
+ f"{ dh .func_in_dtypes ['{func_name}' ]} ={ in_dtypes } seemingly does "
179
+ "not contain any assumed dtypes, so skipping specifying fallback array."
180
+ )
181
+ continue
182
+ for argname in array_argnames :
183
+ func_to_specified_arg_exprs [func_name ][argname ] = fallback_array_expr
184
+
130
185
186
+ def _test_uninspectable_func (func_name : str , func : Callable , stub_sig : Signature ):
187
+ params = list (stub_sig .parameters .values ())
131
188
132
- @given (data = st .data ())
133
- @settings (max_examples = 1 )
134
- def _test_uninspectable_func (
135
- func_name : str , func : Callable , stub_sig : Signature , array : Array , data : DataObject
136
- ):
137
- skip_msg = (
138
- f"Signature for { func_name } () is not inspectable "
139
- "and is too troublesome to test for otherwise"
189
+ if len (params ) == 0 :
190
+ func ()
191
+ return
192
+
193
+ uninspectable_msg = (
194
+ f"Note { func_name } () is not inspectable so arguments are passed "
195
+ "manually to test the signature."
140
196
)
141
- if func_name in [
142
- # 0d shapes
143
- "__bool__" ,
144
- "__int__" ,
145
- "__index__" ,
146
- "__float__" ,
147
- # x2 elements must be >=0
148
- "pow" ,
149
- "bitwise_left_shift" ,
150
- "bitwise_right_shift" ,
151
- # axis default invalid with 0d shapes
152
- "sort" ,
153
- # shape requirements
154
- * matrixy_names ,
155
- ]:
156
- pytest .skip (skip_msg )
157
-
158
- param_to_value : Dict [Parameter , Any ] = {}
159
- for param in stub_sig .parameters .values ():
160
- if param .kind in [Parameter .POSITIONAL_OR_KEYWORD , * VAR_KINDS ]:
197
+
198
+ argname_to_arg = copy (func_to_specified_args [func_name ])
199
+ argname_to_expr = func_to_specified_arg_exprs [func_name ]
200
+ for argname , expr in argname_to_expr .items ():
201
+ assert argname not in argname_to_arg .keys () # sanity check
202
+ try :
203
+ argname_to_arg [argname ] = eval (expr , {"xp" : xp })
204
+ except Exception as e :
161
205
pytest .skip (
162
- skip_msg + f" (because '{ param .name } ' is a { kind_to_str [param .kind ]} )"
163
- )
164
- elif param .default != Parameter .empty :
165
- value = param .default
166
- elif param .name in ["x" , "x1" ]:
167
- dtypes = get_dtypes_strategy (func_name )
168
- value = data .draw (
169
- xps .arrays (dtype = dtypes , shape = hh .shapes (min_side = 1 )), label = param .name
206
+ f"Exception occured when evaluating { argname } ={ expr } : { e } \n "
207
+ f"{ uninspectable_msg } "
170
208
)
171
- elif param .name in ["x2" , "other" ]:
172
- if param .name == "x2" :
173
- assert "x1" in [p .name for p in param_to_value .keys ()] # sanity check
174
- orig = next (v for p , v in param_to_value .items () if p .name == "x1" )
209
+
210
+ posargs = []
211
+ posorkw_args = {}
212
+ kwargs = {}
213
+ no_arg_msg = (
214
+ "We have no argument specified for '{}'. Please ensure you're using "
215
+ "the latest version of array-api-tests, then open an issue if one "
216
+ f"doesn't already exist. { uninspectable_msg } "
217
+ )
218
+ for param in params :
219
+ if param .kind == Parameter .POSITIONAL_ONLY :
220
+ try :
221
+ posargs .append (argname_to_arg [param .name ])
222
+ except KeyError :
223
+ pytest .skip (no_arg_msg .format (param .name ))
224
+ elif param .kind == Parameter .POSITIONAL_OR_KEYWORD :
225
+ if param .default == Parameter .empty :
226
+ try :
227
+ posorkw_args [param .name ] = argname_to_arg [param .name ]
228
+ except KeyError :
229
+ pytest .skip (no_arg_msg .format (param .name ))
175
230
else :
176
- assert array is not None # sanity check
177
- orig = array
178
- value = data . draw (
179
- xps . arrays ( dtype = orig . dtype , shape = orig . shape ), label = param . name
180
- )
231
+ assert argname_to_arg [ param . name ]
232
+ posorkw_args [ param . name ] = param . default
233
+ elif param . kind == Parameter . KEYWORD_ONLY :
234
+ assert param . default != Parameter . empty # sanity check
235
+ kwargs [ param . name ] = param . default
181
236
else :
182
- pytest .skip (
183
- skip_msg + f" (because no default was found for argument { param .name } )"
184
- )
185
- param_to_value [param ] = value
186
-
187
- args : List [Any ] = [
188
- v for p , v in param_to_value .items () if p .kind == Parameter .POSITIONAL_ONLY
189
- ]
190
- kwargs : Dict [str , Any ] = {
191
- p .name : v for p , v in param_to_value .items () if p .kind == Parameter .KEYWORD_ONLY
192
- }
193
- f_func = make_pretty_func (func_name , * args , ** kwargs )
194
- note (f"trying { f_func } " )
195
- func (* args , ** kwargs )
237
+ assert param .kind in VAR_KINDS # sanity check
238
+ pytest .skip (no_arg_msg .format (param .name ))
239
+ if len (posorkw_args ) == 0 :
240
+ func (* posargs , ** kwargs )
241
+ else :
242
+ posorkw_name_to_arg_pairs = list (posorkw_args .items ())
243
+ for i in range (len (posorkw_name_to_arg_pairs ), - 1 , - 1 ):
244
+ extra_posargs = [arg for _ , arg in posorkw_name_to_arg_pairs [:i ]]
245
+ extra_kwargs = dict (posorkw_name_to_arg_pairs [i :])
246
+ func (* posargs , * extra_posargs , ** kwargs , ** extra_kwargs )
196
247
197
248
198
- def _test_func_signature (func : Callable , stub : FunctionType , array = None ):
249
+ def _test_func_signature (func : Callable , stub : FunctionType , is_method = False ):
199
250
stub_sig = signature (stub )
200
251
# If testing against array, ignore 'self' arg in stub as it won't be present
201
252
# in func (which should be a method).
202
- if array is not None :
253
+ if is_method :
203
254
stub_params = list (stub_sig .parameters .values ())
204
- del stub_params [0 ]
255
+ if stub_params [0 ].name == "self" :
256
+ del stub_params [0 ]
205
257
stub_sig = Signature (
206
258
parameters = stub_params , return_annotation = stub_sig .return_annotation
207
259
)
208
260
209
261
try :
210
262
sig = signature (func )
211
- _test_inspectable_func (sig , stub_sig )
212
263
except ValueError :
213
- _test_uninspectable_func (stub .__name__ , func , stub_sig , array )
264
+ try :
265
+ _test_uninspectable_func (stub .__name__ , func , stub_sig )
266
+ except Exception as e :
267
+ raise e from None # suppress parent exception for cleaner pytest output
268
+ else :
269
+ _test_inspectable_func (sig , stub_sig )
214
270
215
271
216
272
@pytest .mark .parametrize (
@@ -244,11 +300,12 @@ def test_extension_func_signature(extension: str, stub: FunctionType):
244
300
245
301
246
302
@pytest .mark .parametrize ("stub" , array_methods , ids = lambda f : f .__name__ )
247
- @given (st .data ())
248
- @settings (max_examples = 1 )
249
- def test_array_method_signature (stub : FunctionType , data : DataObject ):
250
- dtypes = get_dtypes_strategy (stub .__name__ )
251
- x = data .draw (xps .arrays (dtype = dtypes , shape = hh .shapes (min_side = 1 )), label = "x" )
303
+ def test_array_method_signature (stub : FunctionType ):
304
+ x_expr = func_to_specified_arg_exprs [stub .__name__ ]["self" ]
305
+ try :
306
+ x = eval (x_expr , {"xp" : xp })
307
+ except Exception as e :
308
+ pytest .skip (f"Exception occured when evaluating x={ x_expr } : { e } " )
252
309
assert hasattr (x , stub .__name__ ), f"{ stub .__name__ } not found in array object { x !r} "
253
310
method = getattr (x , stub .__name__ )
254
- _test_func_signature (method , stub , array = x )
311
+ _test_func_signature (method , stub , is_method = True )
0 commit comments