Skip to content

Commit 7dc65ae

Browse files
authored
REF: de-duplicate ExtensionIndex methods (#41791)
1 parent dc4fddc commit 7dc65ae

File tree

5 files changed

+92
-108
lines changed

5 files changed

+92
-108
lines changed

pandas/core/arrays/_mixins.py

+30
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,36 @@ def where(
346346
res_values = np.where(mask, self._ndarray, value)
347347
return self._from_backing_data(res_values)
348348

349+
# ------------------------------------------------------------------------
350+
# Index compat methods
351+
352+
def insert(
353+
self: NDArrayBackedExtensionArrayT, loc: int, item
354+
) -> NDArrayBackedExtensionArrayT:
355+
"""
356+
Make new ExtensionArray inserting new item at location. Follows
357+
Python list.append semantics for negative values.
358+
359+
Parameters
360+
----------
361+
loc : int
362+
item : object
363+
364+
Returns
365+
-------
366+
type(self)
367+
"""
368+
code = self._validate_scalar(item)
369+
370+
new_vals = np.concatenate(
371+
(
372+
self._ndarray[:loc],
373+
np.asarray([code], dtype=self._ndarray.dtype),
374+
self._ndarray[loc:],
375+
)
376+
)
377+
return self._from_backing_data(new_vals)
378+
349379
# ------------------------------------------------------------------------
350380
# Additional array methods
351381
# These are not part of the EA API, but we implement them because

pandas/core/dtypes/missing.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,11 @@
3737
is_string_dtype,
3838
needs_i8_conversion,
3939
)
40-
from pandas.core.dtypes.dtypes import ExtensionDtype
40+
from pandas.core.dtypes.dtypes import (
41+
ExtensionDtype,
42+
IntervalDtype,
43+
PeriodDtype,
44+
)
4145
from pandas.core.dtypes.generic import (
4246
ABCDataFrame,
4347
ABCExtensionArray,
@@ -630,7 +634,13 @@ def is_valid_na_for_dtype(obj, dtype: DtypeObj) -> bool:
630634
# This is needed for Categorical, but is kind of weird
631635
return True
632636

633-
# must be PeriodDType
637+
elif isinstance(dtype, PeriodDtype):
638+
return not isinstance(obj, (np.datetime64, np.timedelta64, Decimal))
639+
640+
elif isinstance(dtype, IntervalDtype):
641+
return lib.is_float(obj) or obj is None or obj is libmissing.NA
642+
643+
# fallback, default to allowing NaN, None, NA, NaT
634644
return not isinstance(obj, (np.datetime64, np.timedelta64, Decimal))
635645

636646

pandas/core/indexes/extension.py

+40-65
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,12 @@
1313

1414
from pandas._typing import ArrayLike
1515
from pandas.compat.numpy import function as nv
16-
from pandas.errors import AbstractMethodError
1716
from pandas.util._decorators import (
1817
cache_readonly,
1918
doc,
2019
)
2120
from pandas.util._exceptions import rewrite_exception
2221

23-
from pandas.core.dtypes.cast import (
24-
find_common_type,
25-
infer_dtype_from,
26-
)
2722
from pandas.core.dtypes.common import (
2823
is_dtype_equal,
2924
is_object_dtype,
@@ -34,6 +29,7 @@
3429
ABCSeries,
3530
)
3631

32+
from pandas.core.array_algos.putmask import validate_putmask
3733
from pandas.core.arrays import (
3834
Categorical,
3935
DatetimeArray,
@@ -297,6 +293,21 @@ def searchsorted(self, value, side="left", sorter=None) -> np.ndarray:
297293
# overriding IndexOpsMixin improves performance GH#38083
298294
return self._data.searchsorted(value, side=side, sorter=sorter)
299295

296+
def putmask(self, mask, value) -> Index:
297+
mask, noop = validate_putmask(self._data, mask)
298+
if noop:
299+
return self.copy()
300+
301+
try:
302+
self._validate_fill_value(value)
303+
except (ValueError, TypeError):
304+
dtype = self._find_common_type_compat(value)
305+
return self.astype(dtype).putmask(mask, value)
306+
307+
arr = self._data.copy()
308+
arr.putmask(mask, value)
309+
return type(self)._simple_new(arr, name=self.name)
310+
300311
# ---------------------------------------------------------------------
301312

302313
def _get_engine_target(self) -> np.ndarray:
@@ -323,9 +334,30 @@ def repeat(self, repeats, axis=None):
323334
result = self._data.repeat(repeats, axis=axis)
324335
return type(self)._simple_new(result, name=self.name)
325336

326-
def insert(self, loc: int, item):
327-
# ExtensionIndex subclasses must override Index.insert
328-
raise AbstractMethodError(self)
337+
def insert(self, loc: int, item) -> Index:
338+
"""
339+
Make new Index inserting new item at location. Follows
340+
Python list.append semantics for negative values.
341+
342+
Parameters
343+
----------
344+
loc : int
345+
item : object
346+
347+
Returns
348+
-------
349+
new_index : Index
350+
"""
351+
try:
352+
result = self._data.insert(loc, item)
353+
except (ValueError, TypeError):
354+
# e.g. trying to insert an integer into a DatetimeIndex
355+
# We cannot keep the same dtype, so cast to the (often object)
356+
# minimal shared dtype before doing the insert.
357+
dtype = self._find_common_type_compat(item)
358+
return self.astype(dtype).insert(loc, item)
359+
else:
360+
return type(self)._simple_new(result, name=self.name)
329361

330362
def _validate_fill_value(self, value):
331363
"""
@@ -426,60 +458,3 @@ def _get_engine_target(self) -> np.ndarray:
426458
def _from_join_target(self, result: np.ndarray) -> ArrayLike:
427459
assert result.dtype == self._data._ndarray.dtype
428460
return self._data._from_backing_data(result)
429-
430-
def insert(self: _T, loc: int, item) -> Index:
431-
"""
432-
Make new Index inserting new item at location. Follows
433-
Python list.append semantics for negative values.
434-
435-
Parameters
436-
----------
437-
loc : int
438-
item : object
439-
440-
Returns
441-
-------
442-
new_index : Index
443-
444-
Raises
445-
------
446-
ValueError if the item is not valid for this dtype.
447-
"""
448-
arr = self._data
449-
try:
450-
code = arr._validate_scalar(item)
451-
except (ValueError, TypeError):
452-
# e.g. trying to insert an integer into a DatetimeIndex
453-
# We cannot keep the same dtype, so cast to the (often object)
454-
# minimal shared dtype before doing the insert.
455-
dtype, _ = infer_dtype_from(item, pandas_dtype=True)
456-
dtype = find_common_type([self.dtype, dtype])
457-
return self.astype(dtype).insert(loc, item)
458-
else:
459-
new_vals = np.concatenate(
460-
(
461-
arr._ndarray[:loc],
462-
np.asarray([code], dtype=arr._ndarray.dtype),
463-
arr._ndarray[loc:],
464-
)
465-
)
466-
new_arr = arr._from_backing_data(new_vals)
467-
return type(self)._simple_new(new_arr, name=self.name)
468-
469-
def putmask(self, mask, value) -> Index:
470-
res_values = self._data.copy()
471-
try:
472-
res_values.putmask(mask, value)
473-
except (TypeError, ValueError):
474-
return self.astype(object).putmask(mask, value)
475-
476-
return type(self)._simple_new(res_values, name=self.name)
477-
478-
# error: Argument 1 of "_wrap_joined_index" is incompatible with supertype
479-
# "Index"; supertype defines the argument type as "Union[ExtensionArray, ndarray]"
480-
def _wrap_joined_index( # type: ignore[override]
481-
self: _T, joined: NDArrayBackedExtensionArray, other: _T
482-
) -> _T:
483-
name = get_op_result_name(self, other)
484-
485-
return type(self)._simple_new(joined, name=name)

pandas/core/indexes/interval.py

-41
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@
6767
take_nd,
6868
unique,
6969
)
70-
from pandas.core.array_algos.putmask import validate_putmask
7170
from pandas.core.arrays.interval import (
7271
IntervalArray,
7372
_interval_shared_docs,
@@ -854,46 +853,6 @@ def mid(self) -> Index:
854853
def length(self) -> Index:
855854
return Index(self._data.length, copy=False)
856855

857-
def putmask(self, mask, value) -> Index:
858-
mask, noop = validate_putmask(self._data, mask)
859-
if noop:
860-
return self.copy()
861-
862-
try:
863-
self._validate_fill_value(value)
864-
except (ValueError, TypeError):
865-
dtype = self._find_common_type_compat(value)
866-
return self.astype(dtype).putmask(mask, value)
867-
868-
arr = self._data.copy()
869-
arr.putmask(mask, value)
870-
return type(self)._simple_new(arr, name=self.name)
871-
872-
def insert(self, loc: int, item):
873-
"""
874-
Return a new IntervalIndex inserting new item at location. Follows
875-
Python list.append semantics for negative values. Only Interval
876-
objects and NA can be inserted into an IntervalIndex
877-
878-
Parameters
879-
----------
880-
loc : int
881-
item : object
882-
883-
Returns
884-
-------
885-
IntervalIndex
886-
"""
887-
try:
888-
result = self._data.insert(loc, item)
889-
except (ValueError, TypeError):
890-
# e.g trying to insert a string
891-
dtype, _ = infer_dtype_from_scalar(item, pandas_dtype=True)
892-
dtype = find_common_type([self.dtype, dtype])
893-
return self.astype(dtype).insert(loc, item)
894-
895-
return type(self)._simple_new(result, name=self.name)
896-
897856
# --------------------------------------------------------------------
898857
# Rendering Methods
899858
# __repr__ associated methods are based on MultiIndex

pandas/tests/dtypes/test_missing.py

+10
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
)
2525
from pandas.core.dtypes.missing import (
2626
array_equivalent,
27+
is_valid_na_for_dtype,
2728
isna,
2829
isnull,
2930
na_value_for_dtype,
@@ -729,3 +730,12 @@ def test_is_matching_na_nan_matches_none(self):
729730

730731
assert libmissing.is_matching_na(None, np.nan, nan_matches_none=True)
731732
assert libmissing.is_matching_na(np.nan, None, nan_matches_none=True)
733+
734+
735+
class TestIsValidNAForDtype:
736+
def test_is_valid_na_for_dtype_interval(self):
737+
dtype = IntervalDtype("int64", "left")
738+
assert not is_valid_na_for_dtype(NaT, dtype)
739+
740+
dtype = IntervalDtype("datetime64[ns]", "both")
741+
assert not is_valid_na_for_dtype(NaT, dtype)

0 commit comments

Comments
 (0)