Skip to content

BUG: fix tuple array indexing #139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,8 +698,13 @@ def __getitem__(
# docstring of _validate_index
self._validate_index(key)
if isinstance(key, Array):
key = (key,)
if isinstance(key, tuple):
# Indexing self._array with array_api_strict arrays can be erroneous
key = key._array
# e.g., when using non-default device
key = tuple(
subkey._array if isinstance(subkey, Array) else subkey for subkey in key
)
res = self._array.__getitem__(key)
return self._new(res, device=self.device)

Expand Down
22 changes: 13 additions & 9 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pytest

from .. import ones, arange, reshape, asarray, result_type, all, equal
from .. import ones, arange, reshape, asarray, result_type, all, equal, stack
from .._array_object import Array, CPU_DEVICE, Device
from .._dtypes import (
_all_dtypes,
Expand Down Expand Up @@ -101,33 +101,37 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[idx])


def test_indexing_arrays():
# @pytest.mark.parametrize("device", ["CPU_DEVICE", "device1", "device2"])
def test_indexing_arrays(device='device1'):
# indexing with 1D integer arrays and mixes of integers and 1D integer are allowed
device = Device(device)

# 1D array
a = arange(5)
idx = asarray([1, 0, 1, 2, -1])
idx = asarray([1, 0, 1, 2, -1], device=device)
a_idx = a[idx]

a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])])
a_idx_loop = stack([a[idx[i]] for i in range(idx.shape[0])])
assert all(a_idx == a_idx_loop)
assert a_idx.shape == idx.shape

# setitem with arrays is not allowed
with assert_raises(IndexError):
a[idx] = 42

# mixed array and integer indexing
a = reshape(arange(3*4), (3, 4))
idx = asarray([1, 0, 1, 2, -1])
a = reshape(arange(3*4, device=device), (3, 4))
idx = asarray([1, 0, 1, 2, -1], device=device)
a_idx = a[idx, 1]

a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])])
a_idx_loop = stack([a[idx[i], 1] for i in range(idx.shape[0])])
assert all(a_idx == a_idx_loop)
assert a_idx.shape == idx.shape

# index with two arrays
a_idx = a[idx, idx]
a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])])
a_idx_loop = stack([a[idx[i], idx[i]] for i in range(idx.shape[0])])
assert all(a_idx == a_idx_loop)
assert a_idx.shape == a_idx.shape

# setitem with arrays is not allowed
with assert_raises(IndexError):
Expand Down
Loading