Skip to content

Commit a398866

Browse files
authored
Merge pull request #173 from honno/test-take
`test_take`
2 parents af2e299 + 675fa6c commit a398866

File tree

3 files changed

+85
-8
lines changed

3 files changed

+85
-8
lines changed

array_api_tests/_array_module.py

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def __repr__(self):
6262
]
6363
_constants = ["e", "inf", "nan", "pi"]
6464
_funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs]
65+
_funcs += ["take"] # TODO: bump spec and update array-api-tests to new spec layout
6566
_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS
6667

6768
for attr in _top_level_attrs:

array_api_tests/dtype_helpers.py

+22-8
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from typing import Any, Dict, NamedTuple, Sequence, Tuple, Union
66
from warnings import warn
77

8-
from . import api_version
98
from . import _array_module as xp
9+
from . import api_version
1010
from ._array_module import _UndefinedStub
11+
from ._array_module import mod as _xp
1112
from .stubs import name_to_func
1213
from .typing import DataType, ScalarType
1314

@@ -88,6 +89,12 @@ def __repr__(self):
8889
return f"EqualityMapping({self})"
8990

9091

92+
def _filter_stubs(*args):
93+
for a in args:
94+
if not isinstance(a, _UndefinedStub):
95+
yield a
96+
97+
9198
_uint_names = ("uint8", "uint16", "uint32", "uint64")
9299
_int_names = ("int8", "int16", "int32", "int64")
93100
_float_names = ("float32", "float64")
@@ -113,7 +120,14 @@ def __repr__(self):
113120
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes
114121

115122

116-
dtype_to_name = EqualityMapping([(getattr(xp, name), name) for name in _dtype_names])
123+
_dtype_name_pairs = []
124+
for name in _dtype_names:
125+
try:
126+
dtype = getattr(_xp, name)
127+
except AttributeError:
128+
continue
129+
_dtype_name_pairs.append((dtype, name))
130+
dtype_to_name = EqualityMapping(_dtype_name_pairs)
117131

118132

119133
dtype_to_scalars = EqualityMapping(
@@ -173,12 +187,13 @@ class MinMax(NamedTuple):
173187
]
174188
)
175189

190+
176191
dtype_nbits = EqualityMapping(
177-
[(d, 8) for d in [xp.int8, xp.uint8]]
178-
+ [(d, 16) for d in [xp.int16, xp.uint16]]
179-
+ [(d, 32) for d in [xp.int32, xp.uint32, xp.float32]]
180-
+ [(d, 64) for d in [xp.int64, xp.uint64, xp.float64, xp.complex64]]
181-
+ [(xp.complex128, 128)]
192+
[(d, 8) for d in _filter_stubs(xp.int8, xp.uint8)]
193+
+ [(d, 16) for d in _filter_stubs(xp.int16, xp.uint16)]
194+
+ [(d, 32) for d in _filter_stubs(xp.int32, xp.uint32, xp.float32)]
195+
+ [(d, 64) for d in _filter_stubs(xp.int64, xp.uint64, xp.float64, xp.complex64)]
196+
+ [(d, 128) for d in _filter_stubs(xp.complex128)]
182197
)
183198

184199

@@ -265,7 +280,6 @@ class MinMax(NamedTuple):
265280
((xp.complex64, xp.complex64), xp.complex64),
266281
((xp.complex64, xp.complex128), xp.complex128),
267282
((xp.complex128, xp.complex128), xp.complex128),
268-
269283
]
270284
_numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions]
271285
_promotion_table = list(set(_numeric_promotions))
+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import pytest
2+
from hypothesis import given, note
3+
from hypothesis import strategies as st
4+
5+
from . import _array_module as xp
6+
from . import dtype_helpers as dh
7+
from . import hypothesis_helpers as hh
8+
from . import pytest_helpers as ph
9+
from . import shape_helpers as sh
10+
from . import xps
11+
12+
pytestmark = pytest.mark.ci
13+
14+
15+
@pytest.mark.min_version("2022.12")
16+
@given(
17+
x=xps.arrays(xps.scalar_dtypes(), hh.shapes(min_dims=1, min_side=1)),
18+
data=st.data(),
19+
)
20+
def test_take(x, data):
21+
# TODO:
22+
# * negative axis
23+
# * negative indices
24+
# * different dtypes for indices
25+
axis = data.draw(st.integers(0, max(x.ndim - 1, 0)), label="axis")
26+
_indices = data.draw(
27+
st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True),
28+
label="_indices",
29+
)
30+
indices = xp.asarray(_indices, dtype=dh.default_int)
31+
note(f"{indices=}")
32+
33+
out = xp.take(x, indices, axis=axis)
34+
35+
ph.assert_dtype("take", x.dtype, out.dtype)
36+
ph.assert_shape(
37+
"take",
38+
out.shape,
39+
x.shape[:axis] + (len(_indices),) + x.shape[axis + 1 :],
40+
x=x,
41+
indices=indices,
42+
axis=axis,
43+
)
44+
out_indices = sh.ndindex(out.shape)
45+
axis_indices = list(sh.axis_ndindex(x.shape, axis))
46+
for axis_idx in axis_indices:
47+
f_axis_idx = sh.fmt_idx("x", axis_idx)
48+
for i in _indices:
49+
f_take_idx = sh.fmt_idx(f_axis_idx, i)
50+
indexed_x = x[axis_idx][i]
51+
for at_idx in sh.ndindex(indexed_x.shape):
52+
out_idx = next(out_indices)
53+
ph.assert_0d_equals(
54+
"take",
55+
sh.fmt_idx(f_take_idx, at_idx),
56+
indexed_x[at_idx],
57+
sh.fmt_idx("out", out_idx),
58+
out[out_idx],
59+
)
60+
# sanity check
61+
with pytest.raises(StopIteration):
62+
next(out_indices)

0 commit comments

Comments
 (0)