Skip to content

Commit 17e7537

Browse files
authored
Merge pull request #91 from honno/inplace-shapes
Inplace shapes
2 parents 0f63fab + 263b764 commit 17e7537

File tree

2 files changed

+95
-35
lines changed

2 files changed

+95
-35
lines changed

array_api_tests/meta/test_utils.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
from hypothesis import strategies as st
44

55
from .. import _array_module as xp
6-
from .. import xps
6+
from .. import dtype_helpers as dh
77
from .. import shape_helpers as sh
8+
from .. import xps
89
from ..test_creation_functions import frange
910
from ..test_manipulation_functions import roll_ndindex
10-
from ..test_operators_and_elementwise_functions import mock_int_dtype
11+
from ..test_operators_and_elementwise_functions import (
12+
mock_int_dtype,
13+
oneway_broadcastable_shapes,
14+
oneway_promotable_dtypes,
15+
)
1116
from ..test_signatures import extension_module
1217

1318

@@ -115,3 +120,13 @@ def test_int_to_dtype(x, dtype):
115120
except OverflowError:
116121
reject()
117122
assert mock_int_dtype(x, dtype) == d
123+
124+
125+
@given(oneway_promotable_dtypes(dh.all_dtypes))
126+
def test_oneway_promotable_dtypes(D):
127+
assert D.result_dtype == dh.result_type(*D)
128+
129+
130+
@given(oneway_broadcastable_shapes())
131+
def test_oneway_broadcastable_shapes(S):
132+
assert S.result_shape == sh.broadcast_shapes(*S)

array_api_tests/test_operators_and_elementwise_functions.py

+78-33
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,46 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
3030
return xps.boolean_dtypes() | all_integer_dtypes()
3131

3232

33+
class OnewayPromotableDtypes(NamedTuple):
34+
input_dtype: DataType
35+
result_dtype: DataType
36+
37+
38+
@st.composite
39+
def oneway_promotable_dtypes(
40+
draw, dtypes: List[DataType]
41+
) -> st.SearchStrategy[OnewayPromotableDtypes]:
42+
"""Return a strategy for input dtypes that promote to result dtypes."""
43+
d1, d2 = draw(hh.mutually_promotable_dtypes(dtypes=dtypes))
44+
result_dtype = dh.result_type(d1, d2)
45+
if d1 == result_dtype:
46+
return OnewayPromotableDtypes(d2, d1)
47+
elif d2 == result_dtype:
48+
return OnewayPromotableDtypes(d1, d2)
49+
else:
50+
reject()
51+
52+
53+
class OnewayBroadcastableShapes(NamedTuple):
54+
input_shape: Shape
55+
result_shape: Shape
56+
57+
58+
@st.composite
59+
def oneway_broadcastable_shapes(draw) -> st.SearchStrategy[OnewayBroadcastableShapes]:
60+
"""Return a strategy for input shapes that broadcast to result shapes."""
61+
result_shape = draw(hh.shapes(min_side=1))
62+
input_shape = draw(
63+
xps.broadcastable_shapes(
64+
result_shape,
65+
# Override defaults so bad shapes are less likely to be generated.
66+
max_side=None if result_shape == () else max(result_shape),
67+
max_dims=len(result_shape),
68+
).filter(lambda s: sh.broadcast_shapes(result_shape, s) == result_shape)
69+
)
70+
return OnewayBroadcastableShapes(input_shape, result_shape)
71+
72+
3373
def mock_int_dtype(n: int, dtype: DataType) -> int:
3474
"""Returns equivalent of `n` that mocks `dtype` behaviour."""
3575
nbits = dh.dtype_nbits[dtype]
@@ -306,8 +346,14 @@ def __repr__(self):
306346

307347

308348
def make_binary_params(
309-
elwise_func_name: str, dtypes_strat: st.SearchStrategy[DataType]
349+
elwise_func_name: str, dtypes: List[DataType]
310350
) -> List[Param[BinaryParamContext]]:
351+
if hh.FILTER_UNDEFINED_DTYPES:
352+
dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)]
353+
shared_oneway_dtypes = st.shared(oneway_promotable_dtypes(dtypes))
354+
left_dtypes = shared_oneway_dtypes.map(lambda D: D.result_dtype)
355+
right_dtypes = shared_oneway_dtypes.map(lambda D: D.input_dtype)
356+
311357
def make_param(
312358
func_name: str, func_type: FuncType, right_is_scalar: bool
313359
) -> Param[BinaryParamContext]:
@@ -318,26 +364,29 @@ def make_param(
318364
left_sym = "x1"
319365
right_sym = "x2"
320366

321-
shared_dtypes = st.shared(dtypes_strat)
322367
if right_is_scalar:
323-
left_strat = xps.arrays(dtype=shared_dtypes, shape=hh.shapes(**shapes_kw))
324-
right_strat = shared_dtypes.flatmap(
325-
lambda d: xps.from_dtype(d, **finite_kw)
326-
)
368+
left_strat = xps.arrays(dtype=left_dtypes, shape=hh.shapes(**shapes_kw))
369+
right_strat = right_dtypes.flatmap(lambda d: xps.from_dtype(d, **finite_kw))
327370
else:
328371
if func_type is FuncType.IOP:
329-
shared_shapes = st.shared(hh.shapes(**shapes_kw))
330-
left_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes)
331-
right_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes)
372+
shared_oneway_shapes = st.shared(oneway_broadcastable_shapes())
373+
left_strat = xps.arrays(
374+
dtype=left_dtypes,
375+
shape=shared_oneway_shapes.map(lambda S: S.result_shape),
376+
)
377+
right_strat = xps.arrays(
378+
dtype=right_dtypes,
379+
shape=shared_oneway_shapes.map(lambda S: S.input_shape),
380+
)
332381
else:
333382
mutual_shapes = st.shared(
334383
hh.mutually_broadcastable_shapes(2, **shapes_kw)
335384
)
336385
left_strat = xps.arrays(
337-
dtype=shared_dtypes, shape=mutual_shapes.map(lambda pair: pair[0])
386+
dtype=left_dtypes, shape=mutual_shapes.map(lambda pair: pair[0])
338387
)
339388
right_strat = xps.arrays(
340-
dtype=shared_dtypes, shape=mutual_shapes.map(lambda pair: pair[1])
389+
dtype=right_dtypes, shape=mutual_shapes.map(lambda pair: pair[1])
341390
)
342391

343392
if func_type is FuncType.FUNC:
@@ -514,7 +563,7 @@ def test_acosh(x):
514563
)
515564

516565

517-
@pytest.mark.parametrize("ctx,", make_binary_params("add", xps.numeric_dtypes()))
566+
@pytest.mark.parametrize("ctx,", make_binary_params("add", dh.numeric_dtypes))
518567
@given(data=st.data())
519568
def test_add(ctx, data):
520569
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -579,7 +628,7 @@ def test_atanh(x):
579628

580629

581630
@pytest.mark.parametrize(
582-
"ctx", make_binary_params("bitwise_and", boolean_and_all_integer_dtypes())
631+
"ctx", make_binary_params("bitwise_and", dh.bool_and_all_int_dtypes)
583632
)
584633
@given(data=st.data())
585634
def test_bitwise_and(ctx, data):
@@ -598,7 +647,7 @@ def test_bitwise_and(ctx, data):
598647

599648

600649
@pytest.mark.parametrize(
601-
"ctx", make_binary_params("bitwise_left_shift", all_integer_dtypes())
650+
"ctx", make_binary_params("bitwise_left_shift", dh.all_int_dtypes)
602651
)
603652
@given(data=st.data())
604653
def test_bitwise_left_shift(ctx, data):
@@ -638,7 +687,7 @@ def test_bitwise_invert(ctx, data):
638687

639688

640689
@pytest.mark.parametrize(
641-
"ctx", make_binary_params("bitwise_or", boolean_and_all_integer_dtypes())
690+
"ctx", make_binary_params("bitwise_or", dh.bool_and_all_int_dtypes)
642691
)
643692
@given(data=st.data())
644693
def test_bitwise_or(ctx, data):
@@ -657,7 +706,7 @@ def test_bitwise_or(ctx, data):
657706

658707

659708
@pytest.mark.parametrize(
660-
"ctx", make_binary_params("bitwise_right_shift", all_integer_dtypes())
709+
"ctx", make_binary_params("bitwise_right_shift", dh.all_int_dtypes)
661710
)
662711
@given(data=st.data())
663712
def test_bitwise_right_shift(ctx, data):
@@ -678,7 +727,7 @@ def test_bitwise_right_shift(ctx, data):
678727

679728

680729
@pytest.mark.parametrize(
681-
"ctx", make_binary_params("bitwise_xor", boolean_and_all_integer_dtypes())
730+
"ctx", make_binary_params("bitwise_xor", dh.bool_and_all_int_dtypes)
682731
)
683732
@given(data=st.data())
684733
def test_bitwise_xor(ctx, data):
@@ -720,7 +769,7 @@ def test_cosh(x):
720769
unary_assert_against_refimpl("cosh", x, out, math.cosh)
721770

722771

723-
@pytest.mark.parametrize("ctx", make_binary_params("divide", xps.floating_dtypes()))
772+
@pytest.mark.parametrize("ctx", make_binary_params("divide", dh.float_dtypes))
724773
@given(data=st.data())
725774
def test_divide(ctx, data):
726775
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -743,7 +792,7 @@ def test_divide(ctx, data):
743792
)
744793

745794

746-
@pytest.mark.parametrize("ctx", make_binary_params("equal", xps.scalar_dtypes()))
795+
@pytest.mark.parametrize("ctx", make_binary_params("equal", dh.all_dtypes))
747796
@given(data=st.data())
748797
def test_equal(ctx, data):
749798
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -795,9 +844,7 @@ def test_floor(x):
795844
unary_assert_against_refimpl("floor", x, out, math.floor, strict_check=True)
796845

797846

798-
@pytest.mark.parametrize(
799-
"ctx", make_binary_params("floor_divide", xps.numeric_dtypes())
800-
)
847+
@pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.numeric_dtypes))
801848
@given(data=st.data())
802849
def test_floor_divide(ctx, data):
803850
left = data.draw(
@@ -816,7 +863,7 @@ def test_floor_divide(ctx, data):
816863
binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv)
817864

818865

819-
@pytest.mark.parametrize("ctx", make_binary_params("greater", xps.numeric_dtypes()))
866+
@pytest.mark.parametrize("ctx", make_binary_params("greater", dh.numeric_dtypes))
820867
@given(data=st.data())
821868
def test_greater(ctx, data):
822869
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -836,9 +883,7 @@ def test_greater(ctx, data):
836883
)
837884

838885

839-
@pytest.mark.parametrize(
840-
"ctx", make_binary_params("greater_equal", xps.numeric_dtypes())
841-
)
886+
@pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.numeric_dtypes))
842887
@given(data=st.data())
843888
def test_greater_equal(ctx, data):
844889
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -882,7 +927,7 @@ def test_isnan(x):
882927
unary_assert_against_refimpl("isnan", x, out, math.isnan, res_stype=bool)
883928

884929

885-
@pytest.mark.parametrize("ctx", make_binary_params("less", xps.numeric_dtypes()))
930+
@pytest.mark.parametrize("ctx", make_binary_params("less", dh.numeric_dtypes))
886931
@given(data=st.data())
887932
def test_less(ctx, data):
888933
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -902,7 +947,7 @@ def test_less(ctx, data):
902947
)
903948

904949

905-
@pytest.mark.parametrize("ctx", make_binary_params("less_equal", xps.numeric_dtypes()))
950+
@pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.numeric_dtypes))
906951
@given(data=st.data())
907952
def test_less_equal(ctx, data):
908953
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -1014,7 +1059,7 @@ def test_logical_xor(x1, x2):
10141059
)
10151060

10161061

1017-
@pytest.mark.parametrize("ctx", make_binary_params("multiply", xps.numeric_dtypes()))
1062+
@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes))
10181063
@given(data=st.data())
10191064
def test_multiply(ctx, data):
10201065
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -1047,7 +1092,7 @@ def test_negative(ctx, data):
10471092
)
10481093

10491094

1050-
@pytest.mark.parametrize("ctx", make_binary_params("not_equal", xps.scalar_dtypes()))
1095+
@pytest.mark.parametrize("ctx", make_binary_params("not_equal", dh.all_dtypes))
10511096
@given(data=st.data())
10521097
def test_not_equal(ctx, data):
10531098
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -1079,7 +1124,7 @@ def test_positive(ctx, data):
10791124
ph.assert_array(ctx.func_name, out, x)
10801125

10811126

1082-
@pytest.mark.parametrize("ctx", make_binary_params("pow", xps.numeric_dtypes()))
1127+
@pytest.mark.parametrize("ctx", make_binary_params("pow", dh.numeric_dtypes))
10831128
@given(data=st.data())
10841129
def test_pow(ctx, data):
10851130
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -1103,7 +1148,7 @@ def test_pow(ctx, data):
11031148
)
11041149

11051150

1106-
@pytest.mark.parametrize("ctx", make_binary_params("remainder", xps.numeric_dtypes()))
1151+
@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.numeric_dtypes))
11071152
@given(data=st.data())
11081153
def test_remainder(ctx, data):
11091154
left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -1174,7 +1219,7 @@ def test_sqrt(x):
11741219
)
11751220

11761221

1177-
@pytest.mark.parametrize("ctx", make_binary_params("subtract", xps.numeric_dtypes()))
1222+
@pytest.mark.parametrize("ctx", make_binary_params("subtract", dh.numeric_dtypes))
11781223
@given(data=st.data())
11791224
def test_subtract(ctx, data):
11801225
left = data.draw(ctx.left_strat, label=ctx.left_sym)

0 commit comments

Comments
 (0)