Skip to content

Commit e5e53ba

Browse files
committed
Merge pull request #5533 from hayd/groupby_head_tail
PERF faster head, tail and size groupby methods
2 parents b3e3bf1 + ef38319 commit e5e53ba

File tree

2 files changed

+169
-13
lines changed

2 files changed

+169
-13
lines changed

pandas/core/groupby.py

+123-7
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252

5353
_apply_whitelist = frozenset(['last', 'first',
5454
'mean', 'sum', 'min', 'max',
55-
'head', 'tail',
5655
'cumsum', 'cumprod', 'cummin', 'cummax',
5756
'resample',
5857
'describe',
@@ -482,13 +481,19 @@ def picker(arr):
482481
return np.nan
483482
return self.agg(picker)
484483

485-
def cumcount(self, **kwargs):
    """
    Number each item in each group from 0 to the length of that group - 1.

    Essentially this is equivalent to

    >>> self.apply(lambda x: Series(np.arange(len(x)), x.index))

    Parameters
    ----------
    ascending : bool, default True
        If False, number in reverse, from length of group - 1 to 0.

    Example
    -------

    >>> df.groupby('A').cumcount(ascending=False)
    0    3
    1    2
    2    1
    3    1
    4    0
    5    0
    dtype: int64

    """
    # Only 'ascending' is recognised; any other keyword is silently ignored
    # (kwargs-based signature kept for whitelist/dispatch compatibility).
    ascending = kwargs.pop('ascending', True)

    # A single shared 0..max-1 ramp is sliced per group by _cumcount_array,
    # which scatters the per-group counts back to the original positions.
    counter = np.arange(self.grouper._max_groupsize, dtype='int64')
    within_group = self._cumcount_array(counter, ascending=ascending)
    return Series(within_group, self.obj.index)

535+
def head(self, n=5):
    """
    Returns first n rows of each group.

    Essentially equivalent to ``.apply(lambda x: x.head(n))``

    Example
    -------

    >>> df = DataFrame([[1, 2], [1, 4], [5, 6]],
                       columns=['A', 'B'])
    >>> df.groupby('A', as_index=False).head(1)
       A  B
    0  1  2
    2  5  6
    >>> df.groupby('A').head(1)
         A  B
    A
    1 0  1  2
    5 2  5  6

    """
    # Rows whose within-group position is below n belong to the head.
    positions = np.arange(self.grouper._max_groupsize, dtype='int64')
    mask = self._cumcount_array(positions) < n
    result = self.obj[mask]
    if not self.as_index:
        return result
    # as_index=True: prepend the group key(s) as extra index level(s).
    result.index = self._index_with_as_index(mask)
    return result
564+
def tail(self, n=5):
    """
    Returns last n rows of each group.

    Essentially equivalent to ``.apply(lambda x: x.tail(n))``

    Example
    -------

    >>> df = DataFrame([[1, 2], [1, 4], [5, 6]],
                       columns=['A', 'B'])
    >>> df.groupby('A', as_index=False).tail(1)
       A  B
    1  1  4
    2  5  6
    >>> df.groupby('A').tail(1)
         A  B
    A
    1 1  1  4
    5 2  5  6

    """
    # Descending count gives each row -(distance from the end of its group),
    # i.e. the last row of a group is 0, the one before it -1, and so on.
    rng = np.arange(0, -self.grouper._max_groupsize, -1, dtype='int64')
    in_tail = self._cumcount_array(rng, ascending=False) > -n
    tail = self.obj[in_tail]
    if self.as_index:
        # as_index=True: prepend the group key(s) as extra index level(s).
        tail.index = self._index_with_as_index(in_tail)
    return tail
593+
def _cumcount_array(self, arr, **kwargs):
594+
ascending = kwargs.pop('ascending', True)
595+
596+
len_index = len(self.obj.index)
597+
cumcounts = np.zeros(len_index, dtype='int64')
598+
if ascending:
599+
for v in self.indices.values():
600+
cumcounts[v] = arr[:len(v)]
601+
else:
602+
for v in self.indices.values():
603+
cumcounts[v] = arr[len(v)-1::-1]
604+
return cumcounts
605+
606+
def _index_with_as_index(self, b):
607+
"""
608+
Take boolean mask of index to be returned from apply, if as_index=True
609+
610+
"""
611+
# TODO perf, it feels like this should already be somewhere...
612+
from itertools import chain
613+
original = self.obj.index
614+
gp = self.grouper
615+
levels = chain((gp.levels[i][gp.labels[i][b]]
616+
for i in range(len(gp.groupings))),
617+
(original.get_level_values(i)[b]
618+
for i in range(original.nlevels)))
619+
new = MultiIndex.from_arrays(list(levels))
620+
new.names = gp.names + original.names
621+
return new
622+
521623
def _try_cast(self, result, obj):
522624
"""
523625
try to cast the result to our obj original type,
@@ -758,14 +860,28 @@ def names(self):
758860
def size(self):
    """
    Compute group sizes

    """
    # TODO: better impl
    # Counting the raw integer codes directly avoids building an
    # intermediate Series just to call value_counts on it.
    labels, _, ngroups = self.group_info
    counts = algos.value_counts(labels, sort=False).reindex(np.arange(ngroups))
    counts.index = self.result_index
    return counts
768871

872+
@cache_readonly
def _max_groupsize(self):
    """
    Compute size of largest group

    """
    # For many items in each group this is much faster than
    # self.size().max(), in worst case marginally slower
    if not self.indices:
        return 0
    return max(len(rows) for rows in self.indices.values())
884+
769885
@cache_readonly
770886
def groups(self):
771887
if len(self.groupings) == 1:

pandas/tests/test_groupby.py

+46-6
Original file line numberDiff line numberDiff line change
@@ -1203,24 +1203,64 @@ def test_groupby_as_index_apply(self):
12031203
g_not_as = df.groupby('user_id', as_index=False)
12041204

12051205
res_as = g_as.head(2).index
1206-
exp_as = MultiIndex.from_tuples([(1, 0), (1, 2), (2, 1), (3, 4)])
1206+
exp_as = MultiIndex.from_tuples([(1, 0), (2, 1), (1, 2), (3, 4)])
12071207
assert_index_equal(res_as, exp_as)
12081208

12091209
res_not_as = g_not_as.head(2).index
1210-
exp_not_as = Index([0, 2, 1, 4])
1210+
exp_not_as = Index([0, 1, 2, 4])
12111211
assert_index_equal(res_not_as, exp_not_as)
12121212

1213-
res_as = g_as.apply(lambda x: x.head(2)).index
1214-
assert_index_equal(res_not_as, exp_not_as)
1213+
res_as_apply = g_as.apply(lambda x: x.head(2)).index
1214+
res_not_as_apply = g_not_as.apply(lambda x: x.head(2)).index
12151215

1216-
res_not_as = g_not_as.apply(lambda x: x.head(2)).index
1217-
assert_index_equal(res_not_as, exp_not_as)
1216+
# apply doesn't maintain the original ordering
1217+
exp_not_as_apply = Index([0, 2, 1, 4])
1218+
exp_as_apply = MultiIndex.from_tuples([(1, 0), (1, 2), (2, 1), (3, 4)])
1219+
1220+
assert_index_equal(res_as_apply, exp_as_apply)
1221+
assert_index_equal(res_not_as_apply, exp_not_as_apply)
12181222

12191223
ind = Index(list('abcde'))
12201224
df = DataFrame([[1, 2], [2, 3], [1, 4], [1, 5], [2, 6]], index=ind)
12211225
res = df.groupby(0, as_index=False).apply(lambda x: x).index
12221226
assert_index_equal(res, ind)
12231227

1228+
def test_groupby_head_tail(self):
    df = DataFrame([[1, 2], [1, 4], [5, 6]], columns=['A', 'B'])
    grouped_as = df.groupby('A', as_index=True)
    grouped_flat = df.groupby('A', as_index=False)

    # as_index=False keeps the frame's original integer index
    assert_frame_equal(df.loc[[0, 2]], grouped_flat.head(1))
    assert_frame_equal(df.loc[[1, 2]], grouped_flat.tail(1))

    # n <= 0 yields an empty frame
    empty_flat = DataFrame(columns=df.columns)
    for n in (0, -1):
        assert_frame_equal(empty_flat, grouped_flat.head(n))
        assert_frame_equal(empty_flat, grouped_flat.tail(n))

    # n larger than every group returns the whole frame
    assert_frame_equal(df, grouped_flat.head(7))
    assert_frame_equal(df, grouped_flat.tail(7))

    # as_index=True prepends the A column as an index level,
    # built here in a roundabout way for the expected frames
    df_as = df.copy()
    df_as.index = df.set_index('A', append=True,
                               drop=False).index.swaplevel(0, 1)

    assert_frame_equal(df_as.loc[[0, 2]], grouped_as.head(1))
    assert_frame_equal(df_as.loc[[1, 2]], grouped_as.tail(1))

    empty_as = DataFrame(index=df_as.index[:0], columns=df.columns)
    for n in (0, -1):
        assert_frame_equal(empty_as, grouped_as.head(n))
        assert_frame_equal(empty_as, grouped_as.tail(n))

    assert_frame_equal(df_as, grouped_as.head(7))
    assert_frame_equal(df_as, grouped_as.tail(7))
1263+
12241264
def test_groupby_multiple_key(self):
12251265
df = tm.makeTimeDataFrame()
12261266
grouped = df.groupby([lambda x: x.year,

0 commit comments

Comments
 (0)