Skip to content

Commit 7754038

Browse files
committed
Test broadcastable shapes for in-place operators
1 parent 3c85cae commit 7754038

File tree

2 files changed

+39
-5
lines changed

2 files changed

+39
-5
lines changed

array_api_tests/meta/test_utils.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33
from hypothesis import strategies as st
44

55
from .. import _array_module as xp
6-
from .. import xps
76
from .. import shape_helpers as sh
7+
from .. import xps
88
from ..test_creation_functions import frange
99
from ..test_manipulation_functions import roll_ndindex
10-
from ..test_operators_and_elementwise_functions import mock_int_dtype
10+
from ..test_operators_and_elementwise_functions import (
11+
mock_int_dtype,
12+
oneway_broadcastable_shapes,
13+
)
1114
from ..test_signatures import extension_module
1215

1316

@@ -115,3 +118,8 @@ def test_int_to_dtype(x, dtype):
115118
except OverflowError:
116119
reject()
117120
assert mock_int_dtype(x, dtype) == d
121+
122+
123+
@given(oneway_broadcastable_shapes())
124+
def test_oneway_broadcastable_shapes(S):
125+
assert sh.broadcast_shapes(*S) == S.result_shape

array_api_tests/test_operators_and_elementwise_functions.py

+29-3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,26 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
3030
return xps.boolean_dtypes() | all_integer_dtypes()
3131

3232

33+
class OnewayBroadcastableShapes(NamedTuple):
34+
input_shape: Shape
35+
result_shape: Shape
36+
37+
38+
@st.composite
39+
def oneway_broadcastable_shapes(draw) -> st.SearchStrategy[OnewayBroadcastableShapes]:
40+
"""Return a strategy for input shapes that broadcast to result shapes."""
41+
result_shape = draw(hh.shapes(min_side=1))
42+
input_shape = draw(
43+
xps.broadcastable_shapes(
44+
result_shape,
45+
# Override defaults so bad shapes are less likely to be generated.
46+
max_side=None if result_shape == () else max(result_shape),
47+
max_dims=len(result_shape),
48+
).filter(lambda s: sh.broadcast_shapes(result_shape, s) == result_shape)
49+
)
50+
return OnewayBroadcastableShapes(input_shape, result_shape)
51+
52+
3353
def mock_int_dtype(n: int, dtype: DataType) -> int:
3454
"""Returns equivalent of `n` that mocks `dtype` behaviour."""
3555
nbits = dh.dtype_nbits[dtype]
@@ -326,9 +346,15 @@ def make_param(
326346
)
327347
else:
328348
if func_type is FuncType.IOP:
329-
shared_shapes = st.shared(hh.shapes(**shapes_kw))
330-
left_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes)
331-
right_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes)
349+
shared_oneway_shapes = st.shared(oneway_broadcastable_shapes())
350+
left_strat = xps.arrays(
351+
dtype=shared_dtypes,
352+
shape=shared_oneway_shapes.map(lambda S: S.result_shape),
353+
)
354+
right_strat = xps.arrays(
355+
dtype=shared_dtypes,
356+
shape=shared_oneway_shapes.map(lambda S: S.input_shape),
357+
)
332358
else:
333359
mutual_shapes = st.shared(
334360
hh.mutually_broadcastable_shapes(2, **shapes_kw)

0 commit comments

Comments
 (0)