Skip to content

Commit 067335e

Browse files
committed
ENH: add groupby & reduce support to EA
closes pandas-dev#21789 closes pandas-dev#22346
1 parent 40dfadd commit 067335e

File tree

9 files changed

+193
-21
lines changed

9 files changed

+193
-21
lines changed

doc/source/whatsnew/v0.24.0.txt

+8-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Pandas has gained the ability to hold integer dtypes with missing values. This l
4242
Here is an example of the usage.
4343

4444
We can construct a ``Series`` with the specified dtype. The dtype string ``Int64`` is a pandas ``ExtensionDtype``. Specifying a list or array using the traditional missing value
45-
marker of ``np.nan`` will infer to integer dtype. The display of the ``Series`` will also use the ``NaN`` to indicate missing values in string outputs. (:issue:`20700`, :issue:`20747`, :issue:`22441`)
45+
marker of ``np.nan`` will infer to integer dtype. The display of the ``Series`` will also use the ``NaN`` to indicate missing values in string outputs. (:issue:`20700`, :issue:`20747`, :issue:`22441`, :issue:`21789`, :issue:`22346`)
4646

4747
.. ipython:: python
4848

@@ -85,6 +85,13 @@ These dtypes can be merged & reshaped & casted.
8585
pd.concat([df[['A']], df[['B', 'C']]], axis=1).dtypes
8686
df['A'].astype(float)
8787

88+
Reduction and groupby operations such as 'sum' work.
89+
90+
.. ipython:: python
91+
92+
df.sum()
93+
df.groupby('B').A.sum()
94+
8895
.. warning::
8996

9097
The Integer NA support currently uses the capitalized dtype version, e.g. ``Int8`` as compared to the traditional ``int8``. This may be changed at a future date.

pandas/conftest.py

+23
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,29 @@ def all_arithmetic_operators(request):
131131
return request.param
132132

133133

134+
_all_numeric_reductions = ['sum', 'max', 'min',
135+
'mean', 'prod', 'std', 'var', 'median']
136+
137+
138+
@pytest.fixture(params=_all_numeric_reductions)
139+
def all_numeric_reductions(request):
140+
"""
141+
Fixture for numeric reduction names
142+
"""
143+
return request.param
144+
145+
146+
_all_boolean_reductions = ['all', 'any']
147+
148+
149+
@pytest.fixture(params=_all_boolean_reductions)
150+
def all_boolean_reductions(request):
151+
"""
152+
Fixture for boolean reduction names
153+
"""
154+
return request.param
155+
156+
134157
_cython_table = pd.core.base.SelectionMixin._cython_table.items()
135158

136159

pandas/core/arrays/base.py

+29
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ class ExtensionArray(object):
5959
* factorize / _values_for_factorize
6060
* argsort / _values_for_argsort
6161
62+
One can implement methods to handle array reductions.
63+
64+
* _reduce
65+
6266
The remaining methods implemented on this class should be performant,
6367
as they only compose abstract methods. Still, a more efficient
6468
implementation may be available, and these methods can be overridden.
@@ -708,6 +712,31 @@ def _add_comparison_ops(cls):
708712
cls.__le__ = cls._create_comparison_method(operator.le)
709713
cls.__ge__ = cls._create_comparison_method(operator.ge)
710714

715+
def _reduce(self, op, name, axis=0, skipna=True, numeric_only=None,
716+
filter_type=None, **kwargs):
717+
"""Return a scalar result of performing the op
718+
719+
Parameters
720+
----------
721+
op : callable
722+
function to apply to the array
723+
name : str
724+
name of the function
725+
axis : int, default 0
726+
axis over which to apply, defined as 0 currently
727+
skipna : bool, default True
728+
if True, skip NaN values
729+
numeric_only : bool, optional
730+
if True, only perform numeric ops
731+
filter_type : str, optional
732+
kwargs : dict
733+
734+
Returns
735+
-------
736+
scalar
737+
"""
738+
raise AbstractMethodError(self)
739+
711740

712741
class ExtensionScalarOpsMixin(ExtensionOpsMixin):
713742
"""A mixin for defining the arithmetic and logical operations on

pandas/core/arrays/integer.py

+49
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,55 @@ def cmp_method(self, other):
529529
name = '__{name}__'.format(name=op.__name__)
530530
return set_function_name(cmp_method, name, cls)
531531

532+
def _reduce(self, op, name, axis=0, skipna=True, numeric_only=None,
533+
filter_type=None, **kwds):
534+
"""Return a scalar result of performing the op
535+
536+
Parameters
537+
----------
538+
op : callable
539+
function to apply to the array
540+
name : str
541+
name of the function
542+
axis : int, default 0
543+
axis over which to apply, defined as 0 currently
544+
skipna : bool, default True
545+
if True, skip NaN values
546+
numeric_only : bool, optional
547+
if True, only perform numeric ops
548+
filter_type : str, optional
549+
kwds : dict
550+
551+
Returns
552+
-------
553+
scalar
554+
"""
555+
556+
data = self._data
557+
mask = self._mask
558+
559+
# coerce to a nan-aware float if needed
560+
if mask.any():
561+
data = self._data.astype('float64')
562+
data[mask] = self._na_value
563+
564+
result = op(data, axis=axis, skipna=skipna)
565+
566+
# if we have a boolean op, provide coercion back to a bool
567+
# type if possible
568+
if name in ['any', 'all']:
569+
if is_integer(result) or is_float(result):
570+
result = bool(int(result))
571+
572+
# if we have a numeric op, provide coercion back to an integer
573+
# type if possible
574+
elif not isna(result):
575+
int_result = int(result)
576+
if int_result == result:
577+
result = int_result
578+
579+
return result
580+
532581
def _maybe_mask_result(self, result, mask, other, op_name):
533582
"""
534583
Parameters

pandas/tests/arrays/test_integer.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -587,15 +587,18 @@ def test_cross_type_arithmetic():
587587
tm.assert_series_equal(result, expected)
588588

589589

590-
def test_groupby_mean_included():
590+
@pytest.mark.parametrize('op', ['sum', 'min', 'max'])
591+
def test_preserve_groupby_dtypes(op):
592+
# TODO(#22346): preserve Int64 dtype
593+
# for ops that enable (mean would actually work here
594+
# but generally it is a float return value)
591595
df = pd.DataFrame({
592596
"A": ['a', 'b', 'b'],
593597
"B": [1, None, 3],
594598
"C": integer_array([1, None, 3], dtype='Int64'),
595599
})
596600

597-
result = df.groupby("A").sum()
598-
# TODO(#22346): preserve Int64 dtype
601+
result = getattr(df.groupby("A"), op)()
599602
expected = pd.DataFrame({
600603
"B": np.array([1.0, 3.0]),
601604
"C": np.array([1, 3], dtype="int64")

pandas/tests/extension/base/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class TestMyDtype(BaseDtypeTests):
4848
from .interface import BaseInterfaceTests # noqa
4949
from .methods import BaseMethodsTests # noqa
5050
from .ops import BaseArithmeticOpsTests, BaseComparisonOpsTests, BaseOpsUtil # noqa
51+
from .reduce import BaseNumericReduceTests, BaseBooleanReduceTests # noqa
5152
from .missing import BaseMissingTests # noqa
5253
from .reshaping import BaseReshapingTests # noqa
5354
from .setitem import BaseSetitemTests # noqa

pandas/tests/extension/base/groupby.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping):
2525
"B": data_for_grouping})
2626
result = df.groupby("B", as_index=as_index).A.mean()
2727
_, index = pd.factorize(data_for_grouping, sort=True)
28-
# TODO(ExtensionIndex): remove astype
29-
index = pd.Index(index.astype(object), name="B")
28+
29+
index = pd.Index(index, name="B")
3030
expected = pd.Series([3, 1, 4], index=index, name="A")
3131
if as_index:
3232
self.assert_series_equal(result, expected)
@@ -39,8 +39,8 @@ def test_groupby_extension_no_sort(self, data_for_grouping):
3939
"B": data_for_grouping})
4040
result = df.groupby("B", sort=False).A.mean()
4141
_, index = pd.factorize(data_for_grouping, sort=False)
42-
# TODO(ExtensionIndex): remove astype
43-
index = pd.Index(index.astype(object), name="B")
42+
43+
index = pd.Index(index, name="B")
4444
expected = pd.Series([1, 3, 4], index=index, name="A")
4545
self.assert_series_equal(result, expected)
4646

pandas/tests/extension/base/reduce.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import warnings
2+
import pytest
3+
import pandas.util.testing as tm
4+
import pandas as pd
5+
from .base import BaseExtensionTests
6+
7+
8+
class BaseReduceTests(BaseExtensionTests):
9+
"""
10+
Reduction specific tests. Generally these only
11+
make sense for numeric/boolean operations.
12+
"""
13+
def check_reduce(self, s, op_name, skipna):
14+
result = getattr(s, op_name)(skipna=skipna)
15+
expected = getattr(s.astype('float64'), op_name)(skipna=skipna)
16+
tm.assert_almost_equal(result, expected)
17+
18+
19+
class BaseNumericReduceTests(BaseReduceTests):
20+
21+
@pytest.mark.parametrize('skipna', [True, False])
22+
def test_reduce_series(self, data, all_numeric_reductions, skipna):
23+
op_name = all_numeric_reductions
24+
s = pd.Series(data)
25+
26+
# min/max with empty produce numpy warnings
27+
with warnings.catch_warnings(record=True):
28+
warnings.simplefilter("ignore", RuntimeWarning)
29+
self.check_reduce(s, op_name, skipna)
30+
31+
32+
class BaseBooleanReduceTests(BaseReduceTests):
33+
34+
@pytest.mark.parametrize('skipna', [True, False])
35+
def test_reduce_series(self, data, all_boolean_reductions, skipna):
36+
op_name = all_boolean_reductions
37+
s = pd.Series(data)
38+
self.check_reduce(s, op_name, skipna)

pandas/tests/extension/test_integer.py

+35-13
Original file line numberDiff line numberDiff line change
@@ -210,17 +210,39 @@ class TestCasting(base.BaseCastingTests):
210210

211211
class TestGroupby(base.BaseGroupbyTests):
212212

213-
@pytest.mark.xfail(reason="groupby not working", strict=True)
214-
def test_groupby_extension_no_sort(self, data_for_grouping):
215-
super(TestGroupby, self).test_groupby_extension_no_sort(
216-
data_for_grouping)
217-
218-
@pytest.mark.parametrize('as_index', [
219-
pytest.param(True,
220-
marks=pytest.mark.xfail(reason="groupby not working",
221-
strict=True)),
222-
False
223-
])
213+
@pytest.mark.parametrize('as_index', [True, False])
224214
def test_groupby_extension_agg(self, as_index, data_for_grouping):
225-
super(TestGroupby, self).test_groupby_extension_agg(
226-
as_index, data_for_grouping)
215+
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4],
216+
"B": data_for_grouping})
217+
result = df.groupby("B", as_index=as_index).A.mean()
218+
_, index = pd.factorize(data_for_grouping, sort=True)
219+
220+
# TODO(ExtensionIndex): remove coercion to object
221+
# we don't have an easy way to represent an EA as an Index object
222+
index = pd.Index(index, name="B", dtype=object)
223+
expected = pd.Series([3, 1, 4], index=index, name="A")
224+
if as_index:
225+
self.assert_series_equal(result, expected)
226+
else:
227+
expected = expected.reset_index()
228+
self.assert_frame_equal(result, expected)
229+
230+
def test_groupby_extension_no_sort(self, data_for_grouping):
231+
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4],
232+
"B": data_for_grouping})
233+
result = df.groupby("B", sort=False).A.mean()
234+
_, index = pd.factorize(data_for_grouping, sort=False)
235+
236+
# TODO(ExtensionIndex): remove coercion to object
237+
# we don't have an easy way to represent an EA as an Index object
238+
index = pd.Index(index, name="B", dtype=object)
239+
expected = pd.Series([1, 3, 4], index=index, name="A")
240+
self.assert_series_equal(result, expected)
241+
242+
243+
class TestNumericReduce(base.BaseNumericReduceTests):
244+
pass
245+
246+
247+
class TestBooleanReduce(base.BaseBooleanReduceTests):
248+
pass

0 commit comments

Comments
 (0)