@@ -69,6 +69,12 @@ def _quantile_75(x):
69
69
``False`` for better performance. Note this does not influence
70
70
the order of observations within each group. Groupby preserves
71
71
the order of rows within each group.
72
+ group_keys : bool, optional
73
+ When calling apply and the ``by`` argument produces a like-indexed
74
+ result, add group keys to index to identify pieces. By default group
75
+ keys are not included when the result's index (and column) labels match
76
+ the inputs, and are included otherwise. This argument has no effect if
77
+ the result produced is not like-indexed with respect to the input.
72
78
{ret}
73
79
Examples
74
80
--------
@@ -135,6 +141,32 @@ def _quantile_75(x):
135
141
Type
136
142
Wild 185.0
137
143
Captive 210.0
144
+
145
+ >>> df = cudf.DataFrame({{'A': 'a a b'.split(),
146
+ ... 'B': [1,2,3],
147
+ ... 'C': [4,6,5]}})
148
+ >>> g1 = df.groupby('A', group_keys=False)
149
+ >>> g2 = df.groupby('A', group_keys=True)
150
+
151
+ Notice that ``g1`` and ``g2`` have two groups, ``a`` and ``b``, and only
152
+ differ in their ``group_keys`` argument. Calling `apply` in various ways,
153
+ we can get different grouping results:
154
+
155
+ >>> g1[['B', 'C']].apply(lambda x: x / x.sum())
156
+ B C
157
+ 0 0.333333 0.4
158
+ 1 0.666667 0.6
159
+ 2 1.000000 1.0
160
+
161
+ In the above, the groups are not part of the index. We can have them included
162
+ by using ``g2`` where ``group_keys=True``:
163
+
164
+ >>> g2[['B', 'C']].apply(lambda x: x / x.sum())
165
+ B C
166
+ A
167
+ a 0 0.333333 0.4
168
+ 1 0.666667 0.6
169
+ b 2 1.000000 1.0
138
170
"""
139
171
)
140
172
@@ -174,7 +206,14 @@ class GroupBy(Serializable, Reducible, Scannable):
174
206
_MAX_GROUPS_BEFORE_WARN = 100
175
207
176
208
def __init__ (
177
- self , obj , by = None , level = None , sort = False , as_index = True , dropna = True
209
+ self ,
210
+ obj ,
211
+ by = None ,
212
+ level = None ,
213
+ sort = False ,
214
+ as_index = True ,
215
+ dropna = True ,
216
+ group_keys = True ,
178
217
):
179
218
"""
180
219
Group a DataFrame or Series by a set of columns.
@@ -210,6 +249,7 @@ def __init__(
210
249
self ._level = level
211
250
self ._sort = sort
212
251
self ._dropna = dropna
252
+ self ._group_keys = group_keys
213
253
214
254
if isinstance (by , _Grouping ):
215
255
by ._obj = self .obj
@@ -544,7 +584,9 @@ def _grouped(self):
544
584
grouped_key_cols , grouped_value_cols , offsets = self ._groupby .groups (
545
585
[* self .obj ._index ._columns , * self .obj ._columns ]
546
586
)
547
- grouped_keys = cudf .core .index ._index_from_columns (grouped_key_cols )
587
+ grouped_keys = cudf .core .index ._index_from_columns (
588
+ grouped_key_cols , name = self .grouping .keys .name
589
+ )
548
590
grouped_values = self .obj ._from_columns_like_self (
549
591
grouped_value_cols ,
550
592
column_names = self .obj ._column_names ,
@@ -707,7 +749,7 @@ def mult(df):
707
749
"""
708
750
if not callable (function ):
709
751
raise TypeError (f"type { type (function )} is not callable" )
710
- group_names , offsets , _ , grouped_values = self ._grouped ()
752
+ group_names , offsets , group_keys , grouped_values = self ._grouped ()
711
753
712
754
ngroups = len (offsets ) - 1
713
755
if ngroups > self ._MAX_GROUPS_BEFORE_WARN :
@@ -726,14 +768,21 @@ def mult(df):
726
768
if cudf .api .types .is_scalar (chunk_results [0 ]):
727
769
result = cudf .Series (chunk_results , index = group_names )
728
770
result .index .names = self .grouping .names
729
- elif isinstance (chunk_results [0 ], cudf .Series ):
730
- if isinstance (self .obj , cudf .DataFrame ):
771
+ else :
772
+ if isinstance (chunk_results [0 ], cudf .Series ) and isinstance (
773
+ self .obj , cudf .DataFrame
774
+ ):
731
775
result = cudf .concat (chunk_results , axis = 1 ).T
732
776
result .index .names = self .grouping .names
733
777
else :
734
778
result = cudf .concat (chunk_results )
735
- else :
736
- result = cudf .concat (chunk_results )
779
+ if self ._group_keys :
780
+ result .index = cudf .MultiIndex ._from_data (
781
+ {
782
+ group_keys .name : group_keys ._column ,
783
+ None : grouped_values .index ._column ,
784
+ }
785
+ )
737
786
738
787
if self ._sort :
739
788
result = result .sort_index ()
@@ -1582,7 +1631,10 @@ class DataFrameGroupBy(GroupBy, GetAttrGetItemMixin):
1582
1631
1583
1632
def __getitem__ (self , key ):
1584
1633
return self .obj [key ].groupby (
1585
- by = self .grouping .keys , dropna = self ._dropna , sort = self ._sort
1634
+ by = self .grouping .keys ,
1635
+ dropna = self ._dropna ,
1636
+ sort = self ._sort ,
1637
+ group_keys = self ._group_keys ,
1586
1638
)
1587
1639
1588
1640
0 commit comments