From 03d07dfe353d37de5865879ea5acd237b4df1772 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Sat, 4 Apr 2020 12:22:16 -0700 Subject: [PATCH] BUG: 2D indexing on DTA/TDA/PA --- pandas/core/arrays/datetimelike.py | 12 ++-------- pandas/core/indexes/extension.py | 5 +++- pandas/tests/arrays/test_datetimelike.py | 29 ++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 11 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index c0bbbebac7c33..4fabd8f558fee 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -550,10 +550,7 @@ def __getitem__(self, key): key = np.asarray(key, dtype=bool) key = check_array_indexer(self, key) - if key.all(): - key = slice(0, None, None) - else: - key = lib.maybe_booleans_to_slice(key.view(np.uint8)) + key = lib.maybe_booleans_to_slice(key.view(np.uint8)) elif isinstance(key, list) and len(key) == 1 and isinstance(key[0], slice): # see https://github.com/pandas-dev/pandas/issues/31299, need to allow # this for now (would otherwise raise in check_array_indexer) @@ -561,7 +558,7 @@ def __getitem__(self, key): else: key = check_array_indexer(self, key) - is_period = is_period_dtype(self) + is_period = is_period_dtype(self.dtype) if is_period: freq = self.freq else: @@ -577,11 +574,6 @@ def __getitem__(self, key): freq = self.freq result = getitem(key) - if result.ndim > 1: - # To support MPL which performs slicing with 2 dim - # even though it only has 1 dim by definition - return result - return self._simple_new(result, dtype=self.dtype, freq=freq) def __setitem__( diff --git a/pandas/core/indexes/extension.py b/pandas/core/indexes/extension.py index f38a4fb83c64f..c752990531b34 100644 --- a/pandas/core/indexes/extension.py +++ b/pandas/core/indexes/extension.py @@ -214,7 +214,10 @@ class ExtensionIndex(Index): def __getitem__(self, key): result = self._data[key] if isinstance(result, type(self._data)): - return type(self)(result, name=self.name) + if result.ndim == 1: + return type(self)(result, name=self.name) + # Unpack to ndarray for MPL compat + result = result._data # Includes cases where we get a 2D ndarray back for MPL compat deprecate_ndim_indexing(result) diff --git a/pandas/tests/arrays/test_datetimelike.py b/pandas/tests/arrays/test_datetimelike.py index 83995ab26cb56..fe35344f46688 100644 --- a/pandas/tests/arrays/test_datetimelike.py +++ b/pandas/tests/arrays/test_datetimelike.py @@ -60,6 +60,12 @@ def timedelta_index(request): class SharedTests: index_cls: Type[Union[DatetimeIndex, PeriodIndex, TimedeltaIndex]] + @pytest.fixture + def arr1d(self): + data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9 + arr = self.array_cls(data, freq="D") + return arr + def test_compare_len1_raises(self): # make sure we raise when comparing with different lengths, specific # to the case where one has length-1, which numpy would broadcast @@ -204,6 +210,18 @@ def test_searchsorted(self): result = arr.searchsorted(pd.NaT) assert result == 0 + def test_getitem_2d(self, arr1d): + # 2d slicing on a 1D array + expected = type(arr1d)(arr1d._data[:, np.newaxis], dtype=arr1d.dtype) + result = arr1d[:, np.newaxis] + tm.assert_equal(result, expected) + + # Lookup on a 2D array + arr2d = expected + expected = type(arr2d)(arr2d._data[:3, 0], dtype=arr2d.dtype) + result = arr2d[:3, 0] + tm.assert_equal(result, expected) + def test_setitem(self): data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9 arr = self.array_cls(data, freq="D") @@ -265,6 +283,13 @@ class TestDatetimeArray(SharedTests): array_cls = DatetimeArray dtype = pd.Timestamp + @pytest.fixture + def arr1d(self, tz_naive_fixture): + tz = tz_naive_fixture + dti = pd.date_range("2016-01-01 01:01:00", periods=3, freq="H", tz=tz) + dta = dti._data + return dta + def test_round(self, tz_naive_fixture): # GH#24064 tz = tz_naive_fixture @@ -645,6 +670,10 @@ class TestPeriodArray(SharedTests): array_cls = PeriodArray dtype = pd.Period + @pytest.fixture + def arr1d(self, period_index): + return period_index._data + def test_from_pi(self, period_index): pi = period_index arr = PeriodArray(pi)