diff --git a/README.md b/README.md index 436cfb89..1d4ad770 100644 --- a/README.md +++ b/README.md @@ -157,6 +157,11 @@ Use the `--ci` flag to run only the primary and special cases tests. You can ignore the other test cases as they are redundant for the purposes of checking compliance. +#### Data-dependent shapes + +Use the `--disable-data-dependent-shapes` flag to skip testing functions which have +[data-dependent shapes](https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html). + #### Extensions By default, tests for the optional Array API extensions such as @@ -200,16 +205,10 @@ instead of having a seperate `skips.txt` file, e.g.: # Skip test cases with known issues cat << EOF >> skips.txt - # Skip specific test case, e.g. when argsort() does not respect relative order - # https://github.com/numpy/numpy/issues/20778 + # Comments can still work here array_api_tests/test_sorting_functions.py::test_argsort - - # Skip specific test case parameter, e.g. you forgot to implement in-place adds array_api_tests/test_add[__iadd__(x1, x2)] array_api_tests/test_add[__iadd__(x, s)] - - # Skip module, e.g. when your set functions treat NaNs as non-distinct - # https://github.com/numpy/numpy/issues/20326 array_api_tests/test_set_functions.py EOF diff --git a/array_api_tests/meta/test_pytest_helpers.py b/array_api_tests/meta/test_pytest_helpers.py index 21da2264..117e2b11 100644 --- a/array_api_tests/meta/test_pytest_helpers.py +++ b/array_api_tests/meta/test_pytest_helpers.py @@ -1,7 +1,7 @@ from pytest import raises -from .. import pytest_helpers as ph from .. import _array_module as xp +from .. import pytest_helpers as ph def test_assert_dtype(): @@ -11,3 +11,12 @@ def test_assert_dtype(): ph.assert_dtype("bool_func", [xp.uint8, xp.int8], xp.bool, xp.bool) ph.assert_dtype("single_promoted_func", [xp.uint8], xp.uint8) ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool) + + +def test_assert_array(): + ph.assert_array("int zeros", xp.asarray(0), xp.asarray(0)) + ph.assert_array("pos zeros", xp.asarray(0.0), xp.asarray(0.0)) + with raises(AssertionError): + ph.assert_array("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0)) + with raises(AssertionError): + ph.assert_array("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0)) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 9a5ffbb2..989b486f 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -14,6 +14,8 @@ "doesnt_raise", "nargs", "fmt_kw", + "is_pos_zero", + "is_neg_zero", "assert_dtype", "assert_kw_dtype", "assert_default_float", @@ -22,6 +24,7 @@ "assert_shape", "assert_result_shape", "assert_keepdimable_shape", + "assert_0d_equals", "assert_fill", "assert_array", ] @@ -69,6 +72,14 @@ def fmt_kw(kw: Dict[str, Any]) -> str: return ", ".join(f"{k}={v}" for k, v in kw.items()) +def is_pos_zero(n: float) -> bool: + return n == 0 and math.copysign(1, n) == 1 + + +def is_neg_zero(n: float) -> bool: + return n == 0 and math.copysign(1, n) == -1 + + def assert_dtype( func_name: str, in_dtype: Union[DataType, Sequence[DataType]], @@ -232,15 +243,28 @@ def assert_fill( def assert_array(func_name: str, out: Array, expected: Array, /, **kw): assert_dtype(func_name, out.dtype, expected.dtype) assert_shape(func_name, out.shape, expected.shape, **kw) - msg = f"out not as expected [{func_name}({fmt_kw(kw)})]\n{out=}\n{expected=}" + f_func = f"[{func_name}({fmt_kw(kw)})]" if dh.is_float_dtype(out.dtype): - neg_zeros = expected == -0.0 - assert xp.all((out == -0.0) == neg_zeros), msg - pos_zeros = expected == +0.0 - assert xp.all((out == +0.0) == pos_zeros), msg - nans = xp.isnan(expected) - assert xp.all(xp.isnan(out) == nans), msg - mask = ~(neg_zeros | pos_zeros | nans) - assert xp.all(out[mask] == expected[mask]), msg + for idx in sh.ndindex(out.shape): + at_out = out[idx] + at_expected = expected[idx] + msg = ( + f"{sh.fmt_idx('out', idx)}={at_out}, should be {at_expected} " + f"{f_func}" + ) + if xp.isnan(at_expected): + assert xp.isnan(at_out), msg + elif at_expected == 0.0 or at_expected == -0.0: + scalar_at_expected = float(at_expected) + scalar_at_out = float(at_out) + if is_pos_zero(scalar_at_expected): + assert is_pos_zero(scalar_at_out), msg + else: + assert is_neg_zero(scalar_at_expected) # sanity check + assert is_neg_zero(scalar_at_out), msg + else: + assert at_out == at_expected, msg else: - assert xp.all(out == expected), msg + assert xp.all(out == expected), ( + f"out not as expected {f_func}\n" f"{out=}\n{expected=}" + ) diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index a81339d0..583eda76 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -280,7 +280,7 @@ def test_asarray_arrays(x, data): if copy: assert not xp.all( out == x - ), "xp.all(out == x)=True, but should be False after x was mutated\n{out=}" + ), f"xp.all(out == x)=True, but should be False after x was mutated\n{out=}" elif copy is False: pass # TODO diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 4e131050..2c9da2b9 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -123,7 +123,10 @@ def default_filter(s: Scalar) -> bool: Used by default as these values are typically special-cased. """ - return math.isfinite(s) and s is not -0.0 and s is not +0.0 + if isinstance(s, int): # note bools are ints + return True + else: + return math.isfinite(s) and s != 0 T = TypeVar("T") @@ -538,7 +541,7 @@ def test_abs(ctx, data): abs, # type: ignore expr_template="abs({})={}", filter_=lambda s: ( - s == float("infinity") or (math.isfinite(s) and s is not -0.0) + s == float("infinity") or (math.isfinite(s) and not ph.is_neg_zero(s)) ), ) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 01c26d0c..24325685 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -21,13 +21,10 @@ data=st.data(), ) def test_argmax(x, data): - kw = data.draw( - hh.kwargs( - axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)), - keepdims=st.booleans(), - ), - label="kw", - ) + axis_strat = st.none() + if x.ndim > 0: + axis_strat |= st.integers(-x.ndim, max(x.ndim - 1, 0)) + kw = data.draw(hh.kwargs(axis=axis_strat, keepdims=st.booleans()), label="kw") out = xp.argmax(x, **kw) @@ -56,13 +53,10 @@ def test_argmax(x, data): data=st.data(), ) def test_argmin(x, data): - kw = data.draw( - hh.kwargs( - axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)), - keepdims=st.booleans(), - ), - label="kw", - ) + axis_strat = st.none() + if x.ndim > 0: + axis_strat |= st.integers(-x.ndim, max(x.ndim - 1, 0)) + kw = data.draw(hh.kwargs(axis=axis_strat, keepdims=st.booleans()), label="kw") out = xp.argmin(x, **kw) @@ -82,7 +76,7 @@ def test_argmin(x, data): ph.assert_scalar_equals("argmin", int, out_idx, min_i, expected) -# TODO: skip if opted out +@pytest.mark.data_dependent_shapes @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) def test_nonzero(x): out = xp.nonzero(x) diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py index 5ceceb54..5bae6147 100644 --- a/array_api_tests/test_set_functions.py +++ b/array_api_tests/test_set_functions.py @@ -12,7 +12,7 @@ from . import shape_helpers as sh from . import xps -pytestmark = pytest.mark.ci +pytestmark = [pytest.mark.ci, pytest.mark.data_dependent_shapes] @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) diff --git a/conftest.py b/conftest.py index 2af3fef1..9fec536b 100644 --- a/conftest.py +++ b/conftest.py @@ -35,6 +35,13 @@ def pytest_addoption(parser): default=[], help="disable testing for Array API extension(s)", ) + # data-dependent shape + parser.addoption( + "--disable-data-dependent-shapes", + "--disable-dds", + action="store_true", + help="disable testing functions with output shapes dependent on input", + ) # CI parser.addoption( "--ci", @@ -47,6 +54,9 @@ def pytest_configure(config): config.addinivalue_line( "markers", "xp_extension(ext): tests an Array API extension" ) + config.addinivalue_line( + "markers", "data_dependent_shapes: output shapes are dependent on inputs" + ) config.addinivalue_line("markers", "ci: primary test") # Hypothesis hypothesis_max_examples = config.getoption("--hypothesis-max-examples") @@ -83,9 +93,15 @@ def xp_has_ext(ext: str) -> bool: def pytest_collection_modifyitems(config, items): disabled_exts = config.getoption("--disable-extension") + disabled_dds = config.getoption("--disable-data-dependent-shapes") ci = config.getoption("--ci") for item in items: markers = list(item.iter_markers()) + # skip if specified in skips.txt + for id_ in skip_ids: + if item.nodeid.startswith(id_): + item.add_marker(mark.skip(reason="skips.txt")) + break # skip if disabled or non-existent extension ext_mark = next((m for m in markers if m.name == "xp_extension"), None) if ext_mark is not None: @@ -96,11 +112,14 @@ def pytest_collection_modifyitems(config, items): ) elif not xp_has_ext(ext): item.add_marker(mark.skip(reason=f"{ext} not found in array module")) - # skip if specified in skips.txt - for id_ in skip_ids: - if item.nodeid.startswith(id_): - item.add_marker(mark.skip(reason="skips.txt")) - break + # skip if disabled by dds flag + if disabled_dds: + for m in markers: + if m.name == "data_dependent_shapes": + item.add_marker( + mark.skip(reason="disabled via --disable-data-dependent-shapes") + ) + break # skip if test not appropiate for CI if ci: ci_mark = next((m for m in markers if m.name == "ci"), None)