diff --git a/pandas/core/indexes/datetimes.py b/pandas/core/indexes/datetimes.py index 0c40900d54b53..4697f2d2e59a4 100644 --- a/pandas/core/indexes/datetimes.py +++ b/pandas/core/indexes/datetimes.py @@ -831,20 +831,6 @@ def slice_indexer(self, start=None, end=None, step=None, kind=None): else: raise - # -------------------------------------------------------------------- - # Wrapping DatetimeArray - - def __getitem__(self, key): - result = self._data.__getitem__(key) - if is_scalar(result): - return result - elif result.ndim > 1: - # To support MPL which performs slicing with 2 dim - # even though it only has 1 dim by definition - assert isinstance(result, np.ndarray), result - return result - return type(self)(result, name=self.name) - # -------------------------------------------------------------------- @Substitution(klass="DatetimeIndex") diff --git a/pandas/core/indexes/extension.py b/pandas/core/indexes/extension.py index 9011616dfe496..bd089f574a313 100644 --- a/pandas/core/indexes/extension.py +++ b/pandas/core/indexes/extension.py @@ -3,6 +3,8 @@ """ from typing import List +import numpy as np + from pandas.compat.numpy import function as nv from pandas.util._decorators import cache_readonly @@ -170,6 +172,29 @@ class ExtensionIndex(Index): __le__ = _make_wrapped_comparison_op("__le__") __ge__ = _make_wrapped_comparison_op("__ge__") + def __getitem__(self, key): + result = self._data[key] + if isinstance(result, type(self._data)): + return type(self)(result, name=self.name) + + # Includes cases where we get a 2D ndarray back for MPL compat + return result + + def __iter__(self): + return self._data.__iter__() + + @property + def _ndarray_values(self) -> np.ndarray: + return self._data._ndarray_values + + def dropna(self, how="any"): + if how not in ("any", "all"): + raise ValueError(f"invalid how option: {how}") + + if self.hasnans: + return self._shallow_copy(self._data[~self._isnan]) + return self._shallow_copy() + def repeat(self, repeats, axis=None): nv.validate_repeat(tuple(), dict(axis=axis)) result = self._data.repeat(repeats, axis=axis) diff --git a/pandas/core/indexes/timedeltas.py b/pandas/core/indexes/timedeltas.py index 86dd4525c7d6d..1f3182bc83e1d 100644 --- a/pandas/core/indexes/timedeltas.py +++ b/pandas/core/indexes/timedeltas.py @@ -211,15 +211,6 @@ def _formatter_func(self): return _get_format_timedelta64(self, box=True) - # ------------------------------------------------------------------- - # Wrapping TimedeltaArray - - def __getitem__(self, key): - result = self._data.__getitem__(key) - if is_scalar(result): - return result - return type(self)(result, name=self.name) - # ------------------------------------------------------------------- @Appender(_index_shared_docs["astype"])