diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 3c09a1d5..7e7294f8 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -10,7 +10,7 @@ from hypothesis import assume, reject from hypothesis.strategies import (SearchStrategy, booleans, composite, floats, integers, complex_numbers, just, lists, none, one_of, - sampled_from, shared, builds, nothing) + sampled_from, shared, builds, nothing, permutations) from . import _array_module as xp, api_version from . import array_helpers as ah @@ -148,6 +148,13 @@ def mutually_promotable_dtypes( return one_of(strats).map(tuple) +@composite +def pair_of_mutually_promotable_dtypes(draw, max_size=2, *, dtypes=dh.all_dtypes): + sample = draw(mutually_promotable_dtypes( max_size, dtypes=dtypes)) + permuted = draw(permutations(sample)) + return sample, tuple(permuted) + + class OnewayPromotableDtypes(NamedTuple): input_dtype: DataType result_dtype: DataType diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index f9642f31..e844c432 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -208,7 +208,53 @@ def test_isdtype(dtype, kind): assert out == expected, f"{out=}, but should be {expected} [isdtype()]" -@given(hh.mutually_promotable_dtypes(None)) -def test_result_type(dtypes): - out = xp.result_type(*dtypes) - ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out") +@pytest.mark.min_version("2024.12") +class TestResultType: + @given(dtypes=hh.mutually_promotable_dtypes(None)) + def test_result_type(self, dtypes): + out = xp.result_type(*dtypes) + ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out") + + @given(pair=hh.pair_of_mutually_promotable_dtypes(None)) + def test_shuffled(self, pair): + """Test that result_type is insensitive to the order of arguments.""" + s1, s2 = pair + out1 = xp.result_type(*s1) + out2 = xp.result_type(*s2) + assert out1 == out2 + + @given(pair=hh.pair_of_mutually_promotable_dtypes(2), data=st.data()) + def test_arrays_and_dtypes(self, pair, data): + s1, s2 = pair + a2 = tuple(xp.empty(1, dtype=dt) for dt in s2) + a_and_dt = data.draw(st.permutations(s1 + a2)) + out = xp.result_type(*a_and_dt) + ph.assert_dtype("result_type", in_dtype=s1+s2, out_dtype=out, repr_name="out") + + @given(dtypes=hh.mutually_promotable_dtypes(2), data=st.data()) + def test_with_scalars(self, dtypes, data): + out = xp.result_type(*dtypes) + + if out == xp.bool: + scalars = [True] + elif out in dh.all_int_dtypes: + scalars = [1] + elif out in dh.real_dtypes: + scalars = [1, 1.0] + elif out in dh.numeric_dtypes: + scalars = [1, 1.0, 1j] # numeric_types - real_types == complex_types + else: + raise ValueError(f"unknown dtype {out = }.") + + scalar = data.draw(st.sampled_from(scalars)) + inputs = data.draw(st.permutations(dtypes + (scalar,))) + + out_scalar = xp.result_type(*inputs) + assert out_scalar == out + + # retry with arrays + arrays = tuple(xp.empty(1, dtype=dt) for dt in dtypes) + inputs = data.draw(st.permutations(arrays + (scalar,))) + out_scalar = xp.result_type(*inputs) + assert out_scalar == out +