diff --git a/array_api_tests/_array_module.py b/array_api_tests/_array_module.py index b4aaf76c..4ab584b7 100644 --- a/array_api_tests/_array_module.py +++ b/array_api_tests/_array_module.py @@ -62,6 +62,7 @@ def __repr__(self): ] _constants = ["e", "inf", "nan", "pi"] _funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs] +_funcs += ["take"] # TODO: bump spec and update array-api-tests to new spec layout _top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS for attr in _top_level_attrs: diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index fb167168..8fb8fa7e 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -5,9 +5,10 @@ from typing import Any, Dict, NamedTuple, Sequence, Tuple, Union from warnings import warn -from . import api_version from . import _array_module as xp +from . import api_version from ._array_module import _UndefinedStub +from ._array_module import mod as _xp from .stubs import name_to_func from .typing import DataType, ScalarType @@ -88,6 +89,12 @@ def __repr__(self): return f"EqualityMapping({self})" +def _filter_stubs(*args): + for a in args: + if not isinstance(a, _UndefinedStub): + yield a + + _uint_names = ("uint8", "uint16", "uint32", "uint64") _int_names = ("int8", "int16", "int32", "int64") _float_names = ("float32", "float64") @@ -113,7 +120,14 @@ def __repr__(self): bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes -dtype_to_name = EqualityMapping([(getattr(xp, name), name) for name in _dtype_names]) +_dtype_name_pairs = [] +for name in _dtype_names: + try: + dtype = getattr(_xp, name) + except AttributeError: + continue + _dtype_name_pairs.append((dtype, name)) +dtype_to_name = EqualityMapping(_dtype_name_pairs) dtype_to_scalars = EqualityMapping( @@ -173,12 +187,13 @@ class MinMax(NamedTuple): ] ) + dtype_nbits = EqualityMapping( - [(d, 8) for d in [xp.int8, xp.uint8]] - + [(d, 16) for d in [xp.int16, xp.uint16]] - + [(d, 32) for d in [xp.int32, xp.uint32, xp.float32]] - + [(d, 64) for d in [xp.int64, xp.uint64, xp.float64, xp.complex64]] - + [(xp.complex128, 128)] + [(d, 8) for d in _filter_stubs(xp.int8, xp.uint8)] + + [(d, 16) for d in _filter_stubs(xp.int16, xp.uint16)] + + [(d, 32) for d in _filter_stubs(xp.int32, xp.uint32, xp.float32)] + + [(d, 64) for d in _filter_stubs(xp.int64, xp.uint64, xp.float64, xp.complex64)] + + [(d, 128) for d in _filter_stubs(xp.complex128)] ) @@ -265,7 +280,6 @@ class MinMax(NamedTuple): ((xp.complex64, xp.complex64), xp.complex64), ((xp.complex64, xp.complex128), xp.complex128), ((xp.complex128, xp.complex128), xp.complex128), - ] _numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions] _promotion_table = list(set(_numeric_promotions)) diff --git a/array_api_tests/test_indexing_functions.py b/array_api_tests/test_indexing_functions.py new file mode 100644 index 00000000..6aa80fed --- /dev/null +++ b/array_api_tests/test_indexing_functions.py @@ -0,0 +1,62 @@ +import pytest +from hypothesis import given, note +from hypothesis import strategies as st + +from . import _array_module as xp +from . import dtype_helpers as dh +from . import hypothesis_helpers as hh +from . import pytest_helpers as ph +from . import shape_helpers as sh +from . import xps + +pytestmark = pytest.mark.ci + + +@pytest.mark.min_version("2022.12") +@given( + x=xps.arrays(xps.scalar_dtypes(), hh.shapes(min_dims=1, min_side=1)), + data=st.data(), +) +def test_take(x, data): + # TODO: + # * negative axis + # * negative indices + # * different dtypes for indices + axis = data.draw(st.integers(0, max(x.ndim - 1, 0)), label="axis") + _indices = data.draw( + st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True), + label="_indices", + ) + indices = xp.asarray(_indices, dtype=dh.default_int) + note(f"{indices=}") + + out = xp.take(x, indices, axis=axis) + + ph.assert_dtype("take", x.dtype, out.dtype) + ph.assert_shape( + "take", + out.shape, + x.shape[:axis] + (len(_indices),) + x.shape[axis + 1 :], + x=x, + indices=indices, + axis=axis, + ) + out_indices = sh.ndindex(out.shape) + axis_indices = list(sh.axis_ndindex(x.shape, axis)) + for axis_idx in axis_indices: + f_axis_idx = sh.fmt_idx("x", axis_idx) + for i in _indices: + f_take_idx = sh.fmt_idx(f_axis_idx, i) + indexed_x = x[axis_idx][i] + for at_idx in sh.ndindex(indexed_x.shape): + out_idx = next(out_indices) + ph.assert_0d_equals( + "take", + sh.fmt_idx(f_take_idx, at_idx), + indexed_x[at_idx], + sh.fmt_idx("out", out_idx), + out[out_idx], + ) + # sanity check + with pytest.raises(StopIteration): + next(out_indices)