diff --git a/array_api_tests/meta/test_pytest_helpers.py b/array_api_tests/meta/test_pytest_helpers.py index a6851a15..17ed5534 100644 --- a/array_api_tests/meta/test_pytest_helpers.py +++ b/array_api_tests/meta/test_pytest_helpers.py @@ -5,18 +5,18 @@ def test_assert_dtype(): - ph.assert_dtype("promoted_func", [xp.uint8, xp.int8], xp.int16) + ph.assert_dtype("promoted_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.int16) with raises(AssertionError): - ph.assert_dtype("bad_func", [xp.uint8, xp.int8], xp.float32) - ph.assert_dtype("bool_func", [xp.uint8, xp.int8], xp.bool, xp.bool) - ph.assert_dtype("single_promoted_func", [xp.uint8], xp.uint8) - ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool) + ph.assert_dtype("bad_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.float32) + ph.assert_dtype("bool_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.bool, expected=xp.bool) + ph.assert_dtype("single_promoted_func", in_dtype=[xp.uint8], out_dtype=xp.uint8) + ph.assert_dtype("single_bool_func", in_dtype=[xp.uint8], out_dtype=xp.bool, expected=xp.bool) def test_assert_array_elements(): - ph.assert_array_elements("int zeros", xp.asarray(0), xp.asarray(0)) - ph.assert_array_elements("pos zeros", xp.asarray(0.0), xp.asarray(0.0)) + ph.assert_array_elements("int zeros", out=xp.asarray(0), expected=xp.asarray(0)) + ph.assert_array_elements("pos zeros", out=xp.asarray(0.0), expected=xp.asarray(0.0)) with raises(AssertionError): - ph.assert_array_elements("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0)) + ph.assert_array_elements("mixed sign zeros", out=xp.asarray(0.0), expected=xp.asarray(-0.0)) with raises(AssertionError): - ph.assert_array_elements("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0)) + ph.assert_array_elements("mixed sign zeros", out=xp.asarray(-0.0), expected=xp.asarray(0.0)) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 0eb34180..0d354c0b 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -82,10 +82,10 @@ def is_neg_zero(n: float) -> bool: def assert_dtype( func_name: str, + *, in_dtype: Union[DataType, Sequence[DataType]], out_dtype: DataType, expected: Optional[DataType] = None, - *, repr_name: str = "out.dtype", ): """ @@ -96,7 +96,7 @@ def assert_dtype( >>> x = xp.arange(5, dtype=xp.uint8) >>> out = xp.abs(x) - >>> assert_dtype('abs', x.dtype, out.dtype) + >>> assert_dtype('abs', in_dtype=x.dtype, out_dtype=out.dtype) is equivalent to @@ -108,7 +108,7 @@ def assert_dtype( >>> x1 = xp.arange(5, dtype=xp.uint8) >>> x2 = xp.arange(5, dtype=xp.uint16) >>> out = xp.add(x1, x2) - >>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype) + >>> assert_dtype('add', in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) is equivalent to @@ -119,7 +119,7 @@ def assert_dtype( >>> x = xp.arange(5, dtype=xp.int8) >>> out = xp.sum(x) >>> default_int = xp.asarray(0).dtype - >>> assert_dtype('sum', x, out.dtype, default_int) + >>> assert_dtype('sum', in_dtype=x, out_dtype=out.dtype, expected=default_int) """ in_dtypes = in_dtype if isinstance(in_dtype, Sequence) and not isinstance(in_dtype, str) else [in_dtype] @@ -135,13 +135,18 @@ def assert_dtype( assert out_dtype == expected, msg -def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType): +def assert_kw_dtype( + func_name: str, + *, + kw_dtype: DataType, + out_dtype: DataType, +): """ Assert the output dtype is the passed keyword dtype, e.g. >>> kw = {'dtype': xp.uint8} - >>> out = xp.ones(5, **kw) - >>> assert_kw_dtype('ones', kw['dtype'], out.dtype) + >>> out = xp.ones(5, kw=kw) + >>> assert_kw_dtype('ones', kw_dtype=kw['dtype'], out_dtype=out.dtype) """ f_kw_dtype = dh.dtype_to_name[kw_dtype] @@ -222,17 +227,17 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty def assert_shape( func_name: str, + *, out_shape: Union[int, Shape], expected: Union[int, Shape], - /, repr_name="out.shape", - **kw, + kw: dict = {}, ): """ Assert the output shape is as expected, e.g. >>> out = xp.ones((3, 3, 3)) - >>> assert_shape('ones', out.shape, (3, 3, 3)) + >>> assert_shape('ones', out_shape=out.shape, expected=(3, 3, 3)) """ if isinstance(out_shape, int): @@ -249,11 +254,10 @@ def assert_result_shape( func_name: str, in_shapes: Sequence[Shape], out_shape: Shape, - /, expected: Optional[Shape] = None, *, repr_name="out.shape", - **kw, + kw: dict = {}, ): """ Assert the output shape is as expected. @@ -262,7 +266,7 @@ def assert_result_shape( in_shapes, to test against out_shape, e.g. >>> out = xp.add(xp.ones((3, 1)), xp.ones((1, 3))) - >>> assert_shape('add', [(3, 1), (1, 3)], out.shape) + >>> assert_result_shape('add', in_shape=[(3, 1), (1, 3)], out_shape=out.shape) is equivalent to @@ -281,12 +285,12 @@ def assert_result_shape( def assert_keepdimable_shape( func_name: str, + *, in_shape: Shape, out_shape: Shape, axes: Tuple[int, ...], keepdims: bool, - /, - **kw, + kw: dict = {}, ): """ Assert the output shape from a keepdimable function is as expected, e.g. @@ -294,8 +298,8 @@ def assert_keepdimable_shape( >>> x = xp.asarray([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) >>> out1 = xp.max(x, keepdims=False) >>> out2 = xp.max(x, keepdims=True) - >>> assert_keepdimable_shape('max', x.shape, out1.shape, (0, 1), False) - >>> assert_keepdimable_shape('max', x.shape, out2.shape, (0, 1), True) + >>> assert_keepdimable_shape('max', in_shape=x.shape, out_shape=out1.shape, axes=(0, 1), keepdims=False) + >>> assert_keepdimable_shape('max', in_shape=x.shape, out_shape=out2.shape, axes=(0, 1), keepdims=True) is equivalent to @@ -307,19 +311,26 @@ def assert_keepdimable_shape( shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape)) else: shape = tuple(side for axis, side in enumerate(in_shape) if axis not in axes) - assert_shape(func_name, out_shape, shape, **kw) + assert_shape(func_name, out_shape=out_shape, expected=shape, kw=kw) def assert_0d_equals( - func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw + func_name: str, + *, + x_repr: str, + x_val: Array, + out_repr: str, + out_val: Array, + kw: dict = {}, ): """ Assert a 0d array is as expected, e.g. >>> x = xp.asarray([0, 1, 2]) - >>> res = xp.asarray(x, copy=True) + >>> kw = {'copy': True} + >>> res = xp.asarray(x, **kw) >>> res[0] = 42 - >>> assert_0d_equals('asarray', 'x[0]', x[0], 'x[0]', res[0]) + >>> assert_0d_equals('asarray', x_repr='x[0]', x_val=x[0], out_repr='x[0]', out_val=res[0], kw=kw) is equivalent to @@ -338,20 +349,20 @@ def assert_0d_equals( def assert_scalar_equals( func_name: str, + *, type_: ScalarType, idx: Shape, out: Scalar, expected: Scalar, - /, repr_name: str = "out", - **kw, + kw: dict = {}, ): """ Assert a 0d array, convered to a scalar, is as expected, e.g. >>> x = xp.ones(5, dtype=xp.uint8) >>> out = xp.sum(x) - >>> assert_scalar_equals('sum', int, (), int(out), 5) + >>> assert_scalar_equals('sum', type_int, out=(), out=int(out), expected=5) is equivalent to @@ -372,13 +383,18 @@ def assert_scalar_equals( def assert_fill( - func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw + func_name: str, + *, + fill_value: Scalar, + dtype: DataType, + out: Array, + kw: dict = {}, ): """ Assert all elements of an array is as expected, e.g. >>> out = xp.full(5, 42, dtype=xp.uint8) - >>> assert_fill('full', 42, xp.uint8, out, 5) + >>> assert_fill('full', fill_value=42, dtype=xp.uint8, out=out, kw=dict(shape=5)) is equivalent to @@ -408,14 +424,19 @@ def _assert_float_element(at_out: Array, at_expected: Array, msg: str): def assert_array_elements( - func_name: str, out: Array, expected: Array, /, *, out_repr: str = "out", **kw + func_name: str, + *, + out: Array, + expected: Array, + out_repr: str = "out", + kw: dict = {}, ): """ Assert array elements are (strictly) as expected, e.g. >>> x = xp.arange(5) >>> out = xp.asarray(x) - >>> assert_array_elements('asarray', out, x) + >>> assert_array_elements('asarray', out=out, expected=x) is equivalent to @@ -423,7 +444,7 @@ def assert_array_elements( """ dh.result_type(out.dtype, expected.dtype) # sanity check - assert_shape(func_name, out.shape, expected.shape, **kw) # sanity check + assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check f_func = f"[{func_name}({fmt_kw(kw)})]" if out.dtype in dh.float_dtypes: for idx in sh.ndindex(out.shape): diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 4a539fc7..79b0e7d3 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -89,11 +89,11 @@ def test_getitem(shape, dtype, data): out = x[key] - ph.assert_dtype("__getitem__", x.dtype, out.dtype) + ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype) _key = normalise_key(key, shape) - axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape) - ph.assert_shape("__getitem__", out.shape, out_shape) - out_zero_sided = any(side == 0 for side in out_shape) + axes_indices, expected_shape = get_indexed_axes_and_out_shape(_key, shape) + ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape) + out_zero_sided = any(side == 0 for side in expected_shape) if not zero_sided and not out_zero_sided: out_obj = [] for idx in product(*axes_indices): @@ -101,9 +101,9 @@ def test_getitem(shape, dtype, data): for i in idx: val = val[i] out_obj.append(val) - out_obj = sh.reshape(out_obj, out_shape) + out_obj = sh.reshape(out_obj, expected_shape) expected = xp.asarray(out_obj, dtype=dtype) - ph.assert_array_elements("__getitem__", out, expected) + ph.assert_array_elements("__getitem__", out=out, expected=expected) @given( @@ -131,8 +131,8 @@ def test_setitem(shape, dtypes, data): res = xp.asarray(x, copy=True) res[key] = value - ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype") - ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.shape") + ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype") + ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.shape") f_res = sh.fmt_idx("x", key) if isinstance(value, get_args(Scalar)): msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]" @@ -141,11 +141,15 @@ def test_setitem(shape, dtypes, data): else: assert res[key] == value, msg else: - ph.assert_array_elements("__setitem__", res[key], value, out_repr=f_res) + ph.assert_array_elements("__setitem__", out=res[key], expected=value, out_repr=f_res) unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices)) for idx in unaffected_indices: ph.assert_0d_equals( - "__setitem__", f"old {f_res}", x[idx], f"modified {f_res}", res[idx] + "__setitem__", + x_repr=f"old {f_res}", + x_val=x[idx], + out_repr=f"modified {f_res}", + out_val=res[idx], ) @@ -171,14 +175,14 @@ def test_getitem_masking(shape, data): out = x[key] - ph.assert_dtype("__getitem__", x.dtype, out.dtype) + ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype) if key.ndim == 0: - out_shape = (1,) if key else (0,) - out_shape += x.shape + expected_shape = (1,) if key else (0,) + expected_shape += x.shape else: size = int(xp.sum(xp.astype(key, xp.uint8))) - out_shape = (size,) + x.shape[key.ndim :] - ph.assert_shape("__getitem__", out.shape, out_shape) + expected_shape = (size,) + x.shape[key.ndim :] + ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape) if not any(s == 0 for s in key.shape): assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios out_indices = sh.ndindex(out.shape) @@ -187,10 +191,10 @@ def test_getitem_masking(shape, data): out_idx = next(out_indices) ph.assert_0d_equals( "__getitem__", - f"x[{x_idx}]", - x[x_idx], - f"out[{out_idx}]", - out[out_idx], + x_repr=f"x[{x_idx}]", + x_val=x[x_idx], + out_repr=f"out[{out_idx}]", + out_val=out[out_idx], ) @@ -205,27 +209,35 @@ def test_setitem_masking(shape, data): res = xp.asarray(x, copy=True) res[key] = value - ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype") - ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.dtype") + ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype") + ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.dtype") scalar_type = dh.get_scalar_type(x.dtype) for idx in sh.ndindex(x.shape): if key[idx]: if isinstance(value, get_args(Scalar)): ph.assert_scalar_equals( "__setitem__", - scalar_type, - idx, - scalar_type(res[idx]), - value, + type_=scalar_type, + idx=idx, + out=scalar_type(res[idx]), + expected=value, repr_name="modified x", ) else: ph.assert_0d_equals( - "__setitem__", "value", value, f"modified x[{idx}]", res[idx] + "__setitem__", + x_repr="value", + x_val=value, + out_repr=f"modified x[{idx}]", + out_val=res[idx] ) else: ph.assert_0d_equals( - "__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx] + "__setitem__", + x_repr=f"old x[{idx}]", + x_val=x[idx], + out_repr=f"modified x[{idx}]", + out_val=res[idx] ) diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 2ebd3b07..cc6acbbe 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -159,7 +159,7 @@ def test_arange(dtype, data): else: ph.assert_default_float("arange", out.dtype) else: - ph.assert_kw_dtype("arange", dtype, out.dtype) + ph.assert_kw_dtype("arange", kw_dtype=dtype, out_dtype=out.dtype) f_sig = ", ".join(str(n) for n in args) if len(kwargs) > 0: f_sig += f", {ph.fmt_kw(kwargs)}" @@ -189,7 +189,7 @@ def test_arange(dtype, data): if dh.is_int_dtype(_dtype): elements = list(r) assume(out_size == len(elements)) - ph.assert_array_elements("arange", out, xp.asarray(elements, dtype=_dtype)) + ph.assert_array_elements("arange", out=out, expected=xp.asarray(elements, dtype=_dtype)) else: assume(out_size == size) if out_size > 0: @@ -247,11 +247,11 @@ def test_asarray_scalars(shape, data): assert out.dtype in dtype_family, msg else: assert kw["dtype"] == _dtype # sanity check - ph.assert_kw_dtype("asarray", _dtype, out.dtype) - ph.assert_shape("asarray", out.shape, shape) + ph.assert_kw_dtype("asarray", kw_dtype=_dtype, out_dtype=out.dtype) + ph.assert_shape("asarray", out_shape=out.shape, expected=shape) for idx, v_expect in zip(sh.ndindex(out.shape), _obj): v = scalar_type(out[idx]) - ph.assert_scalar_equals("asarray", scalar_type, idx, v, v_expect, **kw) + ph.assert_scalar_equals("asarray", type_=scalar_type, idx=idx, out=v, expected=v_expect, kw=kw) def scalar_eq(s1: Scalar, s2: Scalar) -> bool: @@ -280,11 +280,11 @@ def test_asarray_arrays(shape, dtypes, data): dtype = kw.get("dtype", None) if dtype is None: - ph.assert_dtype("asarray", x.dtype, out.dtype) + ph.assert_dtype("asarray", in_dtype=x.dtype, out_dtype=out.dtype) else: - ph.assert_kw_dtype("asarray", dtype, out.dtype) - ph.assert_shape("asarray", out.shape, x.shape) - ph.assert_array_elements("asarray", out, x, **kw) + ph.assert_kw_dtype("asarray", kw_dtype=dtype, out_dtype=out.dtype) + ph.assert_shape("asarray", out_shape=out.shape, expected=x.shape) + ph.assert_array_elements("asarray", out=out, expected=x, kw=kw) copy = kw.get("copy", None) if copy is not None: stype = dh.get_scalar_type(x.dtype) @@ -301,7 +301,7 @@ def test_asarray_arrays(shape, dtypes, data): note(f"mutated {x=}") # sanity check ph.assert_scalar_equals( - "__setitem__", stype, idx, stype(x[idx]), value, repr_name="x" + "__setitem__", type_=stype, idx=idx, out=stype(x[idx]), expected=value, repr_name="x" ) new_out_value = stype(out[idx]) f_out = f"{sh.fmt_idx('out', idx)}={new_out_value}" @@ -321,8 +321,8 @@ def test_empty(shape, kw): if kw.get("dtype", None) is None: ph.assert_default_float("empty", out.dtype) else: - ph.assert_kw_dtype("empty", kw["dtype"], out.dtype) - ph.assert_shape("empty", out.shape, shape, shape=shape) + ph.assert_kw_dtype("empty", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("empty", out_shape=out.shape, expected=shape, kw=dict(shape=shape)) @given( @@ -332,10 +332,10 @@ def test_empty(shape, kw): def test_empty_like(x, kw): out = xp.empty_like(x, **kw) if kw.get("dtype", None) is None: - ph.assert_dtype("empty_like", x.dtype, out.dtype) + ph.assert_dtype("empty_like", in_dtype=x.dtype, out_dtype=out.dtype) else: - ph.assert_kw_dtype("empty_like", kw["dtype"], out.dtype) - ph.assert_shape("empty_like", out.shape, x.shape) + ph.assert_kw_dtype("empty_like", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("empty_like", out_shape=out.shape, expected=x.shape) @given( @@ -351,9 +351,9 @@ def test_eye(n_rows, n_cols, kw): if kw.get("dtype", None) is None: ph.assert_default_float("eye", out.dtype) else: - ph.assert_kw_dtype("eye", kw["dtype"], out.dtype) + ph.assert_kw_dtype("eye", kw_dtype=kw["dtype"], out_dtype=out.dtype) _n_cols = n_rows if n_cols is None else n_cols - ph.assert_shape("eye", out.shape, (n_rows, _n_cols), n_rows=n_rows, n_cols=n_cols) + ph.assert_shape("eye", out_shape=out.shape, expected=(n_rows, _n_cols), kw=dict(n_rows=n_rows, n_cols=n_cols)) f_func = f"[eye({n_rows=}, {n_cols=})]" for i in range(n_rows): for j in range(_n_cols): @@ -421,9 +421,9 @@ def test_full(shape, fill_value, kw): assert isinstance(fill_value, complex) # sanity check ph.assert_default_complex("full", out.dtype) else: - ph.assert_kw_dtype("full", kw["dtype"], out.dtype) - ph.assert_shape("full", out.shape, shape, shape=shape) - ph.assert_fill("full", fill_value, dtype, out, fill_value=fill_value) + ph.assert_kw_dtype("full", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("full", out_shape=out.shape, expected=shape, kw=dict(shape=shape)) + ph.assert_fill("full", fill_value=fill_value, dtype=dtype, out=out, kw=dict(fill_value=fill_value)) @st.composite @@ -444,11 +444,11 @@ def test_full_like(x, fill_value, kw): out = xp.full_like(x, fill_value, **kw) dtype = kw.get("dtype", None) or x.dtype if kw.get("dtype", None) is None: - ph.assert_dtype("full_like", x.dtype, out.dtype) + ph.assert_dtype("full_like", in_dtype=x.dtype, out_dtype=out.dtype) else: - ph.assert_kw_dtype("full_like", kw["dtype"], out.dtype) - ph.assert_shape("full_like", out.shape, x.shape) - ph.assert_fill("full_like", fill_value, dtype, out, fill_value=fill_value) + ph.assert_kw_dtype("full_like", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("full_like", out_shape=out.shape, expected=x.shape) + ph.assert_fill("full_like", fill_value=fill_value, dtype=dtype, out=out, kw=dict(fill_value=fill_value)) finite_kw = {"allow_nan": False, "allow_infinity": False} @@ -484,8 +484,8 @@ def test_linspace(num, dtype, endpoint, data): if dtype is None: ph.assert_default_float("linspace", out.dtype) else: - ph.assert_kw_dtype("linspace", dtype, out.dtype) - ph.assert_shape("linspace", out.shape, num, start=stop, stop=stop, num=num) + ph.assert_kw_dtype("linspace", kw_dtype=dtype, out_dtype=out.dtype) + ph.assert_shape("linspace", out_shape=out.shape, expected=num, kw=dict(start=start, stop=stop, num=num)) f_func = f"[linspace({start}, {stop}, {num})]" if num > 0: assert xp.equal( @@ -501,7 +501,7 @@ def test_linspace(num, dtype, endpoint, data): # the first num elements when endpoint=False expected = xp.linspace(start, stop, num + 1, dtype=dtype, endpoint=True) expected = expected[:-1] - ph.assert_array_elements("linspace", out, expected) + ph.assert_array_elements("linspace", out=out, expected=expected) @given(dtype=xps.numeric_dtypes(), data=st.data()) @@ -524,7 +524,7 @@ def test_meshgrid(dtype, data): assert math.prod(math.prod(x.shape) for x in arrays) <= hh.MAX_ARRAY_SIZE out = xp.meshgrid(*arrays) for i, x in enumerate(out): - ph.assert_dtype("meshgrid", dtype, x.dtype, repr_name=f"out[{i}].dtype") + ph.assert_dtype("meshgrid", in_dtype=dtype, out_dtype=x.dtype, repr_name=f"out[{i}].dtype") def make_one(dtype: DataType) -> Scalar: @@ -542,10 +542,11 @@ def test_ones(shape, kw): if kw.get("dtype", None) is None: ph.assert_default_float("ones", out.dtype) else: - ph.assert_kw_dtype("ones", kw["dtype"], out.dtype) - ph.assert_shape("ones", out.shape, shape, shape=shape) + ph.assert_kw_dtype("ones", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("ones", out_shape=out.shape, expected=shape, + kw={'shape': shape, **kw}) dtype = kw.get("dtype", None) or dh.default_float - ph.assert_fill("ones", make_one(dtype), dtype, out) + ph.assert_fill("ones", fill_value=make_one(dtype), dtype=dtype, out=out, kw=kw) @given( @@ -555,12 +556,13 @@ def test_ones(shape, kw): def test_ones_like(x, kw): out = xp.ones_like(x, **kw) if kw.get("dtype", None) is None: - ph.assert_dtype("ones_like", x.dtype, out.dtype) + ph.assert_dtype("ones_like", in_dtype=x.dtype, out_dtype=out.dtype) else: - ph.assert_kw_dtype("ones_like", kw["dtype"], out.dtype) - ph.assert_shape("ones_like", out.shape, x.shape) + ph.assert_kw_dtype("ones_like", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("ones_like", out_shape=out.shape, expected=x.shape, kw=kw) dtype = kw.get("dtype", None) or x.dtype - ph.assert_fill("ones_like", make_one(dtype), dtype, out) + ph.assert_fill("ones_like", fill_value=make_one(dtype), dtype=dtype, + out=out, kw=kw) def make_zero(dtype: DataType) -> Scalar: @@ -576,12 +578,13 @@ def make_zero(dtype: DataType) -> Scalar: def test_zeros(shape, kw): out = xp.zeros(shape, **kw) if kw.get("dtype", None) is None: - ph.assert_default_float("zeros", out.dtype) + ph.assert_default_float("zeros", out_dtype=out.dtype) else: - ph.assert_kw_dtype("zeros", kw["dtype"], out.dtype) - ph.assert_shape("zeros", out.shape, shape, shape=shape) + ph.assert_kw_dtype("zeros", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("zeros", out_shape=out.shape, expected=shape, kw={'shape': shape, **kw}) dtype = kw.get("dtype", None) or dh.default_float - ph.assert_fill("zeros", make_zero(dtype), dtype, out) + ph.assert_fill("zeros", fill_value=make_zero(dtype), dtype=dtype, out=out, + kw=kw) @given( @@ -591,9 +594,11 @@ def test_zeros(shape, kw): def test_zeros_like(x, kw): out = xp.zeros_like(x, **kw) if kw.get("dtype", None) is None: - ph.assert_dtype("zeros_like", x.dtype, out.dtype) + ph.assert_dtype("zeros_like", in_dtype=x.dtype, out_dtype=out.dtype) else: - ph.assert_kw_dtype("zeros_like", kw["dtype"], out.dtype) - ph.assert_shape("zeros_like", out.shape, x.shape) + ph.assert_kw_dtype("zeros_like", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("zeros_like", out_shape=out.shape, expected=x.shape, + kw=kw) dtype = kw.get("dtype", None) or x.dtype - ph.assert_fill("zeros_like", make_zero(dtype), dtype, out) + ph.assert_fill("zeros_like", fill_value=make_zero(dtype), dtype=dtype, + out=out, kw=kw) diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 5cd409ce..6bf6ed7a 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -58,8 +58,8 @@ def test_astype(x_dtype, dtype, kw, data): out = xp.astype(x, dtype, **kw) - ph.assert_kw_dtype("astype", dtype, out.dtype) - ph.assert_shape("astype", out.shape, x.shape) + ph.assert_kw_dtype("astype", kw_dtype=dtype, out_dtype=out.dtype) + ph.assert_shape("astype", out_shape=out.shape, expected=x.shape, kw=kw) # TODO: test values # TODO: test copy @@ -75,16 +75,19 @@ def test_broadcast_arrays(shapes, data): out = xp.broadcast_arrays(*arrays) - out_shape = sh.broadcast_shapes(*shapes) + expected_shape = sh.broadcast_shapes(*shapes) for i, x in enumerate(arrays): ph.assert_dtype( - "broadcast_arrays", x.dtype, out[i].dtype, repr_name=f"out[{i}].dtype" + "broadcast_arrays", + in_dtype=x.dtype, + out_dtype=out[i].dtype, + repr_name=f"out[{i}].dtype" ) ph.assert_result_shape( "broadcast_arrays", - shapes, - out[i].shape, - out_shape, + in_shapes=shapes, + out_shape=out[i].shape, + expected=expected_shape, repr_name=f"out[{i}].shape", ) # TODO: test values @@ -101,8 +104,8 @@ def test_broadcast_to(x, data): out = xp.broadcast_to(x, shape) - ph.assert_dtype("broadcast_to", x.dtype, out.dtype) - ph.assert_shape("broadcast_to", out.shape, shape) + ph.assert_dtype("broadcast_to", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("broadcast_to", out_shape=out.shape, expected=shape) # TODO: test values @@ -177,4 +180,4 @@ def test_iinfo(dtype): @given(hh.mutually_promotable_dtypes(None)) def test_result_type(dtypes): out = xp.result_type(*dtypes) - ph.assert_dtype("result_type", dtypes, out, repr_name="out") + ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out") diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index cc07e6b4..291f2bfb 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -295,7 +295,7 @@ def test_matmul(x1, x2): else: res = _array_module.matmul(x1, x2) - ph.assert_dtype("matmul", [x1.dtype, x2.dtype], res.dtype) + ph.assert_dtype("matmul", in_dtype=[x1.dtype, x2.dtype], out_dtype=res.dtype) if len(x1.shape) == len(x2.shape) == 1: assert res.shape == () @@ -585,7 +585,7 @@ def test_tensordot(dtypes, shape, data): out = xp.tensordot(x1, x2, axes=len(shape)) - ph.assert_dtype("tensordot", dtypes, out.dtype) + ph.assert_dtype("tensordot", in_dtype=dtypes, out_dtype=out.dtype) # TODO: assert shape and elements @@ -642,7 +642,7 @@ def test_vecdot(dtypes, shape, data): out = xp.vecdot(x1, x2, **kw) - ph.assert_dtype("vecdot", dtypes, out.dtype) + ph.assert_dtype("vecdot", in_dtype=dtypes, out_dtype=out.dtype) # TODO: assert shape and elements diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index c5f19633..8f48dd9d 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -32,11 +32,11 @@ def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: def assert_array_ndindex( func_name: str, x: Array, + *, x_indices: Iterable[Union[int, Shape]], out: Array, out_indices: Iterable[Union[int, Shape]], - /, - **kw, + kw: dict = {}, ): msg_suffix = f" [{func_name}({ph.fmt_kw(kw)})]\n {x=}\n{out=}" for x_idx, out_idx in zip(x_indices, out_indices): @@ -77,7 +77,7 @@ def test_concat(dtypes, base_shape, data): out = xp.concat(arrays, **kw) - ph.assert_dtype("concat", dtypes, out.dtype) + ph.assert_dtype("concat", in_dtype=dtypes, out_dtype=out.dtype) shapes = tuple(x.shape for x in arrays) if _axis is None: @@ -88,7 +88,7 @@ def test_concat(dtypes, base_shape, data): for other_shape in shapes[1:]: shape[_axis] += other_shape[_axis] shape = tuple(shape) - ph.assert_result_shape("concat", shapes, out.shape, shape, **kw) + ph.assert_result_shape("concat", in_shapes=shapes, out_shape=out.shape, expected=shape, kw=kw) if _axis is None: out_indices = (i for i in range(math.prod(out.shape))) @@ -97,11 +97,11 @@ def test_concat(dtypes, base_shape, data): out_i = next(out_indices) ph.assert_0d_equals( "concat", - f"x{x_num}[{x_idx}]", - x[x_idx], - f"out[{out_i}]", - out[out_i], - **kw, + x_repr=f"x{x_num}[{x_idx}]", + x_val=x[x_idx], + out_repr=f"out[{out_i}]", + out_val=out[out_i], + kw=kw, ) else: out_indices = sh.ndindex(out.shape) @@ -113,11 +113,11 @@ def test_concat(dtypes, base_shape, data): out_idx = next(out_indices) ph.assert_0d_equals( "concat", - f"x{x_num}[{f_idx}][{x_idx}]", - indexed_x[x_idx], - f"out[{out_idx}]", - out[out_idx], - **kw, + x_repr=f"x{x_num}[{f_idx}][{x_idx}]", + x_val=indexed_x[x_idx], + out_repr=f"out[{out_idx}]", + out_val=out[out_idx], + kw=kw, ) @@ -136,16 +136,16 @@ def test_expand_dims(x, axis): out = xp.expand_dims(x, axis=axis) - ph.assert_dtype("expand_dims", x.dtype, out.dtype) + ph.assert_dtype("expand_dims", in_dtype=x.dtype, out_dtype=out.dtype) shape = [side for side in x.shape] index = axis if axis >= 0 else x.ndim + axis + 1 shape.insert(index, 1) shape = tuple(shape) - ph.assert_result_shape("expand_dims", [x.shape], out.shape, shape) + ph.assert_result_shape("expand_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape) assert_array_ndindex( - "expand_dims", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape) + "expand_dims", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape) ) @@ -174,16 +174,16 @@ def test_squeeze(x, data): out = xp.squeeze(x, axis) - ph.assert_dtype("squeeze", x.dtype, out.dtype) + ph.assert_dtype("squeeze", in_dtype=x.dtype, out_dtype=out.dtype) shape = [] for i, side in enumerate(x.shape): if i not in axes: shape.append(side) shape = tuple(shape) - ph.assert_result_shape("squeeze", [x.shape], out.shape, shape, axis=axis) + ph.assert_result_shape("squeeze", in_shapes=[x.shape], out_shape=out.shape, expected=shape, kw=dict(axis=axis)) - assert_array_ndindex("squeeze", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape)) + assert_array_ndindex("squeeze", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape)) @given( @@ -201,12 +201,13 @@ def test_flip(x, data): out = xp.flip(x, **kw) - ph.assert_dtype("flip", x.dtype, out.dtype) + ph.assert_dtype("flip", in_dtype=x.dtype, out_dtype=out.dtype) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) for indices in sh.axes_ndindex(x.shape, _axes): reverse_indices = indices[::-1] - assert_array_ndindex("flip", x, indices, out, reverse_indices) + assert_array_ndindex("flip", x, x_indices=indices, out=out, + out_indices=reverse_indices, kw=kw) @given( @@ -223,18 +224,19 @@ def test_flip(x, data): def test_permute_dims(x, axes): out = xp.permute_dims(x, axes) - ph.assert_dtype("permute_dims", x.dtype, out.dtype) + ph.assert_dtype("permute_dims", in_dtype=x.dtype, out_dtype=out.dtype) shape = [None for _ in range(len(axes))] for i, dim in enumerate(axes): side = x.shape[dim] shape[i] = side shape = tuple(shape) - ph.assert_result_shape("permute_dims", [x.shape], out.shape, shape, axes=axes) + ph.assert_result_shape("permute_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape, kw=dict(axes=axes)) indices = list(sh.ndindex(x.shape)) permuted_indices = [tuple(idx[axis] for axis in axes) for idx in indices] - assert_array_ndindex("permute_dims", x, indices, out, permuted_indices) + assert_array_ndindex("permute_dims", x, x_indices=indices, out=out, + out_indices=permuted_indices) @st.composite @@ -257,7 +259,7 @@ def test_reshape(x, data): out = xp.reshape(x, shape) - ph.assert_dtype("reshape", x.dtype, out.dtype) + ph.assert_dtype("reshape", in_dtype=x.dtype, out_dtype=out.dtype) _shape = list(shape) if any(side == -1 for side in shape): @@ -265,9 +267,9 @@ def test_reshape(x, data): rsize = math.prod(shape) * -1 _shape[shape.index(-1)] = size / rsize _shape = tuple(_shape) - ph.assert_result_shape("reshape", [x.shape], out.shape, _shape, shape=shape) + ph.assert_result_shape("reshape", in_shapes=[x.shape], out_shape=out.shape, expected=_shape, kw=dict(shape=shape)) - assert_array_ndindex("reshape", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape)) + assert_array_ndindex("reshape", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape)) def roll_ndindex(shape: Shape, shifts: Tuple[int], axes: Tuple[int]) -> Iterator[Shape]: @@ -301,21 +303,21 @@ def test_roll(x, data): kw = {"shift": shift, **kw} # for error messages - ph.assert_dtype("roll", x.dtype, out.dtype) + ph.assert_dtype("roll", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_result_shape("roll", [x.shape], out.shape) + ph.assert_result_shape("roll", in_shapes=[x.shape], out_shape=out.shape, kw=kw) if kw.get("axis", None) is None: assert isinstance(shift, int) # sanity check indices = list(sh.ndindex(x.shape)) shifted_indices = deque(indices) shifted_indices.rotate(-shift) - assert_array_ndindex("roll", x, indices, out, shifted_indices, **kw) + assert_array_ndindex("roll", x, x_indices=indices, out=out, out_indices=shifted_indices, kw=kw) else: shifts = (shift,) if isinstance(shift, int) else shift axes = sh.normalise_axis(kw["axis"], x.ndim) shifted_indices = roll_ndindex(x.shape, shifts, axes) - assert_array_ndindex("roll", x, sh.ndindex(x.shape), out, shifted_indices, **kw) + assert_array_ndindex("roll", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=shifted_indices, kw=kw) @given( @@ -336,7 +338,7 @@ def test_stack(shape, dtypes, kw, data): out = xp.stack(arrays, **kw) - ph.assert_dtype("stack", dtypes, out.dtype) + ph.assert_dtype("stack", in_dtype=dtypes, out_dtype=out.dtype) axis = kw.get("axis", 0) _axis = axis if axis >= 0 else len(shape) + axis + 1 @@ -344,7 +346,7 @@ def test_stack(shape, dtypes, kw, data): _shape.insert(_axis, len(arrays)) _shape = tuple(_shape) ph.assert_result_shape( - "stack", tuple(x.shape for x in arrays), out.shape, _shape, **kw + "stack", in_shapes=tuple(x.shape for x in arrays), out_shape=out.shape, expected=_shape, kw=kw ) out_indices = sh.ndindex(out.shape) @@ -356,9 +358,9 @@ def test_stack(shape, dtypes, kw, data): out_idx = next(out_indices) ph.assert_0d_equals( "stack", - f"x{x_num}[{f_idx}][{x_idx}]", - indexed_x[x_idx], - f"out[{out_idx}]", - out[out_idx], - **kw, + x_repr=f"x{x_num}[{f_idx}][{x_idx}]", + x_val=indexed_x[x_idx], + out_repr=f"out[{out_idx}]", + out_val=out[out_idx], + kw=kw, ) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 691494cb..086d5b75 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -622,7 +622,7 @@ def binary_param_assert_dtype( else: in_dtypes = [left.dtype, right.dtype] # type: ignore ph.assert_dtype( - ctx.func_name, in_dtypes, res.dtype, expected, repr_name=f"{ctx.res_name}.dtype" + ctx.func_name, in_dtype=in_dtypes, out_dtype=res.dtype, expected=expected, repr_name=f"{ctx.res_name}.dtype" ) @@ -638,7 +638,7 @@ def binary_param_assert_shape( else: in_shapes = [left.shape, right.shape] # type: ignore ph.assert_result_shape( - ctx.func_name, in_shapes, res.shape, expected, repr_name=f"{ctx.res_name}.shape" + ctx.func_name, in_shapes=in_shapes, out_shape=res.shape, expected=expected, repr_name=f"{ctx.res_name}.shape" ) @@ -699,8 +699,8 @@ def test_abs(ctx, data): if x.dtype in dh.complex_dtypes: assert out.dtype == dh.dtype_components[x.dtype] else: - ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) - ph.assert_shape(ctx.func_name, out.shape, x.shape) + ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl( ctx.func_name, x, @@ -717,8 +717,8 @@ def test_abs(ctx, data): @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_acos(x): out = xp.acos(x) - ph.assert_dtype("acos", x.dtype, out.dtype) - ph.assert_shape("acos", out.shape, x.shape) + ph.assert_dtype("acos", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("acos", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl( "acos", x, out, math.acos, filter_=lambda s: default_filter(s) and -1 <= s <= 1 ) @@ -727,8 +727,8 @@ def test_acos(x): @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_acosh(x): out = xp.acosh(x) - ph.assert_dtype("acosh", x.dtype, out.dtype) - ph.assert_shape("acosh", out.shape, x.shape) + ph.assert_dtype("acosh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("acosh", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl( "acosh", x, out, math.acosh, filter_=lambda s: default_filter(s) and s >= 1 ) @@ -753,8 +753,8 @@ def test_add(ctx, data): @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_asin(x): out = xp.asin(x) - ph.assert_dtype("asin", x.dtype, out.dtype) - ph.assert_shape("asin", out.shape, x.shape) + ph.assert_dtype("asin", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("asin", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl( "asin", x, out, math.asin, filter_=lambda s: default_filter(s) and -1 <= s <= 1 ) @@ -763,32 +763,32 @@ def test_asin(x): @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_asinh(x): out = xp.asinh(x) - ph.assert_dtype("asinh", x.dtype, out.dtype) - ph.assert_shape("asinh", out.shape, x.shape) + ph.assert_dtype("asinh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("asinh", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("asinh", x, out, math.asinh) @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_atan(x): out = xp.atan(x) - ph.assert_dtype("atan", x.dtype, out.dtype) - ph.assert_shape("atan", out.shape, x.shape) + ph.assert_dtype("atan", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("atan", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("atan", x, out, math.atan) @given(*hh.two_mutual_arrays(dh.float_dtypes)) def test_atan2(x1, x2): out = xp.atan2(x1, x2) - ph.assert_dtype("atan2", [x1.dtype, x2.dtype], out.dtype) - ph.assert_result_shape("atan2", [x1.shape, x2.shape], out.shape) + ph.assert_dtype("atan2", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) + ph.assert_result_shape("atan2", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) binary_assert_against_refimpl("atan2", x1, x2, out, math.atan2) @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_atanh(x): out = xp.atanh(x) - ph.assert_dtype("atanh", x.dtype, out.dtype) - ph.assert_shape("atanh", out.shape, x.shape) + ph.assert_dtype("atanh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("atanh", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl( "atanh", x, @@ -848,8 +848,8 @@ def test_bitwise_invert(ctx, data): out = ctx.func(x) - ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) - ph.assert_shape(ctx.func_name, out.shape, x.shape) + ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) if x.dtype == xp.bool: refimpl = operator.not_ else: @@ -919,8 +919,8 @@ def test_bitwise_xor(ctx, data): @given(xps.arrays(dtype=xps.real_dtypes(), shape=hh.shapes())) def test_ceil(x): out = xp.ceil(x) - ph.assert_dtype("ceil", x.dtype, out.dtype) - ph.assert_shape("ceil", out.shape, x.shape) + ph.assert_dtype("ceil", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("ceil", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True) @@ -929,24 +929,24 @@ def test_ceil(x): @given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) def test_conj(x): out = xp.conj(x) - ph.assert_dtype("conj", x.dtype, out.dtype) - ph.assert_shape("conj", out.shape, x.shape) + ph.assert_dtype("conj", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("conj", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate")) @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_cos(x): out = xp.cos(x) - ph.assert_dtype("cos", x.dtype, out.dtype) - ph.assert_shape("cos", out.shape, x.shape) + ph.assert_dtype("cos", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("cos", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("cos", x, out, math.cos) @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_cosh(x): out = xp.cosh(x) - ph.assert_dtype("cosh", x.dtype, out.dtype) - ph.assert_shape("cosh", out.shape, x.shape) + ph.assert_dtype("cosh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("cosh", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("cosh", x, out, math.cosh) @@ -1006,24 +1006,24 @@ def test_equal(ctx, data): @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_exp(x): out = xp.exp(x) - ph.assert_dtype("exp", x.dtype, out.dtype) - ph.assert_shape("exp", out.shape, x.shape) + ph.assert_dtype("exp", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("exp", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("exp", x, out, math.exp) @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_expm1(x): out = xp.expm1(x) - ph.assert_dtype("expm1", x.dtype, out.dtype) - ph.assert_shape("expm1", out.shape, x.shape) + ph.assert_dtype("expm1", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("expm1", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("expm1", x, out, math.expm1) @given(xps.arrays(dtype=xps.real_dtypes(), shape=hh.shapes())) def test_floor(x): out = xp.floor(x) - ph.assert_dtype("floor", x.dtype, out.dtype) - ph.assert_shape("floor", out.shape, x.shape) + ph.assert_dtype("floor", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("floor", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("floor", x, out, math.floor, strict_check=True) @@ -1091,32 +1091,32 @@ def test_greater_equal(ctx, data): @given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) def test_imag(x): out = xp.imag(x) - ph.assert_dtype("imag", x.dtype, out.dtype, dh.dtype_components[x.dtype]) - ph.assert_shape("imag", out.shape, x.shape) + ph.assert_dtype("imag", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype]) + ph.assert_shape("imag", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag")) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_isfinite(x): out = xp.isfinite(x) - ph.assert_dtype("isfinite", x.dtype, out.dtype, xp.bool) - ph.assert_shape("isfinite", out.shape, x.shape) + ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) + ph.assert_shape("isfinite", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("isfinite", x, out, math.isfinite, res_stype=bool) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_isinf(x): out = xp.isinf(x) - ph.assert_dtype("isfinite", x.dtype, out.dtype, xp.bool) - ph.assert_shape("isinf", out.shape, x.shape) + ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) + ph.assert_shape("isinf", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("isinf", x, out, math.isinf, res_stype=bool) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_isnan(x): out = xp.isnan(x) - ph.assert_dtype("isnan", x.dtype, out.dtype, xp.bool) - ph.assert_shape("isnan", out.shape, x.shape) + ph.assert_dtype("isnan", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) + ph.assert_shape("isnan", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("isnan", x, out, math.isnan, res_stype=bool) @@ -1163,8 +1163,8 @@ def test_less_equal(ctx, data): @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_log(x): out = xp.log(x) - ph.assert_dtype("log", x.dtype, out.dtype) - ph.assert_shape("log", out.shape, x.shape) + ph.assert_dtype("log", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("log", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl( "log", x, out, math.log, filter_=lambda s: default_filter(s) and s >= 1 ) @@ -1173,8 +1173,8 @@ def test_log(x): @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_log1p(x): out = xp.log1p(x) - ph.assert_dtype("log1p", x.dtype, out.dtype) - ph.assert_shape("log1p", out.shape, x.shape) + ph.assert_dtype("log1p", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("log1p", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl( "log1p", x, out, math.log1p, filter_=lambda s: default_filter(s) and s >= 1 ) @@ -1183,8 +1183,8 @@ def test_log1p(x): @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_log2(x): out = xp.log2(x) - ph.assert_dtype("log2", x.dtype, out.dtype) - ph.assert_shape("log2", out.shape, x.shape) + ph.assert_dtype("log2", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("log2", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl( "log2", x, out, math.log2, filter_=lambda s: default_filter(s) and s > 1 ) @@ -1193,8 +1193,8 @@ def test_log2(x): @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_log10(x): out = xp.log10(x) - ph.assert_dtype("log10", x.dtype, out.dtype) - ph.assert_shape("log10", out.shape, x.shape) + ph.assert_dtype("log10", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("log10", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl( "log10", x, out, math.log10, filter_=lambda s: default_filter(s) and s > 0 ) @@ -1207,16 +1207,16 @@ def logaddexp(l: float, r: float) -> float: @given(*hh.two_mutual_arrays(dh.float_dtypes)) def test_logaddexp(x1, x2): out = xp.logaddexp(x1, x2) - ph.assert_dtype("logaddexp", [x1.dtype, x2.dtype], out.dtype) - ph.assert_result_shape("logaddexp", [x1.shape, x2.shape], out.shape) + ph.assert_dtype("logaddexp", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) + ph.assert_result_shape("logaddexp", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) binary_assert_against_refimpl("logaddexp", x1, x2, out, logaddexp) @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_and(x1, x2): out = xp.logical_and(x1, x2) - ph.assert_dtype("logical_and", [x1.dtype, x2.dtype], out.dtype) - ph.assert_result_shape("logical_and", [x1.shape, x2.shape], out.shape) + ph.assert_dtype("logical_and", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) + ph.assert_result_shape("logical_and", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) binary_assert_against_refimpl( "logical_and", x1, x2, out, operator.and_, expr_template="({} and {})={}" ) @@ -1225,8 +1225,8 @@ def test_logical_and(x1, x2): @given(xps.arrays(dtype=xp.bool, shape=hh.shapes())) def test_logical_not(x): out = xp.logical_not(x) - ph.assert_dtype("logical_not", x.dtype, out.dtype) - ph.assert_shape("logical_not", out.shape, x.shape) + ph.assert_dtype("logical_not", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("logical_not", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl( "logical_not", x, out, operator.not_, expr_template="(not {})={}" ) @@ -1235,8 +1235,8 @@ def test_logical_not(x): @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_or(x1, x2): out = xp.logical_or(x1, x2) - ph.assert_dtype("logical_or", [x1.dtype, x2.dtype], out.dtype) - ph.assert_result_shape("logical_or", [x1.shape, x2.shape], out.shape) + ph.assert_dtype("logical_or", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) + ph.assert_result_shape("logical_or", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) binary_assert_against_refimpl( "logical_or", x1, x2, out, operator.or_, expr_template="({} or {})={}" ) @@ -1245,8 +1245,8 @@ def test_logical_or(x1, x2): @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_xor(x1, x2): out = xp.logical_xor(x1, x2) - ph.assert_dtype("logical_xor", [x1.dtype, x2.dtype], out.dtype) - ph.assert_result_shape("logical_xor", [x1.shape, x2.shape], out.shape) + ph.assert_dtype("logical_xor", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) + ph.assert_result_shape("logical_xor", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) binary_assert_against_refimpl( "logical_xor", x1, x2, out, operator.xor, expr_template="({} ^ {})={}" ) @@ -1276,8 +1276,8 @@ def test_negative(ctx, data): out = ctx.func(x) - ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) - ph.assert_shape(ctx.func_name, out.shape, x.shape) + ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl( ctx.func_name, x, out, operator.neg, expr_template="-({})={}" # type: ignore ) @@ -1310,9 +1310,9 @@ def test_positive(ctx, data): out = ctx.func(x) - ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) - ph.assert_shape(ctx.func_name, out.shape, x.shape) - ph.assert_array_elements(ctx.func_name, out, x) + ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) + ph.assert_array_elements(ctx.func_name, out=out, expected=x) @pytest.mark.parametrize("ctx", make_binary_params("pow", dh.numeric_dtypes)) @@ -1342,8 +1342,8 @@ def test_pow(ctx, data): @given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) def test_real(x): out = xp.real(x) - ph.assert_dtype("real", x.dtype, out.dtype, dh.dtype_components[x.dtype]) - ph.assert_shape("real", out.shape, x.shape) + ph.assert_dtype("real", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype]) + ph.assert_shape("real", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("real", x, out, operator.attrgetter("real")) @@ -1367,16 +1367,16 @@ def test_remainder(ctx, data): @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_round(x): out = xp.round(x) - ph.assert_dtype("round", x.dtype, out.dtype) - ph.assert_shape("round", out.shape, x.shape) + ph.assert_dtype("round", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("round", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("round", x, out, round, strict_check=True) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(), elements=finite_kw)) def test_sign(x): out = xp.sign(x) - ph.assert_dtype("sign", x.dtype, out.dtype) - ph.assert_shape("sign", out.shape, x.shape) + ph.assert_dtype("sign", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("sign", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl( "sign", x, out, lambda s: math.copysign(1, s), filter_=lambda s: s != 0 ) @@ -1385,24 +1385,24 @@ def test_sign(x): @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_sin(x): out = xp.sin(x) - ph.assert_dtype("sin", x.dtype, out.dtype) - ph.assert_shape("sin", out.shape, x.shape) + ph.assert_dtype("sin", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("sin", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("sin", x, out, math.sin) @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_sinh(x): out = xp.sinh(x) - ph.assert_dtype("sinh", x.dtype, out.dtype) - ph.assert_shape("sinh", out.shape, x.shape) + ph.assert_dtype("sinh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("sinh", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("sinh", x, out, math.sinh) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_square(x): out = xp.square(x) - ph.assert_dtype("square", x.dtype, out.dtype) - ph.assert_shape("square", out.shape, x.shape) + ph.assert_dtype("square", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("square", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl( "square", x, out, lambda s: s**2, expr_template="{}²={}" ) @@ -1411,8 +1411,8 @@ def test_square(x): @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_sqrt(x): out = xp.sqrt(x) - ph.assert_dtype("sqrt", x.dtype, out.dtype) - ph.assert_shape("sqrt", out.shape, x.shape) + ph.assert_dtype("sqrt", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("sqrt", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl( "sqrt", x, out, math.sqrt, filter_=lambda s: default_filter(s) and s >= 0 ) @@ -1437,22 +1437,22 @@ def test_subtract(ctx, data): @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_tan(x): out = xp.tan(x) - ph.assert_dtype("tan", x.dtype, out.dtype) - ph.assert_shape("tan", out.shape, x.shape) + ph.assert_dtype("tan", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("tan", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("tan", x, out, math.tan) @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_tanh(x): out = xp.tanh(x) - ph.assert_dtype("tanh", x.dtype, out.dtype) - ph.assert_shape("tanh", out.shape, x.shape) + ph.assert_dtype("tanh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("tanh", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("tanh", x, out, math.tanh) @given(xps.arrays(dtype=xps.real_dtypes(), shape=xps.array_shapes())) def test_trunc(x): out = xp.trunc(x) - ph.assert_dtype("trunc", x.dtype, out.dtype) - ph.assert_shape("trunc", out.shape, x.shape) + ph.assert_dtype("trunc", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("trunc", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("trunc", x, out, math.trunc, strict_check=True) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 41f5a77c..b09e1379 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -30,13 +30,14 @@ def test_argmax(x, data): ), label="kw", ) + keepdims = kw.get("keepdims", False) out = xp.argmax(x, **kw) ph.assert_default_index("argmax", out.dtype) axes = sh.normalise_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( - "argmax", x.shape, out.shape, axes, kw.get("keepdims", False), **kw + "argmax", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw ) scalar_type = dh.get_scalar_type(x.dtype) for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)): @@ -46,7 +47,8 @@ def test_argmax(x, data): s = scalar_type(x[idx]) elements.append(s) expected = max(range(len(elements)), key=elements.__getitem__) - ph.assert_scalar_equals("argmax", int, out_idx, max_i, expected) + ph.assert_scalar_equals("argmax", type_=int, idx=out_idx, out=max_i, + expected=expected, kw=kw) @given( @@ -65,13 +67,14 @@ def test_argmin(x, data): ), label="kw", ) + keepdims = kw.get("keepdims", False) out = xp.argmin(x, **kw) ph.assert_default_index("argmin", out.dtype) axes = sh.normalise_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( - "argmin", x.shape, out.shape, axes, kw.get("keepdims", False), **kw + "argmin", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw ) scalar_type = dh.get_scalar_type(x.dtype) for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)): @@ -81,7 +84,7 @@ def test_argmin(x, data): s = scalar_type(x[idx]) elements.append(s) expected = min(range(len(elements)), key=elements.__getitem__) - ph.assert_scalar_equals("argmin", int, out_idx, min_i, expected) + ph.assert_scalar_equals("argmin", type_=int, idx=out_idx, out=min_i, expected=expected) @pytest.mark.data_dependent_shapes @@ -138,7 +141,7 @@ def test_where(shapes, dtypes, data): out = xp.where(cond, x1, x2) shape = sh.broadcast_shapes(*shapes) - ph.assert_shape("where", out.shape, shape) + ph.assert_shape("where", out_shape=out.shape, expected=shape) # TODO: generate indices without broadcasting arrays _cond = xp.broadcast_to(cond, shape) _x1 = xp.broadcast_to(x1, shape) @@ -146,9 +149,17 @@ def test_where(shapes, dtypes, data): for idx in sh.ndindex(shape): if _cond[idx]: ph.assert_0d_equals( - "where", f"_x1[{idx}]", _x1[idx], f"out[{idx}]", out[idx] + "where", + x_repr=f"_x1[{idx}]", + x_val=_x1[idx], + out_repr=f"out[{idx}]", + out_val=out[idx] ) else: ph.assert_0d_equals( - "where", f"_x2[{idx}]", _x2[idx], f"out[{idx}]", out[idx] + "where", + x_repr=f"_x2[{idx}]", + x_val=_x2[idx], + out_repr=f"out[{idx}]", + out_val=out[idx] ) diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py index 193087d9..92d39739 100644 --- a/array_api_tests/test_set_functions.py +++ b/array_api_tests/test_set_functions.py @@ -26,7 +26,7 @@ def test_unique_all(x): assert hasattr(out, "counts") ph.assert_dtype( - "unique_all", x.dtype, out.values.dtype, repr_name="out.values.dtype" + "unique_all", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype" ) ph.assert_default_index( "unique_all", out.indices.dtype, repr_name="out.indices.dtype" @@ -43,8 +43,8 @@ def test_unique_all(x): ), f"{out.indices.shape=}, but should be {out.values.shape=}" ph.assert_shape( "unique_all", - out.inverse_indices.shape, - x.shape, + out_shape=out.inverse_indices.shape, + expected=x.shape, repr_name="out.inverse_indices.shape", ) assert ( @@ -122,7 +122,7 @@ def test_unique_counts(x): assert hasattr(out, "values") assert hasattr(out, "counts") ph.assert_dtype( - "unique_counts", x.dtype, out.values.dtype, repr_name="out.values.dtype" + "unique_counts", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype" ) ph.assert_default_index( "unique_counts", out.counts.dtype, repr_name="out.counts.dtype" @@ -169,7 +169,7 @@ def test_unique_inverse(x): assert hasattr(out, "values") assert hasattr(out, "inverse_indices") ph.assert_dtype( - "unique_inverse", x.dtype, out.values.dtype, repr_name="out.values.dtype" + "unique_inverse", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype" ) ph.assert_default_index( "unique_inverse", @@ -178,8 +178,8 @@ def test_unique_inverse(x): ) ph.assert_shape( "unique_inverse", - out.inverse_indices.shape, - x.shape, + out_shape=out.inverse_indices.shape, + expected=x.shape, repr_name="out.inverse_indices.shape", ) scalar_type = dh.get_scalar_type(out.values.dtype) @@ -219,7 +219,7 @@ def test_unique_inverse(x): @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) def test_unique_values(x): out = xp.unique_values(x) - ph.assert_dtype("unique_values", x.dtype, out.dtype) + ph.assert_dtype("unique_values", in_dtype=x.dtype, out_dtype=out.dtype) scalar_type = dh.get_scalar_type(x.dtype) distinct = set(scalar_type(x[idx]) for idx in sh.ndindex(x.shape)) vals_idx = {} diff --git a/array_api_tests/test_sorting_functions.py b/array_api_tests/test_sorting_functions.py index 69149c1b..14e65802 100644 --- a/array_api_tests/test_sorting_functions.py +++ b/array_api_tests/test_sorting_functions.py @@ -22,8 +22,7 @@ def assert_scalar_in_set( idx: Shape, out: Scalar, set_: Set[Scalar], - /, - **kw, + kw={}, ): out_repr = "out" if idx == () else f"out[{idx}]" if cmath.isnan(out): @@ -57,7 +56,7 @@ def test_argsort(x, data): out = xp.argsort(x, **kw) ph.assert_default_index("argsort", out.dtype) - ph.assert_shape("argsort", out.shape, x.shape, **kw) + ph.assert_shape("argsort", out_shape=out.shape, expected=x.shape, kw=kw) axis = kw.get("axis", -1) axes = sh.normalise_axis(axis, x.ndim) scalar_type = dh.get_scalar_type(x.dtype) @@ -69,7 +68,7 @@ def test_argsort(x, data): ) if kw.get("stable", True): for idx, o in zip(indices, sorders): - ph.assert_scalar_equals("argsort", int, idx, int(out[idx]), o, **kw) + ph.assert_scalar_equals("argsort", type_=int, idx=idx, out=int(out[idx]), expected=o, kw=kw) else: idx_elements = dict(zip(indices, elements)) idx_orders = dict(zip(indices, orders)) @@ -84,11 +83,11 @@ def test_argsort(x, data): out_o = int(out[idx]) if len(expected_orders) == 1: ph.assert_scalar_equals( - "argsort", int, idx, out_o, expected_orders[0], **kw + "argsort", type_=int, idx=idx, out=out_o, expected=expected_orders[0], kw=kw ) else: assert_scalar_in_set( - "argsort", idx, out_o, set(expected_orders), **kw + "argsort", idx=idx, out=out_o, set_=set(expected_orders), kw=kw ) @@ -116,8 +115,8 @@ def test_sort(x, data): out = xp.sort(x, **kw) - ph.assert_dtype("sort", out.dtype, x.dtype) - ph.assert_shape("sort", out.shape, x.shape, **kw) + ph.assert_dtype("sort", out_dtype=out.dtype, in_dtype=x.dtype) + ph.assert_shape("sort", out_shape=out.shape, expected=x.shape, kw=kw) axis = kw.get("axis", -1) axes = sh.normalise_axis(axis, x.ndim) scalar_type = dh.get_scalar_type(x.dtype) @@ -132,9 +131,9 @@ def test_sort(x, data): # TODO: error message when unstable should not imply just one idx ph.assert_0d_equals( "sort", - f"x[{x_idx}]", - x[x_idx], - f"out[{out_idx}]", - out[out_idx], - **kw, + x_repr=f"x[{x_idx}]", + x_val=x[x_idx], + out_repr=f"out[{out_idx}]", + out_val=out[out_idx], + kw=kw, ) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 2e4167ce..4aefb5b9 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -1254,7 +1254,7 @@ def test_binary(func_name, func, case, x1, x2, data): res = func(x1, x2) # sanity check - ph.assert_result_shape(func_name, [x1.shape, x2.shape], res.shape, result_shape) + ph.assert_result_shape(func_name, in_shapes=[x1.shape, x2.shape], out_shape=res.shape, expected=result_shape) good_example = False for l_idx, r_idx, o_idx in all_indices: @@ -1306,7 +1306,7 @@ def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data): res = xp.asarray(x1, copy=True) res = iop(res, x2) # sanity check - ph.assert_result_shape(iop_name, [x1.shape, x2.shape], res.shape) + ph.assert_result_shape(iop_name, in_shapes=[x1.shape, x2.shape], out_shape=res.shape) good_example = False for l_idx, r_idx, o_idx in all_indices: @@ -1341,7 +1341,7 @@ def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data): def test_empty_arrays(func_name, expected): # TODO: parse docstrings to get expected func = getattr(xp, func_name) out = func(xp.asarray([], dtype=dh.default_float)) - ph.assert_shape(func_name, out.shape, ()) # sanity check + ph.assert_shape(func_name, out_shape=out.shape, expected=()) # sanity check msg = f"{out=!r}, but should be {expected}" if math.isnan(expected): assert xp.isnan(out), msg @@ -1366,5 +1366,5 @@ def test_nan_propagation(func_name, x, data): out = func(x) - ph.assert_shape(func_name, out.shape, ()) # sanity check + ph.assert_shape(func_name, out_shape=out.shape, expected=()) # sanity check assert xp.isnan(out), f"{out=!r}, but should be NaN" diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 2d433dc6..b4c92590 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -36,13 +36,14 @@ def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: ) def test_max(x, data): kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") + keepdims = kw.get("keepdims", False) out = xp.max(x, **kw) - ph.assert_dtype("max", x.dtype, out.dtype) + ph.assert_dtype("max", in_dtype=x.dtype, out_dtype=out.dtype) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( - "max", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw + "max", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw ) scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): @@ -52,7 +53,7 @@ def test_max(x, data): s = scalar_type(x[idx]) elements.append(s) expected = max(elements) - ph.assert_scalar_equals("max", scalar_type, out_idx, max_, expected) + ph.assert_scalar_equals("max", type_=scalar_type, idx=out_idx, out=max_, expected=expected) @given( @@ -65,13 +66,14 @@ def test_max(x, data): ) def test_mean(x, data): kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") + keepdims = kw.get("keepdims", False) out = xp.mean(x, **kw) - ph.assert_dtype("mean", x.dtype, out.dtype) + ph.assert_dtype("mean", in_dtype=x.dtype, out_dtype=out.dtype) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( - "mean", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw + "mean", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw ) # Values testing mean is too finicky @@ -86,13 +88,14 @@ def test_mean(x, data): ) def test_min(x, data): kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") + keepdims = kw.get("keepdims", False) out = xp.min(x, **kw) - ph.assert_dtype("min", x.dtype, out.dtype) + ph.assert_dtype("min", in_dtype=x.dtype, out_dtype=out.dtype) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( - "min", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw + "min", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw ) scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): @@ -102,7 +105,7 @@ def test_min(x, data): s = scalar_type(x[idx]) elements.append(s) expected = min(elements) - ph.assert_scalar_equals("min", scalar_type, out_idx, min_, expected) + ph.assert_scalar_equals("min", type_=scalar_type, idx=out_idx, out=min_, expected=expected) @given( @@ -122,6 +125,7 @@ def test_prod(x, data): ), label="kw", ) + keepdims = kw.get("keepdims", False) try: out = xp.prod(x, **kw) @@ -155,10 +159,10 @@ def test_prod(x, data): if _dtype in dh.uint_dtypes: assert dh.is_int_dtype(out.dtype) # sanity check else: - ph.assert_dtype("prod", x.dtype, out.dtype, _dtype) + ph.assert_dtype("prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=_dtype) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( - "prod", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw + "prod", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw ) scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): @@ -172,7 +176,7 @@ def test_prod(x, data): if dh.is_int_dtype(out.dtype): m, M = dh.dtype_ranges[out.dtype] assume(m <= expected <= M) - ph.assert_scalar_equals("prod", scalar_type, out_idx, prod, expected) + ph.assert_scalar_equals("prod", type_=scalar_type, idx=out_idx, out=prod, expected=expected) @given( @@ -191,21 +195,22 @@ def test_std(x, data): st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), label="correction", ) - keepdims = data.draw(st.booleans(), label="keepdims") + _keepdims = data.draw(st.booleans(), label="keepdims") kw = data.draw( hh.specified_kwargs( ("axis", axis, None), ("correction", correction, 0.0), - ("keepdims", keepdims, False), + ("keepdims", _keepdims, False), ), label="kw", ) + keepdims = kw.get("keepdims", False) out = xp.std(x, **kw) - ph.assert_dtype("std", x.dtype, out.dtype) + ph.assert_dtype("std", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_keepdimable_shape( - "std", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw + "std", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw ) # We can't easily test the result(s) as standard deviation methods vary a lot @@ -227,6 +232,7 @@ def test_sum(x, data): ), label="kw", ) + keepdims = kw.get("keepdims", False) try: out = xp.sum(x, **kw) @@ -260,10 +266,10 @@ def test_sum(x, data): if _dtype in dh.uint_dtypes: assert dh.is_int_dtype(out.dtype) # sanity check else: - ph.assert_dtype("sum", x.dtype, out.dtype, _dtype) + ph.assert_dtype("sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=_dtype) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( - "sum", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw + "sum", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw ) scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): @@ -277,7 +283,7 @@ def test_sum(x, data): if dh.is_int_dtype(out.dtype): m, M = dh.dtype_ranges[out.dtype] assume(m <= expected <= M) - ph.assert_scalar_equals("sum", scalar_type, out_idx, sum_, expected) + ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected) @given( @@ -296,20 +302,21 @@ def test_var(x, data): st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), label="correction", ) - keepdims = data.draw(st.booleans(), label="keepdims") + _keepdims = data.draw(st.booleans(), label="keepdims") kw = data.draw( hh.specified_kwargs( ("axis", axis, None), ("correction", correction, 0.0), - ("keepdims", keepdims, False), + ("keepdims", _keepdims, False), ), label="kw", ) + keepdims = kw.get("keepdims", False) out = xp.var(x, **kw) - ph.assert_dtype("var", x.dtype, out.dtype) + ph.assert_dtype("var", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_keepdimable_shape( - "var", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw + "var", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw ) # We can't easily test the result(s) as variance methods vary a lot diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py index 7c09fb27..ead5c9d2 100644 --- a/array_api_tests/test_utility_functions.py +++ b/array_api_tests/test_utility_functions.py @@ -18,13 +18,14 @@ ) def test_all(x, data): kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") + keepdims = kw.get("keepdims", False) out = xp.all(x, **kw) - ph.assert_dtype("all", x.dtype, out.dtype, xp.bool) + ph.assert_dtype("all", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( - "all", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw + "all", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw ) scalar_type = dh.get_scalar_type(x.dtype) for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): @@ -34,7 +35,8 @@ def test_all(x, data): s = scalar_type(x[idx]) elements.append(s) expected = all(elements) - ph.assert_scalar_equals("all", scalar_type, out_idx, result, expected) + ph.assert_scalar_equals("all", type_=scalar_type, idx=out_idx, + out=result, expected=expected, kw=kw) @given( @@ -43,13 +45,14 @@ def test_all(x, data): ) def test_any(x, data): kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") + keepdims = kw.get("keepdims", False) out = xp.any(x, **kw) - ph.assert_dtype("any", x.dtype, out.dtype, xp.bool) + ph.assert_dtype("any", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( - "any", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw + "any", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw, ) scalar_type = dh.get_scalar_type(x.dtype) for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): @@ -59,4 +62,5 @@ def test_any(x, data): s = scalar_type(x[idx]) elements.append(s) expected = any(elements) - ph.assert_scalar_equals("any", scalar_type, out_idx, result, expected) + ph.assert_scalar_equals("any", type_=scalar_type, idx=out_idx, + out=result, expected=expected, kw=kw)