@@ -653,37 +653,37 @@ def _iterate_slices(self):
653
653
def transform (self , func , * args , ** kwargs ):
654
654
raise AbstractMethodError (self )
655
655
656
- def _cumcount_array (self , arr = None , ascending = True ):
656
+ def _cumcount_array (self , ascending = True ):
657
657
"""
658
- arr is where cumcount gets its values from
658
+ Parameters
659
+ ----------
660
+ ascending : bool, default True
661
+ If False, number in reverse, from length of group - 1 to 0.
659
662
660
663
Note
661
664
----
662
665
this is currently implementing sort=False
663
666
(though the default is sort=True) for groupby in general
664
667
"""
665
- if arr is None :
666
- arr = np .arange (self .grouper ._max_groupsize , dtype = 'int64' )
667
-
668
- len_index = len (self ._selected_obj .index )
669
- cumcounts = np .zeros (len_index , dtype = arr .dtype )
670
- if not len_index :
671
- return cumcounts
668
+ ids , _ , ngroups = self .grouper .group_info
669
+ sorter = _get_group_index_sorter (ids , ngroups )
670
+ ids , count = ids [sorter ], len (ids )
672
671
673
- indices , values = [], []
674
- for v in self .indices .values ():
675
- indices .append (v )
672
+ if count == 0 :
673
+ return np .empty (0 , dtype = np .int64 )
676
674
677
- if ascending :
678
- values .append (arr [:len (v )])
679
- else :
680
- values .append (arr [len (v ) - 1 ::- 1 ])
675
+ run = np .r_ [True , ids [:- 1 ] != ids [1 :]]
676
+ rep = np .diff (np .r_ [np .nonzero (run )[0 ], count ])
677
+ out = (~ run ).cumsum ()
681
678
682
- indices = np .concatenate (indices )
683
- values = np .concatenate (values )
684
- cumcounts [indices ] = values
679
+ if ascending :
680
+ out -= np .repeat (out [run ], rep )
681
+ else :
682
+ out = np .repeat (out [np .r_ [run [1 :], True ]], rep ) - out
685
683
686
- return cumcounts
684
+ rev = np .empty (count , dtype = np .intp )
685
+ rev [sorter ] = np .arange (count , dtype = np .intp )
686
+ return out [rev ].astype (np .int64 , copy = False )
687
687
688
688
def _index_with_as_index (self , b ):
689
689
"""
@@ -1170,47 +1170,21 @@ def nth(self, n, dropna=None):
1170
1170
else :
1171
1171
raise TypeError ("n needs to be an int or a list/set/tuple of ints" )
1172
1172
1173
- m = self .grouper ._max_groupsize
1174
- # filter out values that are outside [-m, m)
1175
- pos_nth_values = [i for i in nth_values if i >= 0 and i < m ]
1176
- neg_nth_values = [i for i in nth_values if i < 0 and i >= - m ]
1177
-
1173
+ nth_values = np .array (nth_values , dtype = np .intp )
1178
1174
self ._set_selection_from_grouper ()
1179
- if not dropna : # good choice
1180
- if not pos_nth_values and not neg_nth_values :
1181
- # no valid nth values
1182
- return self ._selected_obj .loc [[]]
1183
-
1184
- rng = np .zeros (m , dtype = bool )
1185
- for i in pos_nth_values :
1186
- rng [i ] = True
1187
- is_nth = self ._cumcount_array (rng )
1188
1175
1189
- if neg_nth_values :
1190
- rng = np .zeros (m , dtype = bool )
1191
- for i in neg_nth_values :
1192
- rng [- i - 1 ] = True
1193
- is_nth |= self ._cumcount_array (rng , ascending = False )
1176
+ if not dropna :
1177
+ mask = np .in1d (self ._cumcount_array (), nth_values ) | \
1178
+ np .in1d (self ._cumcount_array (ascending = False ) + 1 , - nth_values )
1194
1179
1195
- result = self ._selected_obj [is_nth ]
1180
+ out = self ._selected_obj [mask ]
1181
+ if not self .as_index :
1182
+ return out
1196
1183
1197
- # the result index
1198
- if self .as_index :
1199
- ax = self .obj ._info_axis
1200
- names = self .grouper .names
1201
- if self .obj .ndim == 1 :
1202
- # this is a pass-thru
1203
- pass
1204
- elif all ([x in ax for x in names ]):
1205
- indicies = [self .obj [name ][is_nth ] for name in names ]
1206
- result .index = MultiIndex .from_arrays (
1207
- indicies ).set_names (names )
1208
- elif self ._group_selection is not None :
1209
- result .index = self .obj ._get_axis (self .axis )[is_nth ]
1210
-
1211
- result = result .sort_index ()
1184
+ ids , _ , _ = self .grouper .group_info
1185
+ out .index = self .grouper .result_index [ids [mask ]]
1212
1186
1213
- return result
1187
+ return out . sort_index () if self . sort else out
1214
1188
1215
1189
if isinstance (self ._selected_obj , DataFrame ) and \
1216
1190
dropna not in ['any' , 'all' ]:
@@ -1241,8 +1215,8 @@ def nth(self, n, dropna=None):
1241
1215
axis = self .axis , level = self .level ,
1242
1216
sort = self .sort )
1243
1217
1244
- sizes = dropped .groupby (grouper ). size ( )
1245
- result = dropped . groupby ( grouper ) .nth (n )
1218
+ grb = dropped .groupby (grouper , as_index = self . as_index , sort = self . sort )
1219
+ sizes , result = grb . size (), grb .nth (n )
1246
1220
mask = (sizes < max_len ).values
1247
1221
1248
1222
# set the results which don't meet the criteria
@@ -1380,11 +1354,8 @@ def head(self, n=5):
1380
1354
0 1 2
1381
1355
2 5 6
1382
1356
"""
1383
-
1384
- obj = self ._selected_obj
1385
- in_head = self ._cumcount_array () < n
1386
- head = obj [in_head ]
1387
- return head
1357
+ mask = self ._cumcount_array () < n
1358
+ return self ._selected_obj [mask ]
1388
1359
1389
1360
@Substitution (name = 'groupby' )
1390
1361
@Appender (_doc_template )
@@ -1409,12 +1380,8 @@ def tail(self, n=5):
1409
1380
0 a 1
1410
1381
2 b 1
1411
1382
"""
1412
-
1413
- obj = self ._selected_obj
1414
- rng = np .arange (0 , - self .grouper ._max_groupsize , - 1 , dtype = 'int64' )
1415
- in_tail = self ._cumcount_array (rng , ascending = False ) > - n
1416
- tail = obj [in_tail ]
1417
- return tail
1383
+ mask = self ._cumcount_array (ascending = False ) < n
1384
+ return self ._selected_obj [mask ]
1418
1385
1419
1386
1420
1387
@Appender (GroupBy .__doc__ )
0 commit comments