Skip to content

Commit c0deed1

Browse files
arw2019luckyvs1
authored andcommitted
BUG: Allow numeric ExtensionDtypes in DataFrame.select_dtypes (pandas-dev#38246)
1 parent a8d1436 commit c0deed1

File tree

3 files changed

+63
-7
lines changed

3 files changed

+63
-7
lines changed

doc/source/whatsnew/v1.3.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ Timezones
174174
Numeric
175175
^^^^^^^
176176
- Bug in :meth:`DataFrame.quantile`, :meth:`DataFrame.sort_values` causing incorrect subsequent indexing behavior (:issue:`38351`)
177-
-
177+
- Bug in :meth:`DataFrame.select_dtypes` with ``include=np.number`` now retains numeric ``ExtensionDtype`` columns (:issue:`35340`)
178178
-
179179

180180
Conversion

pandas/core/frame.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -3719,12 +3719,14 @@ def extract_unique_dtypes_from_dtypes_set(
37193719
extracted_dtypes = [
37203720
unique_dtype
37213721
for unique_dtype in unique_dtypes
3722-
# error: Argument 1 to "tuple" has incompatible type
3723-
# "FrozenSet[Union[ExtensionDtype, str, Any, Type[str],
3724-
# Type[float], Type[int], Type[complex], Type[bool]]]";
3725-
# expected "Iterable[Union[type, Tuple[Any, ...]]]"
3726-
if issubclass(
3727-
unique_dtype.type, tuple(dtypes_set) # type: ignore[arg-type]
3722+
if (
3723+
issubclass(
3724+
unique_dtype.type, tuple(dtypes_set) # type: ignore[arg-type]
3725+
)
3726+
or (
3727+
np.number in dtypes_set
3728+
and getattr(unique_dtype, "_is_numeric", False)
3729+
)
37283730
)
37293731
]
37303732
return extracted_dtypes

pandas/tests/frame/methods/test_select_dtypes.py

+54
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,46 @@
11
import numpy as np
22
import pytest
33

4+
from pandas.core.dtypes.dtypes import ExtensionDtype
5+
46
import pandas as pd
57
from pandas import DataFrame, Timestamp
68
import pandas._testing as tm
9+
from pandas.core.arrays import ExtensionArray
10+
11+
12+
class DummyDtype(ExtensionDtype):
13+
type = int
14+
15+
def __init__(self, numeric):
16+
self._numeric = numeric
17+
18+
@property
19+
def name(self):
20+
return "Dummy"
21+
22+
@property
23+
def _is_numeric(self):
24+
return self._numeric
25+
26+
27+
class DummyArray(ExtensionArray):
28+
def __init__(self, data, dtype):
29+
self.data = data
30+
self._dtype = dtype
31+
32+
def __array__(self, dtype):
33+
return self.data
34+
35+
@property
36+
def dtype(self):
37+
return self._dtype
38+
39+
def __len__(self) -> int:
40+
return len(self.data)
41+
42+
def __getitem__(self, item):
43+
pass
744

845

946
class TestSelectDtypes:
@@ -322,3 +359,20 @@ def test_select_dtypes_typecodes(self):
322359
expected = df
323360
FLOAT_TYPES = list(np.typecodes["AllFloat"])
324361
tm.assert_frame_equal(df.select_dtypes(FLOAT_TYPES), expected)
362+
363+
@pytest.mark.parametrize(
364+
"arr,expected",
365+
(
366+
(np.array([1, 2], dtype=np.int32), True),
367+
(pd.array([1, 2], dtype="Int32"), True),
368+
(pd.array(["a", "b"], dtype="string"), False),
369+
(DummyArray([1, 2], dtype=DummyDtype(numeric=True)), True),
370+
(DummyArray([1, 2], dtype=DummyDtype(numeric=False)), False),
371+
),
372+
)
373+
def test_select_dtypes_numeric(self, arr, expected):
374+
# GH 35340
375+
376+
df = DataFrame(arr)
377+
is_selected = df.select_dtypes(np.number).shape == df.shape
378+
assert is_selected == expected

0 commit comments

Comments
 (0)