From 95eed2557f682b6b831ce5f62c1f1a4f862b14b0 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Fri, 3 Jan 2020 15:25:09 -0800 Subject: [PATCH 1/3] implement ExtensionIndex --- pandas/core/indexes/category.py | 10 ++--- pandas/core/indexes/datetimelike.py | 33 +-------------- pandas/core/indexes/datetimes.py | 26 ++---------- pandas/core/indexes/extension.py | 65 ++++++++++++++++++++++++++++- pandas/core/indexes/interval.py | 16 ++----- pandas/core/indexes/period.py | 9 ---- pandas/core/indexes/timedeltas.py | 19 ++------- 7 files changed, 82 insertions(+), 96 deletions(-) diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py index 96bfff9a0a09f..3ac1902f35535 100644 --- a/pandas/core/indexes/category.py +++ b/pandas/core/indexes/category.py @@ -32,6 +32,8 @@ import pandas.core.missing as missing from pandas.core.ops import get_op_result_name +from .extension import ExtensionIndex + _index_doc_kwargs = dict(ibase._index_doc_kwargs) _index_doc_kwargs.update(dict(target_klass="CategoricalIndex")) @@ -65,7 +67,7 @@ typ="method", overwrite=True, ) -class CategoricalIndex(Index, accessor.PandasDelegate): +class CategoricalIndex(ExtensionIndex, accessor.PandasDelegate): """ Index based on an underlying :class:`Categorical`. @@ -268,14 +270,12 @@ def _create_categorical(cls, data, dtype=None): return data @classmethod - def _simple_new(cls, values, name=None, dtype=None, **kwargs): + def _simple_new(cls, values, name=None, dtype=None): result = object.__new__(cls) values = cls._create_categorical(values, dtype=dtype) result._data = values result.name = name - for k, v in kwargs.items(): - setattr(result, k, v) result._reset_identity() result._no_setting_name = False @@ -415,7 +415,7 @@ def astype(self, dtype, copy=True): if dtype == self.dtype: return self.copy() if copy else self - return super().astype(dtype=dtype, copy=copy) + return Index.astype(self, dtype=dtype, copy=copy) @cache_readonly def _isnan(self): diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 7bf1a601a0ab6..c6ace4097d09a 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -40,7 +40,7 @@ from pandas.tseries.frequencies import DateOffset, to_offset -from .extension import inherit_names +from .extension import ExtensionIndex, inherit_names _index_doc_kwargs = dict(ibase._index_doc_kwargs) @@ -90,7 +90,7 @@ def wrapper(left, right): ["__iter__", "mean", "freq", "freqstr", "_ndarray_values", "asi8", "_box_values"], DatetimeLikeArrayMixin, ) -class DatetimeIndexOpsMixin(ExtensionOpsMixin): +class DatetimeIndexOpsMixin(ExtensionIndex, ExtensionOpsMixin): """ Common ops mixin to support a unified interface datetimelike Index. """ @@ -109,17 +109,6 @@ class DatetimeIndexOpsMixin(ExtensionOpsMixin): def is_all_dates(self) -> bool: return True - def unique(self, level=None): - if level is not None: - self._validate_index_level(level) - - result = self._data.unique() - - # Note: if `self` is already unique, then self.unique() should share - # a `freq` with self. If not already unique, then self.freq must be - # None, so again sharing freq is correct. - return self._shallow_copy(result._data) - @classmethod def _create_comparison_method(cls, op): """ @@ -186,12 +175,6 @@ def equals(self, other): # have different timezone return False - elif is_period_dtype(self): - if not is_period_dtype(other): - return False - if self.freq != other.freq: - return False - return np.array_equal(self.asi8, other.asi8) def _ensure_localized( @@ -574,18 +557,6 @@ def _concat_same_dtype(self, to_concat, name): return self._simple_new(new_data, **attribs) - @Appender(_index_shared_docs["astype"]) - def astype(self, dtype, copy=True): - if is_dtype_equal(self.dtype, dtype) and copy is False: - # Ensure that self.astype(self.dtype) is self - return self - - new_values = self._data.astype(dtype, copy=copy) - - # pass copy=False because any copying will be done in the - # _data.astype call above - return Index(new_values, dtype=new_values.dtype, name=self.name, copy=False) - def shift(self, periods=1, freq=None): """ Shift index by desired number of time frequency increments. diff --git a/pandas/core/indexes/datetimes.py b/pandas/core/indexes/datetimes.py index bc6b8ff845a56..189ef92a5196f 100644 --- a/pandas/core/indexes/datetimes.py +++ b/pandas/core/indexes/datetimes.py @@ -388,18 +388,12 @@ def _formatter_func(self): # -------------------------------------------------------------------- # Set Operation Methods - def _union(self, other, sort): + def _union(self, other: "DatetimeIndex", sort): if not len(other) or self.equals(other) or not len(self): return super()._union(other, sort=sort) - if len(other) == 0 or self.equals(other) or len(self) == 0: - return super().union(other, sort=sort) - - if not isinstance(other, DatetimeIndex): - try: - other = DatetimeIndex(other) - except TypeError: - pass + # We are called by `union`, which is responsible for this validation + assert isinstance(other, DatetimeIndex) this, other = self._maybe_utc_convert(other) @@ -949,20 +943,6 @@ def slice_indexer(self, start=None, end=None, step=None, kind=None): else: raise - # -------------------------------------------------------------------- - # Wrapping DatetimeArray - - def __getitem__(self, key): - result = self._data.__getitem__(key) - if is_scalar(result): - return result - elif result.ndim > 1: - # To support MPL which performs slicing with 2 dim - # even though it only has 1 dim by definition - assert isinstance(result, np.ndarray), result - return result - return type(self)(result, name=self.name) - # -------------------------------------------------------------------- @Substitution(klass="DatetimeIndex") diff --git a/pandas/core/indexes/extension.py b/pandas/core/indexes/extension.py index 779cd8eac4eaf..78ae98f4fec47 100644 --- a/pandas/core/indexes/extension.py +++ b/pandas/core/indexes/extension.py @@ -3,7 +3,15 @@ """ from typing import List -from pandas.util._decorators import cache_readonly +import numpy as np + +from pandas.util._decorators import Appender, cache_readonly + +from pandas.core.dtypes.common import is_dtype_equal + +from pandas.core.arrays import ExtensionArray + +from .base import Index, _index_shared_docs def inherit_from_data(name: str, delegate, cache: bool = False): @@ -76,3 +84,58 @@ def wrapper(cls): return cls return wrapper + + +class ExtensionIndex(Index): + _data: ExtensionArray + + def __getitem__(self, key): + result = self._data[key] + if isinstance(result, type(self._data)): + return type(self)(result, name=self.name) + + # Includes cases where we get a 2D ndarray back for MPL compat + return result + + def __iter__(self): + return self._data.__iter__() + + @property + def _ndarray_values(self) -> np.ndarray: + return self._data._ndarray_values + + @Appender(_index_shared_docs["astype"]) + def astype(self, dtype, copy=True): + if is_dtype_equal(self.dtype, dtype) and copy is False: + # Ensure that self.astype(self.dtype) is self + return self + + new_values = self._data.astype(dtype, copy=copy) + + # pass copy=False because any copying will be done in the + # _data.astype call above + return Index(new_values, dtype=new_values.dtype, name=self.name, copy=False) + + def dropna(self, how="any"): + if how not in ("any", "all"): + raise ValueError(f"invalid how option: {how}") + + if self.hasnans: + return self._shallow_copy(self._data[~self._isnan]) + return self._shallow_copy() + + def _get_unique_index(self, dropna=False): + if self.is_unique and not dropna: + return self + + result = self._data.unique() + if dropna and self.hasnans: + result = result[~result.isna()] + return self._shallow_copy(result) + + def unique(self, level=None): + if level is not None: + self._validate_index_level(level) + + result = self._data.unique() + return self._shallow_copy(result) diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index cae9fa949f711..b2cef6ccd870f 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -58,7 +58,7 @@ from pandas.tseries.frequencies import to_offset from pandas.tseries.offsets import DateOffset -from .extension import inherit_names +from .extension import ExtensionIndex, inherit_names _VALID_CLOSED = {"left", "right", "both", "neither"} _index_doc_kwargs = dict(ibase._index_doc_kwargs) @@ -213,7 +213,7 @@ def func(intvidx_self, other, sort=False): overwrite=True, ) @inherit_names(["is_non_overlapping_monotonic", "mid"], IntervalArray, cache=True) -class IntervalIndex(IntervalMixin, Index, accessor.PandasDelegate): +class IntervalIndex(IntervalMixin, ExtensionIndex, accessor.PandasDelegate): _typ = "intervalindex" _comparables = ["name"] _attributes = ["name", "closed"] @@ -430,7 +430,7 @@ def astype(self, dtype, copy=True): new_values = self.values.astype(dtype, copy=copy) if is_interval_dtype(new_values): return self._shallow_copy(new_values.left, new_values.right) - return super().astype(dtype, copy=copy) + return Index.astype(self, dtype, copy=copy) @property def inferred_type(self) -> str: @@ -664,7 +664,7 @@ def _maybe_convert_i8(self, key): return key_i8 - def _check_method(self, method): + def _check_method(self, method): # TODO: Doesnt need to be a method if method is None: return @@ -994,14 +994,6 @@ def take(self, indices, axis=0, allow_fill=True, fill_value=None, **kwargs): ) return self._shallow_copy(result) - def __getitem__(self, value): - result = self._data[value] - if isinstance(result, IntervalArray): - return self._shallow_copy(result) - else: - # scalar - return result - # -------------------------------------------------------------------- # Rendering Methods # __repr__ associated methods are based on MultiIndex diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index 83a65b6505446..2d9dd7a8487f8 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -599,15 +599,6 @@ def get_indexer_non_unique(self, target): indexer, missing = self._int64index.get_indexer_non_unique(target) return ensure_platform_int(indexer), missing - def _get_unique_index(self, dropna=False): - """ - wrap Index._get_unique_index to handle NaT - """ - res = super()._get_unique_index(dropna=dropna) - if dropna: - res = res.dropna() - return res - def get_loc(self, key, method=None, tolerance=None): """ Get integer location for requested label diff --git a/pandas/core/indexes/timedeltas.py b/pandas/core/indexes/timedeltas.py index 894b430f1c4fd..45a83fc12a094 100644 --- a/pandas/core/indexes/timedeltas.py +++ b/pandas/core/indexes/timedeltas.py @@ -225,15 +225,6 @@ def _formatter_func(self): return _get_format_timedelta64(self, box=True) - # ------------------------------------------------------------------- - # Wrapping TimedeltaArray - - def __getitem__(self, key): - result = self._data.__getitem__(key) - if is_scalar(result): - return result - return type(self)(result, name=self.name) - # ------------------------------------------------------------------- @Appender(_index_shared_docs["astype"]) @@ -249,15 +240,13 @@ def astype(self, dtype, copy=True): return Index(result.astype("i8"), name=self.name) return DatetimeIndexOpsMixin.astype(self, dtype, copy=copy) - def _union(self, other, sort): + def _union(self, other: "TimedeltaIndex", sort): if len(other) == 0 or self.equals(other) or len(self) == 0: return super()._union(other, sort=sort) - if not isinstance(other, TimedeltaIndex): - try: - other = TimedeltaIndex(other) - except (TypeError, ValueError): - pass + # We are called by `union`, which is responsible for this validation + assert isinstance(other, TimedeltaIndex) + this, other = self, other if this._can_fast_union(other): From be7eef6fa7ce409f675ff3040e40cd06f146369a Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Sun, 5 Jan 2020 14:45:35 -0800 Subject: [PATCH 2/3] remove comment --- pandas/core/indexes/interval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 5b827487ca383..a7be345165438 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -673,7 +673,7 @@ def _maybe_convert_i8(self, key): return key_i8 - def _check_method(self, method): # TODO: Doesnt need to be a method + def _check_method(self, method): if method is None: return From 1a5387a466f9af2bfddc1f7ccebf9618cdfd297e Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Mon, 6 Jan 2020 14:39:19 -0800 Subject: [PATCH 3/3] troubleshoot perf --- pandas/core/indexes/interval.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 6727200299292..9ce917b004bc1 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -436,7 +436,7 @@ def astype(self, dtype, copy=True): new_values = self.values.astype(dtype, copy=copy) if is_interval_dtype(new_values): return self._shallow_copy(new_values.left, new_values.right) - return Index.astype(self, dtype, copy=copy) + return super().astype(dtype, copy=copy) @property def inferred_type(self) -> str: @@ -1000,6 +1000,14 @@ def take(self, indices, axis=0, allow_fill=True, fill_value=None, **kwargs): ) return self._shallow_copy(result) + def __getitem__(self, value): + result = self._data[value] + if isinstance(result, IntervalArray): + return self._shallow_copy(result) + else: + # scalar + return result + # -------------------------------------------------------------------- # Rendering Methods # __repr__ associated methods are based on MultiIndex