Skip to content

REF: de-duplicate ExtensionIndex methods #41791

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,36 @@ def where(
res_values = np.where(mask, self._ndarray, value)
return self._from_backing_data(res_values)

# ------------------------------------------------------------------------
# Index compat methods

def insert(
    self: NDArrayBackedExtensionArrayT, loc: int, item
) -> NDArrayBackedExtensionArrayT:
    """
    Return a new array with ``item`` placed at position ``loc``.

    ``loc`` is used as a slice boundary on the backing ndarray, so negative
    values count from the end in the same way as ``list.insert``.

    Parameters
    ----------
    loc : int
    item : object
        Passed through ``self._validate_scalar`` before being inserted.

    Returns
    -------
    type(self)
    """
    code = self._validate_scalar(item)

    backing = self._ndarray
    # Wrap the validated value in a length-1 array of the backing dtype so
    # that all three pieces being concatenated share one dtype.
    middle = np.asarray([code], dtype=backing.dtype)
    pieces = (backing[:loc], middle, backing[loc:])
    return self._from_backing_data(np.concatenate(pieces))

# ------------------------------------------------------------------------
# Additional array methods
# These are not part of the EA API, but we implement them because
Expand Down
14 changes: 12 additions & 2 deletions pandas/core/dtypes/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@
is_string_dtype,
needs_i8_conversion,
)
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.dtypes import (
ExtensionDtype,
IntervalDtype,
PeriodDtype,
)
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCExtensionArray,
Expand Down Expand Up @@ -630,7 +634,13 @@ def is_valid_na_for_dtype(obj, dtype: DtypeObj) -> bool:
# This is needed for Categorical, but is kind of weird
return True

# must be PeriodDtype
elif isinstance(dtype, PeriodDtype):
return not isinstance(obj, (np.datetime64, np.timedelta64, Decimal))

elif isinstance(dtype, IntervalDtype):
return lib.is_float(obj) or obj is None or obj is libmissing.NA

# fallback, default to allowing NaN, None, NA, NaT
return not isinstance(obj, (np.datetime64, np.timedelta64, Decimal))


Expand Down
105 changes: 40 additions & 65 deletions pandas/core/indexes/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,12 @@

from pandas._typing import ArrayLike
from pandas.compat.numpy import function as nv
from pandas.errors import AbstractMethodError
from pandas.util._decorators import (
cache_readonly,
doc,
)
from pandas.util._exceptions import rewrite_exception

from pandas.core.dtypes.cast import (
find_common_type,
infer_dtype_from,
)
from pandas.core.dtypes.common import (
is_dtype_equal,
is_object_dtype,
Expand All @@ -34,6 +29,7 @@
ABCSeries,
)

from pandas.core.array_algos.putmask import validate_putmask
from pandas.core.arrays import (
Categorical,
DatetimeArray,
Expand Down Expand Up @@ -297,6 +293,21 @@ def searchsorted(self, value, side="left", sorter=None) -> np.ndarray:
# overriding IndexOpsMixin improves performance GH#38083
return self._data.searchsorted(value, side=side, sorter=sorter)

def putmask(self, mask, value) -> Index:
    """
    Return a new Index with ``putmask(mask, value)`` applied to its data.

    Parameters
    ----------
    mask : object
        Normalized by ``validate_putmask`` against ``self._data``.
    value : object
        Checked with ``self._validate_fill_value``.

    Returns
    -------
    Index
        A copy of ``self`` when ``validate_putmask`` reports a no-op;
        an index of the same type when ``value`` validates; otherwise the
        result of ``putmask`` on ``self`` cast to the dtype returned by
        ``self._find_common_type_compat(value)``.
    """
    mask, nothing_to_set = validate_putmask(self._data, mask)
    if nothing_to_set:
        return self.copy()

    try:
        self._validate_fill_value(value)
    except (ValueError, TypeError):
        # The value was rejected for the current dtype: cast first, then
        # let the cast index perform the putmask.
        recast = self.astype(self._find_common_type_compat(value))
        return recast.putmask(mask, value)

    new_data = self._data.copy()
    new_data.putmask(mask, value)
    return type(self)._simple_new(new_data, name=self.name)

# ---------------------------------------------------------------------

def _get_engine_target(self) -> np.ndarray:
Expand All @@ -323,9 +334,30 @@ def repeat(self, repeats, axis=None):
result = self._data.repeat(repeats, axis=axis)
return type(self)._simple_new(result, name=self.name)

def insert(self, loc: int, item):
    """
    Placeholder: ExtensionIndex subclasses must override Index.insert.

    Raises
    ------
    AbstractMethodError
        Always.
    """
    raise AbstractMethodError(self)
def insert(self, loc: int, item) -> Index:
    """
    Return a new Index with ``item`` inserted at position ``loc``.

    The insertion is delegated to ``self._data.insert``. If that raises
    ``ValueError`` or ``TypeError`` the index is first cast to the dtype
    returned by ``self._find_common_type_compat(item)`` and the insert is
    retried on the cast index.

    Parameters
    ----------
    loc : int
    item : object

    Returns
    -------
    new_index : Index
    """
    try:
        new_data = self._data.insert(loc, item)
    except (ValueError, TypeError):
        # e.g. trying to insert an integer into a DatetimeIndex.
        # The current dtype cannot hold the item, so move to the (often
        # object) minimal shared dtype and insert there instead.
        shared = self._find_common_type_compat(item)
        return self.astype(shared).insert(loc, item)

    return type(self)._simple_new(new_data, name=self.name)

def _validate_fill_value(self, value):
"""
Expand Down Expand Up @@ -426,60 +458,3 @@ def _get_engine_target(self) -> np.ndarray:
def _from_join_target(self, result: np.ndarray) -> ArrayLike:
    """
    Wrap a join-result ndarray via ``self._data._from_backing_data``.

    Asserts that ``result`` has the same dtype as the backing ndarray.
    """
    data = self._data
    assert result.dtype == data._ndarray.dtype
    return data._from_backing_data(result)

def insert(self: _T, loc: int, item) -> Index:
    """
    Return a new Index with ``item`` inserted at position ``loc``.

    ``loc`` is used as a slice boundary on the backing ndarray, so negative
    values count from the end in the same way as ``list.insert``.

    If ``item`` fails ``_validate_scalar`` on the backing array, the index
    is cast to the common dtype of ``self.dtype`` and the dtype inferred
    from ``item``, and the insert is retried on the cast index.

    Parameters
    ----------
    loc : int
    item : object

    Returns
    -------
    new_index : Index
    """
    data = self._data
    try:
        code = data._validate_scalar(item)
    except (ValueError, TypeError):
        # e.g. trying to insert an integer into a DatetimeIndex.
        # The current dtype cannot hold the item, so move to the (often
        # object) minimal shared dtype and insert there instead.
        item_dtype, _ = infer_dtype_from(item, pandas_dtype=True)
        shared = find_common_type([self.dtype, item_dtype])
        return self.astype(shared).insert(loc, item)

    backing = data._ndarray
    middle = np.asarray([code], dtype=backing.dtype)
    new_vals = np.concatenate((backing[:loc], middle, backing[loc:]))
    wrapped = data._from_backing_data(new_vals)
    return type(self)._simple_new(wrapped, name=self.name)

def putmask(self, mask, value) -> Index:
    """
    Return a new Index with ``putmask(mask, value)`` applied to its data.

    The operation runs on a copy of ``self._data``. If the array raises
    ``TypeError`` or ``ValueError``, the index is cast to object dtype and
    the putmask is performed there instead.
    """
    new_data = self._data.copy()
    try:
        new_data.putmask(mask, value)
    except (TypeError, ValueError):
        return self.astype(object).putmask(mask, value)
    else:
        return type(self)._simple_new(new_data, name=self.name)

# error: Argument 1 of "_wrap_joined_index" is incompatible with supertype
# "Index"; supertype defines the argument type as "Union[ExtensionArray, ndarray]"
def _wrap_joined_index(  # type: ignore[override]
    self: _T, joined: NDArrayBackedExtensionArray, other: _T
) -> _T:
    """
    Wrap ``joined`` in an index of ``type(self)``.

    The name of the result comes from ``get_op_result_name(self, other)``.
    """
    return type(self)._simple_new(joined, name=get_op_result_name(self, other))
41 changes: 0 additions & 41 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
take_nd,
unique,
)
from pandas.core.array_algos.putmask import validate_putmask
from pandas.core.arrays.interval import (
IntervalArray,
_interval_shared_docs,
Expand Down Expand Up @@ -854,46 +853,6 @@ def mid(self) -> Index:
def length(self) -> Index:
    """
    Return an Index built from ``self._data.length`` with ``copy=False``.
    """
    lengths = self._data.length
    return Index(lengths, copy=False)

def putmask(self, mask, value) -> Index:
    """
    Return a new Index with ``putmask(mask, value)`` applied to its data.

    Returns a copy of ``self`` when ``validate_putmask`` reports a no-op.
    If ``value`` fails ``self._validate_fill_value``, the index is cast to
    the dtype from ``self._find_common_type_compat(value)`` and the putmask
    is performed on the cast index.
    """
    mask, is_noop = validate_putmask(self._data, mask)
    if is_noop:
        return self.copy()

    try:
        self._validate_fill_value(value)
    except (ValueError, TypeError):
        common = self._find_common_type_compat(value)
        widened = self.astype(common)
        return widened.putmask(mask, value)

    filled = self._data.copy()
    filled.putmask(mask, value)
    return type(self)._simple_new(filled, name=self.name)

def insert(self, loc: int, item):
    """
    Return a new index with ``item`` inserted at position ``loc``.

    The insertion is delegated to ``self._data.insert``. When the array
    raises ``ValueError`` or ``TypeError`` (e.g. for a string), the index
    is cast to the common dtype of ``self.dtype`` and the dtype inferred
    from ``item``, and the insert is retried on the cast index.

    Parameters
    ----------
    loc : int
    item : object

    Returns
    -------
    IntervalIndex
        When the backing array accepts ``item``; otherwise the result of
        inserting into the cast index.
    """
    try:
        new_data = self._data.insert(loc, item)
    except (ValueError, TypeError):
        # e.g. trying to insert a string
        item_dtype, _ = infer_dtype_from_scalar(item, pandas_dtype=True)
        shared = find_common_type([self.dtype, item_dtype])
        return self.astype(shared).insert(loc, item)
    else:
        return type(self)._simple_new(new_data, name=self.name)

# --------------------------------------------------------------------
# Rendering Methods
# __repr__ associated methods are based on MultiIndex
Expand Down
10 changes: 10 additions & 0 deletions pandas/tests/dtypes/test_missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from pandas.core.dtypes.missing import (
array_equivalent,
is_valid_na_for_dtype,
isna,
isnull,
na_value_for_dtype,
Expand Down Expand Up @@ -729,3 +730,12 @@ def test_is_matching_na_nan_matches_none(self):

assert libmissing.is_matching_na(None, np.nan, nan_matches_none=True)
assert libmissing.is_matching_na(np.nan, None, nan_matches_none=True)


class TestIsValidNAForDtype:
    """Checks for ``is_valid_na_for_dtype`` against specific dtypes."""

    def test_is_valid_na_for_dtype_interval(self):
        # The IntervalDtype branch accepts only float NaN, None and NA.
        # NaT must be rejected whatever the subtype, including datetime64
        # where it could otherwise look like a plausible missing value.
        for dtype in (
            IntervalDtype("int64", "left"),
            IntervalDtype("datetime64[ns]", "both"),
        ):
            assert not is_valid_na_for_dtype(NaT, dtype)

            # Positive cases: without these, a change that rejected every
            # NA value would still pass the test.
            assert is_valid_na_for_dtype(np.nan, dtype)
            assert is_valid_na_for_dtype(None, dtype)
            assert is_valid_na_for_dtype(libmissing.NA, dtype)