Skip to content

Commit d3e8f6d

Browse files
Add support for group_keys in groupby (#11659)
- [x] This PR adds support for `group_keys` in `groupby`. Starting with pandas 1.5.0, issues around `group_keys` have been resolved: pandas-dev/pandas#34998 pandas-dev/pandas#47185 - [x] This PR defaults `group_keys` to `False`, which matches the default pandas plans to adopt in a future version. - [x] Required to unblock `pandas-1.5.0` upgrade in cudf: #11617 Authors: - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - Bradley Dice (https://github.com/bdice) - Ashwin Srinath (https://github.com/shwina) URL: #11659
1 parent 0684ee1 commit d3e8f6d

File tree

6 files changed

+101
-17
lines changed

6 files changed

+101
-17
lines changed

python/cudf/cudf/core/_compat.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@
1212
PANDAS_GE_133 = PANDAS_VERSION >= version.parse("1.3.3")
1313
PANDAS_GE_134 = PANDAS_VERSION >= version.parse("1.3.4")
1414
PANDAS_LT_140 = PANDAS_VERSION < version.parse("1.4.0")
15+
PANDAS_GE_150 = PANDAS_VERSION >= version.parse("1.5.0")

python/cudf/cudf/core/dataframe.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3836,7 +3836,7 @@ def groupby(
38363836
level=None,
38373837
as_index=True,
38383838
sort=False,
3839-
group_keys=True,
3839+
group_keys=False,
38403840
squeeze=False,
38413841
observed=False,
38423842
dropna=True,

python/cudf/cudf/core/groupby/groupby.py

+60-8
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ def _quantile_75(x):
6969
``False`` for better performance. Note this does not influence
7070
the order of observations within each group. Groupby preserves
7171
the order of rows within each group.
72+
group_keys : bool, optional
73+
When calling apply and the ``by`` argument produces a like-indexed
74+
result, add group keys to index to identify pieces. By default group
75+
keys are not included when the result's index (and column) labels match
76+
the inputs, and are included otherwise. This argument has no effect if
77+
the result produced is not like-indexed with respect to the input.
7278
{ret}
7379
Examples
7480
--------
@@ -135,6 +141,32 @@ def _quantile_75(x):
135141
Type
136142
Wild 185.0
137143
Captive 210.0
144+
145+
>>> df = cudf.DataFrame({{'A': 'a a b'.split(),
146+
... 'B': [1,2,3],
147+
... 'C': [4,6,5]}})
148+
>>> g1 = df.groupby('A', group_keys=False)
149+
>>> g2 = df.groupby('A', group_keys=True)
150+
151+
Notice that ``g1`` and ``g2`` have two groups, ``a`` and ``b``, and only
152+
differ in their ``group_keys`` argument. Calling ``apply`` in various ways,
153+
we can get different grouping results:
154+
155+
>>> g1[['B', 'C']].apply(lambda x: x / x.sum())
156+
B C
157+
0 0.333333 0.4
158+
1 0.666667 0.6
159+
2 1.000000 1.0
160+
161+
In the above, the groups are not part of the index. We can have them included
162+
by using ``g2`` where ``group_keys=True``:
163+
164+
>>> g2[['B', 'C']].apply(lambda x: x / x.sum())
165+
B C
166+
A
167+
a 0 0.333333 0.4
168+
1 0.666667 0.6
169+
b 2 1.000000 1.0
138170
"""
139171
)
140172

@@ -174,7 +206,14 @@ class GroupBy(Serializable, Reducible, Scannable):
174206
_MAX_GROUPS_BEFORE_WARN = 100
175207

176208
def __init__(
177-
self, obj, by=None, level=None, sort=False, as_index=True, dropna=True
209+
self,
210+
obj,
211+
by=None,
212+
level=None,
213+
sort=False,
214+
as_index=True,
215+
dropna=True,
216+
group_keys=True,
178217
):
179218
"""
180219
Group a DataFrame or Series by a set of columns.
@@ -210,6 +249,7 @@ def __init__(
210249
self._level = level
211250
self._sort = sort
212251
self._dropna = dropna
252+
self._group_keys = group_keys
213253

214254
if isinstance(by, _Grouping):
215255
by._obj = self.obj
@@ -544,7 +584,9 @@ def _grouped(self):
544584
grouped_key_cols, grouped_value_cols, offsets = self._groupby.groups(
545585
[*self.obj._index._columns, *self.obj._columns]
546586
)
547-
grouped_keys = cudf.core.index._index_from_columns(grouped_key_cols)
587+
grouped_keys = cudf.core.index._index_from_columns(
588+
grouped_key_cols, name=self.grouping.keys.name
589+
)
548590
grouped_values = self.obj._from_columns_like_self(
549591
grouped_value_cols,
550592
column_names=self.obj._column_names,
@@ -707,7 +749,7 @@ def mult(df):
707749
"""
708750
if not callable(function):
709751
raise TypeError(f"type {type(function)} is not callable")
710-
group_names, offsets, _, grouped_values = self._grouped()
752+
group_names, offsets, group_keys, grouped_values = self._grouped()
711753

712754
ngroups = len(offsets) - 1
713755
if ngroups > self._MAX_GROUPS_BEFORE_WARN:
@@ -726,14 +768,21 @@ def mult(df):
726768
if cudf.api.types.is_scalar(chunk_results[0]):
727769
result = cudf.Series(chunk_results, index=group_names)
728770
result.index.names = self.grouping.names
729-
elif isinstance(chunk_results[0], cudf.Series):
730-
if isinstance(self.obj, cudf.DataFrame):
771+
else:
772+
if isinstance(chunk_results[0], cudf.Series) and isinstance(
773+
self.obj, cudf.DataFrame
774+
):
731775
result = cudf.concat(chunk_results, axis=1).T
732776
result.index.names = self.grouping.names
733777
else:
734778
result = cudf.concat(chunk_results)
735-
else:
736-
result = cudf.concat(chunk_results)
779+
if self._group_keys:
780+
result.index = cudf.MultiIndex._from_data(
781+
{
782+
group_keys.name: group_keys._column,
783+
None: grouped_values.index._column,
784+
}
785+
)
737786

738787
if self._sort:
739788
result = result.sort_index()
@@ -1582,7 +1631,10 @@ class DataFrameGroupBy(GroupBy, GetAttrGetItemMixin):
15821631

15831632
def __getitem__(self, key):
15841633
return self.obj[key].groupby(
1585-
by=self.grouping.keys, dropna=self._dropna, sort=self._sort
1634+
by=self.grouping.keys,
1635+
dropna=self._dropna,
1636+
sort=self._sort,
1637+
group_keys=self._group_keys,
15861638
)
15871639

15881640

python/cudf/cudf/core/indexed_frame.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -3535,19 +3535,14 @@ def groupby(
35353535
level=None,
35363536
as_index=True,
35373537
sort=False,
3538-
group_keys=True,
3538+
group_keys=False,
35393539
squeeze=False,
35403540
observed=False,
35413541
dropna=True,
35423542
):
35433543
if axis not in (0, "index"):
35443544
raise NotImplementedError("axis parameter is not yet implemented")
35453545

3546-
if group_keys is not True:
3547-
raise NotImplementedError(
3548-
"The group_keys keyword is not yet implemented"
3549-
)
3550-
35513546
if squeeze is not False:
35523547
raise NotImplementedError(
35533548
"squeeze parameter is not yet implemented"
@@ -3562,6 +3557,8 @@ def groupby(
35623557
raise TypeError(
35633558
"groupby() requires either by or level to be specified."
35643559
)
3560+
if group_keys is None:
3561+
group_keys = False
35653562

35663563
return (
35673564
self.__class__._resampler(self, by=by)
@@ -3573,6 +3570,7 @@ def groupby(
35733570
as_index=as_index,
35743571
dropna=dropna,
35753572
sort=sort,
3573+
group_keys=group_keys,
35763574
)
35773575
)
35783576

python/cudf/cudf/core/series.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3075,7 +3075,7 @@ def groupby(
30753075
level=None,
30763076
as_index=True,
30773077
sort=False,
3078-
group_keys=True,
3078+
group_keys=False,
30793079
squeeze=False,
30803080
observed=False,
30813081
dropna=True,

python/cudf/cudf/tests/test_groupby.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414

1515
import cudf
1616
from cudf import DataFrame, Series
17-
from cudf.core._compat import PANDAS_GE_110, PANDAS_GE_130, PANDAS_LT_140
17+
from cudf.core._compat import (
18+
PANDAS_GE_110,
19+
PANDAS_GE_130,
20+
PANDAS_GE_150,
21+
PANDAS_LT_140,
22+
)
1823
from cudf.testing._utils import (
1924
DATETIME_TYPES,
2025
SIGNED_TYPES,
@@ -2677,3 +2682,31 @@ def test_groupby_pct_change_empty_columns():
26772682
expected = pdf.groupby("id").pct_change()
26782683

26792684
assert_eq(expected, actual)
2685+
2686+
2687+
@pytest.mark.parametrize(
2688+
"group_keys",
2689+
[
2690+
None,
2691+
pytest.param(
2692+
True,
2693+
marks=pytest.mark.xfail(
2694+
condition=not PANDAS_GE_150,
2695+
reason="https://github.com/pandas-dev/pandas/pull/34998",
2696+
),
2697+
),
2698+
False,
2699+
],
2700+
)
2701+
def test_groupby_group_keys(group_keys):
2702+
gdf = cudf.DataFrame(
2703+
{"A": "a a b".split(), "B": [1, 2, 3], "C": [4, 6, 5]}
2704+
)
2705+
pdf = gdf.to_pandas()
2706+
2707+
g_group = gdf.groupby("A", group_keys=group_keys)
2708+
p_group = pdf.groupby("A", group_keys=group_keys)
2709+
2710+
actual = g_group[["B", "C"]].apply(lambda x: x / x.sum())
2711+
expected = p_group[["B", "C"]].apply(lambda x: x / x.sum())
2712+
assert_eq(actual, expected)

0 commit comments

Comments
 (0)