Skip to content

Commit f09c939

Browse files
authored
Merge pull request #57 from honno/reorg
Indexing tests
2 parents 69c2dab + 043e437 commit f09c939

File tree

5 files changed

+152
-184
lines changed

5 files changed

+152
-184
lines changed

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
pytest
2-
hypothesis>=6.30.0
2+
hypothesis>=6.31.1
33
regex
44
removestar

xptests/pytest_helpers.py

+18
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"assert_result_shape",
2424
"assert_keepdimable_shape",
2525
"assert_fill",
26+
"assert_array",
2627
]
2728

2829

@@ -226,3 +227,20 @@ def assert_fill(
226227
assert ah.all(ah.isnan(out)), msg
227228
else:
228229
assert ah.all(ah.equal(out, ah.asarray(fill_value, dtype=dtype))), msg
230+
231+
232+
def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
233+
assert_dtype(func_name, out.dtype, expected.dtype, **kw)
234+
assert_shape(func_name, out.shape, expected.shape, **kw)
235+
msg = f"out not as expected [{func_name}({fmt_kw(kw)})]\n{out=}\n{expected=}"
236+
if dh.is_float_dtype(out.dtype):
237+
neg_zeros = expected == -0.0
238+
assert xp.all((out == -0.0) == neg_zeros), msg
239+
pos_zeros = expected == +0.0
240+
assert xp.all((out == +0.0) == pos_zeros), msg
241+
nans = xp.isnan(expected)
242+
assert xp.all(xp.isnan(out) == nans), msg
243+
mask = ~(neg_zeros | pos_zeros | nans)
244+
assert xp.all(out[mask] == expected[mask]), msg
245+
else:
246+
assert xp.all(out == expected), msg

xptests/test_array2scalar.py

-47
This file was deleted.

xptests/test_array_object.py

+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import math
2+
from itertools import product
3+
from typing import Sequence, Union, get_args
4+
5+
import pytest
6+
from hypothesis import assume, given, note
7+
from hypothesis import strategies as st
8+
9+
from . import _array_module as xp
10+
from . import dtype_helpers as dh
11+
from . import hypothesis_helpers as hh
12+
from . import pytest_helpers as ph
13+
from . import xps
14+
from .typing import DataType, Param, Scalar, ScalarType, Shape
15+
16+
17+
def reshape(
18+
flat_seq: Sequence[Scalar], shape: Shape
19+
) -> Union[Scalar, Sequence[Scalar]]:
20+
"""Reshape a flat sequence"""
21+
if len(shape) == 0:
22+
assert len(flat_seq) == 1 # sanity check
23+
return flat_seq[0]
24+
elif len(shape) == 1:
25+
return flat_seq
26+
size = len(flat_seq)
27+
n = math.prod(shape[1:])
28+
return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)]
29+
30+
31+
@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
32+
def test_getitem(shape, data):
33+
size = math.prod(shape)
34+
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
35+
obj = data.draw(
36+
st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
37+
lambda l: reshape(l, shape)
38+
),
39+
label="obj",
40+
)
41+
x = xp.asarray(obj, dtype=dtype)
42+
note(f"{x=}")
43+
key = data.draw(xps.indices(shape=shape), label="key")
44+
45+
out = x[key]
46+
47+
ph.assert_dtype("__getitem__", x.dtype, out.dtype)
48+
_key = tuple(key) if isinstance(key, tuple) else (key,)
49+
if Ellipsis in _key:
50+
start_a = _key.index(Ellipsis)
51+
stop_a = start_a + (len(shape) - (len(_key) - 1))
52+
slices = tuple(slice(None, None) for _ in range(start_a, stop_a))
53+
_key = _key[:start_a] + slices + _key[start_a + 1 :]
54+
axes_indices = []
55+
out_shape = []
56+
for a, i in enumerate(_key):
57+
if isinstance(i, int):
58+
axes_indices.append([i])
59+
else:
60+
side = shape[a]
61+
indices = range(side)[i]
62+
axes_indices.append(indices)
63+
out_shape.append(len(indices))
64+
out_shape = tuple(out_shape)
65+
ph.assert_shape("__getitem__", out.shape, out_shape)
66+
assume(all(len(indices) > 0 for indices in axes_indices))
67+
out_obj = []
68+
for idx in product(*axes_indices):
69+
val = obj
70+
for i in idx:
71+
val = val[i]
72+
out_obj.append(val)
73+
out_obj = reshape(out_obj, out_shape)
74+
expected = xp.asarray(out_obj, dtype=dtype)
75+
ph.assert_array("__getitem__", out, expected)
76+
77+
78+
@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
79+
def test_setitem(shape, data):
80+
size = math.prod(shape)
81+
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
82+
obj = data.draw(
83+
st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
84+
lambda l: reshape(l, shape)
85+
),
86+
label="obj",
87+
)
88+
x = xp.asarray(obj, dtype=dtype)
89+
note(f"{x=}")
90+
key = data.draw(xps.indices(shape=shape, max_dims=0), label="key")
91+
value = data.draw(
92+
xps.from_dtype(dtype) | xps.arrays(dtype=dtype, shape=()), label="value"
93+
)
94+
95+
res = xp.asarray(x, copy=True)
96+
res[key] = value
97+
98+
ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype")
99+
ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.shape")
100+
if isinstance(value, get_args(Scalar)):
101+
msg = f"x[{key}]={res[key]!r}, but should be {value=} [__setitem__()]"
102+
if math.isnan(value):
103+
assert xp.isnan(res[key]), msg
104+
else:
105+
assert res[key] == value, msg
106+
else:
107+
ph.assert_0d_equals("__setitem__", "value", value, f"x[{key}]", res[key])
108+
109+
110+
# TODO: test boolean indexing
111+
112+
113+
def make_param(method_name: str, dtype: DataType, stype: ScalarType) -> Param:
114+
return pytest.param(
115+
method_name, dtype, stype, id=f"{method_name}({dh.dtype_to_name[dtype]})"
116+
)
117+
118+
119+
@pytest.mark.parametrize(
120+
"method_name, dtype, stype",
121+
[make_param("__bool__", xp.bool, bool)]
122+
+ [make_param("__int__", d, int) for d in dh.all_int_dtypes]
123+
+ [make_param("__index__", d, int) for d in dh.all_int_dtypes]
124+
+ [make_param("__float__", d, float) for d in dh.float_dtypes],
125+
)
126+
@given(data=st.data())
127+
def test_duck_typing(method_name, dtype, stype, data):
128+
x = data.draw(xps.arrays(dtype, shape=()), label="x")
129+
method = getattr(x, method_name)
130+
out = method()
131+
assert isinstance(
132+
out, stype
133+
), f"{method_name}({x})={out}, which is not a {stype.__name__} scalar"

xptests/test_indexing.py

-136
This file was deleted.

0 commit comments

Comments
 (0)