Skip to content

Commit 9b4d0f1

Browse files
TomAugspurgerjreback
authored andcommitted
ENH: Support ExtensionArray in Groupby (#20502)
1 parent 48f9a9a commit 9b4d0f1

File tree

6 files changed

+115
-8
lines changed

6 files changed

+115
-8
lines changed

pandas/core/groupby.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
DataError, SpecificationError)
4545
from pandas.core.index import (Index, MultiIndex,
4646
CategoricalIndex, _ensure_index)
47-
from pandas.core.arrays import Categorical
47+
from pandas.core.arrays import ExtensionArray, Categorical
4848
from pandas.core.frame import DataFrame
4949
from pandas.core.generic import NDFrame, _shared_docs
5050
from pandas.core.internals import BlockManager, make_block
@@ -2968,7 +2968,7 @@ def __init__(self, index, grouper=None, obj=None, name=None, level=None,
29682968

29692969
# no level passed
29702970
elif not isinstance(self.grouper,
2971-
(Series, Index, Categorical, np.ndarray)):
2971+
(Series, Index, ExtensionArray, np.ndarray)):
29722972
if getattr(self.grouper, 'ndim', 1) != 1:
29732973
t = self.name or str(type(self.grouper))
29742974
raise ValueError("Grouper for '%s' not 1-dimensional" % t)

pandas/tests/extension/base/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class TestMyDtype(BaseDtypeTests):
4444
from .constructors import BaseConstructorsTests # noqa
4545
from .dtype import BaseDtypeTests # noqa
4646
from .getitem import BaseGetitemTests # noqa
47+
from .groupby import BaseGroupbyTests # noqa
4748
from .interface import BaseInterfaceTests # noqa
4849
from .methods import BaseMethodsTests # noqa
4950
from .missing import BaseMissingTests # noqa
+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import pytest
2+
3+
import pandas.util.testing as tm
4+
import pandas as pd
5+
from .base import BaseExtensionTests
6+
7+
8+
class BaseGroupbyTests(BaseExtensionTests):
9+
"""Groupby-specific tests."""
10+
11+
def test_grouping_grouper(self, data_for_grouping):
12+
df = pd.DataFrame({
13+
"A": ["B", "B", None, None, "A", "A", "B", "C"],
14+
"B": data_for_grouping
15+
})
16+
gr1 = df.groupby("A").grouper.groupings[0]
17+
gr2 = df.groupby("B").grouper.groupings[0]
18+
19+
tm.assert_numpy_array_equal(gr1.grouper, df.A.values)
20+
tm.assert_extension_array_equal(gr2.grouper, data_for_grouping)
21+
22+
@pytest.mark.parametrize('as_index', [True, False])
23+
def test_groupby_extension_agg(self, as_index, data_for_grouping):
24+
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4],
25+
"B": data_for_grouping})
26+
result = df.groupby("B", as_index=as_index).A.mean()
27+
_, index = pd.factorize(data_for_grouping, sort=True)
28+
# TODO(ExtensionIndex): remove astype
29+
index = pd.Index(index.astype(object), name="B")
30+
expected = pd.Series([3, 1, 4], index=index, name="A")
31+
if as_index:
32+
self.assert_series_equal(result, expected)
33+
else:
34+
expected = expected.reset_index()
35+
self.assert_frame_equal(result, expected)
36+
37+
def test_groupby_extension_no_sort(self, data_for_grouping):
38+
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4],
39+
"B": data_for_grouping})
40+
result = df.groupby("B", sort=False).A.mean()
41+
_, index = pd.factorize(data_for_grouping, sort=False)
42+
# TODO(ExtensionIndex): remove astype
43+
index = pd.Index(index.astype(object), name="B")
44+
expected = pd.Series([1, 3, 4], index=index, name="A")
45+
self.assert_series_equal(result, expected)
46+
47+
def test_groupby_extension_transform(self, data_for_grouping):
48+
valid = data_for_grouping[~data_for_grouping.isna()]
49+
df = pd.DataFrame({"A": [1, 1, 3, 3, 1, 4],
50+
"B": valid})
51+
52+
result = df.groupby("B").A.transform(len)
53+
expected = pd.Series([3, 3, 2, 2, 3, 1], name="A")
54+
55+
self.assert_series_equal(result, expected)
56+
57+
@pytest.mark.parametrize('op', [
58+
lambda x: 1,
59+
lambda x: [1] * len(x),
60+
lambda x: pd.Series([1] * len(x)),
61+
lambda x: x,
62+
], ids=['scalar', 'list', 'series', 'object'])
63+
def test_groupby_extension_apply(self, data_for_grouping, op):
64+
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4],
65+
"B": data_for_grouping})
66+
df.groupby("B").apply(op)
67+
df.groupby("B").A.apply(op)
68+
df.groupby("A").apply(op)
69+
df.groupby("A").B.apply(op)

pandas/tests/extension/decimal/test_decimal.py

+4
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ class TestCasting(BaseDecimal, base.BaseCastingTests):
127127
pass
128128

129129

130+
class TestGroupby(BaseDecimal, base.BaseGroupbyTests):
131+
pass
132+
133+
130134
def test_series_constructor_coerce_data_to_extension_dtype_raises():
131135
xpr = ("Cannot cast data to extension dtype 'decimal'. Pass the "
132136
"extension array directly.")

pandas/tests/extension/json/array.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ def _concat_same_type(cls, to_concat):
113113
return cls(data)
114114

115115
def _values_for_factorize(self):
116-
frozen = tuple(tuple(x.items()) for x in self)
117-
return np.array(frozen, dtype=object), ()
116+
frozen = self._values_for_argsort()
117+
return frozen, ()
118118

119119
def _values_for_argsort(self):
120120
# Disable NumPy's shape inference by including an empty tuple...

pandas/tests/extension/json/test_json.py

+37-4
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,12 @@ def test_fillna_frame(self):
8989
"""We treat dictionaries as a mapping in fillna, not a scalar."""
9090

9191

92-
class TestMethods(base.BaseMethodsTests):
93-
unhashable = pytest.mark.skip(reason="Unhashable")
94-
unstable = pytest.mark.skipif(not PY36, # 3.6 or higher
95-
reason="Dictionary order unstable")
92+
unhashable = pytest.mark.skip(reason="Unhashable")
93+
unstable = pytest.mark.skipif(not PY36, # 3.6 or higher
94+
reason="Dictionary order unstable")
95+
9696

97+
class TestMethods(base.BaseMethodsTests):
9798
@unhashable
9899
def test_value_counts(self, all_data, dropna):
99100
pass
@@ -118,6 +119,7 @@ def test_sort_values(self, data_for_sorting, ascending):
118119
super(TestMethods, self).test_sort_values(
119120
data_for_sorting, ascending)
120121

122+
@unstable
121123
@pytest.mark.parametrize('ascending', [True, False])
122124
def test_sort_values_missing(self, data_missing_for_sorting, ascending):
123125
super(TestMethods, self).test_sort_values_missing(
@@ -126,3 +128,34 @@ def test_sort_values_missing(self, data_missing_for_sorting, ascending):
126128

127129
class TestCasting(base.BaseCastingTests):
128130
pass
131+
132+
133+
class TestGroupby(base.BaseGroupbyTests):
134+
135+
@unhashable
136+
def test_groupby_extension_transform(self):
137+
"""
138+
This currently fails in Series.name.setter, since the
139+
name must be hashable, but the value is a dictionary.
140+
I think this is what we want, i.e. `.name` should be the original
141+
values, and not the values for factorization.
142+
"""
143+
144+
@unhashable
145+
def test_groupby_extension_apply(self):
146+
"""
147+
This fails in Index._do_unique_check with
148+
149+
> hash(val)
150+
E TypeError: unhashable type: 'UserDict' with
151+
152+
I suspect that once we support Index[ExtensionArray],
153+
we'll be able to dispatch unique.
154+
"""
155+
156+
@unstable
157+
@pytest.mark.parametrize('as_index', [True, False])
158+
def test_groupby_extension_agg(self, as_index, data_for_grouping):
159+
super(TestGroupby, self).test_groupby_extension_agg(
160+
as_index, data_for_grouping
161+
)

0 commit comments

Comments
 (0)