Skip to content

ENH: Select numeric ExtensionDtypes with DataFrame.select_dtypes #35341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 17 commits into from
40 changes: 40 additions & 0 deletions pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,46 @@ def needs_i8_conversion(arr_or_dtype) -> bool:
)


def np_issubclass_compat(unique_dtype, dtypes_set) -> bool:
"""
Check whether the provided dtype is a subclass of, or has an attribute
(e.g. _is_numeric) indiciating it is a subclass of any of the dtypes in
dtypes_set.

Parameters
----------
unique_dtype : dtype
The dtype to check.
dtypes_set : array-like
The dtypes to check unique_dtype is a sublass of.

Returns
-------
boolean
Whether or not the unique_dtype is a subclass of dtype_set.

Examples
--------
>>> np_issubclass_compat(pd.Int16Dtype(), [np.bool_, np.float])
False
>>> np_issubclass_compat(pd.Int16Dtype(), [np.integer])
True
>>> np_issubclass_compat(pd.BooleanDtype(), [np.bool_])
True
>>> np_issubclass_compat(pd.Float64Dtype(), [np.float])
True
>>> np_issubclass_compat(pd.Float64Dtype(), [np.number])
True
"""
if issubclass(unique_dtype.type, tuple(dtypes_set)) or (
np.number in dtypes_set
and hasattr(unique_dtype, "_is_numeric") # is an extensionarray
and unique_dtype._is_numeric
):
return True
return False


def is_numeric_dtype(arr_or_dtype) -> bool:
"""
Check whether the provided array or dtype is of a numeric dtype.
Expand Down
10 changes: 3 additions & 7 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
is_scalar,
is_sequence,
needs_i8_conversion,
np_issubclass_compat,
pandas_dtype,
)
from pandas.core.dtypes.missing import isna, notna
Expand Down Expand Up @@ -3580,6 +3581,7 @@ def select_dtypes(self, include=None, exclude=None) -> DataFrame:
4 True 1.0
5 False 2.0
"""

if not is_list_like(include):
include = (include,) if include is not None else ()
if not is_list_like(exclude):
Expand Down Expand Up @@ -3610,13 +3612,7 @@ def extract_unique_dtypes_from_dtypes_set(
extracted_dtypes = [
unique_dtype
for unique_dtype in unique_dtypes
# error: Argument 1 to "tuple" has incompatible type
# "FrozenSet[Union[ExtensionDtype, str, Any, Type[str],
# Type[float], Type[int], Type[complex], Type[bool]]]";
# expected "Iterable[Union[type, Tuple[Any, ...]]]"
if issubclass(
unique_dtype.type, tuple(dtypes_set) # type: ignore[arg-type]
)
if np_issubclass_compat(unique_dtype, dtypes_set)
]
return extracted_dtypes

Expand Down
50 changes: 50 additions & 0 deletions pandas/tests/extension/test_select_dtypes_numeric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import numpy as np

from pandas.core.dtypes.dtypes import ExtensionDtype

import pandas as pd
from pandas.core.arrays import ExtensionArray


class DummyDtype(ExtensionDtype):
type = int

def __init__(self, numeric):
self._numeric = numeric

@property
def name(self):
return "Dummy"

@property
def _is_numeric(self):
return self._numeric


class DummyArray(ExtensionArray):
def __init__(self, data, dtype):
self.data = data
self._dtype = dtype

def __array__(self, dtype):
return self.data

@property
def dtype(self):
return self._dtype

def __len__(self) -> int:
return len(self.data)

def __getitem__(self, item):
pass


def test_select_dtypes_numeric():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all of this should go in tests/frame/methods/test_select_dtypes.py

da = DummyArray([1, 2], dtype=DummyDtype(numeric=True))
df = pd.DataFrame(da)
assert df.select_dtypes(np.number).shape == df.shape

da = DummyArray([1, 2], dtype=DummyDtype(numeric=False))
df = pd.DataFrame(da)
assert df.select_dtypes(np.number).shape != df.shape
47 changes: 47 additions & 0 deletions pandas/tests/frame/methods/test_select_dtypes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,46 @@
import numpy as np
import pytest

from pandas.core.dtypes.dtypes import ExtensionDtype

import pandas as pd
from pandas import DataFrame, Timestamp
import pandas._testing as tm
from pandas.core.arrays import ExtensionArray


class DummyDtype(ExtensionDtype):
type = int

def __init__(self, numeric):
self._numeric = numeric

@property
def name(self):
return "Dummy"

@property
def _is_numeric(self):
return self._numeric


class DummyArray(ExtensionArray):
def __init__(self, data, dtype):
self.data = data
self._dtype = dtype

def __array__(self, dtype):
return self.data

@property
def dtype(self):
return self._dtype

def __len__(self) -> int:
return len(self.data)

def __getitem__(self, item):
pass


class TestSelectDtypes:
Expand Down Expand Up @@ -324,3 +361,13 @@ def test_select_dtypes_typecodes(self):
expected = df
FLOAT_TYPES = list(np.typecodes["AllFloat"])
tm.assert_frame_equal(df.select_dtypes(FLOAT_TYPES), expected)

def test_select_dtypes_numeric(self):
# GH 35340
da = DummyArray([1, 2], dtype=DummyDtype(numeric=True))
df = pd.DataFrame(da)
assert df.select_dtypes(np.number).shape == df.shape

da = DummyArray([1, 2], dtype=DummyDtype(numeric=False))
df = pd.DataFrame(da)
assert df.select_dtypes(np.number).shape != df.shape