Skip to content

Commit fdde7ec

Browse files
authored
Merge pull request #120 from honno/ph-docs
`pytest_helpers.py` documentation
2 parents 37be0aa + b62af64 commit fdde7ec

File tree

7 files changed

+192
-49
lines changed

7 files changed

+192
-49
lines changed

.github/workflows/numpy.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@ jobs:
3333
3434
# copy not implemented
3535
array_api_tests/test_creation_functions.py::test_asarray_arrays
36-
# https://github.com/numpy/numpy/issues/18881
37-
array_api_tests/test_creation_functions.py::test_linspace
3836
# https://github.com/numpy/numpy/issues/20870
3937
array_api_tests/test_data_type_functions.py::test_can_cast
4038
# The return dtype for trace is not consistent in the spec

array_api_tests/pytest_helpers.py

Lines changed: 160 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import Any, Dict, Optional, Sequence, Tuple, Union
44

55
from . import _array_module as xp
6-
from . import array_helpers as ah
76
from . import dtype_helpers as dh
87
from . import shape_helpers as sh
98
from . import stubs
@@ -88,6 +87,40 @@ def assert_dtype(
8887
*,
8988
repr_name: str = "out.dtype",
9089
):
90+
"""
91+
Assert the output dtype is as expected.
92+
93+
If expected=None, we infer the expected dtype as in_dtype, to test
94+
out_dtype, e.g.
95+
96+
>>> x = xp.arange(5, dtype=xp.uint8)
97+
>>> out = xp.abs(x)
98+
>>> assert_dtype('abs', x.dtype, out.dtype)
99+
100+
is equivalent to
101+
102+
>>> assert out.dtype == xp.uint8
103+
104+
Or for multiple input dtypes, the expected dtype is inferred from their
105+
resulting type promotion, e.g.
106+
107+
>>> x1 = xp.arange(5, dtype=xp.uint8)
108+
>>> x2 = xp.arange(5, dtype=xp.uint16)
109+
>>> out = xp.add(x1, x2)
110+
>>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype)
111+
112+
is equivalent to
113+
114+
>>> assert out.dtype == xp.uint16
115+
116+
We can also specify the expected dtype ourselves, e.g.
117+
118+
>>> x = xp.arange(5, dtype=xp.int8)
119+
>>> out = xp.sum(x)
120+
>>> default_int = xp.asarray(0).dtype
121+
>>> assert_dtype('sum', x, out.dtype, default_int)
122+
123+
"""
91124
in_dtypes = in_dtype if isinstance(in_dtype, Sequence) else [in_dtype]
92125
f_in_dtypes = dh.fmt_types(tuple(in_dtypes))
93126
f_out_dtype = dh.dtype_to_name[out_dtype]
@@ -102,6 +135,14 @@ def assert_dtype(
102135

103136

104137
def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
138+
"""
139+
Assert the output dtype is the passed keyword dtype, e.g.
140+
141+
>>> kw = {'dtype': xp.uint8}
142+
>>> out = xp.ones(5, **kw)
143+
>>> assert_kw_dtype('ones', kw['dtype'], out.dtype)
144+
145+
"""
105146
f_kw_dtype = dh.dtype_to_name[kw_dtype]
106147
f_out_dtype = dh.dtype_to_name[out_dtype]
107148
msg = (
@@ -111,33 +152,54 @@ def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
111152
assert out_dtype == kw_dtype, msg
112153

113154

114-
def assert_default_float(func_name: str, dtype: DataType):
115-
f_dtype = dh.dtype_to_name[dtype]
155+
def assert_default_float(func_name: str, out_dtype: DataType):
156+
"""
157+
Assert the output dtype is the default float, e.g.
158+
159+
>>> out = xp.ones(5)
160+
>>> assert_default_float('ones', out.dtype)
161+
162+
"""
163+
f_dtype = dh.dtype_to_name[out_dtype]
116164
f_default = dh.dtype_to_name[dh.default_float]
117165
msg = (
118166
f"out.dtype={f_dtype}, should be default "
119167
f"floating-point dtype {f_default} [{func_name}()]"
120168
)
121-
assert dtype == dh.default_float, msg
169+
assert out_dtype == dh.default_float, msg
122170

123171

124-
def assert_default_int(func_name: str, dtype: DataType):
125-
f_dtype = dh.dtype_to_name[dtype]
172+
def assert_default_int(func_name: str, out_dtype: DataType):
173+
"""
174+
Assert the output dtype is the default int, e.g.
175+
176+
>>> out = xp.full(5, 42)
177+
>>> assert_default_int('full', out.dtype)
178+
179+
"""
180+
f_dtype = dh.dtype_to_name[out_dtype]
126181
f_default = dh.dtype_to_name[dh.default_int]
127182
msg = (
128183
f"out.dtype={f_dtype}, should be default "
129184
f"integer dtype {f_default} [{func_name}()]"
130185
)
131-
assert dtype == dh.default_int, msg
186+
assert out_dtype == dh.default_int, msg
187+
188+
189+
def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dtype"):
190+
"""
191+
Assert the output dtype is the default index dtype, e.g.
132192
193+
>>> out = xp.argmax(xp.arange(5))
194+
>>> assert_default_int('argmax', out.dtype)
133195
134-
def assert_default_index(func_name: str, dtype: DataType, repr_name="out.dtype"):
135-
f_dtype = dh.dtype_to_name[dtype]
196+
"""
197+
f_dtype = dh.dtype_to_name[out_dtype]
136198
msg = (
137199
f"{repr_name}={f_dtype}, should be the default index dtype, "
138200
f"which is either int32 or int64 [{func_name}()]"
139201
)
140-
assert dtype in (xp.int32, xp.int64), msg
202+
assert out_dtype in (xp.int32, xp.int64), msg
141203

142204

143205
def assert_shape(
@@ -148,6 +210,13 @@ def assert_shape(
148210
repr_name="out.shape",
149211
**kw,
150212
):
213+
"""
214+
Assert the output shape is as expected, e.g.
215+
216+
>>> out = xp.ones((3, 3, 3))
217+
>>> assert_shape('ones', out.shape, (3, 3, 3))
218+
219+
"""
151220
if isinstance(out_shape, int):
152221
out_shape = (out_shape,)
153222
if isinstance(expected, int):
@@ -168,6 +237,20 @@ def assert_result_shape(
168237
repr_name="out.shape",
169238
**kw,
170239
):
240+
"""
241+
Assert the output shape is as expected.
242+
243+
If expected=None, we infer the expected shape as the result of broadcasting
244+
in_shapes, to test against out_shape, e.g.
245+
246+
>>> out = xp.add(xp.ones((3, 1)), xp.ones((1, 3)))
247+
>>> assert_shape('add', [(3, 1), (1, 3)], out.shape)
248+
249+
is equivalent to
250+
251+
>>> assert out.shape == (3, 3)
252+
253+
"""
171254
if expected is None:
172255
expected = sh.broadcast_shapes(*in_shapes)
173256
f_in_shapes = " . ".join(str(s) for s in in_shapes)
@@ -180,13 +263,28 @@ def assert_result_shape(
180263

181264
def assert_keepdimable_shape(
182265
func_name: str,
183-
out_shape: Shape,
184266
in_shape: Shape,
267+
out_shape: Shape,
185268
axes: Tuple[int, ...],
186269
keepdims: bool,
187270
/,
188271
**kw,
189272
):
273+
"""
274+
Assert the output shape from a keepdimable function is as expected, e.g.
275+
276+
>>> x = xp.asarray([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
277+
>>> out1 = xp.max(x, keepdims=False)
278+
>>> out2 = xp.max(x, keepdims=True)
279+
>>> assert_keepdimable_shape('max', x.shape, out1.shape, (0, 1), False)
280+
>>> assert_keepdimable_shape('max', x.shape, out2.shape, (0, 1), True)
281+
282+
is equivalent to
283+
284+
>>> assert out1.shape == ()
285+
>>> assert out2.shape == (1, 1)
286+
287+
"""
190288
if keepdims:
191289
shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape))
192290
else:
@@ -197,6 +295,19 @@ def assert_keepdimable_shape(
197295
def assert_0d_equals(
198296
func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw
199297
):
298+
"""
299+
Assert a 0d array is as expected, e.g.
300+
301+
>>> x = xp.asarray([0, 1, 2])
302+
>>> res = xp.asarray(x, copy=True)
303+
>>> res[0] = 42
304+
>>> assert_0d_equals('__setitem__', 'x[0]', x[0], 'x[0]', res[0])
305+
306+
is equivalent to
307+
308+
>>> assert res[0] == x[0]
309+
310+
"""
200311
msg = (
201312
f"{out_repr}={out_val}, but should be {x_repr}={x_val} "
202313
f"[{func_name}({fmt_kw(kw)})]"
@@ -217,9 +328,21 @@ def assert_scalar_equals(
217328
repr_name: str = "out",
218329
**kw,
219330
):
331+
"""
332+
Assert a 0d array, convered to a scalar, is as expected, e.g.
333+
334+
>>> x = xp.ones(5, dtype=xp.uint8)
335+
>>> out = xp.sum(x)
336+
>>> assert_scalar_equals('sum', int, (), int(out), 5)
337+
338+
is equivalent to
339+
340+
>>> assert int(out) == 5
341+
342+
"""
220343
repr_name = repr_name if idx == () else f"{repr_name}[{idx}]"
221344
f_func = f"{func_name}({fmt_kw(kw)})"
222-
if type_ is bool or type_ is int:
345+
if type_ in [bool, int]:
223346
msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
224347
assert out == expected, msg
225348
elif math.isnan(expected):
@@ -233,14 +356,37 @@ def assert_scalar_equals(
233356
def assert_fill(
234357
func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw
235358
):
359+
"""
360+
Assert all elements of an array is as expected, e.g.
361+
362+
>>> out = xp.full(5, 42, dtype=xp.uint8)
363+
>>> assert_fill('full', 42, xp.uint8, out, 5)
364+
365+
is equivalent to
366+
367+
>>> assert xp.all(out == 42)
368+
369+
"""
236370
msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}"
237371
if math.isnan(fill_value):
238-
assert ah.all(ah.isnan(out)), msg
372+
assert xp.all(xp.isnan(out)), msg
239373
else:
240-
assert ah.all(ah.equal(out, ah.asarray(fill_value, dtype=dtype))), msg
374+
assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg
241375

242376

243377
def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
378+
"""
379+
Assert array is (strictly) as expected, e.g.
380+
381+
>>> x = xp.arange(5)
382+
>>> out = xp.asarray(x)
383+
>>> assert_array('asarray', out, x)
384+
385+
is equivalent to
386+
387+
>>> assert xp.all(out == x)
388+
389+
"""
244390
assert_dtype(func_name, out.dtype, expected.dtype)
245391
assert_shape(func_name, out.shape, expected.shape, **kw)
246392
f_func = f"[{func_name}({fmt_kw(kw)})]"

array_api_tests/test_creation_functions.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from hypothesis import strategies as st
88

99
from . import _array_module as xp
10-
from . import array_helpers as ah
1110
from . import dtype_helpers as dh
1211
from . import hypothesis_helpers as hh
1312
from . import pytest_helpers as ph
@@ -181,12 +180,12 @@ def test_arange(dtype, data):
181180
if dh.is_int_dtype(_dtype):
182181
elements = list(r)
183182
assume(out.size == len(elements))
184-
ah.assert_exactly_equal(out, ah.asarray(elements, dtype=_dtype))
183+
ph.assert_array("arange", out, xp.asarray(elements, dtype=_dtype))
185184
else:
186185
assume(out.size == size)
187186
if out.size > 0:
188-
assert ah.equal(
189-
out[0], ah.asarray(_start, dtype=out.dtype)
187+
assert xp.equal(
188+
out[0], xp.asarray(_start, dtype=out.dtype)
190189
), f"out[0]={out[0]}, but should be {_start} {f_func}"
191190

192191

@@ -421,8 +420,8 @@ def test_linspace(num, dtype, endpoint, data):
421420
start = data.draw(xps.from_dtype(_dtype, **finite_kw), label="start")
422421
stop = data.draw(xps.from_dtype(_dtype, **finite_kw), label="stop")
423422
# avoid overflow errors
424-
assume(not ah.isnan(ah.asarray(stop - start, dtype=_dtype)))
425-
assume(not ah.isnan(ah.asarray(start - stop, dtype=_dtype)))
423+
assume(not xp.isnan(xp.asarray(stop - start, dtype=_dtype)))
424+
assume(not xp.isnan(xp.asarray(start - stop, dtype=_dtype)))
426425

427426
kw = data.draw(
428427
hh.specified_kwargs(
@@ -440,20 +439,20 @@ def test_linspace(num, dtype, endpoint, data):
440439
ph.assert_shape("linspace", out.shape, num, start=stop, stop=stop, num=num)
441440
f_func = f"[linspace({start}, {stop}, {num})]"
442441
if num > 0:
443-
assert ah.equal(
444-
out[0], ah.asarray(start, dtype=out.dtype)
442+
assert xp.equal(
443+
out[0], xp.asarray(start, dtype=out.dtype)
445444
), f"out[0]={out[0]}, but should be {start} {f_func}"
446445
if endpoint:
447446
if num > 1:
448-
assert ah.equal(
449-
out[-1], ah.asarray(stop, dtype=out.dtype)
447+
assert xp.equal(
448+
out[-1], xp.asarray(stop, dtype=out.dtype)
450449
), f"out[-1]={out[-1]}, but should be {stop} {f_func}"
451450
else:
452451
# linspace(..., num, endpoint=True) should return an array equivalent to
453452
# the first num elements when endpoint=False
454453
expected = xp.linspace(start, stop, num + 1, dtype=dtype, endpoint=True)
455454
expected = expected[:-1]
456-
ah.assert_exactly_equal(out, expected)
455+
ph.assert_array("linspace", out, expected)
457456

458457

459458
@given(dtype=xps.numeric_dtypes(), data=st.data())

0 commit comments

Comments
 (0)