Skip to content

Commit 513c02c

Browse files
TomAugspurgerjorisvandenbossche
authored andcommitted
API: ExtensionDtype._is_numeric (#22345)
1 parent b6e35ff commit 513c02c

File tree

9 files changed

+77
-2
lines changed

9 files changed

+77
-2
lines changed

doc/source/whatsnew/v0.24.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@ ExtensionType Changes
442442
- ``ExtensionArray`` has gained the abstract methods ``.dropna()`` (:issue:`21185`)
443443
- ``ExtensionDtype`` has gained the ability to instantiate from string dtypes, e.g. ``decimal`` would instantiate a registered ``DecimalDtype``; furthermore
444444
the ``ExtensionDtype`` has gained the method ``construct_array_type`` (:issue:`21185`)
445+
- Added ``ExtensionDtype._is_numeric`` for controlling whether an extension dtype is considered numeric (:issue:`22290`).
445446
- The ``ExtensionArray`` constructor, ``_from_sequence`` now take the keyword arg ``copy=False`` (:issue:`21185`)
446447
- Bug in :meth:`Series.get` for ``Series`` using ``ExtensionArray`` and integer index (:issue:`21257`)
447448
- :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`)

pandas/core/arrays/integer.py

+4
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ def is_signed_integer(self):
4747
def is_unsigned_integer(self):
4848
return self.kind == 'u'
4949

50+
@property
51+
def _is_numeric(self):
52+
return True
53+
5054
@cache_readonly
5155
def numpy_dtype(self):
5256
""" Return an instance of our numpy dtype """

pandas/core/dtypes/base.py

+17
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,18 @@ def is_dtype(cls, dtype):
9494
except TypeError:
9595
return False
9696

97+
@property
98+
def _is_numeric(self):
99+
# type: () -> bool
100+
"""
101+
Whether columns with this dtype should be considered numeric.
102+
103+
By default ExtensionDtypes are assumed to be non-numeric.
104+
They'll be excluded from operations that exclude non-numeric
105+
columns, like (groupby) reductions, plotting, etc.
106+
"""
107+
return False
108+
97109

98110
class ExtensionDtype(_DtypeOpsMixin):
99111
"""A custom data type, to be paired with an ExtensionArray.
@@ -109,6 +121,11 @@ class ExtensionDtype(_DtypeOpsMixin):
109121
* name
110122
* construct_from_string
111123
124+
The following attributes influence the behavior of the dtype in
125+
pandas operations
126+
127+
* _is_numeric
128+
112129
Optionally one can override construct_array_type for construction
113130
with the name of this dtype via the Registry
114131

pandas/core/internals/blocks.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ def _astype(self, dtype, copy=False, errors='raise', values=None,
665665
pass
666666

667667
newb = make_block(values, placement=self.mgr_locs,
668-
klass=klass)
668+
klass=klass, ndim=self.ndim)
669669
except:
670670
if errors == 'raise':
671671
raise
@@ -1950,6 +1950,10 @@ def is_view(self):
19501950
"""Extension arrays are never treated as views."""
19511951
return False
19521952

1953+
@property
1954+
def is_numeric(self):
1955+
return self.values.dtype._is_numeric
1956+
19531957
def setitem(self, indexer, value, mgr=None):
19541958
"""Set the value inplace, returning a same-typed block.
19551959

pandas/tests/extension/base/groupby.py

+13
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,16 @@ def test_groupby_extension_apply(self, data_for_grouping, op):
6767
df.groupby("B").A.apply(op)
6868
df.groupby("A").apply(op)
6969
df.groupby("A").B.apply(op)
70+
71+
def test_in_numeric_groupby(self, data_for_grouping):
72+
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4],
73+
"B": data_for_grouping,
74+
"C": [1, 1, 1, 1, 1, 1, 1, 1]})
75+
result = df.groupby("A").sum().columns
76+
77+
if data_for_grouping.dtype._is_numeric:
78+
expected = pd.Index(['B', 'C'])
79+
else:
80+
expected = pd.Index(['C'])
81+
82+
tm.assert_index_equal(result, expected)

pandas/tests/extension/base/interface.py

+4
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,7 @@ def test_no_values_attribute(self, data):
6767
# code, disallowing this for now until solved
6868
assert not hasattr(data, 'values')
6969
assert not hasattr(data, '_values')
70+
71+
def test_is_numeric_honored(self, data):
72+
result = pd.Series(data)
73+
assert result._data.blocks[0].is_numeric is data.dtype._is_numeric

pandas/tests/extension/decimal/array.py

+4
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ def construct_from_string(cls, string):
4444
raise TypeError("Cannot construct a '{}' from "
4545
"'{}'".format(cls, string))
4646

47+
@property
48+
def _is_numeric(self):
49+
return True
50+
4751

4852
class DecimalArray(ExtensionArray, ExtensionScalarOpsMixin):
4953

pandas/tests/extension/integer/test_integer.py

+16
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,22 @@ def test_cross_type_arithmetic():
766766
tm.assert_series_equal(result, expected)
767767

768768

769+
def test_groupby_mean_included():
770+
df = pd.DataFrame({
771+
"A": ['a', 'b', 'b'],
772+
"B": [1, None, 3],
773+
"C": IntegerArray([1, None, 3], dtype='Int64'),
774+
})
775+
776+
result = df.groupby("A").sum()
777+
# TODO(#22346): preserve Int64 dtype
778+
expected = pd.DataFrame({
779+
"B": np.array([1.0, 3.0]),
780+
"C": np.array([1, 3], dtype="int64")
781+
}, index=pd.Index(['a', 'b'], name='A'))
782+
tm.assert_frame_equal(result, expected)
783+
784+
769785
def test_astype_nansafe():
770786
# https://github.com/pandas-dev/pandas/pull/22343
771787
arr = IntegerArray([np.nan, 1, 2], dtype="Int8")

pandas/tests/frame/test_block_internals.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
import numpy as np
1212

1313
from pandas import (DataFrame, Series, Timestamp, date_range, compat,
14-
option_context)
14+
option_context, Categorical)
15+
from pandas.core.arrays import IntegerArray, IntervalArray
1516
from pandas.compat import StringIO
1617
import pandas as pd
1718

@@ -436,6 +437,17 @@ def test_get_numeric_data(self):
436437
expected = df
437438
assert_frame_equal(result, expected)
438439

440+
def test_get_numeric_data_extension_dtype(self):
441+
# GH 22290
442+
df = DataFrame({
443+
'A': IntegerArray([-10, np.nan, 0, 10, 20, 30], dtype='Int64'),
444+
'B': Categorical(list('abcabc')),
445+
'C': IntegerArray([0, 1, 2, 3, np.nan, 5], dtype='UInt8'),
446+
'D': IntervalArray.from_breaks(range(7))})
447+
result = df._get_numeric_data()
448+
expected = df.loc[:, ['A', 'C']]
449+
assert_frame_equal(result, expected)
450+
439451
def test_convert_objects(self):
440452

441453
oops = self.mixed_frame.T.T

0 commit comments

Comments
 (0)