Skip to content

Commit a6e60a7

Browse files
committed
ENH: Add groupby().enumerate method to count groups (#11642)
1 parent 5d791cc commit a6e60a7

File tree

5 files changed

+200
-1
lines changed

5 files changed

+200
-1
lines changed

doc/source/api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -1682,6 +1682,7 @@ Computations / Descriptive Stats
16821682

16831683
GroupBy.count
16841684
GroupBy.cumcount
1685+
GroupBy.enumerate
16851686
GroupBy.first
16861687
GroupBy.head
16871688
GroupBy.last

doc/source/groupby.rst

+18-1
Original file line numberDiff line numberDiff line change
@@ -969,7 +969,7 @@ Enumerate group items
969969
.. versionadded:: 0.13.0
970970

971971
To see the order in which each row appears within its group, use the
972-
``cumcount`` method:
972+
``cumcount`` method (compare with ``enumerate``):
973973

974974
.. ipython:: python
975975
@@ -980,6 +980,23 @@ To see the order in which each row appears within its group, use the
980980
981981
df.groupby('A').cumcount(ascending=False) # kwarg only
982982
983+
Enumerate groups
984+
~~~~~~~~~~~~~~~~
985+
986+
.. versionadded:: 0.19.0
987+
988+
To see the ordering of the groups themselves, you can use the ``enumerate``
989+
method (compare with ``cumcount``):
990+
991+
.. ipython:: python
992+
993+
df = pd.DataFrame(list('aaabba'), columns=['A'])
994+
df
995+
996+
df.groupby('A').enumerate()
997+
998+
df.groupby('A').enumerate(ascending=False) # kwarg only
999+
9831000
Plotting
9841001
~~~~~~~~
9851002

doc/source/whatsnew/v0.19.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,7 @@ Other enhancements
389389

390390
- ``Categorical.astype()`` now accepts an optional boolean argument ``copy``, effective when dtype is categorical (:issue:`13209`)
391391
- ``DataFrame`` has gained the ``.asof()`` method to return the last non-NaN values according to the selected subset (:issue:`13358`)
392+
- A new groupby method ``enumerate``, parallel to the existing ``cumcount``, has been added to return the group order (:issue:`11642`)
392393
- Consistent with the Python API, ``pd.read_csv()`` will now interpret ``+inf`` as positive infinity (:issue:`13274`)
393394
- The ``DataFrame`` constructor will now respect key ordering if a list of ``OrderedDict`` objects are passed in (:issue:`13304`)
394395
- ``pd.read_html()`` has gained support for the ``decimal`` option (:issue:`12907`)

pandas/core/groupby.py

+68
Original file line numberDiff line numberDiff line change
@@ -1329,6 +1329,74 @@ def nth(self, n, dropna=None):
13291329

13301330
return result
13311331

1332+
@Substitution(name='groupby')
1333+
@Appender(_doc_template)
1334+
def enumerate(self, ascending=True):
1335+
"""
1336+
Number each group from 0 to the number of groups - 1.
1337+
1338+
This is the enumerative complement of cumcount. Note that the
1339+
numbers given to the groups match the order in which the groups
1340+
would be seen when iterating over the groupby object, not the
1341+
order they are first observed.
1342+
1343+
Parameters
1344+
----------
1345+
ascending : bool, default True
1346+
If False, number in reverse, from number of group - 1 to 0.
1347+
1348+
Examples
1349+
--------
1350+
1351+
>>> df = pd.DataFrame([['a'], ['a'], ['a'], ['b'], ['b'], ['a']],
1352+
... columns=['A'])
1353+
>>> df
1354+
A
1355+
0 a
1356+
1 a
1357+
2 a
1358+
3 b
1359+
4 b
1360+
5 a
1361+
>>> df.groupby('A').enumerate()
1362+
0 0
1363+
1 0
1364+
2 0
1365+
3 1
1366+
4 1
1367+
5 0
1368+
dtype: int64
1369+
>>> df.groupby('A').enumerate(ascending=False)
1370+
0 1
1371+
1 1
1372+
2 1
1373+
3 0
1374+
4 0
1375+
5 1
1376+
dtype: int64
1377+
>>> df = pd.DataFrame([['b'], ['a'], ['a'], ['b']], columns=['A'])
1378+
>>> df
1379+
A
1380+
0 b
1381+
1 a
1382+
2 a
1383+
3 b
1384+
>>> df.groupby("A").enumerate()
1385+
0 1
1386+
1 0
1387+
2 0
1388+
3 1
1389+
dtype: int64
1390+
"""
1391+
1392+
self._set_group_selection()
1393+
1394+
index = self._selected_obj.index
1395+
result = Series(self.grouper.group_info[0], index)
1396+
if not ascending:
1397+
result = self.ngroups - 1 - result
1398+
return result
1399+
13321400
@Substitution(name='groupby')
13331401
@Appender(_doc_template)
13341402
def cumcount(self, ascending=True):

pandas/tests/test_groupby.py

+112
Original file line numberDiff line numberDiff line change
@@ -5141,6 +5141,118 @@ def test_cumcount_groupby_not_col(self):
51415141
assert_series_equal(expected, g.cumcount())
51425142
assert_series_equal(expected, sg.cumcount())
51435143

5144+
def test_enumerate(self):
5145+
df = DataFrame([['a'], ['a'], ['a'], ['b'], ['a']], columns=['A'])
5146+
g = df.groupby('A')
5147+
sg = g.A
5148+
5149+
expected = Series([0, 0, 0, 1, 0])
5150+
5151+
assert_series_equal(expected, g.enumerate())
5152+
assert_series_equal(expected, sg.enumerate())
5153+
5154+
def test_enumerate_empty(self):
5155+
ge = DataFrame().groupby(level=0)
5156+
se = Series().groupby(level=0)
5157+
5158+
# edge case, as this is usually considered float
5159+
e = Series(dtype='int64')
5160+
5161+
assert_series_equal(e, ge.enumerate())
5162+
assert_series_equal(e, se.enumerate())
5163+
5164+
def test_enumerate_dupe_index(self):
5165+
df = DataFrame([['a'], ['a'], ['a'], ['b'], ['a']], columns=['A'],
5166+
index=[0] * 5)
5167+
g = df.groupby('A')
5168+
sg = g.A
5169+
5170+
expected = Series([0, 0, 0, 1, 0], index=[0] * 5)
5171+
5172+
assert_series_equal(expected, g.enumerate())
5173+
assert_series_equal(expected, sg.enumerate())
5174+
5175+
def test_enumerate_mi(self):
5176+
mi = MultiIndex.from_tuples([[0, 1], [1, 2], [2, 2], [2, 2], [1, 0]])
5177+
df = DataFrame([['a'], ['a'], ['a'], ['b'], ['a']], columns=['A'],
5178+
index=mi)
5179+
g = df.groupby('A')
5180+
sg = g.A
5181+
5182+
expected = Series([0, 0, 0, 1, 0], index=mi)
5183+
5184+
assert_series_equal(expected, g.enumerate())
5185+
assert_series_equal(expected, sg.enumerate())
5186+
5187+
def test_enumerate_groupby_not_col(self):
5188+
df = DataFrame([['a'], ['a'], ['a'], ['b'], ['a']], columns=['A'],
5189+
index=[0] * 5)
5190+
g = df.groupby([0, 0, 0, 1, 0])
5191+
sg = g.A
5192+
5193+
expected = Series([0, 0, 0, 1, 0], index=[0] * 5)
5194+
5195+
assert_series_equal(expected, g.enumerate())
5196+
assert_series_equal(expected, sg.enumerate())
5197+
5198+
def test_enumerate_descending(self):
5199+
df = DataFrame(['a', 'a', 'b', 'a', 'b'], columns=['A'])
5200+
g = df.groupby(['A'])
5201+
5202+
ascending = Series([0, 0, 1, 0, 1])
5203+
descending = Series([1, 1, 0, 1, 0])
5204+
5205+
assert_series_equal(descending, (g.ngroups - 1) - ascending)
5206+
assert_series_equal(ascending, g.enumerate(ascending=True))
5207+
assert_series_equal(descending, g.enumerate(ascending=False))
5208+
5209+
def test_enumerate_matches_cumcount(self):
5210+
# specific case
5211+
df = DataFrame([['a', 'x'], ['a', 'y'], ['b', 'x'],
5212+
['a', 'x'], ['b', 'y']], columns=['A', 'X'])
5213+
g = df.groupby(['A', 'X'])
5214+
5215+
g_enumerate = g.enumerate()
5216+
g_cumcount = g.cumcount()
5217+
expected_enumerate = pd.Series([0, 1, 2, 0, 3])
5218+
expected_cumcount = pd.Series([0, 0, 0, 1, 0])
5219+
5220+
assert_series_equal(g_enumerate, expected_enumerate)
5221+
assert_series_equal(g_cumcount, expected_cumcount)
5222+
5223+
def test_enumerate_cumcount_pair(self):
5224+
from itertools import product
5225+
5226+
# brute force comparison, inefficient but clear
5227+
for p in product(range(3), repeat=4):
5228+
df = DataFrame({'a': p})
5229+
g = df.groupby(['a'])
5230+
5231+
order = sorted(set(p))
5232+
enumerated = [order.index(val) for val in p]
5233+
cumcounted = [p[:i].count(val) for i, val in enumerate(p)]
5234+
5235+
assert_series_equal(g.enumerate(), pd.Series(enumerated))
5236+
assert_series_equal(g.cumcount(), pd.Series(cumcounted))
5237+
5238+
def test_enumerate_respects_groupby_order(self):
5239+
np.random.seed(0)
5240+
df = DataFrame({'a': np.random.choice(list('abcdef'), 100)})
5241+
for sort_flag in (False, True):
5242+
g = df.groupby(['a'], sort=sort_flag)
5243+
df['group_id'] = -1
5244+
df['group_index'] = -1
5245+
5246+
for i, (key, group) in enumerate(g):
5247+
df.loc[group.index, 'group_id'] = i
5248+
for j, ind in enumerate(group.index):
5249+
df.loc[ind, 'group_index'] = j
5250+
5251+
assert_series_equal(pd.Series(df['group_id'].values),
5252+
g.enumerate())
5253+
assert_series_equal(pd.Series(df['group_index'].values),
5254+
g.cumcount())
5255+
51445256
def test_filter_series(self):
51455257
s = pd.Series([1, 3, 20, 5, 22, 24, 7])
51465258
expected_odd = pd.Series([1, 3, 5, 7], index=[0, 1, 3, 6])

0 commit comments

Comments
 (0)