Skip to content

Commit 0f63fab

Browse files
authored
Merge pull request #89 from honno/values-testing
Increase coverage and other niceties for the operator/elementwise tests
2 parents 2e78904 + 3c85cae commit 0f63fab

20 files changed

+911
-1101
lines changed

array_api_tests/array_helpers.py

-11
Original file line numberDiff line numberDiff line change
@@ -306,14 +306,3 @@ def same_sign(x, y):
306306
def assert_same_sign(x, y):
307307
assert all(same_sign(x, y)), "The input arrays do not have the same sign"
308308

309-
def int_to_dtype(x, n, signed):
310-
"""
311-
Convert the Python integer x into an n bit signed or unsigned number.
312-
"""
313-
mask = (1 << n) - 1
314-
x &= mask
315-
if signed:
316-
highest_bit = 1 << (n-1)
317-
if x & highest_bit:
318-
x = -((~x & mask) + 1)
319-
return x

array_api_tests/hypothesis_helpers.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from ._array_module import _UndefinedStub
1717
from ._array_module import bool as bool_dtype
1818
from ._array_module import broadcast_to, eye, float32, float64, full
19-
from .algos import broadcast_shapes
2019
from .function_stubs import elementwise_functions
2120
from .pytest_helpers import nargs
2221
from .typing import Array, DataType, Shape
@@ -243,7 +242,7 @@ def two_broadcastable_shapes(draw):
243242
broadcast to shape1.
244243
"""
245244
shape1, shape2 = draw(two_mutually_broadcastable_shapes)
246-
assume(broadcast_shapes(shape1, shape2) == shape1)
245+
assume(sh.broadcast_shapes(shape1, shape2) == shape1)
247246
return (shape1, shape2)
248247

249248
sizes = integers(0, MAX_ARRAY_SIZE)
@@ -370,6 +369,9 @@ def two_mutual_arrays(
370369
) -> Tuple[SearchStrategy[Array], SearchStrategy[Array]]:
371370
if not isinstance(dtypes, Sequence):
372371
raise TypeError(f"{dtypes=} not a sequence")
372+
if FILTER_UNDEFINED_DTYPES:
373+
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
374+
assert len(dtypes) > 0 # sanity check
373375
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes))
374376
mutual_shapes = shared(two_shapes)
375377
arrays1 = xps.arrays(
+1-15
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
1-
from hypothesis import given, assume
2-
from hypothesis.strategies import integers
3-
4-
from ..array_helpers import exactly_equal, notequal, int_to_dtype
5-
from ..hypothesis_helpers import integer_dtypes
6-
from ..dtype_helpers import dtype_nbits, dtype_signed
71
from .. import _array_module as xp
2+
from ..array_helpers import exactly_equal, notequal
83

94
# TODO: These meta-tests currently only work with NumPy
105

@@ -22,12 +17,3 @@ def test_notequal():
2217
res = xp.asarray([False, True, False, False, False, True, False, True])
2318
assert xp.all(xp.equal(notequal(a, b), res))
2419

25-
@given(integers(), integer_dtypes)
26-
def test_int_to_dtype(x, dtype):
27-
n = dtype_nbits[dtype]
28-
signed = dtype_signed[dtype]
29-
try:
30-
d = xp.asarray(x, dtype=dtype)
31-
except OverflowError:
32-
assume(False)
33-
assert int_to_dtype(x, n, signed) == d

array_api_tests/meta/test_broadcasting.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pytest
66

7-
from ..algos import BroadcastError, _broadcast_shapes
7+
from .. import shape_helpers as sh
88

99

1010
@pytest.mark.parametrize(
@@ -19,7 +19,7 @@
1919
],
2020
)
2121
def test_broadcast_shapes(shape1, shape2, expected):
22-
assert _broadcast_shapes(shape1, shape2) == expected
22+
assert sh._broadcast_shapes(shape1, shape2) == expected
2323

2424

2525
@pytest.mark.parametrize(
@@ -31,5 +31,5 @@ def test_broadcast_shapes(shape1, shape2, expected):
3131
],
3232
)
3333
def test_broadcast_shapes_fails_on_bad_shapes(shape1, shape2):
34-
with pytest.raises(BroadcastError):
35-
_broadcast_shapes(shape1, shape2)
34+
with pytest.raises(sh.BroadcastError):
35+
sh._broadcast_shapes(shape1, shape2)

array_api_tests/meta/test_hypothesis_helpers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from .. import array_helpers as ah
99
from .. import dtype_helpers as dh
1010
from .. import hypothesis_helpers as hh
11+
from .. import shape_helpers as sh
1112
from .. import xps
1213
from .._array_module import _UndefinedStub
13-
from ..algos import broadcast_shapes
1414

1515
UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes)
1616
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
@@ -62,7 +62,7 @@ def test_two_mutually_broadcastable_shapes(pair):
6262
def test_two_broadcastable_shapes(pair):
6363
for shape in pair:
6464
assert valid_shape(shape)
65-
assert broadcast_shapes(pair[0], pair[1]) == pair[0]
65+
assert sh.broadcast_shapes(pair[0], pair[1]) == pair[0]
6666

6767

6868
@given(*hh.two_mutual_arrays())

array_api_tests/meta/test_pytest_helpers.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66

77
def test_assert_dtype():
8-
ph.assert_dtype("promoted_func", (xp.uint8, xp.int8), xp.int16)
8+
ph.assert_dtype("promoted_func", [xp.uint8, xp.int8], xp.int16)
99
with raises(AssertionError):
10-
ph.assert_dtype("bad_func", (xp.uint8, xp.int8), xp.float32)
11-
ph.assert_dtype("bool_func", (xp.uint8, xp.int8), xp.bool, xp.bool)
12-
ph.assert_dtype("single_promoted_func", (xp.uint8,), xp.uint8)
13-
ph.assert_dtype("single_bool_func", (xp.uint8,), xp.bool, xp.bool)
10+
ph.assert_dtype("bad_func", [xp.uint8, xp.int8], xp.float32)
11+
ph.assert_dtype("bool_func", [xp.uint8, xp.int8], xp.bool, xp.bool)
12+
ph.assert_dtype("single_promoted_func", [xp.uint8], xp.uint8)
13+
ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool)

array_api_tests/meta/test_utils.py

+33
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
import pytest
2+
from hypothesis import given, reject
3+
from hypothesis import strategies as st
24

5+
from .. import _array_module as xp
6+
from .. import xps
37
from .. import shape_helpers as sh
48
from ..test_creation_functions import frange
59
from ..test_manipulation_functions import roll_ndindex
10+
from ..test_operators_and_elementwise_functions import mock_int_dtype
611
from ..test_signatures import extension_module
712

813

@@ -82,3 +87,31 @@ def test_axes_ndindex(shape, axes, expected):
8287
)
8388
def test_roll_ndindex(shape, shifts, axes, expected):
8489
assert list(roll_ndindex(shape, shifts, axes)) == expected
90+
91+
92+
@pytest.mark.parametrize(
93+
"idx, expected",
94+
[
95+
((), "x"),
96+
(42, "x[42]"),
97+
((42,), "x[42]"),
98+
(slice(None, 2), "x[:2]"),
99+
(slice(2, None), "x[2:]"),
100+
(slice(0, 2), "x[0:2]"),
101+
(slice(0, 2, -1), "x[0:2:-1]"),
102+
(slice(None, None, -1), "x[::-1]"),
103+
(slice(None, None), "x[:]"),
104+
(..., "x[...]"),
105+
],
106+
)
107+
def test_fmt_idx(idx, expected):
108+
assert sh.fmt_idx("x", idx) == expected
109+
110+
111+
@given(x=st.integers(), dtype=xps.unsigned_integer_dtypes() | xps.integer_dtypes())
112+
def test_int_to_dtype(x, dtype):
113+
try:
114+
d = xp.asarray(x, dtype=dtype)
115+
except OverflowError:
116+
reject()
117+
assert mock_int_dtype(x, dtype) == d

array_api_tests/pytest_helpers.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import math
22
from inspect import getfullargspec
3-
from typing import Any, Dict, Optional, Tuple, Union
3+
from typing import Any, Dict, Optional, Sequence, Tuple, Union
44

55
from . import _array_module as xp
66
from . import array_helpers as ah
77
from . import dtype_helpers as dh
88
from . import function_stubs
9-
from .algos import broadcast_shapes
9+
from . import shape_helpers as sh
1010
from .typing import Array, DataType, Scalar, ScalarType, Shape
1111

1212
__all__ = [
@@ -71,15 +71,14 @@ def fmt_kw(kw: Dict[str, Any]) -> str:
7171

7272
def assert_dtype(
7373
func_name: str,
74-
in_dtypes: Union[DataType, Tuple[DataType, ...]],
74+
in_dtype: Union[DataType, Sequence[DataType]],
7575
out_dtype: DataType,
7676
expected: Optional[DataType] = None,
7777
*,
7878
repr_name: str = "out.dtype",
7979
):
80-
if not isinstance(in_dtypes, tuple):
81-
in_dtypes = (in_dtypes,)
82-
f_in_dtypes = dh.fmt_types(in_dtypes)
80+
in_dtypes = in_dtype if isinstance(in_dtype, Sequence) else [in_dtype]
81+
f_in_dtypes = dh.fmt_types(tuple(in_dtypes))
8382
f_out_dtype = dh.dtype_to_name[out_dtype]
8483
if expected is None:
8584
expected = dh.result_type(*in_dtypes)
@@ -150,7 +149,7 @@ def assert_shape(
150149

151150
def assert_result_shape(
152151
func_name: str,
153-
in_shapes: Tuple[Shape],
152+
in_shapes: Sequence[Shape],
154153
out_shape: Shape,
155154
/,
156155
expected: Optional[Shape] = None,
@@ -159,7 +158,7 @@ def assert_result_shape(
159158
**kw,
160159
):
161160
if expected is None:
162-
expected = broadcast_shapes(*in_shapes)
161+
expected = sh.broadcast_shapes(*in_shapes)
163162
f_in_shapes = " . ".join(str(s) for s in in_shapes)
164163
f_sig = f" {f_in_shapes} "
165164
if kw:

array_api_tests/shape_helpers.py

+105-9
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,67 @@
22
from itertools import product
33
from typing import Iterator, List, Optional, Tuple, Union
44

5-
from .typing import Scalar, Shape
5+
from ndindex import iter_indices as _iter_indices
66

7-
__all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex", "reshape"]
7+
from .typing import AtomicIndex, Index, Scalar, Shape
8+
9+
__all__ = [
10+
"broadcast_shapes",
11+
"normalise_axis",
12+
"ndindex",
13+
"axis_ndindex",
14+
"axes_ndindex",
15+
"reshape",
16+
"fmt_idx",
17+
]
18+
19+
20+
class BroadcastError(ValueError):
21+
"""Shapes do not broadcast with eachother"""
22+
23+
24+
def _broadcast_shapes(shape1: Shape, shape2: Shape) -> Shape:
25+
"""Broadcasts `shape1` and `shape2`"""
26+
N1 = len(shape1)
27+
N2 = len(shape2)
28+
N = max(N1, N2)
29+
shape = [None for _ in range(N)]
30+
i = N - 1
31+
while i >= 0:
32+
n1 = N1 - N + i
33+
if N1 - N + i >= 0:
34+
d1 = shape1[n1]
35+
else:
36+
d1 = 1
37+
n2 = N2 - N + i
38+
if N2 - N + i >= 0:
39+
d2 = shape2[n2]
40+
else:
41+
d2 = 1
42+
43+
if d1 == 1:
44+
shape[i] = d2
45+
elif d2 == 1:
46+
shape[i] = d1
47+
elif d1 == d2:
48+
shape[i] = d1
49+
else:
50+
raise BroadcastError()
51+
52+
i = i - 1
53+
54+
return tuple(shape)
55+
56+
57+
def broadcast_shapes(*shapes: Shape):
58+
if len(shapes) == 0:
59+
raise ValueError("shapes=[] must be non-empty")
60+
elif len(shapes) == 1:
61+
return shapes[0]
62+
result = _broadcast_shapes(shapes[0], shapes[1])
63+
for i in range(2, len(shapes)):
64+
result = _broadcast_shapes(result, shapes[i])
65+
return result
866

967

1068
def normalise_axis(
@@ -17,13 +75,21 @@ def normalise_axis(
1775
return axes
1876

1977

20-
def ndindex(shape):
21-
"""Iterator of n-D indices to an array
78+
def ndindex(shape: Shape) -> Iterator[Index]:
79+
"""Yield every index of a shape"""
80+
return (indices[0] for indices in iter_indices(shape))
81+
2282

23-
Yields tuples of integers to index every element of an array of shape
24-
`shape`. Same as np.ndindex().
25-
"""
26-
return product(*[range(i) for i in shape])
83+
def iter_indices(
84+
*shapes: Shape, skip_axes: Tuple[int, ...] = ()
85+
) -> Iterator[Tuple[Index, ...]]:
86+
"""Wrapper for ndindex.iter_indices()"""
87+
# Prevent iterations if any shape has 0-sides
88+
for shape in shapes:
89+
if 0 in shape:
90+
return
91+
for indices in _iter_indices(*shapes, skip_axes=skip_axes):
92+
yield tuple(i.raw for i in indices) # type: ignore
2793

2894

2995
def axis_ndindex(
@@ -60,7 +126,7 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]:
60126
yield list(indices)
61127

62128

63-
def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List[Scalar]]:
129+
def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List]:
64130
"""Reshape a flat sequence"""
65131
if any(s == 0 for s in shape):
66132
raise ValueError(
@@ -75,3 +141,33 @@ def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List[Scalar]]
75141
size = len(flat_seq)
76142
n = math.prod(shape[1:])
77143
return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)]
144+
145+
146+
def fmt_i(i: AtomicIndex) -> str:
147+
if isinstance(i, int):
148+
return str(i)
149+
elif isinstance(i, slice):
150+
res = ""
151+
if i.start is not None:
152+
res += str(i.start)
153+
res += ":"
154+
if i.stop is not None:
155+
res += str(i.stop)
156+
if i.step is not None:
157+
res += f":{i.step}"
158+
return res
159+
else:
160+
return "..."
161+
162+
163+
def fmt_idx(sym: str, idx: Index) -> str:
164+
if idx == ():
165+
return sym
166+
res = f"{sym}["
167+
_idx = idx if isinstance(idx, tuple) else (idx,)
168+
if len(_idx) == 1:
169+
res += fmt_i(_idx[0])
170+
else:
171+
res += ", ".join(fmt_i(i) for i in _idx)
172+
res += "]"
173+
return res

0 commit comments

Comments
 (0)