Skip to content

Commit b843388

Browse files
jrebacktm9k1
authored andcommitted
ENH: add groupby & reduce support to EA (pandas-dev#22762)
1 parent 19ca934 commit b843388

18 files changed

+269
-31
lines changed

doc/source/whatsnew/v0.24.0.txt

+9-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ Pandas has gained the ability to hold integer dtypes with missing values. This l
4848
Here is an example of the usage.
4949

5050
We can construct a ``Series`` with the specified dtype. The dtype string ``Int64`` is a pandas ``ExtensionDtype``. Specifying a list or array using the traditional missing value
51-
marker of ``np.nan`` will infer to integer dtype. The display of the ``Series`` will also use the ``NaN`` to indicate missing values in string outputs. (:issue:`20700`, :issue:`20747`, :issue:`22441`)
51+
marker of ``np.nan`` will infer to integer dtype. The display of the ``Series`` will also use the ``NaN`` to indicate missing values in string outputs. (:issue:`20700`, :issue:`20747`, :issue:`22441`, :issue:`21789`, :issue:`22346`)
5252

5353
.. ipython:: python
5454

@@ -91,6 +91,13 @@ These dtypes can be merged & reshaped & casted.
9191
pd.concat([df[['A']], df[['B', 'C']]], axis=1).dtypes
9292
df['A'].astype(float)
9393

94+
Reduction and groupby operations such as 'sum' work.
95+
96+
.. ipython:: python
97+
98+
df.sum()
99+
df.groupby('B').A.sum()
100+
94101
.. warning::
95102

96103
The Integer NA support currently uses the captilized dtype version, e.g. ``Int8`` as compared to the traditional ``int8``. This may be changed at a future date.
@@ -567,6 +574,7 @@ update the ``ExtensionDtype._metadata`` tuple to match the signature of your
567574
- Added :meth:`pandas.api.types.register_extension_dtype` to register an extension type with pandas (:issue:`22664`)
568575
- Series backed by an ``ExtensionArray`` now work with :func:`util.hash_pandas_object` (:issue:`23066`)
569576
- Updated the ``.type`` attribute for ``PeriodDtype``, ``DatetimeTZDtype``, and ``IntervalDtype`` to be instances of the dtype (``Period``, ``Timestamp``, and ``Interval`` respectively) (:issue:`22938`)
577+
- Support for reduction operations such as ``sum``, ``mean`` via opt-in base class method override (:issue:`22762`)
570578

571579
.. _whatsnew_0240.api.incompatibilities:
572580

pandas/conftest.py

+24
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,30 @@ def all_arithmetic_operators(request):
131131
return request.param
132132

133133

134+
_all_numeric_reductions = ['sum', 'max', 'min',
135+
'mean', 'prod', 'std', 'var', 'median',
136+
'kurt', 'skew']
137+
138+
139+
@pytest.fixture(params=_all_numeric_reductions)
140+
def all_numeric_reductions(request):
141+
"""
142+
Fixture for numeric reduction names
143+
"""
144+
return request.param
145+
146+
147+
_all_boolean_reductions = ['all', 'any']
148+
149+
150+
@pytest.fixture(params=_all_boolean_reductions)
151+
def all_boolean_reductions(request):
152+
"""
153+
Fixture for boolean reduction names
154+
"""
155+
return request.param
156+
157+
134158
_cython_table = pd.core.base.SelectionMixin._cython_table.items()
135159

136160

pandas/core/arrays/base.py

+31
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ class ExtensionArray(object):
6363
as they only compose abstract methods. Still, a more efficient
6464
implementation may be available, and these methods can be overridden.
6565
66+
One can implement methods to handle array reductions.
67+
68+
* _reduce
69+
6670
This class does not inherit from 'abc.ABCMeta' for performance reasons.
6771
Methods and properties required by the interface raise
6872
``pandas.errors.AbstractMethodError`` and no ``register`` method is
@@ -675,6 +679,33 @@ def _ndarray_values(self):
675679
"""
676680
return np.array(self)
677681

682+
def _reduce(self, name, skipna=True, **kwargs):
683+
"""
684+
Return a scalar result of performing the reduction operation.
685+
686+
Parameters
687+
----------
688+
name : str
689+
Name of the function, supported values are:
690+
{ any, all, min, max, sum, mean, median, prod,
691+
std, var, sem, kurt, skew }.
692+
skipna : bool, default True
693+
If True, skip NaN values.
694+
**kwargs
695+
Additional keyword arguments passed to the reduction function.
696+
Currently, `ddof` is the only supported kwarg.
697+
698+
Returns
699+
-------
700+
scalar
701+
702+
Raises
703+
------
704+
TypeError : subclass does not define reductions
705+
"""
706+
raise TypeError("cannot perform {name} with type {dtype}".format(
707+
name=name, dtype=self.dtype))
708+
678709

679710
class ExtensionOpsMixin(object):
680711
"""

pandas/core/arrays/categorical.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -2069,14 +2069,12 @@ def _reverse_indexer(self):
20692069
return result
20702070

20712071
# reduction ops #
2072-
def _reduce(self, op, name, axis=0, skipna=True, numeric_only=None,
2073-
filter_type=None, **kwds):
2074-
""" perform the reduction type operation """
2072+
def _reduce(self, name, axis=0, skipna=True, **kwargs):
20752073
func = getattr(self, name, None)
20762074
if func is None:
20772075
msg = 'Categorical cannot perform the operation {op}'
20782076
raise TypeError(msg.format(op=name))
2079-
return func(numeric_only=numeric_only, **kwds)
2077+
return func(**kwargs)
20802078

20812079
def min(self, numeric_only=None, **kwargs):
20822080
""" The minimum value of the object.

pandas/core/arrays/integer.py

+26
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pandas.compat import u, range, string_types
99
from pandas.compat import set_function_name
1010

11+
from pandas.core import nanops
1112
from pandas.core.dtypes.cast import astype_nansafe
1213
from pandas.core.dtypes.generic import ABCSeries, ABCIndexClass
1314
from pandas.core.dtypes.common import (
@@ -529,6 +530,31 @@ def cmp_method(self, other):
529530
name = '__{name}__'.format(name=op.__name__)
530531
return set_function_name(cmp_method, name, cls)
531532

533+
def _reduce(self, name, skipna=True, **kwargs):
534+
data = self._data
535+
mask = self._mask
536+
537+
# coerce to a nan-aware float if needed
538+
if mask.any():
539+
data = self._data.astype('float64')
540+
data[mask] = self._na_value
541+
542+
op = getattr(nanops, 'nan' + name)
543+
result = op(data, axis=0, skipna=skipna, mask=mask)
544+
545+
# if we have a boolean op, don't coerce
546+
if name in ['any', 'all']:
547+
pass
548+
549+
# if we have a preservable numeric op,
550+
# provide coercion back to an integer type if possible
551+
elif name in ['sum', 'min', 'max', 'prod'] and notna(result):
552+
int_result = int(result)
553+
if int_result == result:
554+
result = int_result
555+
556+
return result
557+
532558
def _maybe_mask_result(self, result, mask, other, op_name):
533559
"""
534560
Parameters

pandas/core/series.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -3392,16 +3392,25 @@ def _reduce(self, op, name, axis=0, skipna=True, numeric_only=None,
33923392
33933393
"""
33943394
delegate = self._values
3395-
if isinstance(delegate, np.ndarray):
3396-
# Validate that 'axis' is consistent with Series's single axis.
3397-
if axis is not None:
3398-
self._get_axis_number(axis)
3395+
3396+
if axis is not None:
3397+
self._get_axis_number(axis)
3398+
3399+
# dispatch to ExtensionArray interface
3400+
if isinstance(delegate, ExtensionArray):
3401+
return delegate._reduce(name, skipna=skipna, **kwds)
3402+
3403+
# dispatch to numpy arrays
3404+
elif isinstance(delegate, np.ndarray):
33993405
if numeric_only:
34003406
raise NotImplementedError('Series.{0} does not implement '
34013407
'numeric_only.'.format(name))
34023408
with np.errstate(all='ignore'):
34033409
return op(delegate, skipna=skipna, **kwds)
34043410

3411+
# TODO(EA) dispatch to Index
3412+
# remove once all internals extension types are
3413+
# moved to ExtensionArrays
34053414
return delegate._reduce(op=op, name=name, axis=axis, skipna=skipna,
34063415
numeric_only=numeric_only,
34073416
filter_type=filter_type, **kwds)

pandas/tests/arrays/test_integer.py

+41-4
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,13 @@ def _check_op(self, s, op_name, other, exc=None):
114114
# compute expected
115115
mask = s.isna()
116116

117+
# if s is a DataFrame, squeeze to a Series
118+
# for comparison
119+
if isinstance(s, pd.DataFrame):
120+
result = result.squeeze()
121+
s = s.squeeze()
122+
mask = mask.squeeze()
123+
117124
# other array is an Integer
118125
if isinstance(other, IntegerArray):
119126
omask = getattr(other, 'mask', None)
@@ -215,7 +222,6 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
215222
s = pd.Series(data)
216223
self._check_op(s, op, 1, exc=TypeError)
217224

218-
@pytest.mark.xfail(run=False, reason="_reduce needs implementation")
219225
def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
220226
# frame & scalar
221227
op = all_arithmetic_operators
@@ -587,22 +593,53 @@ def test_cross_type_arithmetic():
587593
tm.assert_series_equal(result, expected)
588594

589595

590-
def test_groupby_mean_included():
596+
@pytest.mark.parametrize('op', ['sum', 'min', 'max', 'prod'])
597+
def test_preserve_dtypes(op):
598+
# TODO(#22346): preserve Int64 dtype
599+
# for ops that enable (mean would actually work here
600+
# but generally it is a float return value)
591601
df = pd.DataFrame({
592602
"A": ['a', 'b', 'b'],
593603
"B": [1, None, 3],
594604
"C": integer_array([1, None, 3], dtype='Int64'),
595605
})
596606

597-
result = df.groupby("A").sum()
598-
# TODO(#22346): preserve Int64 dtype
607+
# op
608+
result = getattr(df.C, op)()
609+
assert isinstance(result, int)
610+
611+
# groupby
612+
result = getattr(df.groupby("A"), op)()
599613
expected = pd.DataFrame({
600614
"B": np.array([1.0, 3.0]),
601615
"C": np.array([1, 3], dtype="int64")
602616
}, index=pd.Index(['a', 'b'], name='A'))
603617
tm.assert_frame_equal(result, expected)
604618

605619

620+
@pytest.mark.parametrize('op', ['mean'])
621+
def test_reduce_to_float(op):
622+
# some reduce ops always return float, even if the result
623+
# is a rounded number
624+
df = pd.DataFrame({
625+
"A": ['a', 'b', 'b'],
626+
"B": [1, None, 3],
627+
"C": integer_array([1, None, 3], dtype='Int64'),
628+
})
629+
630+
# op
631+
result = getattr(df.C, op)()
632+
assert isinstance(result, float)
633+
634+
# groupby
635+
result = getattr(df.groupby("A"), op)()
636+
expected = pd.DataFrame({
637+
"B": np.array([1.0, 3.0]),
638+
"C": np.array([1, 3], dtype="float64")
639+
}, index=pd.Index(['a', 'b'], name='A'))
640+
tm.assert_frame_equal(result, expected)
641+
642+
606643
def test_astype_nansafe():
607644
# https://github.com/pandas-dev/pandas/pull/22343
608645
arr = integer_array([np.nan, 1, 2], dtype="Int8")

pandas/tests/dtypes/test_common.py

+2
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,8 @@ def test_is_datetime_or_timedelta_dtype():
386386
assert not com.is_datetime_or_timedelta_dtype(str)
387387
assert not com.is_datetime_or_timedelta_dtype(pd.Series([1, 2]))
388388
assert not com.is_datetime_or_timedelta_dtype(np.array(['a', 'b']))
389+
assert not com.is_datetime_or_timedelta_dtype(
390+
DatetimeTZDtype("ns", "US/Eastern"))
389391

390392
assert com.is_datetime_or_timedelta_dtype(np.datetime64)
391393
assert com.is_datetime_or_timedelta_dtype(np.timedelta64)

pandas/tests/extension/arrow/test_bool.py

+4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ def test_from_dtype(self, data):
3939
pytest.skip("GH-22666")
4040

4141

42+
class TestReduce(base.BaseNoReduceTests):
43+
pass
44+
45+
4246
def test_is_bool_dtype(data):
4347
assert pd.api.types.is_bool_dtype(data)
4448
assert pd.core.common.is_bool_indexer(data)

pandas/tests/extension/base/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class TestMyDtype(BaseDtypeTests):
4848
from .interface import BaseInterfaceTests # noqa
4949
from .methods import BaseMethodsTests # noqa
5050
from .ops import BaseArithmeticOpsTests, BaseComparisonOpsTests, BaseOpsUtil # noqa
51+
from .reduce import BaseNoReduceTests, BaseNumericReduceTests, BaseBooleanReduceTests # noqa
5152
from .missing import BaseMissingTests # noqa
5253
from .reshaping import BaseReshapingTests # noqa
5354
from .setitem import BaseSetitemTests # noqa

pandas/tests/extension/base/groupby.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping):
2525
"B": data_for_grouping})
2626
result = df.groupby("B", as_index=as_index).A.mean()
2727
_, index = pd.factorize(data_for_grouping, sort=True)
28-
# TODO(ExtensionIndex): remove astype
29-
index = pd.Index(index.astype(object), name="B")
28+
29+
index = pd.Index(index, name="B")
3030
expected = pd.Series([3, 1, 4], index=index, name="A")
3131
if as_index:
3232
self.assert_series_equal(result, expected)
@@ -39,8 +39,8 @@ def test_groupby_extension_no_sort(self, data_for_grouping):
3939
"B": data_for_grouping})
4040
result = df.groupby("B", sort=False).A.mean()
4141
_, index = pd.factorize(data_for_grouping, sort=False)
42-
# TODO(ExtensionIndex): remove astype
43-
index = pd.Index(index.astype(object), name="B")
42+
43+
index = pd.Index(index, name="B")
4444
expected = pd.Series([1, 3, 4], index=index, name="A")
4545
self.assert_series_equal(result, expected)
4646

pandas/tests/extension/base/reduce.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import warnings
2+
import pytest
3+
import pandas.util.testing as tm
4+
import pandas as pd
5+
from .base import BaseExtensionTests
6+
7+
8+
class BaseReduceTests(BaseExtensionTests):
9+
"""
10+
Reduction specific tests. Generally these only
11+
make sense for numeric/boolean operations.
12+
"""
13+
def check_reduce(self, s, op_name, skipna):
14+
result = getattr(s, op_name)(skipna=skipna)
15+
expected = getattr(s.astype('float64'), op_name)(skipna=skipna)
16+
tm.assert_almost_equal(result, expected)
17+
18+
19+
class BaseNoReduceTests(BaseReduceTests):
20+
""" we don't define any reductions """
21+
22+
@pytest.mark.parametrize('skipna', [True, False])
23+
def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna):
24+
op_name = all_numeric_reductions
25+
s = pd.Series(data)
26+
27+
with pytest.raises(TypeError):
28+
getattr(s, op_name)(skipna=skipna)
29+
30+
@pytest.mark.parametrize('skipna', [True, False])
31+
def test_reduce_series_boolean(self, data, all_boolean_reductions, skipna):
32+
op_name = all_boolean_reductions
33+
s = pd.Series(data)
34+
35+
with pytest.raises(TypeError):
36+
getattr(s, op_name)(skipna=skipna)
37+
38+
39+
class BaseNumericReduceTests(BaseReduceTests):
40+
41+
@pytest.mark.parametrize('skipna', [True, False])
42+
def test_reduce_series(self, data, all_numeric_reductions, skipna):
43+
op_name = all_numeric_reductions
44+
s = pd.Series(data)
45+
46+
# min/max with empty produce numpy warnings
47+
with warnings.catch_warnings():
48+
warnings.simplefilter("ignore", RuntimeWarning)
49+
self.check_reduce(s, op_name, skipna)
50+
51+
52+
class BaseBooleanReduceTests(BaseReduceTests):
53+
54+
@pytest.mark.parametrize('skipna', [True, False])
55+
def test_reduce_series(self, data, all_boolean_reductions, skipna):
56+
op_name = all_boolean_reductions
57+
s = pd.Series(data)
58+
self.check_reduce(s, op_name, skipna)

0 commit comments

Comments
 (0)