Skip to content

Commit 63d66bc

Browse files
committed
promotable_dtypes() strategy
* Used in test_*_like() tests * Refactors construction of dtype_pairs table * mutually_promotable_dtypes -> mutually_promotable_dtype_pairs for clarity * Rudimentary meta tests
1 parent c46b51d commit 63d66bc

File tree

4 files changed

+78
-34
lines changed

4 files changed

+78
-34
lines changed

array_api_tests/hypothesis_helpers.py

+20-13
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
integer_or_boolean_dtype_objects, dtype_objects)
1818
from ._array_module import (ones, full, float32, float64, bool as bool_dtype, _UndefinedStub)
1919
from . import _array_module
20-
from ._array_module import mod as xp
20+
from . import _array_module as xp
2121

2222
from .function_stubs import elementwise_functions
2323

@@ -50,8 +50,7 @@
5050

5151
shared_dtypes = shared(dtypes)
5252

53-
@composite
54-
def mutually_promotable_dtypes(draw, dtype_objects=dtype_objects):
53+
def make_dtype_pairs():
5554
from .test_type_promotion import dtype_mapping, promotion_table
5655
# sort for shrinking (sampled_from shrinks to the earlier elements in the
5756
# list). Give pairs of the same dtypes first, then smaller dtypes,
@@ -61,17 +60,25 @@ def mutually_promotable_dtypes(draw, dtype_objects=dtype_objects):
6160
# pairs (XXX: Can we redesign the strategies so that they can prefer
6261
# shrinking dtypes over values?)
6362
sorted_table = sorted(promotion_table)
64-
sorted_table = sorted(sorted_table, key=lambda ij: -1 if ij[0] == ij[1] else sorted_table.index(ij))
65-
dtype_pairs = [(dtype_mapping[i], dtype_mapping[j]) for i, j in
66-
sorted_table]
67-
68-
filtered_dtype_pairs = [(i, j) for i, j in dtype_pairs if i in
69-
dtype_objects and j in dtype_objects]
63+
sorted_table = sorted(
64+
sorted_table, key=lambda ij: -1 if ij[0] == ij[1] else sorted_table.index(ij)
65+
)
66+
dtype_pairs = [(dtype_mapping[i], dtype_mapping[j]) for i, j in sorted_table]
7067
if FILTER_UNDEFINED_DTYPES:
71-
filtered_dtype_pairs = [(i, j) for i, j in filtered_dtype_pairs
72-
if not isinstance(i, _UndefinedStub)
73-
and not isinstance(j, _UndefinedStub)]
74-
return draw(sampled_from(filtered_dtype_pairs))
68+
dtype_pairs = [(i, j) for i, j in dtype_pairs
69+
if not isinstance(i, _UndefinedStub)
70+
and not isinstance(j, _UndefinedStub)]
71+
return dtype_pairs
72+
73+
def promotable_dtypes(dtype):
74+
dtype_pairs = make_dtype_pairs()
75+
dtypes = [j for i, j in dtype_pairs if i == dtype]
76+
return sampled_from(dtypes)
77+
78+
def mutually_promotable_dtype_pairs(dtype_objects=dtype_objects):
79+
dtype_pairs = make_dtype_pairs()
80+
dtype_pairs = [(i, j) for i, j in dtype_pairs if i in dtype_objects and j in dtype_objects]
81+
return sampled_from(dtype_pairs)
7582

7683
# shared() allows us to draw either the function or the function name and they
7784
# will both correspond to the same function.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import pytest
2+
from hypothesis import given, strategies as st
3+
4+
from .. import _array_module as xp
5+
from .._array_module import _UndefinedStub
6+
from ..array_helpers import dtype_objects
7+
from ..hypothesis_helpers import (mutually_promotable_dtype_pairs,
8+
promotable_dtypes)
9+
10+
UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dtype_objects)
11+
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
12+
13+
14+
def test_promotable_dtypes():
15+
dtypes = set()
16+
@given(promotable_dtypes(xp.uint16))
17+
def run(dtype):
18+
dtypes.add(dtype)
19+
run()
20+
assert dtypes == {
21+
xp.uint8, xp.uint16, xp.uint32, xp.uint64, xp.int8, xp.int16, xp.int32, xp.int64
22+
}
23+
24+
25+
def test_mutually_promotable_dtype_pairs():
26+
pairs = set()
27+
@given(mutually_promotable_dtype_pairs([xp.float32, xp.float64]))
28+
def run(pair):
29+
pairs.add(pair)
30+
run()
31+
assert pairs == {
32+
(xp.float32, xp.float32),
33+
(xp.float32, xp.float64),
34+
(xp.float64, xp.float32),
35+
(xp.float64, xp.float64),
36+
}
37+

array_api_tests/test_creation_functions.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
zeros, zeros_like, isnan)
44
from .array_helpers import (is_integer_dtype, dtype_ranges,
55
assert_exactly_equal, isintegral, is_float_dtype)
6-
from .hypothesis_helpers import (numeric_dtypes, dtypes, MAX_ARRAY_SIZE,
6+
from .hypothesis_helpers import (numeric_dtypes, dtypes, MAX_ARRAY_SIZE, promotable_dtypes,
77
shapes, sizes, sqrt_sizes, shared_dtypes,
88
scalars, xps)
99

@@ -84,7 +84,7 @@ def test_empty(shape, dtype):
8484
dtype=shared_dtypes,
8585
shape=xps.array_shapes(),
8686
),
87-
dtype=one_of(none(), shared_dtypes),
87+
dtype=one_of(none(), shared_dtypes.flatmap(promotable_dtypes)),
8888
)
8989
def test_empty_like(a, dtype):
9090
kwargs = {} if dtype is None else {'dtype': dtype}
@@ -96,7 +96,7 @@ def test_empty_like(a, dtype):
9696
# assert is_float_dtype(a_like.dtype), "empty_like() should produce an array with the default floating point dtype"
9797
pass
9898
else:
99-
assert a_like.dtype == a.dtype, "empty_like() produced an array with an incorrect dtype"
99+
assert a_like.dtype == dtype, "empty_like() produced an array with an incorrect dtype"
100100

101101
assert a_like.shape == a.shape, "empty_like() produced an array with an incorrect shape"
102102

@@ -155,8 +155,8 @@ def test_full(shape, fill_value, dtype):
155155
dtype=shared_dtypes,
156156
shape=xps.array_shapes(),
157157
),
158-
fill_value=shared_dtypes.flatmap(xps.from_dtype),
159-
dtype=one_of(none(), shared_dtypes),
158+
fill_value=shared_dtypes.flatmap(promotable_dtypes).flatmap(xps.from_dtype),
159+
dtype=one_of(none(), shared_dtypes.flatmap(promotable_dtypes)),
160160
)
161161
def test_full_like(a, fill_value, dtype):
162162
kwargs = {} if dtype is None else {'dtype': dtype}
@@ -167,13 +167,13 @@ def test_full_like(a, fill_value, dtype):
167167
# TODO: Should it actually match a.dtype?
168168
pass
169169
else:
170-
assert a_like.dtype == a.dtype
170+
assert a_like.dtype == dtype
171171

172172
assert a_like.shape == a.shape, "full_like() produced an array with incorrect shape"
173173
if is_float_dtype(a_like.dtype) and isnan(asarray(fill_value)):
174174
assert all(isnan(a_like)), "full_like() array did not equal the fill value"
175175
else:
176-
assert all(equal(a_like, asarray(fill_value, dtype=a.dtype))), "full_like() array did not equal the fill value"
176+
assert all(equal(a_like, asarray(fill_value, dtype=a_like.dtype))), "full_like() array did not equal the fill value"
177177

178178

179179
@given(scalars(shared_dtypes, finite=True),
@@ -247,7 +247,7 @@ def test_ones(shape, dtype):
247247
dtype=shared_dtypes,
248248
shape=xps.array_shapes(),
249249
),
250-
dtype=one_of(none(), shared_dtypes),
250+
dtype=one_of(none(), shared_dtypes.flatmap(promotable_dtypes)),
251251
)
252252
def test_ones_like(a, dtype):
253253
kwargs = {} if dtype is None else {'dtype': dtype}
@@ -260,11 +260,11 @@ def test_ones_like(a, dtype):
260260

261261
a_like = ones_like(a, **kwargs)
262262

263-
if kwargs is None:
263+
if dtype is None:
264264
# TODO: Should it actually match a.dtype?
265265
pass
266266
else:
267-
assert a_like.dtype == a.dtype, "ones_like() produced an array with an incorrect dtype"
267+
assert a_like.dtype == dtype, "ones_like() produced an array with an incorrect dtype"
268268

269269
assert a_like.shape == a.shape, "ones_like() produced an array with an incorrect shape"
270270
assert all(equal(a_like, full((), ONE, dtype=a_like.dtype))), "ones_like() array did not equal 1"
@@ -298,7 +298,7 @@ def test_zeros(shape, dtype):
298298
dtype=shared_dtypes,
299299
shape=xps.array_shapes(),
300300
),
301-
dtype=one_of(none(), shared_dtypes),
301+
dtype=one_of(none(), shared_dtypes.flatmap(promotable_dtypes)),
302302
)
303303
def test_zeros_like(a, dtype):
304304
kwargs = {} if dtype is None else {'dtype': dtype}
@@ -311,11 +311,11 @@ def test_zeros_like(a, dtype):
311311

312312
a_like = zeros_like(a, **kwargs)
313313

314-
if kwargs is None:
314+
if dtype is None:
315315
# TODO: Should it actually match a.dtype?
316316
pass
317317
else:
318-
assert a_like.dtype == a.dtype, "zeros_like() produced an array with an incorrect dtype"
318+
assert a_like.dtype == dtype, "zeros_like() produced an array with an incorrect dtype"
319319

320320
assert a_like.shape == a.shape, "zeros_like() produced an array with an incorrect shape"
321321
assert all(equal(a_like, full((), ZERO, dtype=a_like.dtype))), "zeros_like() array did not equal 0"

array_api_tests/test_elementwise_functions.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
integer_or_boolean_dtype_objects,
2626
boolean_dtype_objects, floating_dtypes,
2727
numeric_dtypes, integer_or_boolean_dtypes,
28-
boolean_dtypes, mutually_promotable_dtypes,
28+
boolean_dtypes, mutually_promotable_dtype_pairs,
2929
array_scalars, xps)
3030
from .array_helpers import (assert_exactly_equal, negative,
3131
positive_mathematical_sign,
@@ -50,17 +50,17 @@
5050
integer_or_boolean_scalars = array_scalars(integer_or_boolean_dtypes)
5151
boolean_scalars = array_scalars(boolean_dtypes)
5252

53-
two_integer_dtypes = mutually_promotable_dtypes(integer_dtype_objects)
54-
two_floating_dtypes = mutually_promotable_dtypes(floating_dtype_objects)
55-
two_numeric_dtypes = mutually_promotable_dtypes(numeric_dtype_objects)
56-
two_integer_or_boolean_dtypes = mutually_promotable_dtypes(integer_or_boolean_dtype_objects)
57-
two_boolean_dtypes = mutually_promotable_dtypes(boolean_dtype_objects)
58-
two_any_dtypes = mutually_promotable_dtypes()
53+
two_integer_dtypes = mutually_promotable_dtype_pairs(integer_dtype_objects)
54+
two_floating_dtypes = mutually_promotable_dtype_pairs(floating_dtype_objects)
55+
two_numeric_dtypes = mutually_promotable_dtype_pairs(numeric_dtype_objects)
56+
two_integer_or_boolean_dtypes = mutually_promotable_dtype_pairs(integer_or_boolean_dtype_objects)
57+
two_boolean_dtypes = mutually_promotable_dtype_pairs(boolean_dtype_objects)
58+
two_any_dtypes = mutually_promotable_dtype_pairs()
5959

6060
@composite
6161
def two_array_scalars(draw, dtype1, dtype2):
6262
# two_dtypes should be a strategy that returns two dtypes (like
63-
# mutually_promotable_dtypes())
63+
# mutually_promotable_dtype_pairs())
6464
return draw(array_scalars(just(dtype1))), draw(array_scalars(just(dtype2)))
6565

6666
def sanity_check(x1, x2):

0 commit comments

Comments
 (0)