Skip to content

Commit d5bc5ae

Browse files
authored
REF: share delete, putmask, insert between ndarray-backed EA indexes (#37529)
1 parent 4c9fa96 commit d5bc5ae

File tree

5 files changed

+96
-100
lines changed

5 files changed

+96
-100
lines changed

pandas/core/arrays/_mixins.py

+4
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ def _box_func(self, x):
4444
"""
4545
return x
4646

47+
def _validate_insert_value(self, value):
48+
# used by NDArrayBackedExtensionIndex.insert
49+
raise AbstractMethodError(self)
50+
4751
# ------------------------------------------------------------------------
4852

4953
def take(

pandas/core/indexes/category.py

+2-51
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from pandas.core.construction import extract_array
2828
import pandas.core.indexes.base as ibase
2929
from pandas.core.indexes.base import Index, _index_shared_docs, maybe_extract_name
30-
from pandas.core.indexes.extension import ExtensionIndex, inherit_names
30+
from pandas.core.indexes.extension import NDArrayBackedExtensionIndex, inherit_names
3131
import pandas.core.missing as missing
3232
from pandas.core.ops import get_op_result_name
3333

@@ -66,7 +66,7 @@
6666
typ="method",
6767
overwrite=True,
6868
)
69-
class CategoricalIndex(ExtensionIndex, accessor.PandasDelegate):
69+
class CategoricalIndex(NDArrayBackedExtensionIndex, accessor.PandasDelegate):
7070
"""
7171
Index based on an underlying :class:`Categorical`.
7272
@@ -425,17 +425,6 @@ def where(self, cond, other=None):
425425
cat = Categorical(values, dtype=self.dtype)
426426
return type(self)._simple_new(cat, name=self.name)
427427

428-
def putmask(self, mask, value):
429-
try:
430-
code_value = self._data._validate_where_value(value)
431-
except (TypeError, ValueError):
432-
return self.astype(object).putmask(mask, value)
433-
434-
codes = self._data._ndarray.copy()
435-
np.putmask(codes, mask, code_value)
436-
cat = self._data._from_backing_data(codes)
437-
return type(self)._simple_new(cat, name=self.name)
438-
439428
def reindex(self, target, method=None, level=None, limit=None, tolerance=None):
440429
"""
441430
Create index with target's values (move/add/delete values as necessary)
@@ -665,44 +654,6 @@ def map(self, mapper):
665654
mapped = self._values.map(mapper)
666655
return Index(mapped, name=self.name)
667656

668-
def delete(self, loc):
669-
"""
670-
Make new Index with passed location(-s) deleted
671-
672-
Returns
673-
-------
674-
new_index : Index
675-
"""
676-
codes = np.delete(self.codes, loc)
677-
cat = self._data._from_backing_data(codes)
678-
return type(self)._simple_new(cat, name=self.name)
679-
680-
def insert(self, loc: int, item):
681-
"""
682-
Make new Index inserting new item at location. Follows
683-
Python list.append semantics for negative values
684-
685-
Parameters
686-
----------
687-
loc : int
688-
item : object
689-
690-
Returns
691-
-------
692-
new_index : Index
693-
694-
Raises
695-
------
696-
ValueError if the item is not in the categories
697-
698-
"""
699-
code = self._data._validate_insert_value(item)
700-
701-
codes = self.codes
702-
codes = np.concatenate((codes[:loc], [code], codes[loc:]))
703-
cat = self._data._from_backing_data(codes)
704-
return type(self)._simple_new(cat, name=self.name)
705-
706657
def _concat(self, to_concat, name):
707658
# if calling index is category, don't check dtype of others
708659
codes = np.concatenate([self._is_dtype_compat(c).codes for c in to_concat])

pandas/core/indexes/datetimelike.py

+32-48
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import pandas.core.indexes.base as ibase
3434
from pandas.core.indexes.base import Index, _index_shared_docs
3535
from pandas.core.indexes.extension import (
36-
ExtensionIndex,
36+
NDArrayBackedExtensionIndex,
3737
inherit_names,
3838
make_wrapped_arith_op,
3939
)
@@ -82,7 +82,7 @@ def wrapper(left, right):
8282
cache=True,
8383
)
8484
@inherit_names(["mean", "asi8", "freq", "freqstr"], DatetimeLikeArrayMixin)
85-
class DatetimeIndexOpsMixin(ExtensionIndex):
85+
class DatetimeIndexOpsMixin(NDArrayBackedExtensionIndex):
8686
"""
8787
Common ops mixin to support a unified interface datetimelike Index.
8888
"""
@@ -191,7 +191,7 @@ def take(self, indices, axis=0, allow_fill=True, fill_value=None, **kwargs):
191191

192192
maybe_slice = lib.maybe_indices_to_slice(indices, len(self))
193193

194-
result = ExtensionIndex.take(
194+
result = NDArrayBackedExtensionIndex.take(
195195
self, indices, axis, allow_fill, fill_value, **kwargs
196196
)
197197
if isinstance(maybe_slice, slice):
@@ -496,17 +496,6 @@ def where(self, cond, other=None):
496496
arr = self._data._from_backing_data(result)
497497
return type(self)._simple_new(arr, name=self.name)
498498

499-
def putmask(self, mask, value):
500-
try:
501-
value = self._data._validate_where_value(value)
502-
except (TypeError, ValueError):
503-
return self.astype(object).putmask(mask, value)
504-
505-
result = self._data._ndarray.copy()
506-
np.putmask(result, mask, value)
507-
arr = self._data._from_backing_data(result)
508-
return type(self)._simple_new(arr, name=self.name)
509-
510499
def _summary(self, name=None) -> str:
511500
"""
512501
Return a summarized representation.
@@ -575,41 +564,30 @@ def shift(self, periods=1, freq=None):
575564
# --------------------------------------------------------------------
576565
# List-like Methods
577566

578-
def delete(self, loc):
579-
new_i8s = np.delete(self.asi8, loc)
580-
567+
def _get_delete_freq(self, loc: int):
568+
"""
569+
Find the `freq` for self.delete(loc).
570+
"""
581571
freq = None
582572
if is_period_dtype(self.dtype):
583573
freq = self.freq
584-
elif is_integer(loc):
585-
if loc in (0, -len(self), -1, len(self) - 1):
586-
freq = self.freq
587-
else:
588-
if is_list_like(loc):
589-
loc = lib.maybe_indices_to_slice(
590-
np.asarray(loc, dtype=np.intp), len(self)
591-
)
592-
if isinstance(loc, slice) and loc.step in (1, None):
593-
if loc.start in (0, None) or loc.stop in (len(self), None):
574+
elif self.freq is not None:
575+
if is_integer(loc):
576+
if loc in (0, -len(self), -1, len(self) - 1):
594577
freq = self.freq
578+
else:
579+
if is_list_like(loc):
580+
loc = lib.maybe_indices_to_slice(
581+
np.asarray(loc, dtype=np.intp), len(self)
582+
)
583+
if isinstance(loc, slice) and loc.step in (1, None):
584+
if loc.start in (0, None) or loc.stop in (len(self), None):
585+
freq = self.freq
586+
return freq
595587

596-
arr = type(self._data)._simple_new(new_i8s, dtype=self.dtype, freq=freq)
597-
return type(self)._simple_new(arr, name=self.name)
598-
599-
def insert(self, loc: int, item):
588+
def _get_insert_freq(self, loc, item):
600589
"""
601-
Make new Index inserting new item at location
602-
603-
Parameters
604-
----------
605-
loc : int
606-
item : object
607-
if not either a Python datetime or a numpy integer-like, returned
608-
Index dtype will be object rather than datetime.
609-
610-
Returns
611-
-------
612-
new_index : Index
590+
Find the `freq` for self.insert(loc, item).
613591
"""
614592
value = self._data._validate_insert_value(item)
615593
item = self._data._box_func(value)
@@ -630,14 +608,20 @@ def insert(self, loc: int, item):
630608
# Adding a single item to an empty index may preserve freq
631609
if self.freq.is_on_offset(item):
632610
freq = self.freq
611+
return freq
633612

634-
arr = self._data
613+
@doc(NDArrayBackedExtensionIndex.delete)
614+
def delete(self, loc):
615+
result = super().delete(loc)
616+
result._data._freq = self._get_delete_freq(loc)
617+
return result
635618

636-
new_values = np.concatenate([arr._ndarray[:loc], [value], arr._ndarray[loc:]])
637-
new_arr = self._data._from_backing_data(new_values)
638-
new_arr._freq = freq
619+
@doc(NDArrayBackedExtensionIndex.insert)
620+
def insert(self, loc: int, item):
621+
result = super().insert(loc, item)
639622

640-
return type(self)._simple_new(new_arr, name=self.name)
623+
result._data._freq = self._get_insert_freq(loc, item)
624+
return result
641625

642626
# --------------------------------------------------------------------
643627
# Join/Set Methods

pandas/core/indexes/extension.py

+57
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
1414

1515
from pandas.core.arrays import ExtensionArray
16+
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
1617
from pandas.core.indexers import deprecate_ndim_indexing
1718
from pandas.core.indexes.base import Index
1819
from pandas.core.ops import get_op_result_name
@@ -281,3 +282,59 @@ def astype(self, dtype, copy=True):
281282
@cache_readonly
282283
def _isnan(self) -> np.ndarray:
283284
return self._data.isna()
285+
286+
287+
class NDArrayBackedExtensionIndex(ExtensionIndex):
288+
"""
289+
Index subclass for indexes backed by NDArrayBackedExtensionArray.
290+
"""
291+
292+
_data: NDArrayBackedExtensionArray
293+
294+
def delete(self, loc):
295+
"""
296+
Make new Index with passed location(-s) deleted
297+
298+
Returns
299+
-------
300+
new_index : Index
301+
"""
302+
new_vals = np.delete(self._data._ndarray, loc)
303+
arr = self._data._from_backing_data(new_vals)
304+
return type(self)._simple_new(arr, name=self.name)
305+
306+
def insert(self, loc: int, item):
307+
"""
308+
Make new Index inserting new item at location. Follows
309+
Python list.append semantics for negative values.
310+
311+
Parameters
312+
----------
313+
loc : int
314+
item : object
315+
316+
Returns
317+
-------
318+
new_index : Index
319+
320+
Raises
321+
------
322+
ValueError if the item is not valid for this dtype.
323+
"""
324+
arr = self._data
325+
code = arr._validate_insert_value(item)
326+
327+
new_vals = np.concatenate((arr._ndarray[:loc], [code], arr._ndarray[loc:]))
328+
new_arr = arr._from_backing_data(new_vals)
329+
return type(self)._simple_new(new_arr, name=self.name)
330+
331+
def putmask(self, mask, value):
332+
try:
333+
value = self._data._validate_where_value(value)
334+
except (TypeError, ValueError):
335+
return self.astype(object).putmask(mask, value)
336+
337+
new_values = self._data._ndarray.copy()
338+
np.putmask(new_values, mask, value)
339+
new_arr = self._data._from_backing_data(new_values)
340+
return type(self)._simple_new(new_arr, name=self.name)

pandas/core/indexes/interval.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,7 @@ def where(self, cond, other=None):
872872
other = self._na_value
873873
values = np.where(cond, self._values, other)
874874
result = IntervalArray(values)
875-
return self._shallow_copy(result)
875+
return type(self)._simple_new(result, name=self.name)
876876

877877
def delete(self, loc):
878878
"""

0 commit comments

Comments
 (0)