Skip to content

Commit c8303d9

Browse files
committed
Generate oneway promotable dtypes for elwise/op tests
1 parent 7754038 commit c8303d9

File tree

2 files changed

+59
-33
lines changed

2 files changed

+59
-33
lines changed

array_api_tests/meta/test_utils.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
from hypothesis import strategies as st
44

55
from .. import _array_module as xp
6+
from .. import dtype_helpers as dh
67
from .. import shape_helpers as sh
78
from .. import xps
89
from ..test_creation_functions import frange
910
from ..test_manipulation_functions import roll_ndindex
1011
from ..test_operators_and_elementwise_functions import (
1112
mock_int_dtype,
1213
oneway_broadcastable_shapes,
14+
oneway_promotable_dtypes,
1315
)
1416
from ..test_signatures import extension_module
1517

@@ -120,6 +122,11 @@ def test_int_to_dtype(x, dtype):
120122
assert mock_int_dtype(x, dtype) == d
121123

122124

125+
@given(oneway_promotable_dtypes(dh.all_dtypes))
126+
def test_oneway_promotable_dtypes(D):
127+
assert D.result_dtype == dh.result_type(*D)
128+
129+
123130
@given(oneway_broadcastable_shapes())
124131
def test_oneway_broadcastable_shapes(S):
125-
assert sh.broadcast_shapes(*S) == S.result_shape
132+
assert S.result_shape == sh.broadcast_shapes(*S)

array_api_tests/test_operators_and_elementwise_functions.py

+51-32
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,26 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
3030
return xps.boolean_dtypes() | all_integer_dtypes()
3131

3232

33+
class OnewayPromotableDtypes(NamedTuple):
34+
input_dtype: DataType
35+
result_dtype: DataType
36+
37+
38+
@st.composite
39+
def oneway_promotable_dtypes(
40+
draw, dtypes: List[DataType]
41+
) -> st.SearchStrategy[OnewayPromotableDtypes]:
42+
"""Return a strategy for input dtypes that promote to result dtypes."""
43+
d1, d2 = draw(hh.mutually_promotable_dtypes(dtypes=dtypes))
44+
result_dtype = dh.result_type(d1, d2)
45+
if d1 == result_dtype:
46+
return OnewayPromotableDtypes(d2, d1)
47+
elif d2 == result_dtype:
48+
return OnewayPromotableDtypes(d1, d2)
49+
else:
50+
reject()
51+
52+
3353
class OnewayBroadcastableShapes(NamedTuple):
3454
input_shape: Shape
3555
result_shape: Shape
@@ -326,8 +346,14 @@ def __repr__(self):
326346

327347

328348
def make_binary_params(
329-
elwise_func_name: str, dtypes_strat: st.SearchStrategy[DataType]
349+
elwise_func_name: str, dtypes: List[DataType]
330350
) -> List[Param[BinaryParamContext]]:
351+
if hh.FILTER_UNDEFINED_DTYPES:
352+
dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)]
353+
shared_oneway_dtypes = st.shared(oneway_promotable_dtypes(dtypes))
354+
left_dtypes = shared_oneway_dtypes.map(lambda D: D.result_dtype)
355+
right_dtypes = shared_oneway_dtypes.map(lambda D: D.input_dtype)
356+
331357
def make_param(
332358
func_name: str, func_type: FuncType, right_is_scalar: bool
333359
) -> Param[BinaryParamContext]:
@@ -338,32 +364,29 @@ def make_param(
338364
left_sym = "x1"
339365
right_sym = "x2"
340366

341-
shared_dtypes = st.shared(dtypes_strat)
342367
if right_is_scalar:
343-
left_strat = xps.arrays(dtype=shared_dtypes, shape=hh.shapes(**shapes_kw))
344-
right_strat = shared_dtypes.flatmap(
345-
lambda d: xps.from_dtype(d, **finite_kw)
346-
)
368+
left_strat = xps.arrays(dtype=left_dtypes, shape=hh.shapes(**shapes_kw))
369+
right_strat = right_dtypes.flatmap(lambda d: xps.from_dtype(d, **finite_kw))
347370
else:
348371
if func_type is FuncType.IOP:
349372
shared_oneway_shapes = st.shared(oneway_broadcastable_shapes())
350373
left_strat = xps.arrays(
351-
dtype=shared_dtypes,
374+
dtype=left_dtypes,
352375
shape=shared_oneway_shapes.map(lambda S: S.result_shape),
353376
)
354377
right_strat = xps.arrays(
355-
dtype=shared_dtypes,
378+
dtype=right_dtypes,
356379
shape=shared_oneway_shapes.map(lambda S: S.input_shape),
357380
)
358381
else:
359382
mutual_shapes = st.shared(
360383
hh.mutually_broadcastable_shapes(2, **shapes_kw)
361384
)
362385
left_strat = xps.arrays(
363-
dtype=shared_dtypes, shape=mutual_shapes.map(lambda pair: pair[0])
386+
dtype=left_dtypes, shape=mutual_shapes.map(lambda pair: pair[0])
364387
)
365388
right_strat = xps.arrays(
366-
dtype=shared_dtypes, shape=mutual_shapes.map(lambda pair: pair[1])
389+
dtype=right_dtypes, shape=mutual_shapes.map(lambda pair: pair[1])
367390
)
368391

369392
if func_type is FuncType.FUNC:
@@ -540,7 +563,7 @@ def test_acosh(x):
540563
)
541564

542565

543-
@pytest.mark.parametrize("ctx,", make_binary_params("add", xps.numeric_dtypes()))
566+
@pytest.mark.parametrize("ctx,", make_binary_params("add", dh.numeric_dtypes))
544567
@given(data=st.data())
545568
def test_add(ctx, data):
546569
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -605,7 +628,7 @@ def test_atanh(x):
605628

606629

607630
@pytest.mark.parametrize(
608-
"ctx", make_binary_params("bitwise_and", boolean_and_all_integer_dtypes())
631+
"ctx", make_binary_params("bitwise_and", dh.bool_and_all_int_dtypes)
609632
)
610633
@given(data=st.data())
611634
def test_bitwise_and(ctx, data):
@@ -624,7 +647,7 @@ def test_bitwise_and(ctx, data):
624647

625648

626649
@pytest.mark.parametrize(
627-
"ctx", make_binary_params("bitwise_left_shift", all_integer_dtypes())
650+
"ctx", make_binary_params("bitwise_left_shift", dh.all_int_dtypes)
628651
)
629652
@given(data=st.data())
630653
def test_bitwise_left_shift(ctx, data):
@@ -664,7 +687,7 @@ def test_bitwise_invert(ctx, data):
664687

665688

666689
@pytest.mark.parametrize(
667-
"ctx", make_binary_params("bitwise_or", boolean_and_all_integer_dtypes())
690+
"ctx", make_binary_params("bitwise_or", dh.bool_and_all_int_dtypes)
668691
)
669692
@given(data=st.data())
670693
def test_bitwise_or(ctx, data):
@@ -683,7 +706,7 @@ def test_bitwise_or(ctx, data):
683706

684707

685708
@pytest.mark.parametrize(
686-
"ctx", make_binary_params("bitwise_right_shift", all_integer_dtypes())
709+
"ctx", make_binary_params("bitwise_right_shift", dh.all_int_dtypes)
687710
)
688711
@given(data=st.data())
689712
def test_bitwise_right_shift(ctx, data):
@@ -704,7 +727,7 @@ def test_bitwise_right_shift(ctx, data):
704727

705728

706729
@pytest.mark.parametrize(
707-
"ctx", make_binary_params("bitwise_xor", boolean_and_all_integer_dtypes())
730+
"ctx", make_binary_params("bitwise_xor", dh.bool_and_all_int_dtypes)
708731
)
709732
@given(data=st.data())
710733
def test_bitwise_xor(ctx, data):
@@ -746,7 +769,7 @@ def test_cosh(x):
746769
unary_assert_against_refimpl("cosh", x, out, math.cosh)
747770

748771

749-
@pytest.mark.parametrize("ctx", make_binary_params("divide", xps.floating_dtypes()))
772+
@pytest.mark.parametrize("ctx", make_binary_params("divide", dh.float_dtypes))
750773
@given(data=st.data())
751774
def test_divide(ctx, data):
752775
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -769,7 +792,7 @@ def test_divide(ctx, data):
769792
)
770793

771794

772-
@pytest.mark.parametrize("ctx", make_binary_params("equal", xps.scalar_dtypes()))
795+
@pytest.mark.parametrize("ctx", make_binary_params("equal", dh.all_dtypes))
773796
@given(data=st.data())
774797
def test_equal(ctx, data):
775798
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -821,9 +844,7 @@ def test_floor(x):
821844
unary_assert_against_refimpl("floor", x, out, math.floor, strict_check=True)
822845

823846

824-
@pytest.mark.parametrize(
825-
"ctx", make_binary_params("floor_divide", xps.numeric_dtypes())
826-
)
847+
@pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.numeric_dtypes))
827848
@given(data=st.data())
828849
def test_floor_divide(ctx, data):
829850
left = data.draw(
@@ -842,7 +863,7 @@ def test_floor_divide(ctx, data):
842863
binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv)
843864

844865

845-
@pytest.mark.parametrize("ctx", make_binary_params("greater", xps.numeric_dtypes()))
866+
@pytest.mark.parametrize("ctx", make_binary_params("greater", dh.numeric_dtypes))
846867
@given(data=st.data())
847868
def test_greater(ctx, data):
848869
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -862,9 +883,7 @@ def test_greater(ctx, data):
862883
)
863884

864885

865-
@pytest.mark.parametrize(
866-
"ctx", make_binary_params("greater_equal", xps.numeric_dtypes())
867-
)
886+
@pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.numeric_dtypes))
868887
@given(data=st.data())
869888
def test_greater_equal(ctx, data):
870889
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -908,7 +927,7 @@ def test_isnan(x):
908927
unary_assert_against_refimpl("isnan", x, out, math.isnan, res_stype=bool)
909928

910929

911-
@pytest.mark.parametrize("ctx", make_binary_params("less", xps.numeric_dtypes()))
930+
@pytest.mark.parametrize("ctx", make_binary_params("less", dh.numeric_dtypes))
912931
@given(data=st.data())
913932
def test_less(ctx, data):
914933
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -928,7 +947,7 @@ def test_less(ctx, data):
928947
)
929948

930949

931-
@pytest.mark.parametrize("ctx", make_binary_params("less_equal", xps.numeric_dtypes()))
950+
@pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.numeric_dtypes))
932951
@given(data=st.data())
933952
def test_less_equal(ctx, data):
934953
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -1040,7 +1059,7 @@ def test_logical_xor(x1, x2):
10401059
)
10411060

10421061

1043-
@pytest.mark.parametrize("ctx", make_binary_params("multiply", xps.numeric_dtypes()))
1062+
@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes))
10441063
@given(data=st.data())
10451064
def test_multiply(ctx, data):
10461065
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -1073,7 +1092,7 @@ def test_negative(ctx, data):
10731092
)
10741093

10751094

1076-
@pytest.mark.parametrize("ctx", make_binary_params("not_equal", xps.scalar_dtypes()))
1095+
@pytest.mark.parametrize("ctx", make_binary_params("not_equal", dh.all_dtypes))
10771096
@given(data=st.data())
10781097
def test_not_equal(ctx, data):
10791098
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -1105,7 +1124,7 @@ def test_positive(ctx, data):
11051124
ph.assert_array(ctx.func_name, out, x)
11061125

11071126

1108-
@pytest.mark.parametrize("ctx", make_binary_params("pow", xps.numeric_dtypes()))
1127+
@pytest.mark.parametrize("ctx", make_binary_params("pow", dh.numeric_dtypes))
11091128
@given(data=st.data())
11101129
def test_pow(ctx, data):
11111130
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -1129,7 +1148,7 @@ def test_pow(ctx, data):
11291148
)
11301149

11311150

1132-
@pytest.mark.parametrize("ctx", make_binary_params("remainder", xps.numeric_dtypes()))
1151+
@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.numeric_dtypes))
11331152
@given(data=st.data())
11341153
def test_remainder(ctx, data):
11351154
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -1200,7 +1219,7 @@ def test_sqrt(x):
12001219
)
12011220

12021221

1203-
@pytest.mark.parametrize("ctx", make_binary_params("subtract", xps.numeric_dtypes()))
1222+
@pytest.mark.parametrize("ctx", make_binary_params("subtract", dh.numeric_dtypes))
12041223
@given(data=st.data())
12051224
def test_subtract(ctx, data):
12061225
left = data.draw(ctx.left_strat, label=ctx.left_sym)

0 commit comments

Comments
 (0)