diff --git a/array-api b/array-api index ab69aa24..ea6a47f0 160000 --- a/array-api +++ b/array-api @@ -1 +1 @@ -Subproject commit ab69aa240025ff1d52525ce3859b69ebfd6b7faf +Subproject commit ea6a47f03e0aa26a9b17e70deba12e4096cfd2f3 diff --git a/array_api_tests/meta/test_pytest_helpers.py b/array_api_tests/meta/test_pytest_helpers.py index 17ed5534..a32c6f33 100644 --- a/array_api_tests/meta/test_pytest_helpers.py +++ b/array_api_tests/meta/test_pytest_helpers.py @@ -20,3 +20,7 @@ def test_assert_array_elements(): ph.assert_array_elements("mixed sign zeros", out=xp.asarray(0.0), expected=xp.asarray(-0.0)) with raises(AssertionError): ph.assert_array_elements("mixed sign zeros", out=xp.asarray(-0.0), expected=xp.asarray(0.0)) + + ph.assert_array_elements("nans", out=xp.asarray(float("nan")), expected=xp.asarray(float("nan"))) + with raises(AssertionError): + ph.assert_array_elements("nan and zero", out=xp.asarray(float("nan")), expected=xp.asarray(0.0)) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index e6ede7b2..ead9fc6e 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -7,6 +7,7 @@ from . import dtype_helpers as dh from . import shape_helpers as sh from . import stubs +from . import xp as _xp from .typing import Array, DataType, Scalar, ScalarType, Shape __all__ = [ @@ -420,6 +421,35 @@ def assert_fill( assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg +def _real_float_strict_equals(out: Array, expected: Array) -> bool: + nan_mask = xp.isnan(out) + if not xp.all(nan_mask == xp.isnan(expected)): + return False + ignore_mask = nan_mask + + # Test sign of zeroes if xp.signbit() available, otherwise ignore as it's + # not that big of a deal for the perf costs. + if hasattr(_xp, "signbit"): + out_zero_mask = out == 0 + out_sign_mask = _xp.signbit(out) + out_pos_zero_mask = out_zero_mask & out_sign_mask + out_neg_zero_mask = out_zero_mask & ~out_sign_mask + expected_zero_mask = expected == 0 + expected_sign_mask = _xp.signbit(expected) + expected_pos_zero_mask = expected_zero_mask & expected_sign_mask + expected_neg_zero_mask = expected_zero_mask & ~expected_sign_mask + pos_zero_match = out_pos_zero_mask == expected_pos_zero_mask + neg_zero_match = out_neg_zero_mask == expected_neg_zero_mask + if not (xp.all(pos_zero_match) and xp.all(neg_zero_match)): + return False + ignore_mask |= out_zero_mask + + replacement = xp.asarray(42, dtype=out.dtype) # i.e. an arbitrary non-zero value that equals itself + assert replacement == replacement # sanity check + match = xp.where(ignore_mask, replacement, out) == xp.where(ignore_mask, replacement, expected) + return xp.all(match) + + def _assert_float_element(at_out: Array, at_expected: Array, msg: str): if xp.isnan(at_expected): assert xp.isnan(at_out), msg @@ -455,31 +485,45 @@ def assert_array_elements( >>> assert xp.all(out == x) """ - __tracebackhide__ = True + # __tracebackhide__ = True dh.result_type(out.dtype, expected.dtype) # sanity check assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check f_func = f"[{func_name}({fmt_kw(kw)})]" + + # First we try short-circuit for a successful assertion by using vectorised checks. + if out.dtype in dh.real_float_dtypes: + if _real_float_strict_equals(out, expected): + return + elif out.dtype in dh.complex_dtypes: + real_match = _real_float_strict_equals(xp.real(out), xp.real(expected)) + imag_match = _real_float_strict_equals(xp.imag(out), xp.imag(expected)) + if real_match and imag_match: + return + else: + match = out == expected + if xp.all(match): + return + + # In case of mismatch, generate a more helpful error. Cycling through all indices is + # costly in some array api implementations, so we only do this in the case of a failure. + msg_template = "{}={}, but should be {} " + f_func if out.dtype in dh.real_float_dtypes: for idx in sh.ndindex(out.shape): at_out = out[idx] at_expected = expected[idx] - msg = ( - f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " - f"{f_func}" - ) + msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected) _assert_float_element(at_out, at_expected, msg) elif out.dtype in dh.complex_dtypes: assert (out.dtype in dh.complex_dtypes) == (expected.dtype in dh.complex_dtypes) for idx in sh.ndindex(out.shape): at_out = out[idx] at_expected = expected[idx] - msg = ( - f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " - f"{f_func}" - ) + msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected) _assert_float_element(xp.real(at_out), xp.real(at_expected), msg) _assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg) else: - assert xp.all( - out == expected - ), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}" + for idx in sh.ndindex(out.shape): + at_out = out[idx] + at_expected = expected[idx] + msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected) + assert at_out == at_expected, msg diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 19e945ca..ec2df060 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -354,14 +354,14 @@ def test_eye(n_rows, n_cols, kw): ph.assert_kw_dtype("eye", kw_dtype=kw["dtype"], out_dtype=out.dtype) _n_cols = n_rows if n_cols is None else n_cols ph.assert_shape("eye", out_shape=out.shape, expected=(n_rows, _n_cols), kw=dict(n_rows=n_rows, n_cols=n_cols)) - f_func = f"[eye({n_rows=}, {n_cols=})]" - for i in range(n_rows): - for j in range(_n_cols): - f_indexed_out = f"out[{i}, {j}]={out[i, j]}" - if j - i == kw.get("k", 0): - assert out[i, j] == 1, f"{f_indexed_out}, should be 1 {f_func}" - else: - assert out[i, j] == 0, f"{f_indexed_out}, should be 0 {f_func}" + k = kw.get("k", 0) + expected = xp.asarray( + [[1 if j - i == k else 0 for j in range(_n_cols)] for i in range(n_rows)], + dtype=out.dtype # Note: dtype already checked above. + ) + if expected.size == 0: + expected = xp.reshape(expected, (n_rows, _n_cols)) + ph.assert_array_elements("eye", out=out, expected=expected, kw=kw) default_unsafe_dtypes = [xp.uint64]