diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py index 14184f044ae95..2469f2a1927ed 100644 --- a/pandas/core/dtypes/common.py +++ b/pandas/core/dtypes/common.py @@ -1231,6 +1231,46 @@ def needs_i8_conversion(arr_or_dtype) -> bool: ) +def np_issubclass_compat(unique_dtype, dtypes_set) -> bool: + """ + Check whether the provided dtype is a subclass of, or has an attribute + (e.g. _is_numeric) indiciating it is a subclass of any of the dtypes in + dtypes_set. + + Parameters + ---------- + unique_dtype : dtype + The dtype to check. + dtypes_set : array-like + The dtypes to check unique_dtype is a sublass of. + + Returns + ------- + boolean + Whether or not the unique_dtype is a subclass of dtype_set. + + Examples + -------- + >>> np_issubclass_compat(pd.Int16Dtype(), [np.bool_, np.float]) + False + >>> np_issubclass_compat(pd.Int16Dtype(), [np.integer]) + True + >>> np_issubclass_compat(pd.BooleanDtype(), [np.bool_]) + True + >>> np_issubclass_compat(pd.Float64Dtype(), [np.float]) + True + >>> np_issubclass_compat(pd.Float64Dtype(), [np.number]) + True + """ + if issubclass(unique_dtype.type, tuple(dtypes_set)) or ( + np.number in dtypes_set + and hasattr(unique_dtype, "_is_numeric") # is an extensionarray + and unique_dtype._is_numeric + ): + return True + return False + + def is_numeric_dtype(arr_or_dtype) -> bool: """ Check whether the provided array or dtype is of a numeric dtype. diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 0314bdc4ee8ed..50855769f17ec 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -112,6 +112,7 @@ is_scalar, is_sequence, needs_i8_conversion, + np_issubclass_compat, pandas_dtype, ) from pandas.core.dtypes.missing import isna, notna @@ -3580,6 +3581,7 @@ def select_dtypes(self, include=None, exclude=None) -> DataFrame: 4 True 1.0 5 False 2.0 """ + if not is_list_like(include): include = (include,) if include is not None else () if not is_list_like(exclude): @@ -3610,13 +3612,7 @@ def extract_unique_dtypes_from_dtypes_set( extracted_dtypes = [ unique_dtype for unique_dtype in unique_dtypes - # error: Argument 1 to "tuple" has incompatible type - # "FrozenSet[Union[ExtensionDtype, str, Any, Type[str], - # Type[float], Type[int], Type[complex], Type[bool]]]"; - # expected "Iterable[Union[type, Tuple[Any, ...]]]" - if issubclass( - unique_dtype.type, tuple(dtypes_set) # type: ignore[arg-type] - ) + if np_issubclass_compat(unique_dtype, dtypes_set) ] return extracted_dtypes diff --git a/pandas/tests/extension/test_select_dtypes_numeric.py b/pandas/tests/extension/test_select_dtypes_numeric.py new file mode 100644 index 0000000000000..b65ef9f3b2b03 --- /dev/null +++ b/pandas/tests/extension/test_select_dtypes_numeric.py @@ -0,0 +1,50 @@ +import numpy as np + +from pandas.core.dtypes.dtypes import ExtensionDtype + +import pandas as pd +from pandas.core.arrays import ExtensionArray + + +class DummyDtype(ExtensionDtype): + type = int + + def __init__(self, numeric): + self._numeric = numeric + + @property + def name(self): + return "Dummy" + + @property + def _is_numeric(self): + return self._numeric + + +class DummyArray(ExtensionArray): + def __init__(self, data, dtype): + self.data = data + self._dtype = dtype + + def __array__(self, dtype): + return self.data + + @property + def dtype(self): + return self._dtype + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, item): + pass + + +def test_select_dtypes_numeric(): + da = DummyArray([1, 2], dtype=DummyDtype(numeric=True)) + df = pd.DataFrame(da) + assert df.select_dtypes(np.number).shape == df.shape + + da = DummyArray([1, 2], dtype=DummyDtype(numeric=False)) + df = pd.DataFrame(da) + assert df.select_dtypes(np.number).shape != df.shape diff --git a/pandas/tests/frame/methods/test_select_dtypes.py b/pandas/tests/frame/methods/test_select_dtypes.py index 4599761909c33..509c5c754af34 100644 --- a/pandas/tests/frame/methods/test_select_dtypes.py +++ b/pandas/tests/frame/methods/test_select_dtypes.py @@ -1,9 +1,46 @@ import numpy as np import pytest +from pandas.core.dtypes.dtypes import ExtensionDtype + import pandas as pd from pandas import DataFrame, Timestamp import pandas._testing as tm +from pandas.core.arrays import ExtensionArray + + +class DummyDtype(ExtensionDtype): + type = int + + def __init__(self, numeric): + self._numeric = numeric + + @property + def name(self): + return "Dummy" + + @property + def _is_numeric(self): + return self._numeric + + +class DummyArray(ExtensionArray): + def __init__(self, data, dtype): + self.data = data + self._dtype = dtype + + def __array__(self, dtype): + return self.data + + @property + def dtype(self): + return self._dtype + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, item): + pass class TestSelectDtypes: @@ -324,3 +361,13 @@ def test_select_dtypes_typecodes(self): expected = df FLOAT_TYPES = list(np.typecodes["AllFloat"]) tm.assert_frame_equal(df.select_dtypes(FLOAT_TYPES), expected) + + def test_select_dtypes_numeric(self): + # GH 35340 + da = DummyArray([1, 2], dtype=DummyDtype(numeric=True)) + df = pd.DataFrame(da) + assert df.select_dtypes(np.number).shape == df.shape + + da = DummyArray([1, 2], dtype=DummyDtype(numeric=False)) + df = pd.DataFrame(da) + assert df.select_dtypes(np.number).shape != df.shape