Skip to content

Commit cff780f

Browse files
authored
REF: avoid special-casing in SelectN (#45956)
1 parent dbf9ffe commit cff780f

File tree

3 files changed

+52
-54
lines changed

3 files changed

+52
-54
lines changed

pandas/core/algorithms.py

+11 -12
Original file line number | Diff line number | Diff line change
@@ -64,6 +64,7 @@
6464
)
6565
from pandas.core.dtypes.concat import concat_compat
6666
from pandas.core.dtypes.dtypes import (
67+
BaseMaskedDtype,
6768
ExtensionDtype,
6869
PandasDtype,
6970
)
@@ -103,6 +104,7 @@
103104
Series,
104105
)
105106
from pandas.core.arrays import (
107+
BaseMaskedArray,
106108
DatetimeArray,
107109
ExtensionArray,
108110
TimedeltaArray,
@@ -142,6 +144,15 @@ def _ensure_data(values: ArrayLike) -> np.ndarray:
142144
if is_object_dtype(values.dtype):
143145
return ensure_object(np.asarray(values))
144146

147+
elif isinstance(values.dtype, BaseMaskedDtype):
148+
# i.e. BooleanArray, FloatingArray, IntegerArray
149+
values = cast("BaseMaskedArray", values)
150+
if not values._hasna:
151+
# No pd.NAs -> We can avoid an object-dtype cast (and copy) GH#41816
152+
# recurse to avoid re-implementing logic for eg bool->uint8
153+
return _ensure_data(values._data)
154+
return np.asarray(values)
155+
145156
elif is_bool_dtype(values.dtype):
146157
if isinstance(values, np.ndarray):
147158
# i.e. actually dtype == np.dtype("bool")
@@ -1188,18 +1199,6 @@ def compute(self, method: str) -> Series:
11881199
dropped = self.obj.dropna()
11891200
nan_index = self.obj.drop(dropped.index)
11901201

1191-
if is_extension_array_dtype(dropped.dtype):
1192-
# GH#41816 bc we have dropped NAs above, MaskedArrays can use the
1193-
# numpy logic.
1194-
from pandas.core.arrays import BaseMaskedArray
1195-
1196-
arr = dropped._values
1197-
if isinstance(arr, BaseMaskedArray):
1198-
ser = type(dropped)(arr._data, index=dropped.index, name=dropped.name)
1199-
1200-
result = type(self)(ser, n=self.n, keep=self.keep).compute(method)
1201-
return result.astype(arr.dtype)
1202-
12031202
# slow method
12041203
if n >= len(self.obj):
12051204
ascending = method == "nsmallest"

pandas/core/arrays/masked.py

+2 -42
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,9 @@
2727
SequenceIndexer,
2828
Shape,
2929
npt,
30-
type_t,
3130
)
3231
from pandas.errors import AbstractMethodError
33-
from pandas.util._decorators import (
34-
cache_readonly,
35-
doc,
36-
)
32+
from pandas.util._decorators import doc
3733
from pandas.util._validators import validate_fillna_kwargs
3834

3935
from pandas.core.dtypes.astype import astype_nansafe
@@ -51,6 +47,7 @@
5147
is_string_dtype,
5248
pandas_dtype,
5349
)
50+
from pandas.core.dtypes.dtypes import BaseMaskedDtype
5451
from pandas.core.dtypes.inference import is_array_like
5552
from pandas.core.dtypes.missing import (
5653
array_equivalent,
@@ -91,43 +88,6 @@
9188
BaseMaskedArrayT = TypeVar("BaseMaskedArrayT", bound="BaseMaskedArray")
9289

9390

94-
class BaseMaskedDtype(ExtensionDtype):
95-
"""
96-
Base class for dtypes for BaseMaskedArray subclasses.
97-
"""
98-
99-
name: str
100-
base = None
101-
type: type
102-
103-
na_value = libmissing.NA
104-
105-
@cache_readonly
106-
def numpy_dtype(self) -> np.dtype:
107-
"""Return an instance of our numpy dtype"""
108-
return np.dtype(self.type)
109-
110-
@cache_readonly
111-
def kind(self) -> str:
112-
return self.numpy_dtype.kind
113-
114-
@cache_readonly
115-
def itemsize(self) -> int:
116-
"""Return the number of bytes in this dtype"""
117-
return self.numpy_dtype.itemsize
118-
119-
@classmethod
120-
def construct_array_type(cls) -> type_t[BaseMaskedArray]:
121-
"""
122-
Return the array type associated with this dtype.
123-
124-
Returns
125-
-------
126-
type
127-
"""
128-
raise NotImplementedError
129-
130-
13191
class BaseMaskedArray(OpsMixin, ExtensionArray):
13292
"""
13393
Base class for masked arrays (which use _data and _mask to store the data).

pandas/core/dtypes/dtypes.py

+39
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import numpy as np
1515
import pytz
1616

17+
from pandas._libs import missing as libmissing
1718
from pandas._libs.interval import Interval
1819
from pandas._libs.properties import cache_readonly
1920
from pandas._libs.tslibs import (
@@ -57,6 +58,7 @@
5758
Index,
5859
)
5960
from pandas.core.arrays import (
61+
BaseMaskedArray,
6062
DatetimeArray,
6163
IntervalArray,
6264
PandasArray,
@@ -1376,3 +1378,40 @@ def itemsize(self) -> int:
13761378
The element size of this data-type object.
13771379
"""
13781380
return self._dtype.itemsize
1381+
1382+
1383+
class BaseMaskedDtype(ExtensionDtype):
1384+
"""
1385+
Base class for dtypes for BaseMaskedArray subclasses.
1386+
"""
1387+
1388+
name: str
1389+
base = None
1390+
type: type
1391+
1392+
na_value = libmissing.NA
1393+
1394+
@cache_readonly
1395+
def numpy_dtype(self) -> np.dtype:
1396+
"""Return an instance of our numpy dtype"""
1397+
return np.dtype(self.type)
1398+
1399+
@cache_readonly
1400+
def kind(self) -> str:
1401+
return self.numpy_dtype.kind
1402+
1403+
@cache_readonly
1404+
def itemsize(self) -> int:
1405+
"""Return the number of bytes in this dtype"""
1406+
return self.numpy_dtype.itemsize
1407+
1408+
@classmethod
1409+
def construct_array_type(cls) -> type_t[BaseMaskedArray]:
1410+
"""
1411+
Return the array type associated with this dtype.
1412+
1413+
Returns
1414+
-------
1415+
type
1416+
"""
1417+
raise NotImplementedError

0 commit comments

Comments
 (0)