Skip to content

pytest_helpers.py documentation #120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/numpy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ jobs:

# copy not implemented
array_api_tests/test_creation_functions.py::test_asarray_arrays
# https://github.com/numpy/numpy/issues/18881
array_api_tests/test_creation_functions.py::test_linspace
# https://github.com/numpy/numpy/issues/20870
array_api_tests/test_data_type_functions.py::test_can_cast
# The return dtype for trace is not consistent in the spec
Expand Down
174 changes: 160 additions & 14 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Any, Dict, Optional, Sequence, Tuple, Union

from . import _array_module as xp
from . import array_helpers as ah
from . import dtype_helpers as dh
from . import shape_helpers as sh
from . import stubs
Expand Down Expand Up @@ -88,6 +87,40 @@ def assert_dtype(
*,
repr_name: str = "out.dtype",
):
"""
Assert the output dtype is as expected.

If expected=None, we infer the expected dtype as in_dtype, to test
out_dtype, e.g.

>>> x = xp.arange(5, dtype=xp.uint8)
>>> out = xp.abs(x)
>>> assert_dtype('abs', x.dtype, out.dtype)

is equivalent to

>>> assert out.dtype == xp.uint8

Or for multiple input dtypes, the expected dtype is inferred from their
resulting type promotion, e.g.

>>> x1 = xp.arange(5, dtype=xp.uint8)
>>> x2 = xp.arange(5, dtype=xp.uint16)
>>> out = xp.add(x1, x2)
>>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype)

is equivalent to

>>> assert out.dtype == xp.uint16

We can also specify the expected dtype ourselves, e.g.

>>> x = xp.arange(5, dtype=xp.int8)
>>> out = xp.sum(x)
>>> default_int = xp.asarray(0).dtype
>>> assert_dtype('sum', x, out.dtype, default_int)

"""
in_dtypes = in_dtype if isinstance(in_dtype, Sequence) else [in_dtype]
f_in_dtypes = dh.fmt_types(tuple(in_dtypes))
f_out_dtype = dh.dtype_to_name[out_dtype]
Expand All @@ -102,6 +135,14 @@ def assert_dtype(


def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
"""
Assert the output dtype is the passed keyword dtype, e.g.

>>> kw = {'dtype': xp.uint8}
>>> out = xp.ones(5, **kw)
>>> assert_kw_dtype('ones', kw['dtype'], out.dtype)

"""
f_kw_dtype = dh.dtype_to_name[kw_dtype]
f_out_dtype = dh.dtype_to_name[out_dtype]
msg = (
Expand All @@ -111,33 +152,54 @@ def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
assert out_dtype == kw_dtype, msg


def assert_default_float(func_name: str, dtype: DataType):
f_dtype = dh.dtype_to_name[dtype]
def assert_default_float(func_name: str, out_dtype: DataType):
"""
Assert the output dtype is the default float, e.g.

>>> out = xp.ones(5)
>>> assert_default_float('ones', out.dtype)

"""
f_dtype = dh.dtype_to_name[out_dtype]
f_default = dh.dtype_to_name[dh.default_float]
msg = (
f"out.dtype={f_dtype}, should be default "
f"floating-point dtype {f_default} [{func_name}()]"
)
assert dtype == dh.default_float, msg
assert out_dtype == dh.default_float, msg


def assert_default_int(func_name: str, dtype: DataType):
f_dtype = dh.dtype_to_name[dtype]
def assert_default_int(func_name: str, out_dtype: DataType):
"""
Assert the output dtype is the default int, e.g.

>>> out = xp.full(5, 42)
>>> assert_default_int('full', out.dtype)

"""
f_dtype = dh.dtype_to_name[out_dtype]
f_default = dh.dtype_to_name[dh.default_int]
msg = (
f"out.dtype={f_dtype}, should be default "
f"integer dtype {f_default} [{func_name}()]"
)
assert dtype == dh.default_int, msg
assert out_dtype == dh.default_int, msg


def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dtype"):
"""
Assert the output dtype is the default index dtype, e.g.

>>> out = xp.argmax(xp.arange(5))
>>> assert_default_int('argmax', out.dtype)

def assert_default_index(func_name: str, dtype: DataType, repr_name="out.dtype"):
f_dtype = dh.dtype_to_name[dtype]
"""
f_dtype = dh.dtype_to_name[out_dtype]
msg = (
f"{repr_name}={f_dtype}, should be the default index dtype, "
f"which is either int32 or int64 [{func_name}()]"
)
assert dtype in (xp.int32, xp.int64), msg
assert out_dtype in (xp.int32, xp.int64), msg


def assert_shape(
Expand All @@ -148,6 +210,13 @@ def assert_shape(
repr_name="out.shape",
**kw,
):
"""
Assert the output shape is as expected, e.g.

>>> out = xp.ones((3, 3, 3))
>>> assert_shape('ones', out.shape, (3, 3, 3))

"""
if isinstance(out_shape, int):
out_shape = (out_shape,)
if isinstance(expected, int):
Expand All @@ -168,6 +237,20 @@ def assert_result_shape(
repr_name="out.shape",
**kw,
):
"""
Assert the output shape is as expected.

If expected=None, we infer the expected shape as the result of broadcasting
in_shapes, to test against out_shape, e.g.

>>> out = xp.add(xp.ones((3, 1)), xp.ones((1, 3)))
>>> assert_shape('add', [(3, 1), (1, 3)], out.shape)

is equivalent to

>>> assert out.shape == (3, 3)

"""
if expected is None:
expected = sh.broadcast_shapes(*in_shapes)
f_in_shapes = " . ".join(str(s) for s in in_shapes)
Expand All @@ -180,13 +263,28 @@ def assert_result_shape(

def assert_keepdimable_shape(
func_name: str,
out_shape: Shape,
in_shape: Shape,
out_shape: Shape,
axes: Tuple[int, ...],
keepdims: bool,
/,
**kw,
):
"""
Assert the output shape from a keepdimable function is as expected, e.g.

>>> x = xp.asarray([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
>>> out1 = xp.max(x, keepdims=False)
>>> out2 = xp.max(x, keepdims=True)
>>> assert_keepdimable_shape('max', x.shape, out1.shape, (0, 1), False)
>>> assert_keepdimable_shape('max', x.shape, out2.shape, (0, 1), True)

is equivalent to

>>> assert out1.shape == ()
>>> assert out2.shape == (1, 1)

"""
if keepdims:
shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape))
else:
Expand All @@ -197,6 +295,19 @@ def assert_keepdimable_shape(
def assert_0d_equals(
func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw
):
"""
Assert a 0d array is as expected, e.g.

>>> x = xp.asarray([0, 1, 2])
>>> res = xp.asarray(x, copy=True)
>>> res[0] = 42
>>> assert_0d_equals('__setitem__', 'x[0]', x[0], 'x[0]', res[0])

is equivalent to

>>> assert res[0] == x[0]

"""
msg = (
f"{out_repr}={out_val}, but should be {x_repr}={x_val} "
f"[{func_name}({fmt_kw(kw)})]"
Expand All @@ -217,9 +328,21 @@ def assert_scalar_equals(
repr_name: str = "out",
**kw,
):
"""
Assert a 0d array, convered to a scalar, is as expected, e.g.

>>> x = xp.ones(5, dtype=xp.uint8)
>>> out = xp.sum(x)
>>> assert_scalar_equals('sum', int, (), int(out), 5)

is equivalent to

>>> assert int(out) == 5

"""
repr_name = repr_name if idx == () else f"{repr_name}[{idx}]"
f_func = f"{func_name}({fmt_kw(kw)})"
if type_ is bool or type_ is int:
if type_ in [bool, int]:
msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
assert out == expected, msg
elif math.isnan(expected):
Expand All @@ -233,14 +356,37 @@ def assert_scalar_equals(
def assert_fill(
func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw
):
"""
Assert all elements of an array is as expected, e.g.

>>> out = xp.full(5, 42, dtype=xp.uint8)
>>> assert_fill('full', 42, xp.uint8, out, 5)

is equivalent to

>>> assert xp.all(out == 42)

"""
msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}"
if math.isnan(fill_value):
assert ah.all(ah.isnan(out)), msg
assert xp.all(xp.isnan(out)), msg
else:
assert ah.all(ah.equal(out, ah.asarray(fill_value, dtype=dtype))), msg
assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg


def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
"""
Assert array is (strictly) as expected, e.g.

>>> x = xp.arange(5)
>>> out = xp.asarray(x)
>>> assert_array('asarray', out, x)

is equivalent to

>>> assert xp.all(out == x)

"""
assert_dtype(func_name, out.dtype, expected.dtype)
assert_shape(func_name, out.shape, expected.shape, **kw)
f_func = f"[{func_name}({fmt_kw(kw)})]"
Expand Down
21 changes: 10 additions & 11 deletions array_api_tests/test_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from hypothesis import strategies as st

from . import _array_module as xp
from . import array_helpers as ah
from . import dtype_helpers as dh
from . import hypothesis_helpers as hh
from . import pytest_helpers as ph
Expand Down Expand Up @@ -181,12 +180,12 @@ def test_arange(dtype, data):
if dh.is_int_dtype(_dtype):
elements = list(r)
assume(out.size == len(elements))
ah.assert_exactly_equal(out, ah.asarray(elements, dtype=_dtype))
ph.assert_array("arange", out, xp.asarray(elements, dtype=_dtype))
else:
assume(out.size == size)
if out.size > 0:
assert ah.equal(
out[0], ah.asarray(_start, dtype=out.dtype)
assert xp.equal(
out[0], xp.asarray(_start, dtype=out.dtype)
), f"out[0]={out[0]}, but should be {_start} {f_func}"


Expand Down Expand Up @@ -421,8 +420,8 @@ def test_linspace(num, dtype, endpoint, data):
start = data.draw(xps.from_dtype(_dtype, **finite_kw), label="start")
stop = data.draw(xps.from_dtype(_dtype, **finite_kw), label="stop")
# avoid overflow errors
assume(not ah.isnan(ah.asarray(stop - start, dtype=_dtype)))
assume(not ah.isnan(ah.asarray(start - stop, dtype=_dtype)))
assume(not xp.isnan(xp.asarray(stop - start, dtype=_dtype)))
assume(not xp.isnan(xp.asarray(start - stop, dtype=_dtype)))

kw = data.draw(
hh.specified_kwargs(
Expand All @@ -440,20 +439,20 @@ def test_linspace(num, dtype, endpoint, data):
ph.assert_shape("linspace", out.shape, num, start=stop, stop=stop, num=num)
f_func = f"[linspace({start}, {stop}, {num})]"
if num > 0:
assert ah.equal(
out[0], ah.asarray(start, dtype=out.dtype)
assert xp.equal(
out[0], xp.asarray(start, dtype=out.dtype)
), f"out[0]={out[0]}, but should be {start} {f_func}"
if endpoint:
if num > 1:
assert ah.equal(
out[-1], ah.asarray(stop, dtype=out.dtype)
assert xp.equal(
out[-1], xp.asarray(stop, dtype=out.dtype)
), f"out[-1]={out[-1]}, but should be {stop} {f_func}"
else:
# linspace(..., num, endpoint=True) should return an array equivalent to
# the first num elements when endpoint=False
expected = xp.linspace(start, stop, num + 1, dtype=dtype, endpoint=True)
expected = expected[:-1]
ah.assert_exactly_equal(out, expected)
ph.assert_array("linspace", out, expected)


@given(dtype=xps.numeric_dtypes(), data=st.data())
Expand Down
Loading