Skip to content

Commit e55e9a0

Browse files
committed
Fixed test_*_like methods to match spec
1 parent e2afeba commit e55e9a0

File tree

2 files changed

+31
-42
lines changed

2 files changed

+31
-42
lines changed

array_api_tests/hypothesis_helpers.py

-6
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,3 @@ def multiaxis_indices(draw, shapes):
242242
extra = draw(lists(one_of(integer_indices(sizes), slices(sizes)), min_size=0, max_size=3))
243243
res += extra
244244
return tuple(res)
245-
246-
247-
shared_optional_promotable_dtypes = one_of(
248-
none(),
249-
promotable_dtypes(shared_dtypes),
250-
)

array_api_tests/test_creation_functions.py

+31-36
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@
33
zeros, zeros_like, isnan)
44
from .array_helpers import (is_integer_dtype, dtype_ranges,
55
assert_exactly_equal, isintegral, is_float_dtype)
6-
from .hypothesis_helpers import (numeric_dtypes, dtypes, MAX_ARRAY_SIZE, promotable_dtypes,
6+
from .hypothesis_helpers import (numeric_dtypes, dtypes, MAX_ARRAY_SIZE,
77
shapes, sizes, sqrt_sizes, shared_dtypes,
8-
scalars, xps, shared_optional_promotable_dtypes)
8+
scalars, xps)
99

1010
from hypothesis import assume, given
1111
from hypothesis.strategies import integers, floats, one_of, none, booleans, just, shared, composite
1212

1313

14+
optional_dtypes = none() | shared_dtypes
15+
shared_optional_dtypes = shared(optional_dtypes, key="optional_dtype")
16+
1417

1518
int_range = integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE)
1619
float_range = floats(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE,
@@ -82,25 +85,25 @@ def test_empty(shape, dtype):
8285

8386

8487
@given(
85-
a=xps.arrays(
86-
dtype=shared_dtypes,
88+
x=xps.arrays(
89+
dtype=dtypes,
8790
shape=shapes,
8891
),
89-
dtype=shared_optional_promotable_dtypes,
92+
dtype=optional_dtypes,
9093
)
91-
def test_empty_like(a, dtype):
94+
def test_empty_like(x, dtype):
9295
kwargs = {} if dtype is None else {'dtype': dtype}
9396

94-
a_like = empty_like(a, **kwargs)
97+
x_like = empty_like(x, **kwargs)
9598

9699
if dtype is None:
97100
# TODO: Should it actually match a.dtype?
98-
# assert is_float_dtype(a_like.dtype), "empty_like() should produce an array with the default floating point dtype"
101+
# assert is_float_dtype(x_like.dtype), "empty_like() should produce an array with the default floating point dtype"
99102
pass
100103
else:
101-
assert a_like.dtype == dtype, "empty_like() produced an array with an incorrect dtype"
104+
assert x_like.dtype == dtype, "empty_like() produced an array with an incorrect dtype"
102105

103-
assert a_like.shape == a.shape, "empty_like() produced an array with an incorrect shape"
106+
assert x_like.shape == x.shape, "empty_like() produced an array with an incorrect shape"
104107

105108

106109
# TODO: Use this method for all optional arguments
@@ -152,8 +155,6 @@ def test_full(shape, fill_value, dtype):
152155
else:
153156
assert all(equal(a, asarray(fill_value, **kwargs))), "full() array did not equal the fill value"
154157

155-
shared_optional_dtypes = shared(none() | shared_dtypes, key="optional_dtype")
156-
157158
@composite
158159
def fill_values(draw):
159160
dtype = draw(shared_optional_dtypes)
@@ -251,31 +252,28 @@ def test_ones(shape, dtype):
251252

252253

253254
@given(
254-
a=xps.arrays(
255-
dtype=shared_dtypes,
256-
shape=shapes,
257-
),
258-
dtype=shared_optional_promotable_dtypes,
255+
x=xps.arrays(dtype=dtypes, shape=shapes),
256+
dtype=optional_dtypes,
259257
)
260-
def test_ones_like(a, dtype):
258+
def test_ones_like(x, dtype):
261259
kwargs = {} if dtype is None else {'dtype': dtype}
262-
if kwargs is None or is_float_dtype(a.dtype):
260+
if kwargs is None or is_float_dtype(x.dtype):
263261
ONE = 1.0
264-
elif is_integer_dtype(a.dtype):
262+
elif is_integer_dtype(x.dtype):
265263
ONE = 1
266264
else:
267265
ONE = True
268266

269-
a_like = ones_like(a, **kwargs)
267+
x_like = ones_like(x, **kwargs)
270268

271269
if dtype is None:
272270
# TODO: Should it actually match a.dtype?
273271
pass
274272
else:
275-
assert a_like.dtype == dtype, "ones_like() produced an array with an incorrect dtype"
273+
assert x_like.dtype == dtype, "ones_like() produced an array with an incorrect dtype"
276274

277-
assert a_like.shape == a.shape, "ones_like() produced an array with an incorrect shape"
278-
assert all(equal(a_like, full((), ONE, dtype=a_like.dtype))), "ones_like() array did not equal 1"
275+
assert x_like.shape == x.shape, "ones_like() produced an array with an incorrect shape"
276+
assert all(equal(x_like, full((), ONE, dtype=x_like.dtype))), "ones_like() array did not equal 1"
279277

280278

281279
@given(shapes, one_of(none(), dtypes))
@@ -302,29 +300,26 @@ def test_zeros(shape, dtype):
302300

303301

304302
@given(
305-
a=xps.arrays(
306-
dtype=shared_dtypes,
307-
shape=shapes,
308-
),
309-
dtype=shared_optional_promotable_dtypes,
303+
x=xps.arrays(dtype=dtypes, shape=shapes),
304+
dtype=optional_dtypes,
310305
)
311-
def test_zeros_like(a, dtype):
306+
def test_zeros_like(x, dtype):
312307
kwargs = {} if dtype is None else {'dtype': dtype}
313-
if dtype is None or is_float_dtype(a.dtype):
308+
if dtype is None or is_float_dtype(x.dtype):
314309
ZERO = 0.0
315-
elif is_integer_dtype(a.dtype):
310+
elif is_integer_dtype(x.dtype):
316311
ZERO = 0
317312
else:
318313
ZERO = False
319314

320-
a_like = zeros_like(a, **kwargs)
315+
x_like = zeros_like(x, **kwargs)
321316

322317
if dtype is None:
323318
# TODO: Should it actually match a.dtype?
324319
pass
325320
else:
326-
assert a_like.dtype == dtype, "zeros_like() produced an array with an incorrect dtype"
321+
assert x_like.dtype == dtype, "zeros_like() produced an array with an incorrect dtype"
327322

328-
assert a_like.shape == a.shape, "zeros_like() produced an array with an incorrect shape"
329-
assert all(equal(a_like, full((), ZERO, dtype=a_like.dtype))), "zeros_like() array did not equal 0"
323+
assert x_like.shape == x.shape, "zeros_like() produced an array with an incorrect shape"
324+
assert all(equal(x_like, full((), ZERO, dtype=x_like.dtype))), "zeros_like() array did not equal 0"
330325

0 commit comments

Comments
 (0)