Skip to content

Use keyword-only arguments for the pytest helpers functions #176

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions array_api_tests/meta/test_pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@


def test_assert_dtype():
ph.assert_dtype("promoted_func", [xp.uint8, xp.int8], xp.int16)
ph.assert_dtype("promoted_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.int16)
with raises(AssertionError):
ph.assert_dtype("bad_func", [xp.uint8, xp.int8], xp.float32)
ph.assert_dtype("bool_func", [xp.uint8, xp.int8], xp.bool, xp.bool)
ph.assert_dtype("single_promoted_func", [xp.uint8], xp.uint8)
ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool)
ph.assert_dtype("bad_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.float32)
ph.assert_dtype("bool_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.bool, expected=xp.bool)
ph.assert_dtype("single_promoted_func", in_dtype=[xp.uint8], out_dtype=xp.uint8)
ph.assert_dtype("single_bool_func", in_dtype=[xp.uint8], out_dtype=xp.bool, expected=xp.bool)


def test_assert_array_elements():
ph.assert_array_elements("int zeros", xp.asarray(0), xp.asarray(0))
ph.assert_array_elements("pos zeros", xp.asarray(0.0), xp.asarray(0.0))
ph.assert_array_elements("int zeros", out=xp.asarray(0), expected=xp.asarray(0))
ph.assert_array_elements("pos zeros", out=xp.asarray(0.0), expected=xp.asarray(0.0))
with raises(AssertionError):
ph.assert_array_elements("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0))
ph.assert_array_elements("mixed sign zeros", out=xp.asarray(0.0), expected=xp.asarray(-0.0))
with raises(AssertionError):
ph.assert_array_elements("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0))
ph.assert_array_elements("mixed sign zeros", out=xp.asarray(-0.0), expected=xp.asarray(0.0))
79 changes: 50 additions & 29 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ def is_neg_zero(n: float) -> bool:

def assert_dtype(
func_name: str,
*,
in_dtype: Union[DataType, Sequence[DataType]],
out_dtype: DataType,
expected: Optional[DataType] = None,
*,
repr_name: str = "out.dtype",
):
"""
Expand All @@ -96,7 +96,7 @@ def assert_dtype(

>>> x = xp.arange(5, dtype=xp.uint8)
>>> out = xp.abs(x)
>>> assert_dtype('abs', x.dtype, out.dtype)
>>> assert_dtype('abs', in_dtype=x.dtype, out_dtype=out.dtype)

is equivalent to

Expand All @@ -108,7 +108,7 @@ def assert_dtype(
>>> x1 = xp.arange(5, dtype=xp.uint8)
>>> x2 = xp.arange(5, dtype=xp.uint16)
>>> out = xp.add(x1, x2)
>>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype)
>>> assert_dtype('add', in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)

is equivalent to

Expand All @@ -119,7 +119,7 @@ def assert_dtype(
>>> x = xp.arange(5, dtype=xp.int8)
>>> out = xp.sum(x)
>>> default_int = xp.asarray(0).dtype
>>> assert_dtype('sum', x, out.dtype, default_int)
>>> assert_dtype('sum', in_dtype=x, out_dtype=out.dtype, expected=default_int)

"""
in_dtypes = in_dtype if isinstance(in_dtype, Sequence) and not isinstance(in_dtype, str) else [in_dtype]
Expand All @@ -135,13 +135,18 @@ def assert_dtype(
assert out_dtype == expected, msg


def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
def assert_kw_dtype(
func_name: str,
*,
kw_dtype: DataType,
out_dtype: DataType,
):
"""
Assert the output dtype is the passed keyword dtype, e.g.

>>> kw = {'dtype': xp.uint8}
>>> out = xp.ones(5, **kw)
>>> assert_kw_dtype('ones', kw['dtype'], out.dtype)
>>> out = xp.ones(5, kw=kw)
>>> assert_kw_dtype('ones', kw_dtype=kw['dtype'], out_dtype=out.dtype)

"""
f_kw_dtype = dh.dtype_to_name[kw_dtype]
Expand Down Expand Up @@ -222,17 +227,17 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty

def assert_shape(
func_name: str,
*,
out_shape: Union[int, Shape],
expected: Union[int, Shape],
/,
repr_name="out.shape",
**kw,
kw: dict = {},
):
"""
Assert the output shape is as expected, e.g.

>>> out = xp.ones((3, 3, 3))
>>> assert_shape('ones', out.shape, (3, 3, 3))
>>> assert_shape('ones', out_shape=out.shape, expected=(3, 3, 3))

"""
if isinstance(out_shape, int):
Expand All @@ -249,11 +254,10 @@ def assert_result_shape(
func_name: str,
in_shapes: Sequence[Shape],
out_shape: Shape,
/,
expected: Optional[Shape] = None,
*,
repr_name="out.shape",
**kw,
kw: dict = {},
):
"""
Assert the output shape is as expected.
Expand All @@ -262,7 +266,7 @@ def assert_result_shape(
in_shapes, to test against out_shape, e.g.

>>> out = xp.add(xp.ones((3, 1)), xp.ones((1, 3)))
>>> assert_shape('add', [(3, 1), (1, 3)], out.shape)
>>> assert_result_shape('add', in_shape=[(3, 1), (1, 3)], out_shape=out.shape)

is equivalent to

Expand All @@ -281,21 +285,21 @@ def assert_result_shape(

def assert_keepdimable_shape(
func_name: str,
*,
in_shape: Shape,
out_shape: Shape,
axes: Tuple[int, ...],
keepdims: bool,
/,
**kw,
kw: dict = {},
):
"""
Assert the output shape from a keepdimable function is as expected, e.g.

>>> x = xp.asarray([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
>>> out1 = xp.max(x, keepdims=False)
>>> out2 = xp.max(x, keepdims=True)
>>> assert_keepdimable_shape('max', x.shape, out1.shape, (0, 1), False)
>>> assert_keepdimable_shape('max', x.shape, out2.shape, (0, 1), True)
>>> assert_keepdimable_shape('max', in_shape=x.shape, out_shape=out1.shape, axes=(0, 1), keepdims=False)
>>> assert_keepdimable_shape('max', in_shape=x.shape, out_shape=out2.shape, axes=(0, 1), keepdims=True)

is equivalent to

Expand All @@ -307,19 +311,26 @@ def assert_keepdimable_shape(
shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape))
else:
shape = tuple(side for axis, side in enumerate(in_shape) if axis not in axes)
assert_shape(func_name, out_shape, shape, **kw)
assert_shape(func_name, out_shape=out_shape, expected=shape, kw=kw)


def assert_0d_equals(
func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw
func_name: str,
*,
x_repr: str,
x_val: Array,
out_repr: str,
out_val: Array,
kw: dict = {},
):
"""
Assert a 0d array is as expected, e.g.

>>> x = xp.asarray([0, 1, 2])
>>> res = xp.asarray(x, copy=True)
>>> kw = {'copy': True}
>>> res = xp.asarray(x, **kw)
>>> res[0] = 42
>>> assert_0d_equals('asarray', 'x[0]', x[0], 'x[0]', res[0])
>>> assert_0d_equals('asarray', x_repr='x[0]', x_val=x[0], out_repr='x[0]', out_val=res[0], kw=kw)

is equivalent to

Expand All @@ -338,20 +349,20 @@ def assert_0d_equals(

def assert_scalar_equals(
func_name: str,
*,
type_: ScalarType,
idx: Shape,
out: Scalar,
expected: Scalar,
/,
repr_name: str = "out",
**kw,
kw: dict = {},
):
"""
Assert a 0d array, convered to a scalar, is as expected, e.g.

>>> x = xp.ones(5, dtype=xp.uint8)
>>> out = xp.sum(x)
>>> assert_scalar_equals('sum', int, (), int(out), 5)
>>> assert_scalar_equals('sum', type_int, out=(), out=int(out), expected=5)

is equivalent to

Expand All @@ -372,13 +383,18 @@ def assert_scalar_equals(


def assert_fill(
func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw
func_name: str,
*,
fill_value: Scalar,
dtype: DataType,
out: Array,
kw: dict = {},
):
"""
Assert all elements of an array is as expected, e.g.

>>> out = xp.full(5, 42, dtype=xp.uint8)
>>> assert_fill('full', 42, xp.uint8, out, 5)
>>> assert_fill('full', fill_value=42, dtype=xp.uint8, out=out, kw=dict(shape=5))

is equivalent to

Expand Down Expand Up @@ -408,22 +424,27 @@ def _assert_float_element(at_out: Array, at_expected: Array, msg: str):


def assert_array_elements(
func_name: str, out: Array, expected: Array, /, *, out_repr: str = "out", **kw
func_name: str,
*,
out: Array,
expected: Array,
out_repr: str = "out",
kw: dict = {},
):
"""
Assert array elements are (strictly) as expected, e.g.

>>> x = xp.arange(5)
>>> out = xp.asarray(x)
>>> assert_array_elements('asarray', out, x)
>>> assert_array_elements('asarray', out=out, expected=x)

is equivalent to

>>> assert xp.all(out == x)

"""
dh.result_type(out.dtype, expected.dtype) # sanity check
assert_shape(func_name, out.shape, expected.shape, **kw) # sanity check
assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check
f_func = f"[{func_name}({fmt_kw(kw)})]"
if out.dtype in dh.float_dtypes:
for idx in sh.ndindex(out.shape):
Expand Down
Loading