Skip to content

Commit 68d7560

Browse files
committed
Redefine custom shapes strategy using xps.array_shapes
1 parent b8399a1 commit 68d7560

File tree

3 files changed

+47
-19
lines changed

3 files changed

+47
-19
lines changed

array_api_tests/hypothesis_helpers.py

+8-13
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
shared, tuples as hypotheses_tuples,
77
floats, just, composite, one_of, none,
88
booleans, SearchStrategy)
9-
from hypothesis.extra.numpy import mutually_broadcastable_shapes
109
from hypothesis.extra.array_api import make_strategies_namespace
1110
from hypothesis import assume
1211

@@ -111,29 +110,25 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False)
111110
return lists(elements, min_size=min_size, max_size=max_size,
112111
unique_by=unique_by, unique=unique).map(tuple)
113112

114-
shapes = tuples(integers(0, 10)).filter(lambda shape: prod(shape) < MAX_ARRAY_SIZE)
115-
116113
# Use this to avoid memory errors with NumPy.
117114
# See https://github.com/numpy/numpy/issues/15753
118-
shapes = tuples(integers(0, 10)).filter(
119-
lambda shape: prod([i for i in shape if i]) < MAX_ARRAY_SIZE)
115+
shapes = xps.array_shapes(min_dims=0, min_side=0).filter(
116+
lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE
117+
)
120118

121-
two_mutually_broadcastable_shapes = mutually_broadcastable_shapes(num_shapes=2)\
119+
two_mutually_broadcastable_shapes = xps.mutually_broadcastable_shapes(num_shapes=2)\
122120
.map(lambda S: S.input_shapes)\
123-
.filter(lambda S: all(prod([i for i in shape if i]) < MAX_ARRAY_SIZE for shape in S))
121+
.filter(lambda S: all(prod(i for i in shape if i) < MAX_ARRAY_SIZE for shape in S))
124122

125123
@composite
126-
def two_broadcastable_shapes(draw, shapes=shapes):
124+
def two_broadcastable_shapes(draw):
127125
"""
128126
This will produce two shapes (shape1, shape2) such that shape2 can be
129127
broadcast to shape1.
130-
131128
"""
132129
from .test_broadcasting import broadcast_shapes
133-
134-
shape1, shape2 = draw(two_mutually_broadcastable_shapes)
135-
if broadcast_shapes(shape1, shape2) != shape1:
136-
assume(False)
130+
shape1, shape2 = draw(two_mutually_broadcastable_shapes)
131+
assume(broadcast_shapes(shape1, shape2) == shape1)
137132
return (shape1, shape2)
138133

139134
sizes = integers(0, MAX_ARRAY_SIZE)

array_api_tests/meta_tests/test_hypothesis_helpers.py

+35-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1+
from math import prod
2+
13
import pytest
24
from hypothesis import given
35

46
from .. import _array_module as xp
57
from .._array_module import _UndefinedStub
68
from ..array_helpers import dtype_objects
7-
from ..hypothesis_helpers import (mutually_promotable_dtype_pairs,
8-
promotable_dtypes)
9+
from ..hypothesis_helpers import (MAX_ARRAY_SIZE,
10+
mutually_promotable_dtype_pairs,
11+
promotable_dtypes, shapes,
12+
two_broadcastable_shapes,
13+
two_mutually_broadcastable_shapes)
914

1015
UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dtype_objects)
1116
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
@@ -27,3 +32,31 @@ def test_mutually_promotable_dtype_pairs(pairs):
2732
(xp.float64, xp.float64),
2833
)
2934

35+
36+
def valid_shape(shape) -> bool:
37+
return (
38+
all(isinstance(side, int) for side in shape)
39+
and all(side >= 0 for side in shape)
40+
and prod(shape) < MAX_ARRAY_SIZE
41+
)
42+
43+
44+
@given(shapes)
45+
def test_shapes(shape):
46+
assert valid_shape(shape)
47+
48+
49+
@given(two_mutually_broadcastable_shapes)
50+
def test_two_mutually_broadcastable_shapes(pair):
51+
for shape in pair:
52+
assert valid_shape(shape)
53+
54+
55+
@given(two_broadcastable_shapes())
56+
def test_two_broadcastable_shapes(pair):
57+
for shape in pair:
58+
assert valid_shape(shape)
59+
60+
from ..test_broadcasting import broadcast_shapes
61+
62+
assert broadcast_shapes(pair[0], pair[1]) == pair[0]

array_api_tests/test_creation_functions.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def test_empty(shape, dtype):
8282
@given(
8383
a=xps.arrays(
8484
dtype=shared_dtypes,
85-
shape=xps.array_shapes(),
85+
shape=shapes,
8686
),
8787
dtype=one_of(none(), promotable_dtypes(shared_dtypes)),
8888
)
@@ -153,7 +153,7 @@ def test_full(shape, fill_value, dtype):
153153
@given(
154154
a=xps.arrays(
155155
dtype=shared_dtypes,
156-
shape=xps.array_shapes(),
156+
shape=shapes,
157157
),
158158
fill_value=promotable_dtypes(shared_dtypes).flatmap(xps.from_dtype),
159159
dtype=one_of(none(), promotable_dtypes(shared_dtypes)),
@@ -245,7 +245,7 @@ def test_ones(shape, dtype):
245245
@given(
246246
a=xps.arrays(
247247
dtype=shared_dtypes,
248-
shape=xps.array_shapes(),
248+
shape=shapes,
249249
),
250250
dtype=one_of(none(), promotable_dtypes(shared_dtypes)),
251251
)
@@ -296,7 +296,7 @@ def test_zeros(shape, dtype):
296296
@given(
297297
a=xps.arrays(
298298
dtype=shared_dtypes,
299-
shape=xps.array_shapes(),
299+
shape=shapes,
300300
),
301301
dtype=one_of(none(), promotable_dtypes(shared_dtypes)),
302302
)

0 commit comments

Comments
 (0)