Skip to content

Commit 00e089b

Browse files
authored
Merge pull request #23 from honno/elementwise-ndarrays
Add multi-dimensional arrays support for existing elementwise tests
2 parents 0b41f77 + 3315e18 commit 00e089b

File tree

5 files changed

+538
-538
lines changed

5 files changed

+538
-538
lines changed

array_api_tests/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from hypothesis.extra.array_api import make_strategies_namespace
2+
3+
from . import _array_module as xp
4+
5+
6+
xps = make_strategies_namespace(xp)

array_api_tests/hypothesis_helpers.py

+15-20
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from hypothesis.strategies import (lists, integers, sampled_from,
88
shared, floats, just, composite, one_of,
99
none, booleans)
10-
from hypothesis.extra.array_api import make_strategies_namespace
1110

1211
from .pytest_helpers import nargs
1312
from .array_helpers import (dtype_ranges, integer_dtype_objects,
@@ -17,15 +16,12 @@
1716
ndindex)
1817
from ._array_module import (full, float32, float64, bool as bool_dtype,
1918
_UndefinedStub, eye, broadcast_to)
20-
from . import _array_module
2119
from . import _array_module as xp
20+
from . import xps
2221

2322
from .function_stubs import elementwise_functions
2423

2524

26-
xps = make_strategies_namespace(xp)
27-
28-
2925
# Set this to True to not fail tests just because a dtype isn't implemented.
3026
# If no compatible dtype is implemented for a given test, the test will fail
3127
# with a hypothesis health check error. Note that this functionality will not
@@ -79,10 +75,6 @@ def mutually_promotable_dtypes(draw, dtype_objects=dtype_objects):
7975
dtype_pairs = [(i, j) for i, j in dtype_pairs if i in dtype_objects and j in dtype_objects]
8076
return draw(sampled_from(dtype_pairs))
8177

82-
shared_mutually_promotable_dtype_pairs = shared(
83-
mutually_promotable_dtypes(), key="mutually_promotable_dtype_pair"
84-
)
85-
8678
# shared() allows us to draw either the function or the function name and they
8779
# will both correspond to the same function.
8880

@@ -93,10 +85,10 @@ def mutually_promotable_dtypes(draw, dtype_objects=dtype_objects):
9385
lambda func_name: nargs(func_name) > 1)
9486

9587
elementwise_function_objects = elementwise_functions_names.map(
96-
lambda i: getattr(_array_module, i))
88+
lambda i: getattr(xp, i))
9789
array_functions = elementwise_function_objects
9890
multiarg_array_functions = multiarg_array_functions_names.map(
99-
lambda i: getattr(_array_module, i))
91+
lambda i: getattr(xp, i))
10092

10193
# Limit the total size of an array shape
10294
MAX_ARRAY_SIZE = 10000
@@ -184,7 +176,6 @@ def two_broadcastable_shapes(draw):
184176
sizes = integers(0, MAX_ARRAY_SIZE)
185177
sqrt_sizes = integers(0, SQRT_MAX_ARRAY_SIZE)
186178

187-
# TODO: Generate general arrays here, rather than just scalars.
188179
numeric_arrays = xps.arrays(
189180
dtype=shared(xps.floating_dtypes(), key='dtypes'),
190181
shape=shared(xps.array_shapes(), key='shapes'),
@@ -295,14 +286,18 @@ def multiaxis_indices(draw, shapes):
295286
return tuple(res)
296287

297288

298-
shared_arrays1 = xps.arrays(
299-
dtype=shared_mutually_promotable_dtype_pairs.map(lambda pair: pair[0]),
300-
shape=shared(two_mutually_broadcastable_shapes, key="shape_pair").map(lambda pair: pair[0]),
301-
)
302-
shared_arrays2 = xps.arrays(
303-
dtype=shared_mutually_promotable_dtype_pairs.map(lambda pair: pair[1]),
304-
shape=shared(two_mutually_broadcastable_shapes, key="shape_pair").map(lambda pair: pair[1]),
305-
)
289+
def two_mutual_arrays(dtypes=dtype_objects):
290+
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes))
291+
mutual_shapes = shared(two_mutually_broadcastable_shapes)
292+
arrays1 = xps.arrays(
293+
dtype=mutual_dtypes.map(lambda pair: pair[0]),
294+
shape=mutual_shapes.map(lambda pair: pair[0]),
295+
)
296+
arrays2 = xps.arrays(
297+
dtype=mutual_dtypes.map(lambda pair: pair[1]),
298+
shape=mutual_shapes.map(lambda pair: pair[1]),
299+
)
300+
return arrays1, arrays2
306301

307302

308303
@composite

array_api_tests/meta_tests/test_hypothesis_helpers.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from math import prod
22

33
import pytest
4-
from hypothesis import given, strategies as st
4+
from hypothesis import given, strategies as st, settings
55

66
from .. import _array_module as xp
77
from .._array_module import _UndefinedStub
88
from .. import array_helpers as ah
99
from .. import hypothesis_helpers as hh
10+
from ..test_broadcasting import broadcast_shapes
11+
from ..test_elementwise_functions import sanity_check
1012

1113
UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in ah.dtype_objects)
1214
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
@@ -44,20 +46,24 @@ def test_two_mutually_broadcastable_shapes(pair):
4446
def test_two_broadcastable_shapes(pair):
4547
for shape in pair:
4648
assert valid_shape(shape)
49+
assert broadcast_shapes(pair[0], pair[1]) == pair[0]
4750

48-
from ..test_broadcasting import broadcast_shapes
4951

50-
assert broadcast_shapes(pair[0], pair[1]) == pair[0]
52+
@given(*hh.two_mutual_arrays())
53+
def test_two_mutual_arrays(x1, x2):
54+
sanity_check(x1, x2)
55+
assert broadcast_shapes(x1.shape, x2.shape) in (x1.shape, x2.shape)
5156

5257

5358
def test_kwargs():
5459
results = []
5560

5661
@given(hh.kwargs(n=st.integers(0, 10), c=st.from_regex("[a-f]")))
62+
@settings(max_examples=100)
5763
def run(kw):
5864
results.append(kw)
59-
6065
run()
66+
6167
assert all(isinstance(kw, dict) for kw in results)
6268
for size in [0, 1, 2]:
6369
assert any(len(kw) == size for kw in results)

array_api_tests/test_creation_functions.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
assert_exactly_equal, isintegral, is_float_dtype)
99
from .hypothesis_helpers import (numeric_dtypes, dtypes, MAX_ARRAY_SIZE,
1010
shapes, sizes, sqrt_sizes, shared_dtypes,
11-
scalars, xps, kwargs)
11+
scalars, kwargs)
12+
from . import xps
1213

1314
from hypothesis import assume, given
1415
from hypothesis.strategies import integers, floats, one_of, none, booleans, just, shared, composite

0 commit comments

Comments
 (0)