diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index a018b4997bf7d..d7feb6e547b22 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -442,6 +442,7 @@ ExtensionType Changes - ``ExtensionArray`` has gained the abstract methods ``.dropna()`` (:issue:`21185`) - ``ExtensionDtype`` has gained the ability to instantiate from string dtypes, e.g. ``decimal`` would instantiate a registered ``DecimalDtype``; furthermore the ``ExtensionDtype`` has gained the method ``construct_array_type`` (:issue:`21185`) +- Added ``ExtensionDtype._is_numeric`` for controlling whether an extension dtype is considered numeric (:issue:`22290`). - The ``ExtensionArray`` constructor, ``_from_sequence`` now take the keyword arg ``copy=False`` (:issue:`21185`) - Bug in :meth:`Series.get` for ``Series`` using ``ExtensionArray`` and integer index (:issue:`21257`) - :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`) diff --git a/pandas/core/arrays/integer.py b/pandas/core/arrays/integer.py index a1ce3541bd5d4..3dffabbe473d3 100644 --- a/pandas/core/arrays/integer.py +++ b/pandas/core/arrays/integer.py @@ -47,6 +47,10 @@ def is_signed_integer(self): def is_unsigned_integer(self): return self.kind == 'u' + @property + def _is_numeric(self): + return True + @cache_readonly def numpy_dtype(self): """ Return an instance of our numpy dtype """ diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index 5f405e0d10657..1ecb6234ad2d9 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -94,6 +94,18 @@ def is_dtype(cls, dtype): except TypeError: return False + @property + def _is_numeric(self): + # type: () -> bool + """ + Whether columns with this dtype should be considered numeric. + + By default ExtensionDtypes are assumed to be non-numeric. + They'll be excluded from operations that exclude non-numeric + columns, like (groupby) reductions, plotting, etc. + """ + return False + class ExtensionDtype(_DtypeOpsMixin): """A custom data type, to be paired with an ExtensionArray. @@ -109,6 +121,11 @@ class ExtensionDtype(_DtypeOpsMixin): * name * construct_from_string + The following attributes influence the behavior of the dtype in + pandas operations + + * _is_numeric + Optionally one can override construct_array_type for construction with the name of this dtype via the Registry diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index 0bfc7650a24aa..57d09ff33d8b4 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -665,7 +665,7 @@ def _astype(self, dtype, copy=False, errors='raise', values=None, pass newb = make_block(values, placement=self.mgr_locs, - klass=klass) + klass=klass, ndim=self.ndim) except: if errors == 'raise': raise @@ -1950,6 +1950,10 @@ def is_view(self): """Extension arrays are never treated as views.""" return False + @property + def is_numeric(self): + return self.values.dtype._is_numeric + def setitem(self, indexer, value, mgr=None): """Set the value inplace, returning a same-typed block. diff --git a/pandas/tests/extension/base/groupby.py b/pandas/tests/extension/base/groupby.py index a29ef2a509a63..174997c7d51e1 100644 --- a/pandas/tests/extension/base/groupby.py +++ b/pandas/tests/extension/base/groupby.py @@ -67,3 +67,16 @@ def test_groupby_extension_apply(self, data_for_grouping, op): df.groupby("B").A.apply(op) df.groupby("A").apply(op) df.groupby("A").B.apply(op) + + def test_in_numeric_groupby(self, data_for_grouping): + df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4], + "B": data_for_grouping, + "C": [1, 1, 1, 1, 1, 1, 1, 1]}) + result = df.groupby("A").sum().columns + + if data_for_grouping.dtype._is_numeric: + expected = pd.Index(['B', 'C']) + else: + expected = pd.Index(['C']) + + tm.assert_index_equal(result, expected) diff --git a/pandas/tests/extension/base/interface.py b/pandas/tests/extension/base/interface.py index 69de0e1900831..99c3b92541cbd 100644 --- a/pandas/tests/extension/base/interface.py +++ b/pandas/tests/extension/base/interface.py @@ -67,3 +67,7 @@ def test_no_values_attribute(self, data): # code, disallowing this for now until solved assert not hasattr(data, 'values') assert not hasattr(data, '_values') + + def test_is_numeric_honored(self, data): + result = pd.Series(data) + assert result._data.blocks[0].is_numeric is data.dtype._is_numeric diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index f3475dead2418..387942234e6fd 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -44,6 +44,10 @@ def construct_from_string(cls, string): raise TypeError("Cannot construct a '{}' from " "'{}'".format(cls, string)) + @property + def _is_numeric(self): + return True + class DecimalArray(ExtensionArray, ExtensionScalarOpsMixin): diff --git a/pandas/tests/extension/integer/test_integer.py b/pandas/tests/extension/integer/test_integer.py index e7f49c99d9f95..f1c833a68c66c 100644 --- a/pandas/tests/extension/integer/test_integer.py +++ b/pandas/tests/extension/integer/test_integer.py @@ -766,6 +766,22 @@ def test_cross_type_arithmetic(): tm.assert_series_equal(result, expected) +def test_groupby_mean_included(): + df = pd.DataFrame({ + "A": ['a', 'b', 'b'], + "B": [1, None, 3], + "C": IntegerArray([1, None, 3], dtype='Int64'), + }) + + result = df.groupby("A").sum() + # TODO(#22346): preserve Int64 dtype + expected = pd.DataFrame({ + "B": np.array([1.0, 3.0]), + "C": np.array([1, 3], dtype="int64") + }, index=pd.Index(['a', 'b'], name='A')) + tm.assert_frame_equal(result, expected) + + def test_astype_nansafe(): # https://github.com/pandas-dev/pandas/pull/22343 arr = IntegerArray([np.nan, 1, 2], dtype="Int8") diff --git a/pandas/tests/frame/test_block_internals.py b/pandas/tests/frame/test_block_internals.py index 8e012922d25f1..d096daaa0b664 100644 --- a/pandas/tests/frame/test_block_internals.py +++ b/pandas/tests/frame/test_block_internals.py @@ -11,7 +11,8 @@ import numpy as np from pandas import (DataFrame, Series, Timestamp, date_range, compat, - option_context) + option_context, Categorical) +from pandas.core.arrays import IntegerArray, IntervalArray from pandas.compat import StringIO import pandas as pd @@ -436,6 +437,17 @@ def test_get_numeric_data(self): expected = df assert_frame_equal(result, expected) + def test_get_numeric_data_extension_dtype(self): + # GH 22290 + df = DataFrame({ + 'A': IntegerArray([-10, np.nan, 0, 10, 20, 30], dtype='Int64'), + 'B': Categorical(list('abcabc')), + 'C': IntegerArray([0, 1, 2, 3, np.nan, 5], dtype='UInt8'), + 'D': IntervalArray.from_breaks(range(7))}) + result = df._get_numeric_data() + expected = df.loc[:, ['A', 'C']] + assert_frame_equal(result, expected) + def test_convert_objects(self): oops = self.mixed_frame.T.T