Skip to content

Commit 7b5e3ab

Browse files
committed
Document remaining pytest helpers
1 parent de35948 commit 7b5e3ab

File tree

1 file changed

+49
-1
lines changed

1 file changed

+49
-1
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,19 @@ def assert_keepdimable_shape(
296296
def assert_0d_equals(
297297
func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw
298298
):
299+
"""
300+
Assert a 0d array is as expected, e.g.
301+
302+
>>> x = xp.asarray([0, 1, 2])
303+
>>> res = xp.asarray(x, copy=True)
304+
>>> res[0] = 42
305+
>>> assert_0d_equals('__setitem__', 'x[0]', x[0], 'x[0]', res[0])
306+
307+
is equivalent to
308+
309+
>>> assert res[0] == x[0]
310+
311+
"""
299312
msg = (
300313
f"{out_repr}={out_val}, but should be {x_repr}={x_val} "
301314
f"[{func_name}({fmt_kw(kw)})]"
@@ -316,9 +329,21 @@ def assert_scalar_equals(
316329
repr_name: str = "out",
317330
**kw,
318331
):
332+
"""
333+
Assert a 0d array, convered to a scalar, is as expected, e.g.
334+
335+
>>> x = xp.ones(5, dtype=xp.uint8)
336+
>>> out = xp.sum(x)
337+
>>> assert_scalar_equals('sum', int, (), int(out), 5)
338+
339+
is equivalent to
340+
341+
>>> assert int(out) == 5
342+
343+
"""
319344
repr_name = repr_name if idx == () else f"{repr_name}[{idx}]"
320345
f_func = f"{func_name}({fmt_kw(kw)})"
321-
if type_ is bool or type_ is int:
346+
if type_ in [bool, int]:
322347
msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
323348
assert out == expected, msg
324349
elif math.isnan(expected):
@@ -332,6 +357,17 @@ def assert_scalar_equals(
332357
def assert_fill(
333358
func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw
334359
):
360+
"""
361+
Assert all elements of an array is as expected, e.g.
362+
363+
>>> out = xp.full(5, 42, dtype=xp.uint8)
364+
>>> assert_fill('full', 42, xp.uint8, out, 5)
365+
366+
is equivalent to
367+
368+
>>> assert xp.all(out == 42)
369+
370+
"""
335371
msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}"
336372
if math.isnan(fill_value):
337373
assert ah.all(ah.isnan(out)), msg
@@ -340,6 +376,18 @@ def assert_fill(
340376

341377

342378
def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
379+
"""
380+
Assert array is (strictly) as expected, e.g.
381+
382+
>>> x = xp.arange(5)
383+
>>> out = xp.asarray(x)
384+
>>> assert_array('asarray', out, x)
385+
386+
is equivalent to
387+
388+
>>> assert xp.all(out == x)
389+
390+
"""
343391
assert_dtype(func_name, out.dtype, expected.dtype)
344392
assert_shape(func_name, out.shape, expected.shape, **kw)
345393
f_func = f"[{func_name}({fmt_kw(kw)})]"

0 commit comments

Comments
 (0)