@@ -20,22 +20,17 @@ def squeeze(x, /, axis):
20
20
...
21
21
22
22
"""
23
+ from collections import defaultdict
23
24
from inspect import Parameter , Signature , signature
24
25
from types import FunctionType
25
- from typing import Any , Callable , Dict , List , Literal , get_args
26
+ from typing import Any , Callable , Dict , Literal , get_args
27
+ from warnings import warn
26
28
27
29
import pytest
28
- from hypothesis import given , note , settings
29
- from hypothesis import strategies as st
30
- from hypothesis .strategies import DataObject
31
30
32
31
from . import dtype_helpers as dh
33
- from . import hypothesis_helpers as hh
34
- from . import xps
35
- from ._array_module import _UndefinedStub
36
32
from ._array_module import mod as xp
37
- from .stubs import array_methods , category_to_funcs , extension_to_funcs
38
- from .typing import Array , DataType
33
+ from .stubs import array_methods , category_to_funcs , extension_to_funcs , name_to_func
39
34
40
35
pytestmark = pytest .mark .ci
41
36
@@ -101,17 +96,7 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature):
101
96
)
102
97
103
98
104
- def get_dtypes_strategy (func_name : str ) -> st .SearchStrategy [DataType ]:
105
- if func_name in dh .func_in_dtypes .keys ():
106
- dtypes = dh .func_in_dtypes [func_name ]
107
- if hh .FILTER_UNDEFINED_DTYPES :
108
- dtypes = [d for d in dtypes if not isinstance (d , _UndefinedStub )]
109
- return st .sampled_from (dtypes )
110
- else :
111
- return xps .scalar_dtypes ()
112
-
113
-
114
- def make_pretty_func (func_name : str , * args : Any , ** kwargs : Any ):
99
+ def make_pretty_func (func_name : str , * args : Any , ** kwargs : Any ) -> str :
115
100
f_sig = f"{ func_name } ("
116
101
f_sig += ", " .join (str (a ) for a in args )
117
102
if len (kwargs ) != 0 :
@@ -122,96 +107,161 @@ def make_pretty_func(func_name: str, *args: Any, **kwargs: Any):
122
107
return f_sig
123
108
124
109
125
- matrixy_funcs : List [FunctionType ] = [
126
- * category_to_funcs ["linear_algebra" ],
127
- * extension_to_funcs ["linalg" ],
110
+ # We test uninspectable signatures by passing valid, manually-defined arguments
111
+ # to the signature's function/method.
112
+ #
113
+ # Arguments which require use of the array module are specified as string
114
+ # expressions to be eval()'d on runtime. This is as opposed to just using the
115
+ # array module whilst setting up the tests, which is prone to halt the entire
116
+ # test suite if an array module doesn't support a given expression.
117
+ func_to_specified_args = defaultdict (
118
+ dict ,
119
+ {
120
+ "permute_dims" : {"axes" : 0 },
121
+ "reshape" : {"shape" : (1 , 5 )},
122
+ "broadcast_to" : {"shape" : (1 , 5 )},
123
+ "asarray" : {"obj" : [0 , 1 , 2 , 3 , 4 ]},
124
+ "full_like" : {"fill_value" : 42 },
125
+ "matrix_power" : {"n" : 2 },
126
+ },
127
+ )
128
+ func_to_specified_arg_exprs = defaultdict (
129
+ dict ,
130
+ {
131
+ "stack" : {"arrays" : "[xp.ones((5,)), xp.ones((5,))]" },
132
+ "iinfo" : {"type" : "xp.int64" },
133
+ "finfo" : {"type" : "xp.float64" },
134
+ "logaddexp" : {a : "xp.ones((5,), dtype=xp.float64)" for a in ["x1" , "x2" ]},
135
+ },
136
+ )
137
+ # We default most array arguments heuristically. As functions/methods work only
138
+ # with arrays of certain dtypes and shapes, we specify only supported arrays
139
+ # respective to the function.
140
+ casty_names = ["__bool__" , "__int__" , "__float__" , "__complex__" , "__index__" ]
141
+ matrixy_names = [
142
+ f .__name__
143
+ for f in category_to_funcs ["linear_algebra" ] + extension_to_funcs ["linalg" ]
128
144
]
129
- matrixy_names : List [str ] = [f .__name__ for f in matrixy_funcs ]
130
145
matrixy_names += ["__matmul__" , "triu" , "tril" ]
146
+ for func_name , func in name_to_func .items ():
147
+ stub_sig = signature (func )
148
+ array_argnames = set (stub_sig .parameters .keys ()) & {"x" , "x1" , "x2" , "other" }
149
+ if func in array_methods :
150
+ array_argnames .add ("self" )
151
+ array_argnames -= set (func_to_specified_arg_exprs [func_name ].keys ())
152
+ if len (array_argnames ) > 0 :
153
+ in_dtypes = dh .func_in_dtypes [func_name ]
154
+ for dtype_name in ["float64" , "bool" , "int64" , "complex128" ]:
155
+ # We try float64 first because uninspectable numerical functions
156
+ # tend to support float inputs first-and-foremost (i.e. PyTorch)
157
+ try :
158
+ dtype = getattr (xp , dtype_name )
159
+ except AttributeError :
160
+ pass
161
+ else :
162
+ if dtype in in_dtypes :
163
+ if func_name in casty_names :
164
+ shape = ()
165
+ elif func_name in matrixy_names :
166
+ shape = (2 , 2 )
167
+ else :
168
+ shape = (5 ,)
169
+ fallback_array_expr = f"xp.ones({ shape } , dtype=xp.{ dtype_name } )"
170
+ break
171
+ else :
172
+ warn (
173
+ f"{ dh .func_in_dtypes ['{func_name}' ]} ={ in_dtypes } seemingly does "
174
+ "not contain any assumed dtypes, so skipping specifying fallback array."
175
+ )
176
+ continue
177
+ for argname in array_argnames :
178
+ func_to_specified_arg_exprs [func_name ][argname ] = fallback_array_expr
179
+
180
+
181
+ def _test_uninspectable_func (func_name : str , func : Callable , stub_sig : Signature ):
182
+ if func_name in matrixy_names :
183
+ pytest .xfail ("TODO" )
131
184
185
+ params = list (stub_sig .parameters .values ())
132
186
133
- @given (data = st .data ())
134
- @settings (max_examples = 1 )
135
- def _test_uninspectable_func (
136
- func_name : str , func : Callable , stub_sig : Signature , array : Array , data : DataObject
137
- ):
138
- skip_msg = (
139
- f"Signature for { func_name } () is not inspectable "
140
- "and is too troublesome to test for otherwise"
187
+ if len (params ) == 0 :
188
+ func ()
189
+ return
190
+
191
+ uninspectable_msg = (
192
+ f"Note { func_name } () is not inspectable so arguments are passed "
193
+ "manually to test the signature."
141
194
)
142
- if func_name in [
143
- # 0d shapes
144
- "__bool__" ,
145
- "__int__" ,
146
- "__index__" ,
147
- "__float__" ,
148
- # x2 elements must be >=0
149
- "pow" ,
150
- "bitwise_left_shift" ,
151
- "bitwise_right_shift" ,
152
- # axis default invalid with 0d shapes
153
- "sort" ,
154
- # shape requirements
155
- * matrixy_names ,
156
- ]:
157
- pytest .skip (skip_msg )
158
-
159
- param_to_value : Dict [Parameter , Any ] = {}
160
- for param in stub_sig .parameters .values ():
161
- if param .kind in [Parameter .POSITIONAL_OR_KEYWORD , * VAR_KINDS ]:
195
+
196
+ argname_to_arg = func_to_specified_args [func_name ]
197
+ argname_to_expr = func_to_specified_arg_exprs [func_name ]
198
+ for argname , expr in argname_to_expr .items ():
199
+ assert argname not in argname_to_arg .keys () # sanity check
200
+ try :
201
+ argname_to_arg [argname ] = eval (expr , {"xp" : xp })
202
+ except Exception as e :
162
203
pytest .skip (
163
- skip_msg + f" (because '{ param .name } ' is a { kind_to_str [param .kind ]} )"
164
- )
165
- elif param .default != Parameter .empty :
166
- value = param .default
167
- elif param .name in ["x" , "x1" ]:
168
- dtypes = get_dtypes_strategy (func_name )
169
- value = data .draw (
170
- xps .arrays (dtype = dtypes , shape = hh .shapes (min_side = 1 )), label = param .name
204
+ f"Exception occured when evaluating { argname } ={ expr } : { e } \n "
205
+ f"{ uninspectable_msg } "
171
206
)
172
- elif param .name in ["x2" , "other" ]:
173
- if param .name == "x2" :
174
- assert "x1" in [p .name for p in param_to_value .keys ()] # sanity check
175
- orig = next (v for p , v in param_to_value .items () if p .name == "x1" )
207
+
208
+ posargs = []
209
+ posorkw_args = {}
210
+ kwargs = {}
211
+ no_arg_msg = (
212
+ "We have no argument specified for '{}'. Please ensure you're using "
213
+ "the latest version of array-api-tests, then open an issue if one "
214
+ f"doesn't already exist. { uninspectable_msg } "
215
+ )
216
+ for param in params :
217
+ if param .kind == Parameter .POSITIONAL_ONLY :
218
+ try :
219
+ posargs .append (argname_to_arg [param .name ])
220
+ except KeyError :
221
+ pytest .skip (no_arg_msg .format (param .name ))
222
+ elif param .kind == Parameter .POSITIONAL_OR_KEYWORD :
223
+ if param .default == Parameter .empty :
224
+ try :
225
+ posorkw_args [param .name ] = argname_to_arg [param .name ]
226
+ except KeyError :
227
+ pytest .skip (no_arg_msg .format (param .name ))
176
228
else :
177
- assert array is not None # sanity check
178
- orig = array
179
- value = data . draw (
180
- xps . arrays ( dtype = orig . dtype , shape = orig . shape ), label = param . name
181
- )
229
+ assert argname_to_arg [ param . name ]
230
+ posorkw_args [ param . name ] = param . default
231
+ elif param . kind == Parameter . KEYWORD_ONLY :
232
+ assert param . default != Parameter . empty # sanity check
233
+ kwargs [ param . name ] = param . default
182
234
else :
183
- pytest .skip (
184
- skip_msg + f" (because no default was found for argument { param .name } )"
185
- )
186
- param_to_value [param ] = value
187
-
188
- args : List [Any ] = [
189
- v for p , v in param_to_value .items () if p .kind == Parameter .POSITIONAL_ONLY
190
- ]
191
- kwargs : Dict [str , Any ] = {
192
- p .name : v for p , v in param_to_value .items () if p .kind == Parameter .KEYWORD_ONLY
193
- }
194
- f_func = make_pretty_func (func_name , * args , ** kwargs )
195
- note (f"trying { f_func } " )
196
- func (* args , ** kwargs )
235
+ assert param .kind in VAR_KINDS # sanity check
236
+ pytest .skip (no_arg_msg .format (param .name ))
237
+ if len (posorkw_args ) == 0 :
238
+ func (* posargs , ** kwargs )
239
+ else :
240
+ func (* posargs , ** posorkw_args , ** kwargs )
241
+ # TODO: test all positional and keyword permutations of pos-or-kw args
197
242
198
243
199
- def _test_func_signature (func : Callable , stub : FunctionType , array = None ):
244
+ def _test_func_signature (func : Callable , stub : FunctionType , is_method = False ):
200
245
stub_sig = signature (stub )
201
246
# If testing against array, ignore 'self' arg in stub as it won't be present
202
247
# in func (which should be a method).
203
- if array is not None :
248
+ if is_method :
204
249
stub_params = list (stub_sig .parameters .values ())
205
- del stub_params [0 ]
250
+ if stub_params [0 ].name == "self" :
251
+ del stub_params [0 ]
206
252
stub_sig = Signature (
207
253
parameters = stub_params , return_annotation = stub_sig .return_annotation
208
254
)
209
255
210
256
try :
211
257
sig = signature (func )
212
- _test_inspectable_func (sig , stub_sig )
213
258
except ValueError :
214
- _test_uninspectable_func (stub .__name__ , func , stub_sig , array )
259
+ try :
260
+ _test_uninspectable_func (stub .__name__ , func , stub_sig )
261
+ except Exception as e :
262
+ raise e from None # suppress parent exception for cleaner pytest output
263
+ else :
264
+ _test_inspectable_func (sig , stub_sig )
215
265
216
266
217
267
@pytest .mark .parametrize (
@@ -245,11 +295,12 @@ def test_extension_func_signature(extension: str, stub: FunctionType):
245
295
246
296
247
297
@pytest .mark .parametrize ("stub" , array_methods , ids = lambda f : f .__name__ )
248
- @given (st .data ())
249
- @settings (max_examples = 1 )
250
- def test_array_method_signature (stub : FunctionType , data : DataObject ):
251
- dtypes = get_dtypes_strategy (stub .__name__ )
252
- x = data .draw (xps .arrays (dtype = dtypes , shape = hh .shapes (min_side = 1 )), label = "x" )
298
+ def test_array_method_signature (stub : FunctionType ):
299
+ x_expr = func_to_specified_arg_exprs [stub .__name__ ]["self" ]
300
+ try :
301
+ x = eval (x_expr , {"xp" : xp })
302
+ except Exception as e :
303
+ pytest .skip (f"Exception occured when evaluating x={ x_expr } : { e } " )
253
304
assert hasattr (x , stub .__name__ ), f"{ stub .__name__ } not found in array object { x !r} "
254
305
method = getattr (x , stub .__name__ )
255
- _test_func_signature (method , stub , array = x )
306
+ _test_func_signature (method , stub , is_method = True )
0 commit comments