diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 1f144c72..1a7f6516 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -3,6 +3,7 @@ from itertools import count from typing import Iterator, NamedTuple, Union +import pytest from hypothesis import assume, given, note from hypothesis import strategies as st @@ -76,6 +77,7 @@ def reals(min_value=None, max_value=None) -> st.SearchStrategy[Union[int, float] ) +@pytest.mark.has_setup_funcs @given(dtype=st.none() | hh.real_dtypes, data=st.data()) def test_arange(dtype, data): if dtype is None or dh.is_float_dtype(dtype): @@ -194,6 +196,7 @@ def test_arange(dtype, data): ), f"out[0]={out[0]}, but should be {_start} {f_func}" +@pytest.mark.has_setup_funcs @given(shape=hh.shapes(min_side=1), data=st.data()) def test_asarray_scalars(shape, data): kw = data.draw( @@ -257,6 +260,7 @@ def scalar_eq(s1: Scalar, s2: Scalar) -> bool: return s1 == s2 +@pytest.mark.has_setup_funcs @given( shape=hh.shapes(), dtypes=hh.oneway_promotable_dtypes(dh.all_dtypes), @@ -424,6 +428,7 @@ def test_full(shape, fill_value, kw): ph.assert_fill("full", fill_value=fill_value, dtype=dtype, out=out, kw=dict(fill_value=fill_value)) +@pytest.mark.has_setup_funcs @given(kw=hh.kwargs(dtype=st.none() | hh.all_dtypes), data=st.data()) def test_full_like(kw, data): dtype = kw.get("dtype", None) or data.draw(hh.all_dtypes, label="dtype") @@ -442,6 +447,7 @@ def test_full_like(kw, data): finite_kw = {"allow_nan": False, "allow_infinity": False} +@pytest.mark.has_setup_funcs @given( num=hh.sizes, dtype=st.none() | hh.real_floating_dtypes, @@ -492,6 +498,7 @@ def test_linspace(num, dtype, endpoint, data): ph.assert_array_elements("linspace", out=out, expected=expected) +@pytest.mark.has_setup_funcs @given(dtype=hh.numeric_dtypes, data=st.data()) def test_meshgrid(dtype, data): # The number and size of generated arrays is arbitrarily limited to prevent diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 754b507d..f0021fc9 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -347,6 +347,7 @@ def test_repeat(x, kw, data): reshape_shape = st.shared(hh.shapes(), key="reshape_shape") +@pytest.mark.has_setup_funcs @pytest.mark.unvectorized @given( x=hh.arrays(dtype=hh.all_dtypes, shape=reshape_shape), diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 7fefa151..a899d053 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -781,6 +781,7 @@ def test_acosh(x): ) +@pytest.mark.has_setup_funcs @pytest.mark.parametrize("ctx,", make_binary_params("add", dh.numeric_dtypes)) @given(data=st.data()) def test_add(ctx, data): @@ -854,6 +855,7 @@ def test_atanh(x): ) +@pytest.mark.has_setup_funcs @pytest.mark.parametrize( "ctx", make_binary_params("bitwise_and", dh.bool_and_all_int_dtypes) ) @@ -873,6 +875,7 @@ def test_bitwise_and(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "&", refimpl) +@pytest.mark.has_setup_funcs @pytest.mark.parametrize( "ctx", make_binary_params("bitwise_left_shift", dh.all_int_dtypes) ) @@ -895,6 +898,7 @@ def test_bitwise_left_shift(ctx, data): ) +@pytest.mark.has_setup_funcs @pytest.mark.parametrize( "ctx", make_unary_params("bitwise_invert", dh.bool_and_all_int_dtypes) ) @@ -913,6 +917,7 @@ def test_bitwise_invert(ctx, data): unary_assert_against_refimpl(ctx.func_name, x, out, refimpl, expr_template="~{}={}") +@pytest.mark.has_setup_funcs @pytest.mark.parametrize( "ctx", make_binary_params("bitwise_or", dh.bool_and_all_int_dtypes) ) @@ -932,6 +937,7 @@ def test_bitwise_or(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "|", refimpl) +@pytest.mark.has_setup_funcs @pytest.mark.parametrize( "ctx", make_binary_params("bitwise_right_shift", dh.all_int_dtypes) ) @@ -953,6 +959,7 @@ def test_bitwise_right_shift(ctx, data): ) +@pytest.mark.has_setup_funcs @pytest.mark.parametrize( "ctx", make_binary_params("bitwise_xor", dh.bool_and_all_int_dtypes) ) @@ -981,6 +988,7 @@ def test_ceil(x): @pytest.mark.min_version("2023.12") +@pytest.mark.has_setup_funcs @given(x=hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes()), data=st.data()) def test_clip(x, data): # Ensure that if both min and max are arrays that all three of x, min, max @@ -1145,6 +1153,7 @@ def test_cosh(x): unary_assert_against_refimpl("cosh", x, out, refimpl) +@pytest.mark.has_setup_funcs @pytest.mark.parametrize("ctx", make_binary_params("divide", dh.all_float_dtypes)) @given(data=st.data()) def test_divide(ctx, data): @@ -1168,6 +1177,7 @@ def test_divide(ctx, data): ) +@pytest.mark.has_setup_funcs @pytest.mark.parametrize("ctx", make_binary_params("equal", dh.all_dtypes)) @given(data=st.data()) def test_equal(ctx, data): @@ -1242,6 +1252,7 @@ def refimpl(z): unary_assert_against_refimpl("floor", x, out, refimpl, strict_check=True) +@pytest.mark.has_setup_funcs @pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.real_dtypes)) @given(data=st.data()) def test_floor_divide(ctx, data): @@ -1261,6 +1272,7 @@ def test_floor_divide(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv) +@pytest.mark.has_setup_funcs @pytest.mark.parametrize("ctx", make_binary_params("greater", dh.real_dtypes)) @given(data=st.data()) def test_greater(ctx, data): @@ -1281,6 +1293,7 @@ def test_greater(ctx, data): ) +@pytest.mark.has_setup_funcs @pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.real_dtypes)) @given(data=st.data()) def test_greater_equal(ctx, data): @@ -1352,6 +1365,7 @@ def test_isnan(x): unary_assert_against_refimpl("isnan", x, out, refimpl, res_stype=bool) +@pytest.mark.has_setup_funcs @pytest.mark.parametrize("ctx", make_binary_params("less", dh.real_dtypes)) @given(data=st.data()) def test_less(ctx, data): @@ -1372,6 +1386,7 @@ def test_less(ctx, data): ) +@pytest.mark.has_setup_funcs @pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.real_dtypes)) @given(data=st.data()) def test_less_equal(ctx, data): @@ -1463,6 +1478,7 @@ def logaddexp_refimpl(l: float, r: float) -> float: @pytest.mark.min_version("2023.12") +@pytest.mark.has_setup_funcs @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_logaddexp(x1, x2): out = xp.logaddexp(x1, x2) @@ -1476,6 +1492,7 @@ def test_logaddexp(x1, x2): ) +@pytest.mark.has_setup_funcs @given(hh.arrays(dtype=xp.bool, shape=hh.shapes())) def test_logical_not(x): out = xp.logical_not(x) @@ -1486,6 +1503,7 @@ def test_logical_not(x): ) +@pytest.mark.has_setup_funcs @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_and(x1, x2): out = xp.logical_and(x1, x2) @@ -1500,6 +1518,7 @@ def test_logical_and(x1, x2): ) +@pytest.mark.has_setup_funcs @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_or(x1, x2): out = xp.logical_or(x1, x2) @@ -1514,6 +1533,7 @@ def test_logical_or(x1, x2): ) +@pytest.mark.has_setup_funcs @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_xor(x1, x2): out = xp.logical_xor(x1, x2) @@ -1546,6 +1566,7 @@ def test_minimum(x1, x2): ) +@pytest.mark.has_setup_funcs @pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes)) @given(data=st.data()) def test_multiply(ctx, data): @@ -1577,6 +1598,7 @@ def test_negative(ctx, data): ) +@pytest.mark.has_setup_funcs @pytest.mark.parametrize("ctx", make_binary_params("not_equal", dh.all_dtypes)) @given(data=st.data()) def test_not_equal(ctx, data): @@ -1598,6 +1620,7 @@ def test_not_equal(ctx, data): @pytest.mark.min_version("2024.12") +@pytest.mark.has_setup_funcs @given( shapes=hh.two_mutually_broadcastable_shapes, dtype=hh.real_floating_dtypes, @@ -1617,6 +1640,8 @@ def test_nextafter(shapes, dtype, data): out=out ) + +@pytest.mark.has_setup_funcs @pytest.mark.parametrize("ctx", make_unary_params("positive", dh.numeric_dtypes)) @given(data=st.data()) def test_positive(ctx, data): @@ -1629,6 +1654,7 @@ def test_positive(ctx, data): ph.assert_array_elements(ctx.func_name, out=out, expected=x) +@pytest.mark.has_setup_funcs @pytest.mark.parametrize("ctx", make_binary_params("pow", dh.numeric_dtypes)) @given(data=st.data()) def test_pow(ctx, data): @@ -1676,6 +1702,7 @@ def test_reciprocal(x): @pytest.mark.skip(reason="flaky") +@pytest.mark.has_setup_funcs @pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.real_dtypes)) @given(data=st.data()) def test_remainder(ctx, data): @@ -1770,6 +1797,7 @@ def test_sqrt(x): ) +@pytest.mark.has_setup_funcs @pytest.mark.parametrize("ctx", make_binary_params("subtract", dh.numeric_dtypes)) @given(data=st.data()) def test_subtract(ctx, data): @@ -1923,6 +1951,7 @@ def test_binary_with_scalars_bitwise_shifts(func_data, x1x2): @pytest.mark.min_version("2024.12") +@pytest.mark.has_setup_funcs @pytest.mark.unvectorized @given( x1x2=hh.array_and_py_scalar([xp.int32]), diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py index b6e0a4fe..38ecf450 100644 --- a/array_api_tests/test_utility_functions.py +++ b/array_api_tests/test_utility_functions.py @@ -67,6 +67,7 @@ def test_any(x, data): @pytest.mark.unvectorized @pytest.mark.min_version("2024.12") +@pytest.mark.has_setup_funcs @given( x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)), data=st.data(), diff --git a/conftest.py b/conftest.py index 05baebc1..fadae116 100644 --- a/conftest.py +++ b/conftest.py @@ -98,6 +98,11 @@ def pytest_configure(config): "markers", "unvectorized: asserts against values via element-wise iteration (not performative!)", ) + config.addinivalue_line( + "markers", + "has_setup_funcs: run when essential draw data setup functions used " + "by Hypothesis are available in the namespace", + ) # Hypothesis deadline = None if config.getoption("--hypothesis-disable-deadline") else 800 settings.register_profile( @@ -202,6 +207,9 @@ def pytest_collection_modifyitems(config, items): # ------------------------------------------------------ xfail_mark = get_xfail_mark() + + essential_funcs = ["asarray", "isnan", "reshape", "zeros"] + HAS_ESSENTIAL_FUNCS = all(hasattr(xp, func_name) for func_name in essential_funcs) for item in items: markers = list(item.iter_markers()) @@ -245,6 +253,13 @@ def pytest_collection_modifyitems(config, items): reason=f"requires ARRAY_API_TESTS_VERSION >= {min_version}" ) ) + # skip if namespace doesn't support essential draw data setup functions + if any(m.name == "has_setup_funcs" for m in markers) and not HAS_ESSENTIAL_FUNCS: + item.add_marker( + mark.skip(reason="At least one of the essential data setup " + "functions is not present in the namespace: " + f"{essential_funcs}") + ) # reduce max generated Hypothesis example for unvectorized tests if any(m.name == "unvectorized" for m in markers): # TODO: limit generated examples when settings already applied