Skip to content

BUG: 2D indexing on DTA/TDA/PA #33290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,18 +550,15 @@ def __getitem__(self, key):
key = np.asarray(key, dtype=bool)

key = check_array_indexer(self, key)
if key.all():
key = slice(0, None, None)
else:
key = lib.maybe_booleans_to_slice(key.view(np.uint8))
key = lib.maybe_booleans_to_slice(key.view(np.uint8))
elif isinstance(key, list) and len(key) == 1 and isinstance(key[0], slice):
# see https://github.com/pandas-dev/pandas/issues/31299, need to allow
# this for now (would otherwise raise in check_array_indexer)
pass
else:
key = check_array_indexer(self, key)

is_period = is_period_dtype(self)
is_period = is_period_dtype(self.dtype)
if is_period:
freq = self.freq
else:
Expand All @@ -577,11 +574,6 @@ def __getitem__(self, key):
freq = self.freq

result = getitem(key)
if result.ndim > 1:
# To support MPL which performs slicing with 2 dim
# even though it only has 1 dim by definition
return result

return self._simple_new(result, dtype=self.dtype, freq=freq)

def __setitem__(
Expand Down
5 changes: 4 additions & 1 deletion pandas/core/indexes/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,10 @@ class ExtensionIndex(Index):
def __getitem__(self, key):
result = self._data[key]
if isinstance(result, type(self._data)):
return type(self)(result, name=self.name)
if result.ndim == 1:
return type(self)(result, name=self.name)
# Unpack to ndarray for MPL compat
result = result._data

# Includes cases where we get a 2D ndarray back for MPL compat
deprecate_ndim_indexing(result)
Expand Down
29 changes: 29 additions & 0 deletions pandas/tests/arrays/test_datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ def timedelta_index(request):
class SharedTests:
index_cls: Type[Union[DatetimeIndex, PeriodIndex, TimedeltaIndex]]

@pytest.fixture
def arr1d(self):
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
arr = self.array_cls(data, freq="D")
return arr

def test_compare_len1_raises(self):
# make sure we raise when comparing with different lengths, specific
# to the case where one has length-1, which numpy would broadcast
Expand Down Expand Up @@ -204,6 +210,18 @@ def test_searchsorted(self):
result = arr.searchsorted(pd.NaT)
assert result == 0

def test_getitem_2d(self, arr1d):
# 2d slicing on a 1D array
expected = type(arr1d)(arr1d._data[:, np.newaxis], dtype=arr1d.dtype)
result = arr1d[:, np.newaxis]
tm.assert_equal(result, expected)

# Lookup on a 2D array
arr2d = expected
expected = type(arr2d)(arr2d._data[:3, 0], dtype=arr2d.dtype)
result = arr2d[:3, 0]
tm.assert_equal(result, expected)

def test_setitem(self):
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
arr = self.array_cls(data, freq="D")
Expand Down Expand Up @@ -265,6 +283,13 @@ class TestDatetimeArray(SharedTests):
array_cls = DatetimeArray
dtype = pd.Timestamp

@pytest.fixture
def arr1d(self, tz_naive_fixture):
tz = tz_naive_fixture
dti = pd.date_range("2016-01-01 01:01:00", periods=3, freq="H", tz=tz)
dta = dti._data
return dta

def test_round(self, tz_naive_fixture):
# GH#24064
tz = tz_naive_fixture
Expand Down Expand Up @@ -645,6 +670,10 @@ class TestPeriodArray(SharedTests):
array_cls = PeriodArray
dtype = pd.Period

@pytest.fixture
def arr1d(self, period_index):
return period_index._data

def test_from_pi(self, period_index):
pi = period_index
arr = PeriodArray(pi)
Expand Down