Skip to content

Commit a535f79

Browse files
jbrockmendelKevin D Smith
authored and
Kevin D Smith
committed
REF: share __getitem__ for Categorical/PandasArray/DTA/TDA/PA (pandas-dev#36391)
1 parent 3acd363 commit a535f79

File tree

4 files changed

+41
-36
lines changed

4 files changed

+41
-36
lines changed

pandas/core/arrays/_mixins.py

+26
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44

5+
from pandas._libs import lib
56
from pandas.compat.numpy import function as nv
67
from pandas.errors import AbstractMethodError
78
from pandas.util._decorators import cache_readonly, doc
@@ -30,6 +31,12 @@ def _from_backing_data(self: _T, arr: np.ndarray) -> _T:
3031
"""
3132
raise AbstractMethodError(self)
3233

34+
def _box_func(self, x):
35+
"""
36+
Wrap numpy type in our dtype.type if necessary.
37+
"""
38+
return x
39+
3340
# ------------------------------------------------------------------------
3441

3542
def take(
@@ -168,3 +175,22 @@ def _validate_setitem_key(self, key):
168175

169176
def _validate_setitem_value(self, value):
170177
return value
178+
179+
def __getitem__(self, key):
180+
if lib.is_integer(key):
181+
# fast-path
182+
result = self._ndarray[key]
183+
if self.ndim == 1:
184+
return self._box_func(result)
185+
return self._from_backing_data(result)
186+
187+
key = self._validate_getitem_key(key)
188+
result = self._ndarray[key]
189+
if lib.is_scalar(result):
190+
return self._box_func(result)
191+
192+
result = self._from_backing_data(result)
193+
return result
194+
195+
def _validate_getitem_key(self, key):
196+
return check_array_indexer(self, key)

pandas/core/arrays/categorical.py

+4-10
Original file line numberDiff line numberDiff line change
@@ -1887,17 +1887,11 @@ def __getitem__(self, key):
18871887
"""
18881888
Return an item.
18891889
"""
1890-
if isinstance(key, (int, np.integer)):
1891-
i = self._codes[key]
1892-
return self._box_func(i)
1893-
1894-
key = check_array_indexer(self, key)
1895-
1896-
result = self._codes[key]
1897-
if result.ndim > 1:
1890+
result = super().__getitem__(key)
1891+
if getattr(result, "ndim", 0) > 1:
1892+
result = result._ndarray
18981893
deprecate_ndim_indexing(result)
1899-
return result
1900-
return self._from_backing_data(result)
1894+
return result
19011895

19021896
def _validate_setitem_value(self, value):
19031897
value = extract_array(value, extract_numpy=True)

pandas/core/arrays/datetimelike.py

+7-16
Original file line numberDiff line numberDiff line change
@@ -539,23 +539,11 @@ def __getitem__(self, key):
539539
This getitem defers to the underlying array, which by-definition can
540540
only handle list-likes, slices, and integer scalars
541541
"""
542-
543-
if lib.is_integer(key):
544-
# fast-path
545-
result = self._ndarray[key]
546-
if self.ndim == 1:
547-
return self._box_func(result)
548-
return self._from_backing_data(result)
549-
550-
key = self._validate_getitem_key(key)
551-
result = self._ndarray[key]
542+
result = super().__getitem__(key)
552543
if lib.is_scalar(result):
553-
return self._box_func(result)
554-
555-
result = self._from_backing_data(result)
544+
return result
556545

557-
freq = self._get_getitem_freq(key)
558-
result._freq = freq
546+
result._freq = self._get_getitem_freq(key)
559547
return result
560548

561549
def _validate_getitem_key(self, key):
@@ -572,7 +560,7 @@ def _validate_getitem_key(self, key):
572560
# this for now (would otherwise raise in check_array_indexer)
573561
pass
574562
else:
575-
key = check_array_indexer(self, key)
563+
key = super()._validate_getitem_key(key)
576564
return key
577565

578566
def _get_getitem_freq(self, key):
@@ -582,7 +570,10 @@ def _get_getitem_freq(self, key):
582570
is_period = is_period_dtype(self.dtype)
583571
if is_period:
584572
freq = self.freq
573+
elif self.ndim != 1:
574+
freq = None
585575
else:
576+
key = self._validate_getitem_key(key) # maybe ndarray[bool] -> slice
586577
freq = None
587578
if isinstance(key, slice):
588579
if self.freq is not None and key.step is not None:

pandas/core/arrays/numpy_.py

+4-10
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
2020
from pandas.core.arrays.base import ExtensionOpsMixin
2121
from pandas.core.construction import extract_array
22-
from pandas.core.indexers import check_array_indexer
2322
from pandas.core.missing import backfill_1d, pad_1d
2423

2524

@@ -248,16 +247,11 @@ def __array_ufunc__(self, ufunc, method: str, *inputs, **kwargs):
248247
# ------------------------------------------------------------------------
249248
# Pandas ExtensionArray Interface
250249

251-
def __getitem__(self, item):
252-
if isinstance(item, type(self)):
253-
item = item._ndarray
250+
def _validate_getitem_key(self, key):
251+
if isinstance(key, type(self)):
252+
key = key._ndarray
254253

255-
item = check_array_indexer(self, item)
256-
257-
result = self._ndarray[item]
258-
if not lib.is_scalar(item):
259-
result = type(self)(result)
260-
return result
254+
return super()._validate_getitem_key(key)
261255

262256
def _validate_setitem_value(self, value):
263257
value = extract_array(value, extract_numpy=True)

0 commit comments

Comments
 (0)