Skip to content

Commit af2e299

Browse files
authored
Merge pull request #176 from asmeurer/pytest-helpers-keyword-only
Use keyword-only arguments for the pytest helpers functions
2 parents a533680 + a3505c5 commit af2e299

14 files changed

+369
-305
lines changed

array_api_tests/meta/test_pytest_helpers.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@
55

66

77
def test_assert_dtype():
8-
ph.assert_dtype("promoted_func", [xp.uint8, xp.int8], xp.int16)
8+
ph.assert_dtype("promoted_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.int16)
99
with raises(AssertionError):
10-
ph.assert_dtype("bad_func", [xp.uint8, xp.int8], xp.float32)
11-
ph.assert_dtype("bool_func", [xp.uint8, xp.int8], xp.bool, xp.bool)
12-
ph.assert_dtype("single_promoted_func", [xp.uint8], xp.uint8)
13-
ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool)
10+
ph.assert_dtype("bad_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.float32)
11+
ph.assert_dtype("bool_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.bool, expected=xp.bool)
12+
ph.assert_dtype("single_promoted_func", in_dtype=[xp.uint8], out_dtype=xp.uint8)
13+
ph.assert_dtype("single_bool_func", in_dtype=[xp.uint8], out_dtype=xp.bool, expected=xp.bool)
1414

1515

1616
def test_assert_array_elements():
17-
ph.assert_array_elements("int zeros", xp.asarray(0), xp.asarray(0))
18-
ph.assert_array_elements("pos zeros", xp.asarray(0.0), xp.asarray(0.0))
17+
ph.assert_array_elements("int zeros", out=xp.asarray(0), expected=xp.asarray(0))
18+
ph.assert_array_elements("pos zeros", out=xp.asarray(0.0), expected=xp.asarray(0.0))
1919
with raises(AssertionError):
20-
ph.assert_array_elements("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0))
20+
ph.assert_array_elements("mixed sign zeros", out=xp.asarray(0.0), expected=xp.asarray(-0.0))
2121
with raises(AssertionError):
22-
ph.assert_array_elements("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0))
22+
ph.assert_array_elements("mixed sign zeros", out=xp.asarray(-0.0), expected=xp.asarray(0.0))

array_api_tests/pytest_helpers.py

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,10 @@ def is_neg_zero(n: float) -> bool:
8282

8383
def assert_dtype(
8484
func_name: str,
85+
*,
8586
in_dtype: Union[DataType, Sequence[DataType]],
8687
out_dtype: DataType,
8788
expected: Optional[DataType] = None,
88-
*,
8989
repr_name: str = "out.dtype",
9090
):
9191
"""
@@ -96,7 +96,7 @@ def assert_dtype(
9696
9797
>>> x = xp.arange(5, dtype=xp.uint8)
9898
>>> out = xp.abs(x)
99-
>>> assert_dtype('abs', x.dtype, out.dtype)
99+
>>> assert_dtype('abs', in_dtype=x.dtype, out_dtype=out.dtype)
100100
101101
is equivalent to
102102
@@ -108,7 +108,7 @@ def assert_dtype(
108108
>>> x1 = xp.arange(5, dtype=xp.uint8)
109109
>>> x2 = xp.arange(5, dtype=xp.uint16)
110110
>>> out = xp.add(x1, x2)
111-
>>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype)
111+
>>> assert_dtype('add', in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
112112
113113
is equivalent to
114114
@@ -119,7 +119,7 @@ def assert_dtype(
119119
>>> x = xp.arange(5, dtype=xp.int8)
120120
>>> out = xp.sum(x)
121121
>>> default_int = xp.asarray(0).dtype
122-
>>> assert_dtype('sum', x, out.dtype, default_int)
122+
>>> assert_dtype('sum', in_dtype=x, out_dtype=out.dtype, expected=default_int)
123123
124124
"""
125125
in_dtypes = in_dtype if isinstance(in_dtype, Sequence) and not isinstance(in_dtype, str) else [in_dtype]
@@ -135,13 +135,18 @@ def assert_dtype(
135135
assert out_dtype == expected, msg
136136

137137

138-
def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
138+
def assert_kw_dtype(
139+
func_name: str,
140+
*,
141+
kw_dtype: DataType,
142+
out_dtype: DataType,
143+
):
139144
"""
140145
Assert the output dtype is the passed keyword dtype, e.g.
141146
142147
>>> kw = {'dtype': xp.uint8}
143-
>>> out = xp.ones(5, **kw)
144-
>>> assert_kw_dtype('ones', kw['dtype'], out.dtype)
148+
>>> out = xp.ones(5, kw=kw)
149+
>>> assert_kw_dtype('ones', kw_dtype=kw['dtype'], out_dtype=out.dtype)
145150
146151
"""
147152
f_kw_dtype = dh.dtype_to_name[kw_dtype]
@@ -222,17 +227,17 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty
222227

223228
def assert_shape(
224229
func_name: str,
230+
*,
225231
out_shape: Union[int, Shape],
226232
expected: Union[int, Shape],
227-
/,
228233
repr_name="out.shape",
229-
**kw,
234+
kw: dict = {},
230235
):
231236
"""
232237
Assert the output shape is as expected, e.g.
233238
234239
>>> out = xp.ones((3, 3, 3))
235-
>>> assert_shape('ones', out.shape, (3, 3, 3))
240+
>>> assert_shape('ones', out_shape=out.shape, expected=(3, 3, 3))
236241
237242
"""
238243
if isinstance(out_shape, int):
@@ -249,11 +254,10 @@ def assert_result_shape(
249254
func_name: str,
250255
in_shapes: Sequence[Shape],
251256
out_shape: Shape,
252-
/,
253257
expected: Optional[Shape] = None,
254258
*,
255259
repr_name="out.shape",
256-
**kw,
260+
kw: dict = {},
257261
):
258262
"""
259263
Assert the output shape is as expected.
@@ -262,7 +266,7 @@ def assert_result_shape(
262266
in_shapes, to test against out_shape, e.g.
263267
264268
>>> out = xp.add(xp.ones((3, 1)), xp.ones((1, 3)))
265-
>>> assert_shape('add', [(3, 1), (1, 3)], out.shape)
269+
>>> assert_result_shape('add', in_shape=[(3, 1), (1, 3)], out_shape=out.shape)
266270
267271
is equivalent to
268272
@@ -281,21 +285,21 @@ def assert_result_shape(
281285

282286
def assert_keepdimable_shape(
283287
func_name: str,
288+
*,
284289
in_shape: Shape,
285290
out_shape: Shape,
286291
axes: Tuple[int, ...],
287292
keepdims: bool,
288-
/,
289-
**kw,
293+
kw: dict = {},
290294
):
291295
"""
292296
Assert the output shape from a keepdimable function is as expected, e.g.
293297
294298
>>> x = xp.asarray([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
295299
>>> out1 = xp.max(x, keepdims=False)
296300
>>> out2 = xp.max(x, keepdims=True)
297-
>>> assert_keepdimable_shape('max', x.shape, out1.shape, (0, 1), False)
298-
>>> assert_keepdimable_shape('max', x.shape, out2.shape, (0, 1), True)
301+
>>> assert_keepdimable_shape('max', in_shape=x.shape, out_shape=out1.shape, axes=(0, 1), keepdims=False)
302+
>>> assert_keepdimable_shape('max', in_shape=x.shape, out_shape=out2.shape, axes=(0, 1), keepdims=True)
299303
300304
is equivalent to
301305
@@ -307,19 +311,26 @@ def assert_keepdimable_shape(
307311
shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape))
308312
else:
309313
shape = tuple(side for axis, side in enumerate(in_shape) if axis not in axes)
310-
assert_shape(func_name, out_shape, shape, **kw)
314+
assert_shape(func_name, out_shape=out_shape, expected=shape, kw=kw)
311315

312316

313317
def assert_0d_equals(
314-
func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw
318+
func_name: str,
319+
*,
320+
x_repr: str,
321+
x_val: Array,
322+
out_repr: str,
323+
out_val: Array,
324+
kw: dict = {},
315325
):
316326
"""
317327
Assert a 0d array is as expected, e.g.
318328
319329
>>> x = xp.asarray([0, 1, 2])
320-
>>> res = xp.asarray(x, copy=True)
330+
>>> kw = {'copy': True}
331+
>>> res = xp.asarray(x, **kw)
321332
>>> res[0] = 42
322-
>>> assert_0d_equals('asarray', 'x[0]', x[0], 'x[0]', res[0])
333+
>>> assert_0d_equals('asarray', x_repr='x[0]', x_val=x[0], out_repr='x[0]', out_val=res[0], kw=kw)
323334
324335
is equivalent to
325336
@@ -338,20 +349,20 @@ def assert_0d_equals(
338349

339350
def assert_scalar_equals(
340351
func_name: str,
352+
*,
341353
type_: ScalarType,
342354
idx: Shape,
343355
out: Scalar,
344356
expected: Scalar,
345-
/,
346357
repr_name: str = "out",
347-
**kw,
358+
kw: dict = {},
348359
):
349360
"""
350361
Assert a 0d array, convered to a scalar, is as expected, e.g.
351362
352363
>>> x = xp.ones(5, dtype=xp.uint8)
353364
>>> out = xp.sum(x)
354-
>>> assert_scalar_equals('sum', int, (), int(out), 5)
365+
>>> assert_scalar_equals('sum', type_int, out=(), out=int(out), expected=5)
355366
356367
is equivalent to
357368
@@ -372,13 +383,18 @@ def assert_scalar_equals(
372383

373384

374385
def assert_fill(
375-
func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw
386+
func_name: str,
387+
*,
388+
fill_value: Scalar,
389+
dtype: DataType,
390+
out: Array,
391+
kw: dict = {},
376392
):
377393
"""
378394
Assert all elements of an array is as expected, e.g.
379395
380396
>>> out = xp.full(5, 42, dtype=xp.uint8)
381-
>>> assert_fill('full', 42, xp.uint8, out, 5)
397+
>>> assert_fill('full', fill_value=42, dtype=xp.uint8, out=out, kw=dict(shape=5))
382398
383399
is equivalent to
384400
@@ -408,22 +424,27 @@ def _assert_float_element(at_out: Array, at_expected: Array, msg: str):
408424

409425

410426
def assert_array_elements(
411-
func_name: str, out: Array, expected: Array, /, *, out_repr: str = "out", **kw
427+
func_name: str,
428+
*,
429+
out: Array,
430+
expected: Array,
431+
out_repr: str = "out",
432+
kw: dict = {},
412433
):
413434
"""
414435
Assert array elements are (strictly) as expected, e.g.
415436
416437
>>> x = xp.arange(5)
417438
>>> out = xp.asarray(x)
418-
>>> assert_array_elements('asarray', out, x)
439+
>>> assert_array_elements('asarray', out=out, expected=x)
419440
420441
is equivalent to
421442
422443
>>> assert xp.all(out == x)
423444
424445
"""
425446
dh.result_type(out.dtype, expected.dtype) # sanity check
426-
assert_shape(func_name, out.shape, expected.shape, **kw) # sanity check
447+
assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check
427448
f_func = f"[{func_name}({fmt_kw(kw)})]"
428449
if out.dtype in dh.float_dtypes:
429450
for idx in sh.ndindex(out.shape):

0 commit comments

Comments
 (0)