Skip to content

Commit 8b11476

Browse files
committed
Implement kwargs strategy, and use it in some creation tests
1 parent 70099e5 commit 8b11476

File tree

3 files changed

+114
-114
lines changed

3 files changed

+114
-114
lines changed

array_api_tests/hypothesis_helpers.py

+9
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,12 @@ def multiaxis_indices(draw, shapes):
251251
dtype=shared_mutually_promotable_dtype_pairs.map(lambda pair: pair[1]),
252252
shape=shared(two_mutually_broadcastable_shapes, key="shape_pair").map(lambda pair: pair[1]),
253253
)
254+
255+
256+
@composite
257+
def kwargs(draw, **kw):
258+
result = {}
259+
for k, strat in kw.items():
260+
if draw(booleans()):
261+
result[k] = draw(strat)
262+
return result
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,18 @@
11
from math import prod
22

33
import pytest
4-
from hypothesis import given
4+
from hypothesis import given, strategies as st, assume
55

66
from .. import _array_module as xp
77
from .._array_module import _UndefinedStub
8-
from ..array_helpers import dtype_objects
9-
from ..hypothesis_helpers import (MAX_ARRAY_SIZE,
10-
mutually_promotable_dtypes,
11-
shapes, two_broadcastable_shapes,
12-
two_mutually_broadcastable_shapes)
8+
from .. import array_helpers as ah
9+
from .. import hypothesis_helpers as hh
1310

14-
UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dtype_objects)
11+
UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in ah.dtype_objects)
1512
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
1613

1714

18-
@given(mutually_promotable_dtypes([xp.float32, xp.float64]))
15+
@given(hh.mutually_promotable_dtypes([xp.float32, xp.float64]))
1916
def test_mutually_promotable_dtypes(pairs):
2017
assert pairs in (
2118
(xp.float32, xp.float32),
@@ -29,26 +26,48 @@ def valid_shape(shape) -> bool:
2926
return (
3027
all(isinstance(side, int) for side in shape)
3128
and all(side >= 0 for side in shape)
32-
and prod(shape) < MAX_ARRAY_SIZE
29+
and prod(shape) < hh.MAX_ARRAY_SIZE
3330
)
3431

3532

36-
@given(shapes)
33+
@given(hh.shapes)
3734
def test_shapes(shape):
3835
assert valid_shape(shape)
3936

4037

41-
@given(two_mutually_broadcastable_shapes)
38+
@given(hh.two_mutually_broadcastable_shapes)
4239
def test_two_mutually_broadcastable_shapes(pair):
4340
for shape in pair:
4441
assert valid_shape(shape)
4542

4643

47-
@given(two_broadcastable_shapes())
44+
@given(hh.two_broadcastable_shapes())
4845
def test_two_broadcastable_shapes(pair):
4946
for shape in pair:
5047
assert valid_shape(shape)
5148

5249
from ..test_broadcasting import broadcast_shapes
5350

5451
assert broadcast_shapes(pair[0], pair[1]) == pair[0]
52+
53+
54+
def test_kwargs():
55+
results = []
56+
57+
@given(hh.kwargs(n=st.integers(0, 10), c=st.from_regex("[a-f]")))
58+
def run(kw):
59+
results.append(kw)
60+
61+
run()
62+
assert all(isinstance(kw, dict) for kw in results)
63+
for size in [0, 1, 2]:
64+
assert any(len(kw) == size for kw in results)
65+
66+
n_results = [kw for kw in results if "n" in kw]
67+
assert len(n_results) > 0
68+
assert all(isinstance(kw["n"], int) for kw in n_results)
69+
70+
c_results = [kw for kw in results if "c" in kw]
71+
assert len(c_results) > 0
72+
assert all(isinstance(kw["c"], str) for kw in c_results)
73+

array_api_tests/test_creation_functions.py

+74-102
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from ._array_module import (asarray, arange, ceil, empty, empty_like, eye, full,
22
full_like, equal, all, linspace, ones, ones_like,
3-
zeros, zeros_like, isnan)
3+
zeros, zeros_like, isnan, float32)
44
from .array_helpers import (is_integer_dtype, dtype_ranges,
55
assert_exactly_equal, isintegral, is_float_dtype)
66
from .hypothesis_helpers import (numeric_dtypes, dtypes, MAX_ARRAY_SIZE,
77
shapes, sizes, sqrt_sizes, shared_dtypes,
8-
scalars, xps)
8+
scalars, xps, kwargs)
99

1010
from hypothesis import assume, given
1111
from hypothesis.strategies import integers, floats, one_of, none, booleans, just, shared
@@ -74,7 +74,7 @@ def test_arange(start, stop, step, dtype):
7474
def test_empty(shape, dtype):
7575
if dtype is None:
7676
a = empty(shape)
77-
assert is_float_dtype(a.dtype), "empty() should produce an array with the default floating point dtype"
77+
assert is_float_dtype(a.dtype), "empty() should returned an array with the default floating point dtype"
7878
else:
7979
a = empty(shape, dtype=dtype)
8080
assert a.dtype == dtype
@@ -85,23 +85,17 @@ def test_empty(shape, dtype):
8585

8686

8787
@given(
88-
x=xps.arrays(
89-
dtype=dtypes,
90-
shape=shapes,
91-
),
92-
dtype=optional_dtypes,
88+
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shapes),
89+
kw=kwargs(dtype=none() | xps.scalar_dtypes())
9390
)
94-
def test_empty_like(x, dtype):
95-
kwargs = {} if dtype is None else {'dtype': dtype}
96-
97-
x_like = empty_like(x, **kwargs)
98-
99-
if dtype is None:
100-
assert x_like.dtype == x.dtype, f"{x.dtype=!s}, but empty_like() did not produce a {x.dtype} array - instead was {x_like.dtype}"
91+
def test_empty_like(x, kw):
92+
out = empty_like(x, **kw)
93+
dtype = kw.get("dtype", None) or x.dtype
94+
if kw.get("dtype", None) is None:
95+
assert out.dtype == x.dtype, f"{x.dtype=!s}, but empty_like() returned an array with dtype {out.dtype}"
10196
else:
102-
assert x_like.dtype == dtype, f"{dtype=!s}, but empty_like() did not produce a {dtype} array - instead was {x_like.dtype}"
103-
104-
assert x_like.shape == x.shape, "empty_like() produced an array with an incorrect shape"
97+
assert out.dtype == dtype, f"{dtype=!s}, but empty_like() returned an array with dtype {out.dtype}"
98+
assert out.shape == x.shape, "empty_like() produced an array with an incorrect shape"
10599

106100

107101
# TODO: Use this method for all optional arguments
@@ -117,7 +111,7 @@ def test_eye(n_rows, n_cols, k, dtype):
117111
else:
118112
a = eye(n_rows, n_cols, **kwargs)
119113
if dtype is None:
120-
assert is_float_dtype(a.dtype), "eye() should produce an array with the default floating point dtype"
114+
assert is_float_dtype(a.dtype), "eye() should returned an array with the default floating point dtype"
121115
else:
122116
assert a.dtype == dtype, "eye() did not produce the correct dtype"
123117

@@ -142,7 +136,7 @@ def test_full(shape, fill_value, dtype):
142136

143137
if dtype is None:
144138
# TODO: Should it actually match the fill_value?
145-
# assert a.dtype in _floating_dtypes, "eye() should produce an array with the default floating point dtype"
139+
# assert a.dtype in _floating_dtypes, "eye() should returned an array with the default floating point dtype"
146140
pass
147141
else:
148142
assert a.dtype == dtype
@@ -157,21 +151,20 @@ def test_full(shape, fill_value, dtype):
157151
@given(
158152
x=xps.arrays(dtype=shared_dtypes, shape=shapes),
159153
fill_value=shared_dtypes.flatmap(xps.from_dtype),
160-
dtype=shared_optional_dtypes,
154+
kw=kwargs(dtype=none() | shared_dtypes),
161155
)
162-
def test_full_like(x, fill_value, dtype):
163-
x_like = full_like(x, fill_value, dtype=dtype)
164-
165-
if dtype is None:
166-
assert x_like.dtype == x.dtype, f"{x.dtype=!s}, but full_like() did not produce a {x.dtype} array - instead was {x_like.dtype}"
156+
def test_full_like(x, fill_value, kw):
157+
out = full_like(x, fill_value, **kw)
158+
dtype = kw.get("dtype", None) or x.dtype
159+
if kw.get("dtype", None) is None:
160+
assert out.dtype == x.dtype, f"{x.dtype=!s}, but full_like() returned an array with dtype {out.dtype}"
167161
else:
168-
assert x_like.dtype == dtype, f"{dtype=!s}, but full_like() did not produce a {dtype} array - instead was {x_like.dtype}"
169-
170-
assert x_like.shape == x.shape, "full_like() produced an array with incorrect shape"
171-
if is_float_dtype(x_like.dtype) and isnan(asarray(fill_value)):
172-
assert all(isnan(x_like)), "full_like() array did not equal the fill value"
162+
assert out.dtype == dtype, f"{dtype=!s}, but full_like() returned an array with dtype {out.dtype}"
163+
assert out.shape == x.shape, "{x.shape=}, but full_like() returned an array with shape {out.shape}"
164+
if is_float_dtype(dtype) and isnan(asarray(fill_value)):
165+
assert all(isnan(out)), "full_like() array did not equal the fill value"
173166
else:
174-
assert all(equal(x_like, asarray(fill_value, dtype=x_like.dtype))), "full_like() array did not equal the fill value"
167+
assert all(equal(out, asarray(fill_value, dtype=dtype))), "full_like() array did not equal the fill value"
175168

176169

177170
@given(scalars(shared_dtypes, finite=True),
@@ -192,11 +185,11 @@ def test_linspace(start, stop, num, dtype, endpoint):
192185
a = linspace(start, stop, num, **kwargs)
193186

194187
if dtype is None:
195-
assert is_float_dtype(a.dtype), "linspace() should produce an array with the default floating point dtype"
188+
assert is_float_dtype(a.dtype), "linspace() should returned an array with the default floating point dtype"
196189
else:
197190
assert a.dtype == dtype, "linspace() did not produce the correct dtype"
198191

199-
assert a.shape == (num,), "linspace() did not produce an array with the correct shape"
192+
assert a.shape == (num,), "linspace() did not returned an array with the correct shape"
200193

201194
if endpoint in [None, True]:
202195
if num > 1:
@@ -217,96 +210,75 @@ def test_linspace(start, stop, num, dtype, endpoint):
217210
# for i in range(1, num):
218211
# assert all(equal(a[i], full((), i*(stop - start)/n + start, dtype=dtype))), f"linspace() produced an array with an incorrect value at index {i}"
219212

220-
@given(shapes, one_of(none(), dtypes))
221-
def test_ones(shape, dtype):
222-
kwargs = {} if dtype is None else {'dtype': dtype}
223-
if dtype is None or is_float_dtype(dtype):
224-
ONE = 1.0
213+
214+
def make_one(dtype):
215+
if kwargs is None or is_float_dtype(dtype):
216+
return 1.0
225217
elif is_integer_dtype(dtype):
226-
ONE = 1
218+
return 1
227219
else:
228-
ONE = True
220+
return True
229221

230-
a = ones(shape, **kwargs)
231222

232-
if dtype is None:
233-
# TODO: Should it actually match the fill_value?
234-
# assert a.dtype in _floating_dtypes, "eye() should produce an array with the default floating point dtype"
235-
pass
223+
@given(shapes, kwargs(dtype=none() | xps.scalar_dtypes()))
224+
def test_ones(shape, kw):
225+
out = ones(shape, **kw)
226+
dtype = kw.get("dtype", None) or float32
227+
if kw.get("dtype", None) is None:
228+
assert is_float_dtype(out.dtype), "ones() returned an array with dtype {x.dtype}, but should be the default float dtype"
236229
else:
237-
assert a.dtype == dtype
238-
239-
assert a.shape == shape, "ones() produced an array with incorrect shape"
240-
assert all(equal(a, full((), ONE, **kwargs))), "ones() array did not equal 1"
230+
assert out.dtype == dtype, f"{dtype=!s}, but ones() returned an array with dtype {out.dtype}"
231+
assert out.shape == shape, "ones() produced an array with incorrect shape"
232+
assert all(equal(out, full((), make_one(dtype), dtype=dtype))), "ones() array did not equal 1"
241233

242234

243235
@given(
244236
x=xps.arrays(dtype=dtypes, shape=shapes),
245-
dtype=optional_dtypes,
237+
kw=kwargs(dtype=none() | xps.scalar_dtypes()),
246238
)
247-
def test_ones_like(x, dtype):
248-
kwargs = {} if dtype is None else {'dtype': dtype}
249-
if kwargs is None or is_float_dtype(x.dtype):
250-
ONE = 1.0
251-
elif is_integer_dtype(x.dtype):
252-
ONE = 1
239+
def test_ones_like(x, kw):
240+
out = ones_like(x, **kw)
241+
dtype = kw.get("dtype", None) or x.dtype
242+
if kw.get("dtype", None) is None:
243+
assert out.dtype == x.dtype, f"{x.dtype=!s}, but ones_like() returned an array with dtype {out.dtype}"
253244
else:
254-
ONE = True
245+
assert out.dtype == dtype, f"{dtype=!s}, but ones_like() returned an array with dtype {out.dtype}"
246+
assert out.shape == x.shape, "{x.shape=}, but ones_like() returned an array with shape {out.shape}"
247+
assert all(equal(out, full((), make_one(dtype), dtype=dtype))), "ones_like() array elements did not equal 1"
255248

256-
x_like = ones_like(x, **kwargs)
257249

258-
if dtype is None:
259-
assert x_like.dtype == x.dtype, f"{x.dtype=!s}, but ones_like() did not produce a {x.dtype} array - instead was {x_like.dtype}"
260-
else:
261-
assert x_like.dtype == dtype, f"{dtype=!s}, but ones_like() did not produce a {dtype} array - instead was {x_like.dtype}"
262-
263-
assert x_like.shape == x.shape, "ones_like() produced an array with an incorrect shape"
264-
assert all(equal(x_like, full((), ONE, dtype=x_like.dtype))), "ones_like() array did not equal 1"
265-
266-
267-
@given(shapes, one_of(none(), dtypes))
268-
def test_zeros(shape, dtype):
269-
kwargs = {} if dtype is None else {'dtype': dtype}
270-
if dtype is None or is_float_dtype(dtype):
271-
ZERO = 0.0
250+
def make_zero(dtype):
251+
if is_float_dtype(dtype):
252+
return 0.0
272253
elif is_integer_dtype(dtype):
273-
ZERO = 0
254+
return 0
274255
else:
275-
ZERO = False
256+
return False
276257

277-
a = zeros(shape, **kwargs)
278258

279-
if dtype is None:
280-
# TODO: Should it actually match the fill_value?
281-
# assert a.dtype in _floating_dtypes, "eye() should produce an array with the default floating point dtype"
282-
pass
259+
@given(shapes, kwargs(dtype=none() | xps.scalar_dtypes()))
260+
def test_zeros(shape, kw):
261+
out = zeros(shape, **kw)
262+
dtype = kw.get("dtype", None) or float32
263+
if kw.get("dtype", None) is None:
264+
assert is_float_dtype(out.dtype), "zeros() returned an array with dtype {out.dtype}, but should be the default float dtype"
283265
else:
284-
assert a.dtype == dtype
285-
286-
assert a.shape == shape, "zeros() produced an array with incorrect shape"
287-
assert all(equal(a, full((), ZERO, **kwargs))), "zeros() array did not equal 0"
266+
assert out.dtype == dtype, f"{dtype=!s}, but zeros() returned an array with dtype {out.dtype}"
267+
assert out.shape == shape, "zeros() produced an array with incorrect shape"
268+
assert all(equal(out, full((), make_zero(dtype), dtype=dtype))), "zeros() array did not equal 0"
288269

289270

290271
@given(
291272
x=xps.arrays(dtype=dtypes, shape=shapes),
292-
dtype=optional_dtypes,
273+
kw=kwargs(dtype=none() | xps.scalar_dtypes()),
293274
)
294-
def test_zeros_like(x, dtype):
295-
kwargs = {} if dtype is None else {'dtype': dtype}
296-
if dtype is None or is_float_dtype(x.dtype):
297-
ZERO = 0.0
298-
elif is_integer_dtype(x.dtype):
299-
ZERO = 0
300-
else:
301-
ZERO = False
302-
303-
x_like = zeros_like(x, **kwargs)
304-
305-
if dtype is None:
306-
assert x_like.dtype == x.dtype, f"{x.dtype=!s}, but zeros_like() did not produce a {x.dtype} array - instead was {x_like.dtype}"
275+
def test_zeros_like(x, kw):
276+
out = zeros_like(x, **kw)
277+
dtype = kw.get("dtype", None) or x.dtype
278+
if kw.get("dtype", None) is None:
279+
assert out.dtype == x.dtype, f"{x.dtype=!s}, but zeros_like() returned an array with dtype {out.dtype}"
307280
else:
308-
assert x_like.dtype == dtype, f"{dtype=!s}, but zeros_like() did not produce a {dtype} array - instead was {x_like.dtype}"
309-
310-
assert x_like.shape == x.shape, "zeros_like() produced an array with an incorrect shape"
311-
assert all(equal(x_like, full((), ZERO, dtype=x_like.dtype))), "zeros_like() array did not equal 0"
281+
assert out.dtype == dtype, f"{dtype=!s}, but zeros_like() returned an array with dtype {out.dtype}"
282+
assert out.shape == x.shape, "{x.shape=}, but zeros_like() returned an array with shape {out.shape}"
283+
assert all(equal(out, full((), make_zero(dtype), dtype=out.dtype))), "zeros_like() array elements did not all equal 0"
312284

0 commit comments

Comments
 (0)