diff --git a/array_api_tests/array_helpers.py b/array_api_tests/array_helpers.py index b3ae583c..ef4f719a 100644 --- a/array_api_tests/array_helpers.py +++ b/array_api_tests/array_helpers.py @@ -306,14 +306,3 @@ def same_sign(x, y): def assert_same_sign(x, y): assert all(same_sign(x, y)), "The input arrays do not have the same sign" -def int_to_dtype(x, n, signed): - """ - Convert the Python integer x into an n bit signed or unsigned number. - """ - mask = (1 << n) - 1 - x &= mask - if signed: - highest_bit = 1 << (n-1) - if x & highest_bit: - x = -((~x & mask) + 1) - return x diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index a0adc8c9..38771225 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -16,7 +16,6 @@ from ._array_module import _UndefinedStub from ._array_module import bool as bool_dtype from ._array_module import broadcast_to, eye, float32, float64, full -from .algos import broadcast_shapes from .function_stubs import elementwise_functions from .pytest_helpers import nargs from .typing import Array, DataType, Shape @@ -243,7 +242,7 @@ def two_broadcastable_shapes(draw): broadcast to shape1. """ shape1, shape2 = draw(two_mutually_broadcastable_shapes) - assume(broadcast_shapes(shape1, shape2) == shape1) + assume(sh.broadcast_shapes(shape1, shape2) == shape1) return (shape1, shape2) sizes = integers(0, MAX_ARRAY_SIZE) @@ -370,6 +369,9 @@ def two_mutual_arrays( ) -> Tuple[SearchStrategy[Array], SearchStrategy[Array]]: if not isinstance(dtypes, Sequence): raise TypeError(f"{dtypes=} not a sequence") + if FILTER_UNDEFINED_DTYPES: + dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)] + assert len(dtypes) > 0 # sanity check mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes)) mutual_shapes = shared(two_shapes) arrays1 = xps.arrays( diff --git a/array_api_tests/meta/test_array_helpers.py b/array_api_tests/meta/test_array_helpers.py index 6a6b4849..68f96910 100644 --- a/array_api_tests/meta/test_array_helpers.py +++ b/array_api_tests/meta/test_array_helpers.py @@ -1,10 +1,5 @@ -from hypothesis import given, assume -from hypothesis.strategies import integers - -from ..array_helpers import exactly_equal, notequal, int_to_dtype -from ..hypothesis_helpers import integer_dtypes -from ..dtype_helpers import dtype_nbits, dtype_signed from .. import _array_module as xp +from ..array_helpers import exactly_equal, notequal # TODO: These meta-tests currently only work with NumPy @@ -22,12 +17,3 @@ def test_notequal(): res = xp.asarray([False, True, False, False, False, True, False, True]) assert xp.all(xp.equal(notequal(a, b), res)) -@given(integers(), integer_dtypes) -def test_int_to_dtype(x, dtype): - n = dtype_nbits[dtype] - signed = dtype_signed[dtype] - try: - d = xp.asarray(x, dtype=dtype) - except OverflowError: - assume(False) - assert int_to_dtype(x, n, signed) == d diff --git a/array_api_tests/meta/test_broadcasting.py b/array_api_tests/meta/test_broadcasting.py index e347e525..72de61cf 100644 --- a/array_api_tests/meta/test_broadcasting.py +++ b/array_api_tests/meta/test_broadcasting.py @@ -4,7 +4,7 @@ import pytest -from ..algos import BroadcastError, _broadcast_shapes +from .. import shape_helpers as sh @pytest.mark.parametrize( @@ -19,7 +19,7 @@ ], ) def test_broadcast_shapes(shape1, shape2, expected): - assert _broadcast_shapes(shape1, shape2) == expected + assert sh._broadcast_shapes(shape1, shape2) == expected @pytest.mark.parametrize( @@ -31,5 +31,5 @@ def test_broadcast_shapes(shape1, shape2, expected): ], ) def test_broadcast_shapes_fails_on_bad_shapes(shape1, shape2): - with pytest.raises(BroadcastError): - _broadcast_shapes(shape1, shape2) + with pytest.raises(sh.BroadcastError): + sh._broadcast_shapes(shape1, shape2) diff --git a/array_api_tests/meta/test_hypothesis_helpers.py b/array_api_tests/meta/test_hypothesis_helpers.py index b4cb6e96..647cc145 100644 --- a/array_api_tests/meta/test_hypothesis_helpers.py +++ b/array_api_tests/meta/test_hypothesis_helpers.py @@ -8,9 +8,9 @@ from .. import array_helpers as ah from .. import dtype_helpers as dh from .. import hypothesis_helpers as hh +from .. import shape_helpers as sh from .. import xps from .._array_module import _UndefinedStub -from ..algos import broadcast_shapes UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes) pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")] @@ -62,7 +62,7 @@ def test_two_mutually_broadcastable_shapes(pair): def test_two_broadcastable_shapes(pair): for shape in pair: assert valid_shape(shape) - assert broadcast_shapes(pair[0], pair[1]) == pair[0] + assert sh.broadcast_shapes(pair[0], pair[1]) == pair[0] @given(*hh.two_mutual_arrays()) diff --git a/array_api_tests/meta/test_pytest_helpers.py b/array_api_tests/meta/test_pytest_helpers.py index 9b0f4fad..21da2264 100644 --- a/array_api_tests/meta/test_pytest_helpers.py +++ b/array_api_tests/meta/test_pytest_helpers.py @@ -5,9 +5,9 @@ def test_assert_dtype(): - ph.assert_dtype("promoted_func", (xp.uint8, xp.int8), xp.int16) + ph.assert_dtype("promoted_func", [xp.uint8, xp.int8], xp.int16) with raises(AssertionError): - ph.assert_dtype("bad_func", (xp.uint8, xp.int8), xp.float32) - ph.assert_dtype("bool_func", (xp.uint8, xp.int8), xp.bool, xp.bool) - ph.assert_dtype("single_promoted_func", (xp.uint8,), xp.uint8) - ph.assert_dtype("single_bool_func", (xp.uint8,), xp.bool, xp.bool) + ph.assert_dtype("bad_func", [xp.uint8, xp.int8], xp.float32) + ph.assert_dtype("bool_func", [xp.uint8, xp.int8], xp.bool, xp.bool) + ph.assert_dtype("single_promoted_func", [xp.uint8], xp.uint8) + ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool) diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 3b28b9a9..3cd819b4 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -1,8 +1,13 @@ import pytest +from hypothesis import given, reject +from hypothesis import strategies as st +from .. import _array_module as xp +from .. import xps from .. import shape_helpers as sh from ..test_creation_functions import frange from ..test_manipulation_functions import roll_ndindex +from ..test_operators_and_elementwise_functions import mock_int_dtype from ..test_signatures import extension_module @@ -82,3 +87,31 @@ def test_axes_ndindex(shape, axes, expected): ) def test_roll_ndindex(shape, shifts, axes, expected): assert list(roll_ndindex(shape, shifts, axes)) == expected + + +@pytest.mark.parametrize( + "idx, expected", + [ + ((), "x"), + (42, "x[42]"), + ((42,), "x[42]"), + (slice(None, 2), "x[:2]"), + (slice(2, None), "x[2:]"), + (slice(0, 2), "x[0:2]"), + (slice(0, 2, -1), "x[0:2:-1]"), + (slice(None, None, -1), "x[::-1]"), + (slice(None, None), "x[:]"), + (..., "x[...]"), + ], +) +def test_fmt_idx(idx, expected): + assert sh.fmt_idx("x", idx) == expected + + +@given(x=st.integers(), dtype=xps.unsigned_integer_dtypes() | xps.integer_dtypes()) +def test_int_to_dtype(x, dtype): + try: + d = xp.asarray(x, dtype=dtype) + except OverflowError: + reject() + assert mock_int_dtype(x, dtype) == d diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index d1b48830..9a5ffbb2 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -1,12 +1,12 @@ import math from inspect import getfullargspec -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Union from . import _array_module as xp from . import array_helpers as ah from . import dtype_helpers as dh from . import function_stubs -from .algos import broadcast_shapes +from . import shape_helpers as sh from .typing import Array, DataType, Scalar, ScalarType, Shape __all__ = [ @@ -71,15 +71,14 @@ def fmt_kw(kw: Dict[str, Any]) -> str: def assert_dtype( func_name: str, - in_dtypes: Union[DataType, Tuple[DataType, ...]], + in_dtype: Union[DataType, Sequence[DataType]], out_dtype: DataType, expected: Optional[DataType] = None, *, repr_name: str = "out.dtype", ): - if not isinstance(in_dtypes, tuple): - in_dtypes = (in_dtypes,) - f_in_dtypes = dh.fmt_types(in_dtypes) + in_dtypes = in_dtype if isinstance(in_dtype, Sequence) else [in_dtype] + f_in_dtypes = dh.fmt_types(tuple(in_dtypes)) f_out_dtype = dh.dtype_to_name[out_dtype] if expected is None: expected = dh.result_type(*in_dtypes) @@ -150,7 +149,7 @@ def assert_shape( def assert_result_shape( func_name: str, - in_shapes: Tuple[Shape], + in_shapes: Sequence[Shape], out_shape: Shape, /, expected: Optional[Shape] = None, @@ -159,7 +158,7 @@ def assert_result_shape( **kw, ): if expected is None: - expected = broadcast_shapes(*in_shapes) + expected = sh.broadcast_shapes(*in_shapes) f_in_shapes = " . ".join(str(s) for s in in_shapes) f_sig = f" {f_in_shapes} " if kw: diff --git a/array_api_tests/shape_helpers.py b/array_api_tests/shape_helpers.py index 17dd7f6e..9b3d001b 100644 --- a/array_api_tests/shape_helpers.py +++ b/array_api_tests/shape_helpers.py @@ -2,9 +2,67 @@ from itertools import product from typing import Iterator, List, Optional, Tuple, Union -from .typing import Scalar, Shape +from ndindex import iter_indices as _iter_indices -__all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex", "reshape"] +from .typing import AtomicIndex, Index, Scalar, Shape + +__all__ = [ + "broadcast_shapes", + "normalise_axis", + "ndindex", + "axis_ndindex", + "axes_ndindex", + "reshape", + "fmt_idx", +] + + +class BroadcastError(ValueError): + """Shapes do not broadcast with eachother""" + + +def _broadcast_shapes(shape1: Shape, shape2: Shape) -> Shape: + """Broadcasts `shape1` and `shape2`""" + N1 = len(shape1) + N2 = len(shape2) + N = max(N1, N2) + shape = [None for _ in range(N)] + i = N - 1 + while i >= 0: + n1 = N1 - N + i + if N1 - N + i >= 0: + d1 = shape1[n1] + else: + d1 = 1 + n2 = N2 - N + i + if N2 - N + i >= 0: + d2 = shape2[n2] + else: + d2 = 1 + + if d1 == 1: + shape[i] = d2 + elif d2 == 1: + shape[i] = d1 + elif d1 == d2: + shape[i] = d1 + else: + raise BroadcastError() + + i = i - 1 + + return tuple(shape) + + +def broadcast_shapes(*shapes: Shape): + if len(shapes) == 0: + raise ValueError("shapes=[] must be non-empty") + elif len(shapes) == 1: + return shapes[0] + result = _broadcast_shapes(shapes[0], shapes[1]) + for i in range(2, len(shapes)): + result = _broadcast_shapes(result, shapes[i]) + return result def normalise_axis( @@ -17,13 +75,21 @@ def normalise_axis( return axes -def ndindex(shape): - """Iterator of n-D indices to an array +def ndindex(shape: Shape) -> Iterator[Index]: + """Yield every index of a shape""" + return (indices[0] for indices in iter_indices(shape)) + - Yields tuples of integers to index every element of an array of shape - `shape`. Same as np.ndindex(). - """ - return product(*[range(i) for i in shape]) +def iter_indices( + *shapes: Shape, skip_axes: Tuple[int, ...] = () +) -> Iterator[Tuple[Index, ...]]: + """Wrapper for ndindex.iter_indices()""" + # Prevent iterations if any shape has 0-sides + for shape in shapes: + if 0 in shape: + return + for indices in _iter_indices(*shapes, skip_axes=skip_axes): + yield tuple(i.raw for i in indices) # type: ignore def axis_ndindex( @@ -60,7 +126,7 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]: yield list(indices) -def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List[Scalar]]: +def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List]: """Reshape a flat sequence""" if any(s == 0 for s in shape): raise ValueError( @@ -75,3 +141,33 @@ def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List[Scalar]] size = len(flat_seq) n = math.prod(shape[1:]) return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)] + + +def fmt_i(i: AtomicIndex) -> str: + if isinstance(i, int): + return str(i) + elif isinstance(i, slice): + res = "" + if i.start is not None: + res += str(i.start) + res += ":" + if i.stop is not None: + res += str(i.stop) + if i.step is not None: + res += f":{i.step}" + return res + else: + return "..." + + +def fmt_idx(sym: str, idx: Index) -> str: + if idx == (): + return sym + res = f"{sym}[" + _idx = idx if isinstance(idx, tuple) else (idx,) + if len(_idx) == 1: + res += fmt_i(_idx[0]) + else: + res += ", ".join(fmt_i(i) for i in _idx) + res += "]" + return res diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 9d6e7fe1..a81339d0 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -152,7 +152,7 @@ def test_arange(dtype, data): else: ph.assert_default_float("arange", out.dtype) else: - ph.assert_dtype("arange", (out.dtype,), dtype) + ph.assert_kw_dtype("arange", dtype, out.dtype) f_sig = ", ".join(str(n) for n in args) if len(kwargs) > 0: f_sig += f", {ph.fmt_kw(kwargs)}" @@ -302,7 +302,7 @@ def test_empty(shape, kw): def test_empty_like(x, kw): out = xp.empty_like(x, **kw) if kw.get("dtype", None) is None: - ph.assert_dtype("empty_like", (x.dtype,), out.dtype) + ph.assert_dtype("empty_like", x.dtype, out.dtype) else: ph.assert_kw_dtype("empty_like", kw["dtype"], out.dtype) ph.assert_shape("empty_like", out.shape, x.shape) @@ -399,7 +399,7 @@ def test_full_like(x, fill_value, kw): out = xp.full_like(x, fill_value, **kw) dtype = kw.get("dtype", None) or x.dtype if kw.get("dtype", None) is None: - ph.assert_dtype("full_like", (x.dtype,), out.dtype) + ph.assert_dtype("full_like", x.dtype, out.dtype) else: ph.assert_kw_dtype("full_like", kw["dtype"], out.dtype) ph.assert_shape("full_like", out.shape, x.shape) @@ -459,7 +459,7 @@ def test_linspace(num, dtype, endpoint, data): if dtype is None: ph.assert_default_float("linspace", out.dtype) else: - ph.assert_dtype("linspace", (out.dtype,), dtype) + ph.assert_kw_dtype("linspace", dtype, out.dtype) ph.assert_shape("linspace", out.shape, num, start=stop, stop=stop, num=num) f_func = f"[linspace({start}, {stop}, {num})]" if num > 0: @@ -529,7 +529,7 @@ def test_ones(shape, kw): def test_ones_like(x, kw): out = xp.ones_like(x, **kw) if kw.get("dtype", None) is None: - ph.assert_dtype("ones_like", (x.dtype,), out.dtype) + ph.assert_dtype("ones_like", x.dtype, out.dtype) else: ph.assert_kw_dtype("ones_like", kw["dtype"], out.dtype) ph.assert_shape("ones_like", out.shape, x.shape) @@ -565,7 +565,7 @@ def test_zeros(shape, kw): def test_zeros_like(x, kw): out = xp.zeros_like(x, **kw) if kw.get("dtype", None) is None: - ph.assert_dtype("zeros_like", (x.dtype,), out.dtype) + ph.assert_dtype("zeros_like", x.dtype, out.dtype) else: ph.assert_kw_dtype("zeros_like", kw["dtype"], out.dtype) ph.assert_shape("zeros_like", out.shape, x.shape) diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index ded82682..763c71a4 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -9,8 +9,8 @@ from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph +from . import shape_helpers as sh from . import xps -from .algos import broadcast_shapes from .typing import DataType pytestmark = pytest.mark.ci @@ -70,7 +70,7 @@ def test_broadcast_arrays(shapes, data): out = xp.broadcast_arrays(*arrays) - out_shape = broadcast_shapes(*shapes) + out_shape = sh.broadcast_shapes(*shapes) for i, x in enumerate(arrays): ph.assert_dtype( "broadcast_arrays", x.dtype, out[i].dtype, repr_name=f"out[{i}].dtype" @@ -90,7 +90,7 @@ def test_broadcast_to(x, data): shape = data.draw( hh.mutually_broadcastable_shapes(1, base_shape=x.shape) .map(lambda S: S[0]) - .filter(lambda s: broadcast_shapes(x.shape, s) == s), + .filter(lambda s: sh.broadcast_shapes(x.shape, s) == s), label="shape", ) diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 62c93562..7117c20b 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -31,8 +31,6 @@ from . import pytest_helpers as ph from . import shape_helpers as sh -from .algos import broadcast_shapes - from . import _array_module from . import _array_module as xp from ._array_module import linalg @@ -299,7 +297,7 @@ def test_matmul(x1, x2): else: res = _array_module.matmul(x1, x2) - ph.assert_dtype("matmul", (x1.dtype, x2.dtype), res.dtype) + ph.assert_dtype("matmul", [x1.dtype, x2.dtype], res.dtype) if len(x1.shape) == len(x2.shape) == 1: assert res.shape == () @@ -310,7 +308,7 @@ def test_matmul(x1, x2): assert res.shape == x1.shape[:-1] _test_stacks(_array_module.matmul, x1, x2, res=res, dims=1) else: - stack_shape = broadcast_shapes(x1.shape[:-2], x2.shape[:-2]) + stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2]) assert res.shape == stack_shape + (x1.shape[-2], x2.shape[-1]) _test_stacks(_array_module.matmul, x1, x2, res=res) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 1ae28919..b9d9e03d 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -142,7 +142,7 @@ def test_expand_dims(x, axis): index = axis if axis >= 0 else x.ndim + axis + 1 shape.insert(index, 1) shape = tuple(shape) - ph.assert_result_shape("expand_dims", (x.shape,), out.shape, shape) + ph.assert_result_shape("expand_dims", [x.shape], out.shape, shape) assert_array_ndindex( "expand_dims", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape) @@ -181,7 +181,7 @@ def test_squeeze(x, data): if i not in axes: shape.append(side) shape = tuple(shape) - ph.assert_result_shape("squeeze", (x.shape,), out.shape, shape, axis=axis) + ph.assert_result_shape("squeeze", [x.shape], out.shape, shape, axis=axis) assert_array_ndindex("squeeze", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape)) @@ -230,7 +230,7 @@ def test_permute_dims(x, axes): side = x.shape[dim] shape[i] = side shape = tuple(shape) - ph.assert_result_shape("permute_dims", (x.shape,), out.shape, shape, axes=axes) + ph.assert_result_shape("permute_dims", [x.shape], out.shape, shape, axes=axes) indices = list(sh.ndindex(x.shape)) permuted_indices = [tuple(idx[axis] for axis in axes) for idx in indices] @@ -265,7 +265,7 @@ def test_reshape(x, data): rsize = math.prod(shape) * -1 _shape[shape.index(-1)] = size / rsize _shape = tuple(_shape) - ph.assert_result_shape("reshape", (x.shape,), out.shape, _shape, shape=shape) + ph.assert_result_shape("reshape", [x.shape], out.shape, _shape, shape=shape) assert_array_ndindex("reshape", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape)) @@ -303,7 +303,7 @@ def test_roll(x, data): ph.assert_dtype("roll", x.dtype, out.dtype) - ph.assert_result_shape("roll", (x.shape,), out.shape) + ph.assert_result_shape("roll", [x.shape], out.shape) if kw.get("axis", None) is None: assert isinstance(shift, int) # sanity check diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 5da74014..0eb15462 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1,17 +1,7 @@ -""" -Tests for elementwise functions - -https://data-apis.github.io/array-api/latest/API_specification/elementwise_functions.html - -This tests behavior that is explicitly mentioned in the spec. Note that the -spec does not make any accuracy requirements for functions, so this does not -test that. Tests for the special cases are generated and tested separately in -special_cases/ -""" - import math +import operator from enum import Enum, auto -from typing import Callable, List, Optional, Sequence, Union +from typing import Callable, List, NamedTuple, Optional, TypeVar, Union import pytest from hypothesis import assume, given @@ -25,67 +15,267 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from .algos import broadcast_shapes -from .typing import Array, DataType, Param, Scalar +from .typing import Array, DataType, Param, Scalar, ScalarType, Shape pytestmark = pytest.mark.ci + +def all_integer_dtypes() -> st.SearchStrategy[DataType]: + """Returns a strategy for signed and unsigned integer dtype objects.""" + return xps.unsigned_integer_dtypes() | xps.integer_dtypes() + + +def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]: + """Returns a strategy for boolean and all integer dtype objects.""" + return xps.boolean_dtypes() | all_integer_dtypes() + + +def mock_int_dtype(n: int, dtype: DataType) -> int: + """Returns equivalent of `n` that mocks `dtype` behaviour.""" + nbits = dh.dtype_nbits[dtype] + mask = (1 << nbits) - 1 + n &= mask + if dh.dtype_signed[dtype]: + highest_bit = 1 << (nbits - 1) + if n & highest_bit: + n = -((~n & mask) + 1) + return n + + +# This module tests elementwise functions/operators against a reference +# implementation. We iterate through the input array(s) and resulting array, +# casting the indexed arrays to Python scalars and calculating the expected +# output with `refimpl` function. +# +# This is finicky to refactor, but possible and ultimately worthwhile - hence +# why these *_assert_again_refimpl() utilities exist. +# +# Values which are special-cased are generated and passed, but are filtered by +# the `filter_` callable before they can be asserted against `refimpl`. We +# automatically generate tests for special cases in the special_cases/ dir. We +# still pass them here so as to ensure their presence doesn't affect the outputs +# respective to non-special-cased elements. +# +# By default, results are casted to scalars the same way that the inputs are. +# You can specify a cast via `res_stype, i.e. when a function accepts numerical +# inputs but returns boolean arrays. +# +# By default, floating-point functions/methods are loosely asserted against. Use +# `strict_check=True` when they should be strictly asserted against, i.e. +# when a function should return intergrals. Likewise, use `strict_check=False` +# when integer function/methods should be loosely asserted against, i.e. when +# floats are used internally for optimisation or legacy reasons. + + +def isclose(a: float, b: float, rel_tol: float = 0.25, abs_tol: float = 1) -> bool: + """Wraps math.isclose with very generous defaults. + + This is useful for many floating-point operations where the spec does not + make accuracy requirements. + """ + if not (math.isfinite(a) and math.isfinite(b)): + raise ValueError(f"{a=} and {b=}, but input must be finite") + return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol) + + +def default_filter(s: Scalar) -> bool: + """Returns False when s is a non-finite or a signed zero. + + Used by default as these values are typically special-cased. + """ + return math.isfinite(s) and s is not -0.0 and s is not +0.0 + + +T = TypeVar("T") + + +def unary_assert_against_refimpl( + func_name: str, + in_: Array, + res: Array, + refimpl: Callable[[T], T], + expr_template: Optional[str] = None, + res_stype: Optional[ScalarType] = None, + filter_: Callable[[Scalar], bool] = default_filter, + strict_check: Optional[bool] = None, +): + if in_.shape != res.shape: + raise ValueError(f"{res.shape=}, but should be {in_.shape=}") + if expr_template is None: + expr_template = func_name + "({})={}" + in_stype = dh.get_scalar_type(in_.dtype) + if res_stype is None: + res_stype = in_stype + m, M = dh.dtype_ranges.get(res.dtype, (None, None)) + for idx in sh.ndindex(in_.shape): + scalar_i = in_stype(in_[idx]) + if not filter_(scalar_i): + continue + try: + expected = refimpl(scalar_i) + except Exception: + continue + if res.dtype != xp.bool: + assert m is not None and M is not None # for mypy + if expected <= m or expected >= M: + continue + scalar_o = res_stype(res[idx]) + f_i = sh.fmt_idx("x", idx) + f_o = sh.fmt_idx("out", idx) + expr = expr_template.format(f_i, expected) + if strict_check == False or dh.is_float_dtype(res.dtype): + assert isclose(scalar_o, expected), ( + f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" + f"{f_i}={scalar_i}" + ) + else: + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" + f"{f_i}={scalar_i}" + ) + + +def binary_assert_against_refimpl( + func_name: str, + left: Array, + right: Array, + res: Array, + refimpl: Callable[[T, T], T], + expr_template: Optional[str] = None, + res_stype: Optional[ScalarType] = None, + left_sym: str = "x1", + right_sym: str = "x2", + res_name: str = "out", + filter_: Callable[[Scalar], bool] = default_filter, + strict_check: Optional[bool] = None, +): + if expr_template is None: + expr_template = func_name + "({}, {})={}" + in_stype = dh.get_scalar_type(left.dtype) + if res_stype is None: + res_stype = in_stype + m, M = dh.dtype_ranges.get(res.dtype, (None, None)) + for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): + scalar_l = in_stype(left[l_idx]) + scalar_r = in_stype(right[r_idx]) + if not (filter_(scalar_l) and filter_(scalar_r)): + continue + try: + expected = refimpl(scalar_l, scalar_r) + except Exception: + continue + if res.dtype != xp.bool: + assert m is not None and M is not None # for mypy + if expected <= m or expected >= M: + continue + scalar_o = res_stype(res[o_idx]) + f_l = sh.fmt_idx(left_sym, l_idx) + f_r = sh.fmt_idx(right_sym, r_idx) + f_o = sh.fmt_idx(res_name, o_idx) + expr = expr_template.format(f_l, f_r, expected) + if strict_check == False or dh.is_float_dtype(res.dtype): + assert isclose(scalar_o, expected), ( + f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" + f"{f_l}={scalar_l}, {f_r}={scalar_r}" + ) + else: + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" + f"{f_l}={scalar_l}, {f_r}={scalar_r}" + ) + + +def right_scalar_assert_against_refimpl( + func_name: str, + left: Array, + right: Scalar, + res: Array, + refimpl: Callable[[T, T], T], + expr_template: str = None, + res_stype: Optional[ScalarType] = None, + left_sym: str = "x1", + res_name: str = "out", + filter_: Callable[[Scalar], bool] = default_filter, + strict_check: Optional[bool] = None, +): + if filter_(right): + return # short-circuit here as there will be nothing to test + in_stype = dh.get_scalar_type(left.dtype) + if res_stype is None: + res_stype = in_stype + m, M = dh.dtype_ranges.get(left.dtype, (None, None)) + for idx in sh.ndindex(res.shape): + scalar_l = in_stype(left[idx]) + if not filter_(scalar_l): + continue + try: + expected = refimpl(scalar_l, right) + except Exception: + continue + if left.dtype != xp.bool: + assert m is not None and M is not None # for mypy + if expected <= m or expected >= M: + continue + scalar_o = res_stype(res[idx]) + f_l = sh.fmt_idx(left_sym, idx) + f_o = sh.fmt_idx(res_name, idx) + expr = expr_template.format(f_l, right, expected) + if strict_check == False or dh.is_float_dtype(res.dtype): + assert isclose(scalar_o, expected), ( + f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" + f"{f_l}={scalar_l}" + ) + else: + assert scalar_o == expected, ( + f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" + f"{f_l}={scalar_l}" + ) + + # When appropiate, this module tests operators alongside their respective # elementwise methods. We do this by parametrizing a generalised test method # with every relevant method and operator. # -# Notable arguments in the parameter: +# Notable arguments in the parameter's context object: # - The function object, which for operator test cases is a wrapper that allows # test logic to be generalised. # - The argument strategies, which can be used to draw arguments for the test # case. They may require additional filtering for certain test cases. -# - right_is_scalar (binary parameters), which denotes if the right argument is -# a scalar in a test case. This can be used to appropiately adjust draw -# filtering and test logic. +# - right_is_scalar (binary parameters only), which denotes if the right +# argument is a scalar in a test case. This can be used to appropiately adjust +# draw filtering and test logic. func_to_op = {v: k for k, v in dh.op_to_func.items()} all_op_to_symbol = {**dh.binary_op_to_symbol, **dh.inplace_op_to_symbol} finite_kw = {"allow_nan": False, "allow_infinity": False} -unary_argnames = ("func_name", "func", "strat") -UnaryParam = Param[str, Callable[[Array], Array], st.SearchStrategy[Array]] +class UnaryParamContext(NamedTuple): + func_name: str + func: Callable[[Array], Array] + strat: st.SearchStrategy[Array] -def make_unary_params( - elwise_func_name: str, dtypes: Sequence[DataType] -) -> List[UnaryParam]: - if hh.FILTER_UNDEFINED_DTYPES: - dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)] - strat = xps.arrays(dtype=st.sampled_from(dtypes), shape=hh.shapes()) - func = getattr(xp, elwise_func_name) - op_name = func_to_op[elwise_func_name] - op = lambda x: getattr(x, op_name)() - return [ - pytest.param(elwise_func_name, func, strat, id=elwise_func_name), - pytest.param(op_name, op, strat, id=op_name), - ] + @property + def id(self) -> str: + return f"{self.func_name}" + def __repr__(self): + return f"UnaryParamContext(<{self.id}>)" -binary_argnames = ( - "func_name", - "func", - "left_sym", - "left_strat", - "right_sym", - "right_strat", - "right_is_scalar", - "res_name", -) -BinaryParam = Param[ - str, - Callable[[Array, Union[Scalar, Array]], Array], - str, - st.SearchStrategy[Array], - str, - st.SearchStrategy[Union[Scalar, Array]], - bool, -] + +def make_unary_params( + elwise_func_name: str, dtypes_strat: st.SearchStrategy[DataType] +) -> List[Param[UnaryParamContext]]: + strat = xps.arrays(dtype=dtypes_strat, shape=hh.shapes()) + func_ctx = UnaryParamContext( + func_name=elwise_func_name, func=getattr(xp, elwise_func_name), strat=strat + ) + op_name = func_to_op[elwise_func_name] + op_ctx = UnaryParamContext( + func_name=op_name, func=lambda x: getattr(x, op_name)(), strat=strat + ) + return [pytest.param(func_ctx, id=func_ctx.id), pytest.param(op_ctx, id=op_ctx.id)] class FuncType(Enum): @@ -94,16 +284,33 @@ class FuncType(Enum): IOP = auto() -def make_binary_params( - elwise_func_name: str, dtypes: Sequence[DataType] -) -> List[BinaryParam]: - if hh.FILTER_UNDEFINED_DTYPES: - dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)] - dtypes_strat = st.sampled_from(dtypes) +shapes_kw = {"min_side": 1} + + +class BinaryParamContext(NamedTuple): + func_name: str + func: Callable[[Array, Union[Scalar, Array]], Array] + left_sym: str + left_strat: st.SearchStrategy[Array] + right_sym: str + right_strat: st.SearchStrategy[Union[Scalar, Array]] + right_is_scalar: bool + res_name: str + + @property + def id(self) -> str: + return f"{self.func_name}({self.left_sym}, {self.right_sym})" + + def __repr__(self): + return f"BinaryParamContext(<{self.id}>)" + +def make_binary_params( + elwise_func_name: str, dtypes_strat: st.SearchStrategy[DataType] +) -> List[Param[BinaryParamContext]]: def make_param( func_name: str, func_type: FuncType, right_is_scalar: bool - ) -> BinaryParam: + ) -> Param[BinaryParamContext]: if right_is_scalar: left_sym = "x" right_sym = "s" @@ -113,17 +320,25 @@ def make_param( shared_dtypes = st.shared(dtypes_strat) if right_is_scalar: - left_strat = xps.arrays(dtype=shared_dtypes, shape=hh.shapes()) + left_strat = xps.arrays(dtype=shared_dtypes, shape=hh.shapes(**shapes_kw)) right_strat = shared_dtypes.flatmap( lambda d: xps.from_dtype(d, **finite_kw) ) else: if func_type is FuncType.IOP: - shared_shapes = st.shared(hh.shapes()) + shared_shapes = st.shared(hh.shapes(**shapes_kw)) left_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes) right_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes) else: - left_strat, right_strat = hh.two_mutual_arrays(dtypes) + mutual_shapes = st.shared( + hh.mutually_broadcastable_shapes(2, **shapes_kw) + ) + left_strat = xps.arrays( + dtype=shared_dtypes, shape=mutual_shapes.map(lambda pair: pair[0]) + ) + right_strat = xps.arrays( + dtype=shared_dtypes, shape=mutual_shapes.map(lambda pair: pair[1]) + ) if func_type is FuncType.FUNC: func = getattr(xp, func_name) @@ -142,9 +357,7 @@ def func(l: Array, r: Union[Scalar, Array]) -> Array: def func(l: Array, r: Union[Scalar, Array]) -> Array: locals_ = {} - locals_[left_sym] = ah.asarray( - l, copy=True - ) # prevents left mutating + locals_[left_sym] = ah.asarray(l, copy=True) # prevents mutating l locals_[right_sym] = r exec(expr, locals_) return locals_[left_sym] @@ -156,7 +369,7 @@ def func(l: Array, r: Union[Scalar, Array]) -> Array: else: res_name = "out" - return pytest.param( + ctx = BinaryParamContext( func_name, func, left_sym, @@ -165,8 +378,8 @@ def func(l: Array, r: Union[Scalar, Array]) -> Array: right_strat, right_is_scalar, res_name, - id=f"{func_name}({left_sym}, {right_sym})", ) + return pytest.param(ctx, id=ctx.id) op_name = func_to_op[elwise_func_name] params = [ @@ -182,108 +395,139 @@ def func(l: Array, r: Union[Scalar, Array]) -> Array: return params -def assert_binary_param_dtype( - func_name: str, +def binary_param_assert_dtype( + ctx: BinaryParamContext, left: Array, right: Union[Array, Scalar], - right_is_scalar: bool, res: Array, - res_name: str, expected: Optional[DataType] = None, ): - if right_is_scalar: + if ctx.right_is_scalar: in_dtypes = left.dtype else: - in_dtypes = (left.dtype, right.dtype) # type: ignore + in_dtypes = [left.dtype, right.dtype] # type: ignore ph.assert_dtype( - func_name, in_dtypes, res.dtype, expected, repr_name=f"{res_name}.dtype" + ctx.func_name, in_dtypes, res.dtype, expected, repr_name=f"{ctx.res_name}.dtype" ) -@pytest.mark.parametrize(unary_argnames, make_unary_params("abs", dh.numeric_dtypes)) +def binary_param_assert_shape( + ctx: BinaryParamContext, + left: Array, + right: Union[Array, Scalar], + res: Array, + expected: Optional[Shape] = None, +): + if ctx.right_is_scalar: + in_shapes = [left.shape] + else: + in_shapes = [left.shape, right.shape] # type: ignore + ph.assert_result_shape( + ctx.func_name, in_shapes, res.shape, expected, repr_name=f"{ctx.res_name}.shape" + ) + + +def binary_param_assert_against_refimpl( + ctx: BinaryParamContext, + left: Array, + right: Union[Array, Scalar], + res: Array, + op_sym: str, + refimpl: Callable[[T, T], T], + res_stype: Optional[ScalarType] = None, + filter_: Callable[[Scalar], bool] = default_filter, + strict_check: Optional[bool] = None, +): + expr_template = "({} " + op_sym + " {})={}" + if ctx.right_is_scalar: + right_scalar_assert_against_refimpl( + func_name=ctx.func_name, + left_sym=ctx.left_sym, + left=left, + right=right, + res_stype=res_stype, + res_name=ctx.res_name, + res=res, + refimpl=refimpl, + expr_template=expr_template, + filter_=filter_, + strict_check=strict_check, + ) + else: + binary_assert_against_refimpl( + func_name=ctx.func_name, + left_sym=ctx.left_sym, + left=left, + right_sym=ctx.right_sym, + right=right, + res_stype=res_stype, + res_name=ctx.res_name, + res=res, + refimpl=refimpl, + expr_template=expr_template, + filter_=filter_, + strict_check=strict_check, + ) + + +@pytest.mark.parametrize("ctx", make_unary_params("abs", xps.numeric_dtypes())) @given(data=st.data()) -def test_abs(func_name, func, strat, data): - x = data.draw(strat, label="x") +def test_abs(ctx, data): + x = data.draw(ctx.strat, label="x") + # abs of the smallest negative integer is out-of-scope if x.dtype in dh.int_dtypes: - # abs of the smallest representable negative integer is not defined - mask = xp.not_equal( - x, ah.full(x.shape, dh.dtype_ranges[x.dtype].min, dtype=x.dtype) - ) - x = x[mask] - out = func(x) - ph.assert_dtype(func_name, x.dtype, out.dtype) - ph.assert_shape(func_name, out.shape, x.shape) - assert ah.all( - ah.logical_not(ah.negative_mathematical_sign(out)) - ), f"out elements not all positively signed [{func_name}()]\n{out=}" - less_zero = ah.negative_mathematical_sign(x) - negx = ah.negative(x) - # abs(x) = -x for x < 0 - ah.assert_exactly_equal(out[less_zero], negx[less_zero]) - # abs(x) = x for x >= 0 - ah.assert_exactly_equal( - out[ah.logical_not(less_zero)], x[ah.logical_not(less_zero)] + assume(xp.all(x > dh.dtype_ranges[x.dtype].min)) + + out = ctx.func(x) + + ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) + ph.assert_shape(ctx.func_name, out.shape, x.shape) + unary_assert_against_refimpl( + ctx.func_name, + x, + out, + abs, # type: ignore + expr_template="abs({})={}", + filter_=lambda s: ( + s == float("infinity") or (math.isfinite(s) and s is not -0.0) + ), ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_acos(x): - res = xp.acos(x) - ph.assert_dtype("acos", x.dtype, res.dtype) - ph.assert_shape("acos", res.shape, x.shape) - ONE = ah.one(x.shape, x.dtype) - # Here (and elsewhere), should technically be res.dtype, but this is the - # same as x.dtype, as tested by the type_promotion tests. - PI = ah.π(x.shape, x.dtype) - ZERO = ah.zero(x.shape, x.dtype) - domain = ah.inrange(x, -ONE, ONE) - codomain = ah.inrange(res, ZERO, PI) - # acos maps [-1, 1] to [0, pi]. Values outside this domain are mapped to - # nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + out = xp.acos(x) + ph.assert_dtype("acos", x.dtype, out.dtype) + ph.assert_shape("acos", out.shape, x.shape) + unary_assert_against_refimpl( + "acos", x, out, math.acos, filter_=lambda s: default_filter(s) and -1 <= s <= 1 + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_acosh(x): - res = xp.acosh(x) - ph.assert_dtype("acosh", x.dtype, res.dtype) - ph.assert_shape("acosh", res.shape, x.shape) - ONE = ah.one(x.shape, x.dtype) - INFINITY = ah.infinity(x.shape, x.dtype) - ZERO = ah.zero(x.shape, x.dtype) - domain = ah.inrange(x, ONE, INFINITY) - codomain = ah.inrange(res, ZERO, INFINITY) - # acosh maps [-1, inf] to [0, inf]. Values outside this domain are mapped - # to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) - - -@pytest.mark.parametrize(binary_argnames, make_binary_params("add", dh.numeric_dtypes)) + out = xp.acosh(x) + ph.assert_dtype("acosh", x.dtype, out.dtype) + ph.assert_shape("acosh", out.shape, x.shape) + unary_assert_against_refimpl( + "acosh", x, out, math.acosh, filter_=lambda s: default_filter(s) and s >= 1 + ) + + +@pytest.mark.parametrize("ctx,", make_binary_params("add", xps.numeric_dtypes())) @given(data=st.data()) -def test_add( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_add(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) try: - res = func(left, right) + res = ctx.func(left, right) except OverflowError: reject() - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - if not right_is_scalar: - # add is commutative - expected = func(right, left) - ah.assert_exactly_equal(res, expected) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl(ctx, left, right, res, "+", operator.add) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -291,13 +535,9 @@ def test_asin(x): out = xp.asin(x) ph.assert_dtype("asin", x.dtype, out.dtype) ph.assert_shape("asin", out.shape, x.shape) - ONE = ah.one(x.shape, x.dtype) - PI = ah.π(x.shape, x.dtype) - domain = ah.inrange(x, -ONE, ONE) - codomain = ah.inrange(out, -PI / 2, PI / 2) - # asin maps [-1, 1] to [-pi/2, pi/2]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl( + "asin", x, out, math.asin, filter_=lambda s: default_filter(s) and -1 <= s <= 1 + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -305,12 +545,7 @@ def test_asinh(x): out = xp.asinh(x) ph.assert_dtype("asinh", x.dtype, out.dtype) ph.assert_shape("asinh", out.shape, x.shape) - INFINITY = ah.infinity(x.shape, x.dtype) - domain = ah.inrange(x, -INFINITY, INFINITY) - codomain = ah.inrange(out, -INFINITY, INFINITY) - # asinh maps [-inf, inf] to [-inf, inf]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl("asinh", x, out, math.asinh) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -318,55 +553,15 @@ def test_atan(x): out = xp.atan(x) ph.assert_dtype("atan", x.dtype, out.dtype) ph.assert_shape("atan", out.shape, x.shape) - INFINITY = ah.infinity(x.shape, x.dtype) - PI = ah.π(x.shape, x.dtype) - domain = ah.inrange(x, -INFINITY, INFINITY) - codomain = ah.inrange(out, -PI / 2, PI / 2) - # atan maps [-inf, inf] to [-pi/2, pi/2]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl("atan", x, out, math.atan) @given(*hh.two_mutual_arrays(dh.float_dtypes)) def test_atan2(x1, x2): out = xp.atan2(x1, x2) - ph.assert_dtype("atan2", (x1.dtype, x2.dtype), out.dtype) - ph.assert_result_shape("atan2", (x1.shape, x2.shape), out.shape) - INFINITY1 = ah.infinity(x1.shape, x1.dtype) - INFINITY2 = ah.infinity(x2.shape, x2.dtype) - PI = ah.π(out.shape, out.dtype) - domainx1 = ah.inrange(x1, -INFINITY1, INFINITY1) - domainx2 = ah.inrange(x2, -INFINITY2, INFINITY2) - # codomain = ah.inrange(out, -PI, PI, 1e-5) - codomain = ah.inrange(out, -PI, PI) - # atan2 maps [-inf, inf] x [-inf, inf] to [-pi, pi]. Values outside - # this domain are mapped to nan, which is already tested in the special - # cases. - ah.assert_exactly_equal(ah.logical_and(domainx1, domainx2), codomain) - # From the spec: - # - # The mathematical signs of `x1_i` and `x2_i` determine the quadrant of - # each element-wise out. The quadrant (i.e., branch) is chosen such - # that each element-wise out is the signed angle in radians between the - # ray ending at the origin and passing through the point `(1,0)` and the - # ray ending at the origin and passing through the point `(x2_i, x1_i)`. - - # This is equivalent to atan2(x1, x2) has the same sign as x1 when x2 is - # finite. - pos_x1 = ah.positive_mathematical_sign(x1) - neg_x1 = ah.negative_mathematical_sign(x1) - pos_x2 = ah.positive_mathematical_sign(x2) - neg_x2 = ah.negative_mathematical_sign(x2) - pos_out = ah.positive_mathematical_sign(out) - neg_out = ah.negative_mathematical_sign(out) - ah.assert_exactly_equal( - ah.logical_or(ah.logical_and(pos_x1, pos_x2), ah.logical_and(pos_x1, neg_x2)), - pos_out, - ) - ah.assert_exactly_equal( - ah.logical_or(ah.logical_and(neg_x1, pos_x2), ah.logical_and(neg_x1, neg_x2)), - neg_out, - ) + ph.assert_dtype("atan2", [x1.dtype, x2.dtype], out.dtype) + ph.assert_result_shape("atan2", [x1.shape, x2.shape], out.shape) + binary_assert_against_refimpl("atan2", x1, x2, out, math.atan2) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -374,291 +569,139 @@ def test_atanh(x): out = xp.atanh(x) ph.assert_dtype("atanh", x.dtype, out.dtype) ph.assert_shape("atanh", out.shape, x.shape) - ONE = ah.one(x.shape, x.dtype) - INFINITY = ah.infinity(x.shape, x.dtype) - domain = ah.inrange(x, -ONE, ONE) - codomain = ah.inrange(out, -INFINITY, INFINITY) - # atanh maps [-1, 1] to [-inf, inf]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl( + "atanh", + x, + out, + math.atanh, + filter_=lambda s: default_filter(s) and -1 <= s <= 1, + ) @pytest.mark.parametrize( - binary_argnames, make_binary_params("bitwise_and", dh.bool_and_all_int_dtypes) + "ctx", make_binary_params("bitwise_and", boolean_and_all_integer_dtypes()) ) @given(data=st.data()) -def test_bitwise_and( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) - - res = func(left, right) - - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - if not right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, res.shape, shape, repr_name=f"{res_name}.shape") - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) - - # Compare against the Python & operator. - if res.dtype == xp.bool: - for idx in sh.ndindex(res.shape): - s_left = bool(_left[idx]) - s_right = bool(_right[idx]) - s_res = bool(res[idx]) - assert (s_left and s_right) == s_res - else: - for idx in sh.ndindex(res.shape): - s_left = int(_left[idx]) - s_right = int(_right[idx]) - s_res = int(res[idx]) - s_and = ah.int_to_dtype( - s_left & s_right, - dh.dtype_nbits[res.dtype], - dh.dtype_signed[res.dtype], - ) - assert s_and == s_res +def test_bitwise_and(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) + + res = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + if left.dtype == xp.bool: + refimpl = operator.and_ + else: + refimpl = lambda l, r: mock_int_dtype(l & r, res.dtype) + binary_param_assert_against_refimpl(ctx, left, right, res, "&", refimpl) @pytest.mark.parametrize( - binary_argnames, make_binary_params("bitwise_left_shift", dh.all_int_dtypes) + "ctx", make_binary_params("bitwise_left_shift", all_integer_dtypes()) ) @given(data=st.data()) -def test_bitwise_left_shift( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) - if right_is_scalar: +def test_bitwise_left_shift(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) + if ctx.right_is_scalar: assume(right >= 0) else: assume(not ah.any(ah.isnegative(right))) - res = func(left, right) - - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - if not right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, res.shape, shape, repr_name=f"{res_name}.shape") - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) - - # Compare against the Python << operator. - for idx in sh.ndindex(res.shape): - s_left = int(_left[idx]) - s_right = int(_right[idx]) - s_res = int(res[idx]) - s_shift = ah.int_to_dtype( - # We avoid shifting very large ints - s_left << s_right if s_right < dh.dtype_nbits[res.dtype] else 0, - dh.dtype_nbits[res.dtype], - dh.dtype_signed[res.dtype], - ) - assert s_shift == s_res + res = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + nbits = res.dtype + binary_param_assert_against_refimpl( + ctx, left, right, res, "<<", lambda l, r: l << r if r < nbits else 0 + ) @pytest.mark.parametrize( - unary_argnames, make_unary_params("bitwise_invert", dh.bool_and_all_int_dtypes) + "ctx", make_unary_params("bitwise_invert", boolean_and_all_integer_dtypes()) ) @given(data=st.data()) -def test_bitwise_invert(func_name, func, strat, data): - x = data.draw(strat, label="x") - - out = func(x) - - ph.assert_dtype(func_name, x.dtype, out.dtype) - ph.assert_shape(func_name, out.shape, x.shape) - # Compare against the Python ~ operator. - if out.dtype == xp.bool: - for idx in sh.ndindex(out.shape): - s_x = bool(x[idx]) - s_out = bool(out[idx]) - assert (not s_x) == s_out +def test_bitwise_invert(ctx, data): + x = data.draw(ctx.strat, label="x") + + out = ctx.func(x) + + ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) + ph.assert_shape(ctx.func_name, out.shape, x.shape) + if x.dtype == xp.bool: + refimpl = operator.not_ else: - for idx in sh.ndindex(out.shape): - s_x = int(x[idx]) - s_out = int(out[idx]) - s_invert = ah.int_to_dtype( - ~s_x, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype] - ) - assert s_invert == s_out + refimpl = lambda s: mock_int_dtype(~s, x.dtype) + unary_assert_against_refimpl(ctx.func_name, x, out, refimpl, expr_template="~{}={}") @pytest.mark.parametrize( - binary_argnames, make_binary_params("bitwise_or", dh.bool_and_all_int_dtypes) + "ctx", make_binary_params("bitwise_or", boolean_and_all_integer_dtypes()) ) @given(data=st.data()) -def test_bitwise_or( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) - - res = func(left, right) - - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - if not right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, res.shape, shape, repr_name=f"{res_name}.shape") - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) - - # Compare against the Python | operator. - if res.dtype == xp.bool: - for idx in sh.ndindex(res.shape): - s_left = bool(_left[idx]) - s_right = bool(_right[idx]) - s_res = bool(res[idx]) - assert (s_left or s_right) == s_res - else: - for idx in sh.ndindex(res.shape): - s_left = int(_left[idx]) - s_right = int(_right[idx]) - s_res = int(res[idx]) - s_or = ah.int_to_dtype( - s_left | s_right, - dh.dtype_nbits[res.dtype], - dh.dtype_signed[res.dtype], - ) - assert s_or == s_res +def test_bitwise_or(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) + + res = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + if left.dtype == xp.bool: + refimpl = operator.or_ + else: + refimpl = lambda l, r: mock_int_dtype(l | r, res.dtype) + binary_param_assert_against_refimpl(ctx, left, right, res, "|", refimpl) @pytest.mark.parametrize( - binary_argnames, make_binary_params("bitwise_right_shift", dh.all_int_dtypes) + "ctx", make_binary_params("bitwise_right_shift", all_integer_dtypes()) ) @given(data=st.data()) -def test_bitwise_right_shift( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) - if right_is_scalar: +def test_bitwise_right_shift(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) + if ctx.right_is_scalar: assume(right >= 0) else: assume(not ah.any(ah.isnegative(right))) - res = func(left, right) + res = ctx.func(left, right) - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - if not right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape( - "bitwise_right_shift", res.shape, shape, repr_name=f"{res_name}.shape" - ) - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) - - # Compare against the Python >> operator. - for idx in sh.ndindex(res.shape): - s_left = int(_left[idx]) - s_right = int(_right[idx]) - s_res = int(res[idx]) - s_shift = ah.int_to_dtype( - s_left >> s_right, dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype] - ) - assert s_shift == s_res + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl( + ctx, left, right, res, ">>", lambda l, r: mock_int_dtype(l >> r, res.dtype) + ) @pytest.mark.parametrize( - binary_argnames, make_binary_params("bitwise_xor", dh.bool_and_all_int_dtypes) + "ctx", make_binary_params("bitwise_xor", boolean_and_all_integer_dtypes()) ) @given(data=st.data()) -def test_bitwise_xor( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) - - res = func(left, right) - - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - if not right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, res.shape, shape, repr_name=f"{res_name}.shape") - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) - - # Compare against the Python ^ operator. - if res.dtype == xp.bool: - for idx in sh.ndindex(res.shape): - s_left = bool(_left[idx]) - s_right = bool(_right[idx]) - s_res = bool(res[idx]) - assert (s_left ^ s_right) == s_res - else: - for idx in sh.ndindex(res.shape): - s_left = int(_left[idx]) - s_right = int(_right[idx]) - s_res = int(res[idx]) - s_xor = ah.int_to_dtype( - s_left ^ s_right, - dh.dtype_nbits[res.dtype], - dh.dtype_signed[res.dtype], - ) - assert s_xor == s_res +def test_bitwise_xor(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) + + res = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + if left.dtype == xp.bool: + refimpl = operator.xor + else: + refimpl = lambda l, r: mock_int_dtype(l ^ r, res.dtype) + binary_param_assert_against_refimpl(ctx, left, right, res, "^", refimpl) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_ceil(x): - # This test is almost identical to test_floor() out = xp.ceil(x) ph.assert_dtype("ceil", x.dtype, out.dtype) ph.assert_shape("ceil", out.shape, x.shape) - finite = ah.isfinite(x) - ah.assert_integral(out[finite]) - assert ah.all(ah.less_equal(x[finite], out[finite])) - assert ah.all( - ah.less_equal(out[finite] - x[finite], ah.one(x[finite].shape, x.dtype)) - ) - integers = ah.isintegral(x) - ah.assert_exactly_equal(out[integers], x[integers]) + unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -666,13 +709,7 @@ def test_cos(x): out = xp.cos(x) ph.assert_dtype("cos", x.dtype, out.dtype) ph.assert_shape("cos", out.shape, x.shape) - ONE = ah.one(x.shape, x.dtype) - INFINITY = ah.infinity(x.shape, x.dtype) - domain = ah.inrange(x, -INFINITY, INFINITY, open=True) - codomain = ah.inrange(out, -ONE, ONE) - # cos maps (-inf, inf) to [-1, 1]. Values outside this domain are mapped - # to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl("cos", x, out, math.cos) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -680,100 +717,58 @@ def test_cosh(x): out = xp.cosh(x) ph.assert_dtype("cosh", x.dtype, out.dtype) ph.assert_shape("cosh", out.shape, x.shape) - INFINITY = ah.infinity(x.shape, x.dtype) - domain = ah.inrange(x, -INFINITY, INFINITY) - codomain = ah.inrange(out, -INFINITY, INFINITY) - # cosh maps [-inf, inf] to [-inf, inf]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl("cosh", x, out, math.cosh) -@pytest.mark.parametrize(binary_argnames, make_binary_params("divide", dh.float_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("divide", xps.floating_dtypes())) @given(data=st.data()) -def test_divide( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) - - res = func(left, right) - - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - # There isn't much we can test here. The spec doesn't require any behavior - # beyond the special cases, and indeed, there aren't many mathematical - # properties of division that strictly hold for floating-point numbers. We - # could test that this does implement IEEE 754 division, but we don't yet - # have those sorts in general for this module. +def test_divide(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) + if ctx.right_is_scalar: + assume + + res = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl( + ctx, + left, + right, + res, + "/", + operator.truediv, + filter_=lambda s: math.isfinite(s) and s != 0, + ) -@pytest.mark.parametrize(binary_argnames, make_binary_params("equal", dh.all_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("equal", xps.scalar_dtypes())) @given(data=st.data()) -def test_equal( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) - - out = func(left, right) - - assert_binary_param_dtype( - func_name, left, right, right_is_scalar, out, res_name, xp.bool - ) - # NOTE: ah.assert_exactly_equal() itself uses ah.equal(), so we must be careful - # not to use it here. Otherwise, the test would be circular and - # meaningless. Instead, we implement this by iterating every element of - # the arrays and comparing them. The logic here is also used for the tests - # for the other elementwise functions that accept any input dtype but - # always return bool (greater(), greater_equal(), less(), less_equal(), - # and not_equal()). - if not right_is_scalar: - # First we broadcast the arrays so that they can be indexed uniformly. - # TODO: it should be possible to skip this step if we instead generate - # indices to x1 and x2 that correspond to the broadcasted shapes. This - # would avoid the dependence in this test on broadcast_to(). - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, out.shape, shape) - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) - - # Second, manually promote the dtypes. This is important. If the internal - # type promotion in ah.equal() is wrong, it will not be directly visible in - # the output type, but it can lead to wrong answers. For example, - # ah.equal(array(1.0, dtype=xp.float32), array(1.00000001, dtype=xp.float64)) will - # be wrong if the float64 is downcast to float32. # be wrong if the - # xp.float64 is downcast to float32. See the comment on - # test_elementwise_function_two_arg_bool_type_promotion() in - # test_type_promotion.py. The type promotion for ah.equal() is not *really* - # tested in that file, because doing so requires doing the consistency - - # check we do here rather than just checking the res dtype. +def test_equal(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) + + out = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) + if not ctx.right_is_scalar: + # We manually promote the dtypes as incorrect internal type promotion + # could lead to false positives. For example + # + # >>> xp.equal( + # ... xp.asarray(1.0, dtype=xp.float32), + # ... xp.asarray(1.00000001, dtype=xp.float64), + # ... ) + # + # would erroneously be True if float64 downcasted to float32. promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - _left = ah.asarray(_left, dtype=promoted_dtype) - _right = ah.asarray(_right, dtype=promoted_dtype) - - scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in sh.ndindex(shape): - x1_idx = _left[idx] - x2_idx = _right[idx] - out_idx = out[idx] - assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check - assert bool(out_idx) == (scalar_type(x1_idx) == scalar_type(x2_idx)) + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, "==", operator.eq, res_stype=bool + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -781,13 +776,7 @@ def test_exp(x): out = xp.exp(x) ph.assert_dtype("exp", x.dtype, out.dtype) ph.assert_shape("exp", out.shape, x.shape) - INFINITY = ah.infinity(x.shape, x.dtype) - ZERO = ah.zero(x.shape, x.dtype) - domain = ah.inrange(x, -INFINITY, INFINITY) - codomain = ah.inrange(out, ZERO, INFINITY) - # exp maps [-inf, inf] to [0, inf]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl("exp", x, out, math.exp) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -795,160 +784,78 @@ def test_expm1(x): out = xp.expm1(x) ph.assert_dtype("expm1", x.dtype, out.dtype) ph.assert_shape("expm1", out.shape, x.shape) - INFINITY = ah.infinity(x.shape, x.dtype) - NEGONE = -ah.one(x.shape, x.dtype) - domain = ah.inrange(x, -INFINITY, INFINITY) - codomain = ah.inrange(out, NEGONE, INFINITY) - # expm1 maps [-inf, inf] to [1, inf]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl("expm1", x, out, math.expm1) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_floor(x): - # This test is almost identical to test_ceil out = xp.floor(x) ph.assert_dtype("floor", x.dtype, out.dtype) ph.assert_shape("floor", out.shape, x.shape) - finite = ah.isfinite(x) - ah.assert_integral(out[finite]) - assert ah.all(ah.less_equal(out[finite], x[finite])) - assert ah.all( - ah.less_equal(x[finite] - out[finite], ah.one(x[finite].shape, x.dtype)) - ) - integers = ah.isintegral(x) - ah.assert_exactly_equal(out[integers], x[integers]) + unary_assert_against_refimpl("floor", x, out, math.floor, strict_check=True) @pytest.mark.parametrize( - binary_argnames, make_binary_params("floor_divide", dh.numeric_dtypes) + "ctx", make_binary_params("floor_divide", xps.numeric_dtypes()) ) @given(data=st.data()) -def test_floor_divide( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat.filter(lambda x: not ah.any(x == 0)), label=left_sym) - right = data.draw(right_strat, label=right_sym) - if right_is_scalar: +def test_floor_divide(ctx, data): + left = data.draw( + ctx.left_strat.filter(lambda x: not ah.any(x == 0)), label=ctx.left_sym + ) + right = data.draw(ctx.right_strat, label=ctx.right_sym) + if ctx.right_is_scalar: assume(right != 0) else: assume(not ah.any(right == 0)) - res = func(left, right) - - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - if not right_is_scalar: - if dh.is_int_dtype(left.dtype): - # The spec does not specify the behavior for division by 0 for integer - # dtypes. A library may choose to raise an exception in this case, so - # we avoid passing it in entirely. - div = xp.divide( - ah.asarray(left, dtype=xp.float64), - ah.asarray(right, dtype=xp.float64), - ) - else: - div = xp.divide(left, right) + res = ctx.func(left, right) - # TODO: The spec doesn't clearly specify the behavior of floor_divide on - # infinities. See https://github.com/data-apis/array-api/issues/199. - finite = ah.isfinite(div) - ah.assert_integral(res[finite]) - # TODO: Test the exact output for floor_divide. + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv) -@pytest.mark.parametrize( - binary_argnames, make_binary_params("greater", dh.numeric_dtypes) -) +@pytest.mark.parametrize("ctx", make_binary_params("greater", xps.numeric_dtypes())) @given(data=st.data()) -def test_greater( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_greater(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = func(left, right) - - assert_binary_param_dtype( - func_name, left, right, right_is_scalar, out, res_name, xp.bool - ) - if not right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, out.shape, shape) - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) + out = ctx.func(left, right) + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) + if not ctx.right_is_scalar: + # See test_equal note promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - _left = ah.asarray(_left, dtype=promoted_dtype) - _right = ah.asarray(_right, dtype=promoted_dtype) - - scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in sh.ndindex(shape): - out_idx = out[idx] - x1_idx = _left[idx] - x2_idx = _right[idx] - assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check - assert bool(out_idx) == (scalar_type(x1_idx) > scalar_type(x2_idx)) + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, ">", operator.gt, res_stype=bool + ) @pytest.mark.parametrize( - binary_argnames, make_binary_params("greater_equal", dh.numeric_dtypes) + "ctx", make_binary_params("greater_equal", xps.numeric_dtypes()) ) @given(data=st.data()) -def test_greater_equal( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_greater_equal(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = func(left, right) - - assert_binary_param_dtype( - func_name, left, right, right_is_scalar, out, res_name, xp.bool - ) - if not right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, out.shape, shape) - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) + out = ctx.func(left, right) + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) + if not ctx.right_is_scalar: + # See test_equal note promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - _left = ah.asarray(_left, dtype=promoted_dtype) - _right = ah.asarray(_right, dtype=promoted_dtype) - - scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in sh.ndindex(shape): - out_idx = out[idx] - x1_idx = _left[idx] - x2_idx = _right[idx] - assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check - assert bool(out_idx) == (scalar_type(x1_idx) >= scalar_type(x2_idx)) + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, ">=", operator.ge, res_stype=bool + ) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) @@ -956,157 +863,73 @@ def test_isfinite(x): out = ah.isfinite(x) ph.assert_dtype("isfinite", x.dtype, out.dtype, xp.bool) ph.assert_shape("isfinite", out.shape, x.shape) - if dh.is_int_dtype(x.dtype): - ah.assert_exactly_equal(out, ah.true(x.shape)) - # Test that isfinite, isinf, and isnan are self-consistent. - inf = ah.logical_or(xp.isinf(x), ah.isnan(x)) - ah.assert_exactly_equal(out, ah.logical_not(inf)) - - # Test the exact value by comparing to the math version - if dh.is_float_dtype(x.dtype): - for idx in sh.ndindex(x.shape): - s = float(x[idx]) - assert bool(out[idx]) == math.isfinite(s) + unary_assert_against_refimpl("isfinite", x, out, math.isfinite, res_stype=bool) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_isinf(x): out = xp.isinf(x) - ph.assert_dtype("isfinite", x.dtype, out.dtype, xp.bool) ph.assert_shape("isinf", out.shape, x.shape) - - if dh.is_int_dtype(x.dtype): - ah.assert_exactly_equal(out, ah.false(x.shape)) - finite_or_nan = ah.logical_or(ah.isfinite(x), ah.isnan(x)) - ah.assert_exactly_equal(out, ah.logical_not(finite_or_nan)) - - # Test the exact value by comparing to the math version - if dh.is_float_dtype(x.dtype): - for idx in sh.ndindex(x.shape): - s = float(x[idx]) - assert bool(out[idx]) == math.isinf(s) + unary_assert_against_refimpl("isinf", x, out, math.isinf, res_stype=bool) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_isnan(x): out = ah.isnan(x) - ph.assert_dtype("isnan", x.dtype, out.dtype, xp.bool) ph.assert_shape("isnan", out.shape, x.shape) + unary_assert_against_refimpl("isnan", x, out, math.isnan, res_stype=bool) - if dh.is_int_dtype(x.dtype): - ah.assert_exactly_equal(out, ah.false(x.shape)) - finite_or_inf = ah.logical_or(ah.isfinite(x), xp.isinf(x)) - ah.assert_exactly_equal(out, ah.logical_not(finite_or_inf)) - - # Test the exact value by comparing to the math version - if dh.is_float_dtype(x.dtype): - for idx in sh.ndindex(x.shape): - s = float(x[idx]) - assert bool(out[idx]) == math.isnan(s) - -@pytest.mark.parametrize(binary_argnames, make_binary_params("less", dh.numeric_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("less", xps.numeric_dtypes())) @given(data=st.data()) -def test_less( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) - - out = func(left, right) +def test_less(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) - assert_binary_param_dtype( - func_name, left, right, right_is_scalar, out, res_name, xp.bool - ) - if not right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, out.shape, shape) - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) + out = ctx.func(left, right) + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) + if not ctx.right_is_scalar: + # See test_equal note promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - _left = ah.asarray(_left, dtype=promoted_dtype) - _right = ah.asarray(_right, dtype=promoted_dtype) - - scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in sh.ndindex(shape): - x1_idx = _left[idx] - x2_idx = _right[idx] - out_idx = out[idx] - assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check - assert bool(out_idx) == (scalar_type(x1_idx) < scalar_type(x2_idx)) + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, "<", operator.lt, res_stype=bool + ) -@pytest.mark.parametrize( - binary_argnames, make_binary_params("less_equal", dh.numeric_dtypes) -) +@pytest.mark.parametrize("ctx", make_binary_params("less_equal", xps.numeric_dtypes())) @given(data=st.data()) -def test_less_equal( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_less_equal(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = func(left, right) - - assert_binary_param_dtype( - func_name, left, right, right_is_scalar, out, res_name, xp.bool - ) - if not right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, out.shape, shape) - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) + out = ctx.func(left, right) + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) + if not ctx.right_is_scalar: + # See test_equal note promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - _left = ah.asarray(_left, dtype=promoted_dtype) - _right = ah.asarray(_right, dtype=promoted_dtype) - - scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in sh.ndindex(shape): - x1_idx = _left[idx] - x2_idx = _right[idx] - out_idx = out[idx] - assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check - assert bool(out_idx) == (scalar_type(x1_idx) <= scalar_type(x2_idx)) + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, "<=", operator.le, res_stype=bool + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_log(x): out = xp.log(x) - ph.assert_dtype("log", x.dtype, out.dtype) ph.assert_shape("log", out.shape, x.shape) - - INFINITY = ah.infinity(x.shape, x.dtype) - ZERO = ah.zero(x.shape, x.dtype) - domain = ah.inrange(x, ZERO, INFINITY) - codomain = ah.inrange(out, -INFINITY, INFINITY) - # log maps [0, inf] to [-inf, inf]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl( + "log", x, out, math.log, filter_=lambda s: default_filter(s) and s >= 1 + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -1114,13 +937,9 @@ def test_log1p(x): out = xp.log1p(x) ph.assert_dtype("log1p", x.dtype, out.dtype) ph.assert_shape("log1p", out.shape, x.shape) - INFINITY = ah.infinity(x.shape, x.dtype) - NEGONE = -ah.one(x.shape, x.dtype) - codomain = ah.inrange(x, NEGONE, INFINITY) - domain = ah.inrange(out, -INFINITY, INFINITY) - # log1p maps [1, inf] to [-inf, inf]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl( + "log1p", x, out, math.log1p, filter_=lambda s: default_filter(s) and s >= 1 + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -1128,13 +947,9 @@ def test_log2(x): out = xp.log2(x) ph.assert_dtype("log2", x.dtype, out.dtype) ph.assert_shape("log2", out.shape, x.shape) - INFINITY = ah.infinity(x.shape, x.dtype) - ZERO = ah.zero(x.shape, x.dtype) - domain = ah.inrange(x, ZERO, INFINITY) - codomain = ah.inrange(out, -INFINITY, INFINITY) - # log2 maps [0, inf] to [-inf, inf]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl( + "log2", x, out, math.log2, filter_=lambda s: default_filter(s) and s > 1 + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -1142,35 +957,31 @@ def test_log10(x): out = xp.log10(x) ph.assert_dtype("log10", x.dtype, out.dtype) ph.assert_shape("log10", out.shape, x.shape) - INFINITY = ah.infinity(x.shape, x.dtype) - ZERO = ah.zero(x.shape, x.dtype) - domain = ah.inrange(x, ZERO, INFINITY) - codomain = ah.inrange(out, -INFINITY, INFINITY) - # log10 maps [0, inf] to [-inf, inf]. Values outside this domain are - # mapped to nan, which is already tested in the special cases. - ah.assert_exactly_equal(domain, codomain) + unary_assert_against_refimpl( + "log10", x, out, math.log10, filter_=lambda s: default_filter(s) and s > 0 + ) + + +def logaddexp(l: float, r: float) -> float: + return math.log(math.exp(l) + math.exp(r)) @given(*hh.two_mutual_arrays(dh.float_dtypes)) def test_logaddexp(x1, x2): out = xp.logaddexp(x1, x2) - ph.assert_dtype("logaddexp", (x1.dtype, x2.dtype), out.dtype) - # The spec doesn't require any behavior for this function. We could test - # that this is indeed an approximation of log(exp(x1) + exp(x2)), but we - # don't have tests for this sort of thing for any functions yet. + ph.assert_dtype("logaddexp", [x1.dtype, x2.dtype], out.dtype) + ph.assert_result_shape("logaddexp", [x1.shape, x2.shape], out.shape) + binary_assert_against_refimpl("logaddexp", x1, x2, out, logaddexp) @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_and(x1, x2): out = ah.logical_and(x1, x2) - ph.assert_dtype("logical_and", (x1.dtype, x2.dtype), out.dtype) - # See the comments in test_equal - shape = broadcast_shapes(x1.shape, x2.shape) - ph.assert_shape("logical_and", out.shape, shape) - _x1 = xp.broadcast_to(x1, shape) - _x2 = xp.broadcast_to(x2, shape) - for idx in sh.ndindex(shape): - assert out[idx] == (bool(_x1[idx]) and bool(_x2[idx])) + ph.assert_dtype("logical_and", [x1.dtype, x2.dtype], out.dtype) + ph.assert_result_shape("logical_and", [x1.shape, x2.shape], out.shape) + binary_assert_against_refimpl( + "logical_and", x1, x2, out, operator.and_, expr_template="({} and {})={}" + ) @given(xps.arrays(dtype=xp.bool, shape=hh.shapes())) @@ -1178,165 +989,102 @@ def test_logical_not(x): out = ah.logical_not(x) ph.assert_dtype("logical_not", x.dtype, out.dtype) ph.assert_shape("logical_not", out.shape, x.shape) - for idx in sh.ndindex(x.shape): - assert out[idx] == (not bool(x[idx])) + unary_assert_against_refimpl( + "logical_not", x, out, operator.not_, expr_template="(not {})={}" + ) @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_or(x1, x2): out = ah.logical_or(x1, x2) - ph.assert_dtype("logical_or", (x1.dtype, x2.dtype), out.dtype) - # See the comments in test_equal - shape = broadcast_shapes(x1.shape, x2.shape) - ph.assert_shape("logical_or", out.shape, shape) - _x1 = xp.broadcast_to(x1, shape) - _x2 = xp.broadcast_to(x2, shape) - for idx in sh.ndindex(shape): - assert out[idx] == (bool(_x1[idx]) or bool(_x2[idx])) + ph.assert_dtype("logical_or", [x1.dtype, x2.dtype], out.dtype) + ph.assert_result_shape("logical_or", [x1.shape, x2.shape], out.shape) + binary_assert_against_refimpl( + "logical_or", x1, x2, out, operator.or_, expr_template="({} or {})={}" + ) @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_xor(x1, x2): out = xp.logical_xor(x1, x2) - ph.assert_dtype("logical_xor", (x1.dtype, x2.dtype), out.dtype) - # See the comments in test_equal - shape = broadcast_shapes(x1.shape, x2.shape) - ph.assert_shape("logical_xor", out.shape, shape) - _x1 = xp.broadcast_to(x1, shape) - _x2 = xp.broadcast_to(x2, shape) - for idx in sh.ndindex(shape): - assert out[idx] == (bool(_x1[idx]) ^ bool(_x2[idx])) + ph.assert_dtype("logical_xor", [x1.dtype, x2.dtype], out.dtype) + ph.assert_result_shape("logical_xor", [x1.shape, x2.shape], out.shape) + binary_assert_against_refimpl( + "logical_xor", x1, x2, out, operator.xor, expr_template="({} ^ {})={}" + ) -@pytest.mark.parametrize( - binary_argnames, make_binary_params("multiply", dh.numeric_dtypes) -) +@pytest.mark.parametrize("ctx", make_binary_params("multiply", xps.numeric_dtypes())) @given(data=st.data()) -def test_multiply( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_multiply(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) - res = func(left, right) + res = ctx.func(left, right) - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - if not right_is_scalar: - # multiply is commutative - expected = func(right, left) - ah.assert_exactly_equal(res, expected) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl(ctx, left, right, res, "*", operator.mul) +# TODO: clarify if uints are acceptable, adjust accordingly @pytest.mark.parametrize( - unary_argnames, make_unary_params("negative", dh.numeric_dtypes) + "ctx", make_unary_params("negative", xps.integer_dtypes() | xps.floating_dtypes()) ) @given(data=st.data()) -def test_negative(func_name, func, strat, data): - x = data.draw(strat, label="x") - - out = func(x) - - ph.assert_dtype(func_name, x.dtype, out.dtype) - ph.assert_shape(func_name, out.shape, x.shape) - - # Negation is an involution - ah.assert_exactly_equal(x, func(out)) +def test_negative(ctx, data): + x = data.draw(ctx.strat, label="x") + # negative of the smallest negative integer is out-of-scope + if x.dtype in dh.int_dtypes: + assume(xp.all(x > dh.dtype_ranges[x.dtype].min)) - mask = ah.isfinite(x) - if dh.is_int_dtype(x.dtype): - minval = dh.dtype_ranges[x.dtype][0] - if minval < 0: - # negative of the smallest representable negative integer is not defined - mask = xp.not_equal(x, ah.full(x.shape, minval, dtype=x.dtype)) + out = ctx.func(x) - # Additive inverse - y = xp.add(x[mask], out[mask]) - ah.assert_exactly_equal(y, ah.zero(x[mask].shape, x.dtype)) + ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) + ph.assert_shape(ctx.func_name, out.shape, x.shape) + unary_assert_against_refimpl( + ctx.func_name, x, out, operator.neg, expr_template="-({})={}" # type: ignore + ) -@pytest.mark.parametrize( - binary_argnames, make_binary_params("not_equal", dh.all_dtypes) -) +@pytest.mark.parametrize("ctx", make_binary_params("not_equal", xps.scalar_dtypes())) @given(data=st.data()) -def test_not_equal( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_not_equal(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = func(left, right) - - assert_binary_param_dtype( - func_name, left, right, right_is_scalar, out, res_name, xp.bool - ) - if not right_is_scalar: - # TODO: generate indices without broadcasting arrays (see test_equal comment) - - shape = broadcast_shapes(left.shape, right.shape) - ph.assert_shape(func_name, out.shape, shape) - _left = xp.broadcast_to(left, shape) - _right = xp.broadcast_to(right, shape) + out = ctx.func(left, right) + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) + if not ctx.right_is_scalar: + # See test_equal note promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - _left = ah.asarray(_left, dtype=promoted_dtype) - _right = ah.asarray(_right, dtype=promoted_dtype) - - scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in sh.ndindex(shape): - out_idx = out[idx] - x1_idx = _left[idx] - x2_idx = _right[idx] - assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check - assert bool(out_idx) == (scalar_type(x1_idx) != scalar_type(x2_idx)) + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, "!=", operator.ne, res_stype=bool + ) -@pytest.mark.parametrize( - unary_argnames, make_unary_params("positive", dh.numeric_dtypes) -) +@pytest.mark.parametrize("ctx", make_unary_params("positive", xps.numeric_dtypes())) @given(data=st.data()) -def test_positive(func_name, func, strat, data): - x = data.draw(strat, label="x") +def test_positive(ctx, data): + x = data.draw(ctx.strat, label="x") - out = func(x) + out = ctx.func(x) - ph.assert_dtype(func_name, x.dtype, out.dtype) - ph.assert_shape(func_name, out.shape, x.shape) - # Positive does nothing - ah.assert_exactly_equal(out, x) + ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) + ph.assert_shape(ctx.func_name, out.shape, x.shape) + ph.assert_array(ctx.func_name, out, x) -@pytest.mark.parametrize(binary_argnames, make_binary_params("pow", dh.numeric_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("pow", xps.numeric_dtypes())) @given(data=st.data()) -def test_pow( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) - if right_is_scalar: +def test_pow(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) + if ctx.right_is_scalar: if isinstance(right, int): assume(right >= 0) else: @@ -1344,85 +1092,50 @@ def test_pow( assume(xp.all(right >= 0)) try: - res = func(left, right) + res = ctx.func(left, right) except OverflowError: reject() - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - # There isn't much we can test here. The spec doesn't require any behavior - # beyond the special cases, and indeed, there aren't many mathematical - # properties of exponentiation that strictly hold for floating-point - # numbers. We could test that this does implement IEEE 754 pow, but we - # don't yet have those sorts in general for this module. + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl( + ctx, left, right, res, "**", math.pow, strict_check=False + ) -@pytest.mark.parametrize( - binary_argnames, make_binary_params("remainder", dh.numeric_dtypes) -) +@pytest.mark.parametrize("ctx", make_binary_params("remainder", xps.numeric_dtypes())) @given(data=st.data()) -def test_remainder( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) - if right_is_scalar: - out_dtype = left.dtype +def test_remainder(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) + if ctx.right_is_scalar: + assume(right != 0) else: - out_dtype = dh.result_type(left.dtype, right.dtype) - if dh.is_int_dtype(out_dtype): - if right_is_scalar: - assume(right != 0) - else: - assume(not ah.any(right == 0)) + assume(not ah.any(right == 0)) - res = func(left, right) + res = ctx.func(left, right) - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - # TODO: test results + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl(ctx, left, right, res, "%", operator.mod) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_round(x): out = xp.round(x) - ph.assert_dtype("round", x.dtype, out.dtype) - ph.assert_shape("round", out.shape, x.shape) + unary_assert_against_refimpl("round", x, out, round, strict_check=True) - # Test that the out is integral - finite = ah.isfinite(x) - ah.assert_integral(out[finite]) - - # round(x) should be the neaoutt integer to x. The case where there is a - # tie (round to even) is already handled by the special cases tests. - # This is the same strategy used in the mask in the - # test_round_special_cases_one_arg_two_integers_equally_close special - # cases test. - floor = xp.floor(x) - ceil = xp.ceil(x) - over = xp.subtract(x, floor) - under = xp.subtract(ceil, x) - round_down = ah.less(over, under) - round_up = ah.less(under, over) - ah.assert_exactly_equal(out[round_down], floor[round_down]) - ah.assert_exactly_equal(out[round_up], ceil[round_up]) - - -@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(), elements=finite_kw)) def test_sign(x): out = xp.sign(x) ph.assert_dtype("sign", x.dtype, out.dtype) ph.assert_shape("sign", out.shape, x.shape) - # TODO + unary_assert_against_refimpl( + "sign", x, out, lambda s: math.copysign(1, s), filter_=lambda s: s != 0 + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -1430,7 +1143,7 @@ def test_sin(x): out = xp.sin(x) ph.assert_dtype("sin", x.dtype, out.dtype) ph.assert_shape("sin", out.shape, x.shape) - # TODO + unary_assert_against_refimpl("sin", x, out, math.sin) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -1438,7 +1151,7 @@ def test_sinh(x): out = xp.sinh(x) ph.assert_dtype("sinh", x.dtype, out.dtype) ph.assert_shape("sinh", out.shape, x.shape) - # TODO + unary_assert_against_refimpl("sinh", x, out, math.sinh) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) @@ -1446,6 +1159,9 @@ def test_square(x): out = xp.square(x) ph.assert_dtype("square", x.dtype, out.dtype) ph.assert_shape("square", out.shape, x.shape) + unary_assert_against_refimpl( + "square", x, out, lambda s: s ** 2, expr_template="{}²={}" + ) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -1453,33 +1169,25 @@ def test_sqrt(x): out = xp.sqrt(x) ph.assert_dtype("sqrt", x.dtype, out.dtype) ph.assert_shape("sqrt", out.shape, x.shape) + unary_assert_against_refimpl( + "sqrt", x, out, math.sqrt, filter_=lambda s: default_filter(s) and s >= 0 + ) -@pytest.mark.parametrize( - binary_argnames, make_binary_params("subtract", dh.numeric_dtypes) -) +@pytest.mark.parametrize("ctx", make_binary_params("subtract", xps.numeric_dtypes())) @given(data=st.data()) -def test_subtract( - func_name, - func, - left_sym, - left_strat, - right_sym, - right_strat, - right_is_scalar, - res_name, - data, -): - left = data.draw(left_strat, label=left_sym) - right = data.draw(right_strat, label=right_sym) +def test_subtract(ctx, data): + left = data.draw(ctx.left_strat, label=ctx.left_sym) + right = data.draw(ctx.right_strat, label=ctx.right_sym) try: - res = func(left, right) + res = ctx.func(left, right) except OverflowError: reject() - assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name) - # TODO + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl(ctx, left, right, res, "-", operator.sub) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -1487,7 +1195,7 @@ def test_tan(x): out = xp.tan(x) ph.assert_dtype("tan", x.dtype, out.dtype) ph.assert_shape("tan", out.shape, x.shape) - # TODO + unary_assert_against_refimpl("tan", x, out, math.tan) @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) @@ -1495,7 +1203,7 @@ def test_tanh(x): out = xp.tanh(x) ph.assert_dtype("tanh", x.dtype, out.dtype) ph.assert_shape("tanh", out.shape, x.shape) - # TODO + unary_assert_against_refimpl("tanh", x, out, math.tanh) @given(xps.arrays(dtype=hh.numeric_dtypes, shape=xps.array_shapes())) @@ -1503,8 +1211,4 @@ def test_trunc(x): out = xp.trunc(x) ph.assert_dtype("trunc", x.dtype, out.dtype) ph.assert_shape("trunc", out.shape, x.shape) - if dh.is_int_dtype(x.dtype): - ah.assert_exactly_equal(out, x) - else: - finite = ah.isfinite(x) - ah.assert_integral(out[finite]) + unary_assert_against_refimpl("trunc", x, out, math.trunc, strict_check=True) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index b6a66086..01c26d0c 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -8,7 +8,6 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from .algos import broadcast_shapes pytestmark = pytest.mark.ci @@ -134,7 +133,7 @@ def test_where(shapes, dtypes, data): out = xp.where(cond, x1, x2) - shape = broadcast_shapes(*shapes) + shape = sh.broadcast_shapes(*shapes) ph.assert_shape("where", out.shape, shape) # TODO: generate indices without broadcasting arrays _cond = xp.broadcast_to(cond, shape) diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py index 9679eaac..5ceceb54 100644 --- a/array_api_tests/test_set_functions.py +++ b/array_api_tests/test_set_functions.py @@ -1,8 +1,8 @@ # TODO: disable if opted out, refactor things import math -import pytest from collections import Counter, defaultdict +import pytest from hypothesis import assume, given from . import _array_module as xp diff --git a/array_api_tests/test_sorting_functions.py b/array_api_tests/test_sorting_functions.py index ea375b57..7c5a1411 100644 --- a/array_api_tests/test_sorting_functions.py +++ b/array_api_tests/test_sorting_functions.py @@ -1,7 +1,7 @@ import math -import pytest from typing import Set +import pytest from hypothesis import given from hypothesis import strategies as st from hypothesis.control import assume diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index c955b570..c86111a0 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -1,7 +1,7 @@ import math -import pytest from typing import Optional +import pytest from hypothesis import assume, given from hypothesis import strategies as st from hypothesis.control import reject diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index b1e5a09b..575e9011 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -271,7 +271,7 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data): out = eval(expr, {"x": x, "s": s}) except OverflowError: reject() - ph.assert_dtype(op, (in_dtype, in_stype), out.dtype, out_dtype) + ph.assert_dtype(op, [in_dtype, in_stype], out.dtype, out_dtype) inplace_scalar_params: List[Param[str, str, DataType, ScalarType]] = [] @@ -307,7 +307,7 @@ def test_inplace_op_scalar_promotion(op, expr, dtype, in_stype, data): reject() x = locals_["x"] assert x.dtype == dtype, f"{x.dtype=!s}, but should be {dtype}" - ph.assert_dtype(op, (dtype, in_stype), x.dtype, dtype, repr_name="x.dtype") + ph.assert_dtype(op, [dtype, in_stype], x.dtype, dtype, repr_name="x.dtype") if __name__ == "__main__": diff --git a/array_api_tests/typing.py b/array_api_tests/typing.py index 286ce21b..da8652ae 100644 --- a/array_api_tests/typing.py +++ b/array_api_tests/typing.py @@ -1,4 +1,4 @@ -from typing import Tuple, Type, Union, Any +from typing import Any, Tuple, Type, Union __all__ = [ "DataType", @@ -6,6 +6,8 @@ "ScalarType", "Array", "Shape", + "AtomicIndex", + "Index", "Param", ] @@ -14,4 +16,6 @@ ScalarType = Union[Type[bool], Type[int], Type[float]] Array = Any Shape = Tuple[int, ...] +AtomicIndex = Union[int, "ellipsis", slice] # noqa +Index = Union[AtomicIndex, Tuple[AtomicIndex, ...]] Param = Tuple