diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 3cd819b4..588cfb1b 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -3,11 +3,16 @@ from hypothesis import strategies as st from .. import _array_module as xp -from .. import xps +from .. import dtype_helpers as dh from .. import shape_helpers as sh +from .. import xps from ..test_creation_functions import frange from ..test_manipulation_functions import roll_ndindex -from ..test_operators_and_elementwise_functions import mock_int_dtype +from ..test_operators_and_elementwise_functions import ( + mock_int_dtype, + oneway_broadcastable_shapes, + oneway_promotable_dtypes, +) from ..test_signatures import extension_module @@ -115,3 +120,13 @@ def test_int_to_dtype(x, dtype): except OverflowError: reject() assert mock_int_dtype(x, dtype) == d + + +@given(oneway_promotable_dtypes(dh.all_dtypes)) +def test_oneway_promotable_dtypes(D): + assert D.result_dtype == dh.result_type(*D) + + +@given(oneway_broadcastable_shapes()) +def test_oneway_broadcastable_shapes(S): + assert S.result_shape == sh.broadcast_shapes(*S) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 0eb15462..6947c061 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -30,6 +30,46 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]: return xps.boolean_dtypes() | all_integer_dtypes() +class OnewayPromotableDtypes(NamedTuple): + input_dtype: DataType + result_dtype: DataType + + +@st.composite +def oneway_promotable_dtypes( + draw, dtypes: List[DataType] +) -> st.SearchStrategy[OnewayPromotableDtypes]: + """Return a strategy for input dtypes that promote to result dtypes.""" + d1, d2 = draw(hh.mutually_promotable_dtypes(dtypes=dtypes)) + result_dtype = dh.result_type(d1, d2) + if d1 == result_dtype: + return OnewayPromotableDtypes(d2, d1) + elif d2 == result_dtype: + return OnewayPromotableDtypes(d1, d2) + else: + reject() + + +class OnewayBroadcastableShapes(NamedTuple): + input_shape: Shape + result_shape: Shape + + +@st.composite +def oneway_broadcastable_shapes(draw) -> st.SearchStrategy[OnewayBroadcastableShapes]: + """Return a strategy for input shapes that broadcast to result shapes.""" + result_shape = draw(hh.shapes(min_side=1)) + input_shape = draw( + xps.broadcastable_shapes( + result_shape, + # Override defaults so bad shapes are less likely to be generated. + max_side=None if result_shape == () else max(result_shape), + max_dims=len(result_shape), + ).filter(lambda s: sh.broadcast_shapes(result_shape, s) == result_shape) + ) + return OnewayBroadcastableShapes(input_shape, result_shape) + + def mock_int_dtype(n: int, dtype: DataType) -> int: """Returns equivalent of `n` that mocks `dtype` behaviour.""" nbits = dh.dtype_nbits[dtype] @@ -306,8 +346,14 @@ def __repr__(self): def make_binary_params( - elwise_func_name: str, dtypes_strat: st.SearchStrategy[DataType] + elwise_func_name: str, dtypes: List[DataType] ) -> List[Param[BinaryParamContext]]: + if hh.FILTER_UNDEFINED_DTYPES: + dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)] + shared_oneway_dtypes = st.shared(oneway_promotable_dtypes(dtypes)) + left_dtypes = shared_oneway_dtypes.map(lambda D: D.result_dtype) + right_dtypes = shared_oneway_dtypes.map(lambda D: D.input_dtype) + def make_param( func_name: str, func_type: FuncType, right_is_scalar: bool ) -> Param[BinaryParamContext]: @@ -318,26 +364,29 @@ def make_param( left_sym = "x1" right_sym = "x2" - shared_dtypes = st.shared(dtypes_strat) if right_is_scalar: - left_strat = xps.arrays(dtype=shared_dtypes, shape=hh.shapes(**shapes_kw)) - right_strat = shared_dtypes.flatmap( - lambda d: xps.from_dtype(d, **finite_kw) - ) + left_strat = xps.arrays(dtype=left_dtypes, shape=hh.shapes(**shapes_kw)) + right_strat = right_dtypes.flatmap(lambda d: xps.from_dtype(d, **finite_kw)) else: if func_type is FuncType.IOP: - shared_shapes = st.shared(hh.shapes(**shapes_kw)) - left_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes) - right_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes) + shared_oneway_shapes = st.shared(oneway_broadcastable_shapes()) + left_strat = xps.arrays( + dtype=left_dtypes, + shape=shared_oneway_shapes.map(lambda S: S.result_shape), + ) + right_strat = xps.arrays( + dtype=right_dtypes, + shape=shared_oneway_shapes.map(lambda S: S.input_shape), + ) else: mutual_shapes = st.shared( hh.mutually_broadcastable_shapes(2, **shapes_kw) ) left_strat = xps.arrays( - dtype=shared_dtypes, shape=mutual_shapes.map(lambda pair: pair[0]) + dtype=left_dtypes, shape=mutual_shapes.map(lambda pair: pair[0]) ) right_strat = xps.arrays( - dtype=shared_dtypes, shape=mutual_shapes.map(lambda pair: pair[1]) + dtype=right_dtypes, shape=mutual_shapes.map(lambda pair: pair[1]) ) if func_type is FuncType.FUNC: @@ -514,7 +563,7 @@ def test_acosh(x): ) -@pytest.mark.parametrize("ctx,", make_binary_params("add", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx,", make_binary_params("add", dh.numeric_dtypes)) @given(data=st.data()) def test_add(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -579,7 +628,7 @@ def test_atanh(x): @pytest.mark.parametrize( - "ctx", make_binary_params("bitwise_and", boolean_and_all_integer_dtypes()) + "ctx", make_binary_params("bitwise_and", dh.bool_and_all_int_dtypes) ) @given(data=st.data()) def test_bitwise_and(ctx, data): @@ -598,7 +647,7 @@ def test_bitwise_and(ctx, data): @pytest.mark.parametrize( - "ctx", make_binary_params("bitwise_left_shift", all_integer_dtypes()) + "ctx", make_binary_params("bitwise_left_shift", dh.all_int_dtypes) ) @given(data=st.data()) def test_bitwise_left_shift(ctx, data): @@ -638,7 +687,7 @@ def test_bitwise_invert(ctx, data): @pytest.mark.parametrize( - "ctx", make_binary_params("bitwise_or", boolean_and_all_integer_dtypes()) + "ctx", make_binary_params("bitwise_or", dh.bool_and_all_int_dtypes) ) @given(data=st.data()) def test_bitwise_or(ctx, data): @@ -657,7 +706,7 @@ def test_bitwise_or(ctx, data): @pytest.mark.parametrize( - "ctx", make_binary_params("bitwise_right_shift", all_integer_dtypes()) + "ctx", make_binary_params("bitwise_right_shift", dh.all_int_dtypes) ) @given(data=st.data()) def test_bitwise_right_shift(ctx, data): @@ -678,7 +727,7 @@ def test_bitwise_right_shift(ctx, data): @pytest.mark.parametrize( - "ctx", make_binary_params("bitwise_xor", boolean_and_all_integer_dtypes()) + "ctx", make_binary_params("bitwise_xor", dh.bool_and_all_int_dtypes) ) @given(data=st.data()) def test_bitwise_xor(ctx, data): @@ -720,7 +769,7 @@ def test_cosh(x): unary_assert_against_refimpl("cosh", x, out, math.cosh) -@pytest.mark.parametrize("ctx", make_binary_params("divide", xps.floating_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("divide", dh.float_dtypes)) @given(data=st.data()) def test_divide(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -743,7 +792,7 @@ def test_divide(ctx, data): ) -@pytest.mark.parametrize("ctx", make_binary_params("equal", xps.scalar_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("equal", dh.all_dtypes)) @given(data=st.data()) def test_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -795,9 +844,7 @@ def test_floor(x): unary_assert_against_refimpl("floor", x, out, math.floor, strict_check=True) -@pytest.mark.parametrize( - "ctx", make_binary_params("floor_divide", xps.numeric_dtypes()) -) +@pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.numeric_dtypes)) @given(data=st.data()) def test_floor_divide(ctx, data): left = data.draw( @@ -816,7 +863,7 @@ def test_floor_divide(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv) -@pytest.mark.parametrize("ctx", make_binary_params("greater", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("greater", dh.numeric_dtypes)) @given(data=st.data()) def test_greater(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -836,9 +883,7 @@ def test_greater(ctx, data): ) -@pytest.mark.parametrize( - "ctx", make_binary_params("greater_equal", xps.numeric_dtypes()) -) +@pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.numeric_dtypes)) @given(data=st.data()) def test_greater_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -882,7 +927,7 @@ def test_isnan(x): unary_assert_against_refimpl("isnan", x, out, math.isnan, res_stype=bool) -@pytest.mark.parametrize("ctx", make_binary_params("less", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("less", dh.numeric_dtypes)) @given(data=st.data()) def test_less(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -902,7 +947,7 @@ def test_less(ctx, data): ) -@pytest.mark.parametrize("ctx", make_binary_params("less_equal", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.numeric_dtypes)) @given(data=st.data()) def test_less_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1014,7 +1059,7 @@ def test_logical_xor(x1, x2): ) -@pytest.mark.parametrize("ctx", make_binary_params("multiply", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes)) @given(data=st.data()) def test_multiply(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1047,7 +1092,7 @@ def test_negative(ctx, data): ) -@pytest.mark.parametrize("ctx", make_binary_params("not_equal", xps.scalar_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("not_equal", dh.all_dtypes)) @given(data=st.data()) def test_not_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1079,7 +1124,7 @@ def test_positive(ctx, data): ph.assert_array(ctx.func_name, out, x) -@pytest.mark.parametrize("ctx", make_binary_params("pow", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("pow", dh.numeric_dtypes)) @given(data=st.data()) def test_pow(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1103,7 +1148,7 @@ def test_pow(ctx, data): ) -@pytest.mark.parametrize("ctx", make_binary_params("remainder", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.numeric_dtypes)) @given(data=st.data()) def test_remainder(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1174,7 +1219,7 @@ def test_sqrt(x): ) -@pytest.mark.parametrize("ctx", make_binary_params("subtract", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_binary_params("subtract", dh.numeric_dtypes)) @given(data=st.data()) def test_subtract(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym)