diff --git a/.github/workflows/numpy.yml b/.github/workflows/numpy.yml index 91e16828..74eeccf4 100644 --- a/.github/workflows/numpy.yml +++ b/.github/workflows/numpy.yml @@ -42,6 +42,9 @@ jobs: # https://github.com/numpy/numpy/issues/20326#issuecomment-1012380448 array_api_tests/test_set_functions.py + # https://github.com/numpy/numpy/issues/21373 + array_api_tests/test_array_object.py::test_getitem + # missing copy arg array_api_tests/test_signatures.py::test_func_signature[reshape] diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index be03f5fa..50db7e51 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -31,27 +31,35 @@ def test_getitem(shape, data): obj = data.draw(scalar_objects(dtype, shape), label="obj") x = xp.asarray(obj, dtype=dtype) note(f"{x=}") - key = data.draw(xps.indices(shape=shape), label="key") + key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key") out = x[key] ph.assert_dtype("__getitem__", x.dtype, out.dtype) _key = tuple(key) if isinstance(key, tuple) else (key,) if Ellipsis in _key: - start_a = _key.index(Ellipsis) - stop_a = start_a + (len(shape) - (len(_key) - 1)) - slices = tuple(slice(None, None) for _ in range(start_a, stop_a)) - _key = _key[:start_a] + slices + _key[start_a + 1 :] + nonexpanding_key = tuple(i for i in _key if i is not None) + start_a = nonexpanding_key.index(Ellipsis) + stop_a = start_a + (len(shape) - (len(nonexpanding_key) - 1)) + slices = tuple(slice(None) for _ in range(start_a, stop_a)) + start_pos = _key.index(Ellipsis) + _key = _key[:start_pos] + slices + _key[start_pos + 1 :] axes_indices = [] out_shape = [] - for a, i in enumerate(_key): - if isinstance(i, int): - axes_indices.append([i]) + a = 0 + for i in _key: + if i is None: + out_shape.append(1) else: - side = shape[a] - indices = range(side)[i] - axes_indices.append(indices) - out_shape.append(len(indices)) + if isinstance(i, int): + axes_indices.append([i]) + else: + assert isinstance(i, slice) # sanity check + side = shape[a] + indices = range(side)[i] + axes_indices.append(indices) + out_shape.append(len(indices)) + a += 1 out_shape = tuple(out_shape) ph.assert_shape("__getitem__", out.shape, out_shape) assume(all(len(indices) > 0 for indices in axes_indices)) @@ -104,8 +112,6 @@ def test_setitem(shape, data): ) -# TODO: make mask tests optional - @pytest.mark.data_dependent_shapes @given(hh.shapes(), st.data()) def test_getitem_masking(shape, data): diff --git a/requirements.txt b/requirements.txt index fbc3fca3..b3b26223 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ pytest -hypothesis>=6.31.1 +hypothesis>=6.45.0 ndindex>=1.6