Skip to content

Commit 564279a

Browse files
committed
test_take
1 parent a533680 commit 564279a

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

array_api_tests/_array_module.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def __repr__(self):
6262
]
6363
_constants = ["e", "inf", "nan", "pi"]
6464
_funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs]
65+
_funcs += ["take"] # TODO: bump spec and update array-api-tests to new spec layout
6566
_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS
6667

6768
for attr in _top_level_attrs:
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import pytest
2+
from hypothesis import given, note
3+
from hypothesis import strategies as st
4+
5+
from . import _array_module as xp
6+
from . import api_version
7+
from . import dtype_helpers as dh
8+
from . import hypothesis_helpers as hh
9+
from . import pytest_helpers as ph
10+
from . import shape_helpers as sh
11+
from . import xps
12+
13+
pytestmark = pytest.mark.ci
14+
15+
16+
if api_version >= "2022.12":
17+
18+
@given(
19+
x=xps.arrays(xps.scalar_dtypes(), hh.shapes(min_dims=1, min_side=1)),
20+
data=st.data(),
21+
)
22+
def test_take(x, data):
23+
# TODO:
24+
# * negative axis
25+
# * negative indices
26+
# * different dtypes for indices
27+
axis = data.draw(st.integers(0, max(x.ndim - 1, 0)), label="axis")
28+
_indices = data.draw(
29+
st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True),
30+
label="_indices",
31+
)
32+
indices = xp.asarray(_indices, dtype=dh.default_int)
33+
note(f"{indices=}")
34+
35+
out = xp.take(x, indices, axis=axis)
36+
37+
ph.assert_dtype("take", x.dtype, out.dtype)
38+
ph.assert_shape(
39+
"take",
40+
out.shape,
41+
x.shape[:axis] + (len(_indices),) + x.shape[axis + 1 :],
42+
x=x,
43+
indices=indices,
44+
axis=axis,
45+
)
46+
out_indices = sh.ndindex(out.shape)
47+
axis_indices = list(sh.axis_ndindex(x.shape, axis))
48+
for axis_idx in axis_indices:
49+
f_axis_idx = sh.fmt_idx("x", axis_idx)
50+
for i in _indices:
51+
f_take_idx = sh.fmt_idx(f_axis_idx, i)
52+
indexed_x = x[axis_idx][i]
53+
for at_idx in sh.ndindex(indexed_x.shape):
54+
out_idx = next(out_indices)
55+
ph.assert_0d_equals(
56+
"take",
57+
sh.fmt_idx(f_take_idx, at_idx),
58+
indexed_x[at_idx],
59+
sh.fmt_idx("out", out_idx),
60+
out[out_idx],
61+
)
62+
# sanity check
63+
with pytest.raises(StopIteration):
64+
next(out_indices)

0 commit comments

Comments
 (0)