Skip to content

ENH: Select numeric ExtensionDtypes with DataFrame.select_dtypes #35341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 17 commits into from
12 changes: 12 additions & 0 deletions pandas/compat/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np


# numpy versioning
_np_version = np.__version__
_nlv = LooseVersion(_np_version)
Expand Down Expand Up @@ -62,6 +63,17 @@ def np_array_datetime64_compat(arr, *args, **kwargs):
return np.array(arr, *args, **kwargs)


def np_issubclass_compat(unique_dtype, dtypes_set):
    """
    Check whether a dtype is matched by a set of requested dtype classes.

    A dtype matches when its scalar ``type`` is a subclass of any class in
    ``dtypes_set``.  In addition, when ``np.number`` is one of the requested
    classes, a dtype that exposes a truthy ``_is_numeric`` attribute also
    matches, even if its scalar type does not derive from ``np.number``.

    Parameters
    ----------
    unique_dtype : object
        A dtype-like object exposing a ``type`` attribute and, optionally,
        an ``_is_numeric`` attribute.
    dtypes_set : iterable of type
        The classes to match against (``select_dtypes`` passes a set here).

    Returns
    -------
    bool
        True if ``unique_dtype`` is selected by ``dtypes_set``.

    Examples
    --------
    >>> np_issubclass_compat(np.dtype("int64"), frozenset([np.number]))
    True
    >>> np_issubclass_compat(np.dtype("object"), frozenset([np.number]))
    False
    """
    # Review notes carried over from the PR discussion attached to this
    # function:
    #   * Contributor: this wouldn't go in here; better in
    #     pandas/core/dtypes/common.py, with a full docstring and examples.
    #   * Contributor: this is almost is_numeric, but with the twist about
    #     np.number.
    #   * Author: moved it and added a docstring and examples.  Not aware of
    #     any pandas dtypes that have _is_numeric = True and don't inherit
    #     from np.number.  Open question: should is_numeric return True for
    #     ExtensionDtypes that have _is_numeric = True and don't inherit
    #     from np.number?
    if issubclass(unique_dtype.type, tuple(dtypes_set)):  # type: ignore[arg-type]
        return True

    # getattr with a default replaces the original, mis-parenthesised
    # ``hasattr(unique_dtype, "_is_numeric" and unique_dtype._is_numeric)``,
    # which raised AttributeError for dtypes without ``_is_numeric``.
    return bool(
        np.number in dtypes_set and getattr(unique_dtype, "_is_numeric", False)
    )


__all__ = [
"np",
"_np_version",
Expand Down
10 changes: 3 additions & 7 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
)
from pandas.compat._optional import import_optional_dependency
from pandas.compat.numpy import function as nv
from pandas.compat.numpy import np_issubclass_compat
from pandas.util._decorators import (
Appender,
Substitution,
Expand Down Expand Up @@ -3580,6 +3581,7 @@ def select_dtypes(self, include=None, exclude=None) -> DataFrame:
4 True 1.0
5 False 2.0
"""

if not is_list_like(include):
include = (include,) if include is not None else ()
if not is_list_like(exclude):
Expand Down Expand Up @@ -3610,13 +3612,7 @@ def extract_unique_dtypes_from_dtypes_set(
extracted_dtypes = [
unique_dtype
for unique_dtype in unique_dtypes
# error: Argument 1 to "tuple" has incompatible type
# "FrozenSet[Union[ExtensionDtype, str, Any, Type[str],
# Type[float], Type[int], Type[complex], Type[bool]]]";
# expected "Iterable[Union[type, Tuple[Any, ...]]]"
if issubclass(
unique_dtype.type, tuple(dtypes_set) # type: ignore[arg-type]
)
if np_issubclass_compat(unique_dtype, dtypes_set)
]
return extracted_dtypes

Expand Down
50 changes: 50 additions & 0 deletions pandas/tests/extension/test_select_dtypes_numeric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import numpy as np

import pandas as pd
from pandas.core.arrays import ExtensionArray
from pandas.core.dtypes.dtypes import ExtensionDtype


class DummyDtype(ExtensionDtype):
    """
    Minimal extension dtype whose ``_is_numeric`` answer can be toggled.

    ``_is_numeric`` simply reports the ``_numeric`` flag, so a test can flip
    that flag to make the same dtype look numeric or non-numeric.
    """

    # Scalar type reported for this dtype.
    type = int
    # Backing flag returned by ``_is_numeric``; defaults to False.
    _numeric = False

    @property
    def name(self) -> str:
        """Return the fixed dtype name, ``"Dummy"``."""
        return "Dummy"

    @property
    def _is_numeric(self) -> bool:
        """Return the current value of the ``_numeric`` flag."""
        return self._numeric


class DummyArray(ExtensionArray):
    """
    Bare-bones extension array wrapping an arbitrary sequence.

    Only the pieces visible here are provided: construction, conversion to
    a NumPy array, ``dtype`` and ``len``.  Indexing is a stub.
    """

    # A single dtype instance stored on the class, so it is shared by every
    # DummyArray; changing a flag on it affects all arrays.
    _dtype = DummyDtype()

    def __init__(self, data):
        """Store ``data`` as-is; no copy or validation is performed."""
        self.data = data

    def __array__(self, dtype=None):
        """
        Return the wrapped data as a NumPy array.

        ``dtype`` is optional because NumPy calls ``__array__()`` without
        arguments when no dtype is requested, and the result must be an
        ndarray rather than the raw stored sequence.
        """
        return np.asarray(self.data, dtype=dtype)

    @property
    def dtype(self):
        """Return the shared class-level ``DummyDtype`` instance."""
        return self._dtype

    def __len__(self) -> int:
        """Return the number of wrapped elements."""
        return len(self.data)

    def __getitem__(self, item):
        """Stub: indexing is not implemented and returns None."""
        pass


def test_select_dtypes_numeric():
    """
    ``select_dtypes(np.number)`` should follow the dtype's ``_is_numeric``.

    The same dummy extension array is wrapped in a DataFrame twice: once
    with the dtype flagged numeric (the column must be kept) and once with
    it flagged non-numeric (the column must be dropped).
    """
    # Review note (Member): all of this should go in
    # tests/frame/methods/test_select_dtypes.py

    # The dtype object lives on the DummyArray class, so setting the flag
    # through one array changes it for every DummyArray.
    da = DummyArray([1, 2])
    da._dtype._numeric = True
    df = pd.DataFrame(da)
    assert df.select_dtypes(np.number).shape == df.shape

    da = DummyArray([1, 2])
    da._dtype._numeric = False
    df = pd.DataFrame(da)
    assert df.select_dtypes(np.number).shape != df.shape