Skip to content

Commit 586f63b

Browse files
authored
BUG: 2D indexing on DTA/TDA/PA (#33290)
1 parent 080004c commit 586f63b

File tree

3 files changed

+35
-11
lines changed

3 files changed

+35
-11
lines changed

pandas/core/arrays/datetimelike.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -550,18 +550,15 @@ def __getitem__(self, key):
550550
key = np.asarray(key, dtype=bool)
551551

552552
key = check_array_indexer(self, key)
553-
if key.all():
554-
key = slice(0, None, None)
555-
else:
556-
key = lib.maybe_booleans_to_slice(key.view(np.uint8))
553+
key = lib.maybe_booleans_to_slice(key.view(np.uint8))
557554
elif isinstance(key, list) and len(key) == 1 and isinstance(key[0], slice):
558555
# see https://github.com/pandas-dev/pandas/issues/31299, need to allow
559556
# this for now (would otherwise raise in check_array_indexer)
560557
pass
561558
else:
562559
key = check_array_indexer(self, key)
563560

564-
is_period = is_period_dtype(self)
561+
is_period = is_period_dtype(self.dtype)
565562
if is_period:
566563
freq = self.freq
567564
else:
@@ -577,11 +574,6 @@ def __getitem__(self, key):
577574
freq = self.freq
578575

579576
result = getitem(key)
580-
if result.ndim > 1:
581-
# To support MPL which performs slicing with 2 dim
582-
# even though it only has 1 dim by definition
583-
return result
584-
585577
return self._simple_new(result, dtype=self.dtype, freq=freq)
586578

587579
def __setitem__(

pandas/core/indexes/extension.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,10 @@ class ExtensionIndex(Index):
214214
def __getitem__(self, key):
215215
result = self._data[key]
216216
if isinstance(result, type(self._data)):
217-
return type(self)(result, name=self.name)
217+
if result.ndim == 1:
218+
return type(self)(result, name=self.name)
219+
# Unpack to ndarray for MPL compat
220+
result = result._data
218221

219222
# Includes cases where we get a 2D ndarray back for MPL compat
220223
deprecate_ndim_indexing(result)

pandas/tests/arrays/test_datetimelike.py

+29
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ def timedelta_index(request):
6060
class SharedTests:
6161
index_cls: Type[Union[DatetimeIndex, PeriodIndex, TimedeltaIndex]]
6262

63+
@pytest.fixture
64+
def arr1d(self):
65+
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
66+
arr = self.array_cls(data, freq="D")
67+
return arr
68+
6369
def test_compare_len1_raises(self):
6470
# make sure we raise when comparing with different lengths, specific
6571
# to the case where one has length-1, which numpy would broadcast
@@ -204,6 +210,18 @@ def test_searchsorted(self):
204210
result = arr.searchsorted(pd.NaT)
205211
assert result == 0
206212

213+
def test_getitem_2d(self, arr1d):
214+
# 2d slicing on a 1D array
215+
expected = type(arr1d)(arr1d._data[:, np.newaxis], dtype=arr1d.dtype)
216+
result = arr1d[:, np.newaxis]
217+
tm.assert_equal(result, expected)
218+
219+
# Lookup on a 2D array
220+
arr2d = expected
221+
expected = type(arr2d)(arr2d._data[:3, 0], dtype=arr2d.dtype)
222+
result = arr2d[:3, 0]
223+
tm.assert_equal(result, expected)
224+
207225
def test_setitem(self):
208226
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
209227
arr = self.array_cls(data, freq="D")
@@ -265,6 +283,13 @@ class TestDatetimeArray(SharedTests):
265283
array_cls = DatetimeArray
266284
dtype = pd.Timestamp
267285

286+
@pytest.fixture
287+
def arr1d(self, tz_naive_fixture):
288+
tz = tz_naive_fixture
289+
dti = pd.date_range("2016-01-01 01:01:00", periods=3, freq="H", tz=tz)
290+
dta = dti._data
291+
return dta
292+
268293
def test_round(self, tz_naive_fixture):
269294
# GH#24064
270295
tz = tz_naive_fixture
@@ -645,6 +670,10 @@ class TestPeriodArray(SharedTests):
645670
array_cls = PeriodArray
646671
dtype = pd.Period
647672

673+
@pytest.fixture
674+
def arr1d(self, period_index):
675+
return period_index._data
676+
648677
def test_from_pi(self, period_index):
649678
pi = period_index
650679
arr = PeriodArray(pi)

0 commit comments

Comments
 (0)