diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 05d3b1c797375..020d3091929db 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -174,7 +174,7 @@ Timezones Numeric ^^^^^^^ - Bug in :meth:`DataFrame.quantile`, :meth:`DataFrame.sort_values` causing incorrect subsequent indexing behavior (:issue:`38351`) -- +- Bug in :meth:`DataFrame.select_dtypes` with ``include=np.number`` now retains numeric ``ExtensionDtype`` columns (:issue:`35340`) - Conversion diff --git a/pandas/core/frame.py b/pandas/core/frame.py index d698edbb0b8ad..d123e0735ded1 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -3719,12 +3719,14 @@ def extract_unique_dtypes_from_dtypes_set( extracted_dtypes = [ unique_dtype for unique_dtype in unique_dtypes - # error: Argument 1 to "tuple" has incompatible type - # "FrozenSet[Union[ExtensionDtype, str, Any, Type[str], - # Type[float], Type[int], Type[complex], Type[bool]]]"; - # expected "Iterable[Union[type, Tuple[Any, ...]]]" - if issubclass( - unique_dtype.type, tuple(dtypes_set) # type: ignore[arg-type] + if ( + issubclass( + unique_dtype.type, tuple(dtypes_set) # type: ignore[arg-type] + ) + or ( + np.number in dtypes_set + and getattr(unique_dtype, "_is_numeric", False) + ) ) ] return extracted_dtypes diff --git a/pandas/tests/frame/methods/test_select_dtypes.py b/pandas/tests/frame/methods/test_select_dtypes.py index 2a8826cedd50a..f2dbe4a799a17 100644 --- a/pandas/tests/frame/methods/test_select_dtypes.py +++ b/pandas/tests/frame/methods/test_select_dtypes.py @@ -1,9 +1,46 @@ import numpy as np import pytest +from pandas.core.dtypes.dtypes import ExtensionDtype + import pandas as pd from pandas import DataFrame, Timestamp import pandas._testing as tm +from pandas.core.arrays import ExtensionArray + + +class DummyDtype(ExtensionDtype): + type = int + + def __init__(self, numeric): + self._numeric = numeric + + @property + def name(self): + return "Dummy" + + @property + def _is_numeric(self): + return self._numeric + + +class DummyArray(ExtensionArray): + def __init__(self, data, dtype): + self.data = data + self._dtype = dtype + + def __array__(self, dtype): + return self.data + + @property + def dtype(self): + return self._dtype + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, item): + pass class TestSelectDtypes: @@ -322,3 +359,20 @@ def test_select_dtypes_typecodes(self): expected = df FLOAT_TYPES = list(np.typecodes["AllFloat"]) tm.assert_frame_equal(df.select_dtypes(FLOAT_TYPES), expected) + + @pytest.mark.parametrize( + "arr,expected", + ( + (np.array([1, 2], dtype=np.int32), True), + (pd.array([1, 2], dtype="Int32"), True), + (pd.array(["a", "b"], dtype="string"), False), + (DummyArray([1, 2], dtype=DummyDtype(numeric=True)), True), + (DummyArray([1, 2], dtype=DummyDtype(numeric=False)), False), + ), + ) + def test_select_dtypes_numeric(self, arr, expected): + # GH 35340 + + df = DataFrame(arr) + is_selected = df.select_dtypes(np.number).shape == df.shape + assert is_selected == expected