Skip to content

Commit 445d1c6

Browse files
behzadnouri authored and
jreback committed
PERF: improves performance in GroupBy.cumcount
closes #12839 closes #11039
1 parent 5ae1bd8 commit 445d1c6

File tree

3 files changed

+122
-82
lines changed

3 files changed

+122
-82
lines changed

doc/source/whatsnew/v0.18.1.txt

+75
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,79 @@ These changes conform sparse handling to return the correct types and work to ma
131131
- Bug in ``pd.concat()`` of ``SparseDataFrame`` may raise ``AttributeError`` (:issue:`12174`)
132132
- Bug in ``SparseArray.shift()`` may raise ``NameError`` or ``TypeError`` (:issue:`12908`)
133133

134+
.. _whatsnew_0181.api.groubynth:
135+
136+
``.groupby(..).nth()`` changes
137+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
138+
139+
The index in ``.groupby(..).nth()`` output is now more consistent when the ``as_index`` argument is passed (:issue:`11039`):
140+
141+
.. ipython:: python
142+
143+
df = DataFrame({'A' : ['a', 'b', 'a'],
144+
'B' : [1, 2, 3]})
145+
df
146+
147+
Previous Behavior:
148+
149+
.. code-block:: ipython
150+
151+
In [3]: df.groupby('A', as_index=True)['B'].nth(0)
152+
Out[3]:
153+
0 1
154+
1 2
155+
Name: B, dtype: int64
156+
157+
In [4]: df.groupby('A', as_index=False)['B'].nth(0)
158+
Out[4]:
159+
0 1
160+
1 2
161+
Name: B, dtype: int64
162+
163+
New Behavior:
164+
165+
.. ipython:: python
166+
167+
df.groupby('A', as_index=True)['B'].nth(0)
168+
df.groupby('A', as_index=False)['B'].nth(0)
169+
170+
Furthermore, previously, ``.groupby(..).nth()`` would always sort its result by group key, regardless of whether ``sort=False`` was passed to ``.groupby``.
171+
172+
.. ipython:: python
173+
174+
np.random.seed(1234)
175+
df = pd.DataFrame(np.random.randn(100, 2), columns=['a', 'b'])
176+
df['c'] = np.random.randint(0, 4, 100)
177+
178+
Previous Behavior:
179+
180+
.. code-block:: ipython
181+
182+
In [4]: df.groupby('c', sort=True).nth(1)
183+
Out[4]:
184+
a b
185+
c
186+
0 -0.334077 0.002118
187+
1 0.036142 -2.074978
188+
2 -0.720589 0.887163
189+
3 0.859588 -0.636524
190+
191+
In [5]: df.groupby('c', sort=False).nth(1)
192+
Out[5]:
193+
a b
194+
c
195+
0 -0.334077 0.002118
196+
1 0.036142 -2.074978
197+
2 -0.720589 0.887163
198+
3 0.859588 -0.636524
199+
200+
New Behavior:
201+
202+
.. ipython:: python
203+
204+
df.groupby('c', sort=True).nth(1)
205+
df.groupby('c', sort=False).nth(1)
206+
134207
.. _whatsnew_0181.api:
135208

136209
API changes
@@ -255,6 +328,8 @@ Performance Improvements
255328
~~~~~~~~~~~~~~~~~~~~~~~~
256329

257330
- Improved speed of SAS reader (:issue:`12656`)
331+
- Performance improvements in ``.groupby(..).cumcount()`` (:issue:`11039`)
332+
258333

259334
- Improved performance of ``DataFrame.to_sql`` when checking case sensitivity for tables. Now only checks if table has been created correctly when table name is not lower case. (:issue:`12876`)
260335

pandas/core/groupby.py

+36-69
Original file line numberDiff line numberDiff line change
@@ -653,37 +653,37 @@ def _iterate_slices(self):
653653
def transform(self, func, *args, **kwargs):
654654
raise AbstractMethodError(self)
655655

656-
def _cumcount_array(self, arr=None, ascending=True):
656+
def _cumcount_array(self, ascending=True):
657657
"""
658-
arr is where cumcount gets its values from
658+
Parameters
659+
----------
660+
ascending : bool, default True
661+
If False, number in reverse, from length of group - 1 to 0.
659662
660663
Note
661664
----
662665
this is currently implementing sort=False
663666
(though the default is sort=True) for groupby in general
664667
"""
665-
if arr is None:
666-
arr = np.arange(self.grouper._max_groupsize, dtype='int64')
667-
668-
len_index = len(self._selected_obj.index)
669-
cumcounts = np.zeros(len_index, dtype=arr.dtype)
670-
if not len_index:
671-
return cumcounts
668+
ids, _, ngroups = self.grouper.group_info
669+
sorter = _get_group_index_sorter(ids, ngroups)
670+
ids, count = ids[sorter], len(ids)
672671

673-
indices, values = [], []
674-
for v in self.indices.values():
675-
indices.append(v)
672+
if count == 0:
673+
return np.empty(0, dtype=np.int64)
676674

677-
if ascending:
678-
values.append(arr[:len(v)])
679-
else:
680-
values.append(arr[len(v) - 1::-1])
675+
run = np.r_[True, ids[:-1] != ids[1:]]
676+
rep = np.diff(np.r_[np.nonzero(run)[0], count])
677+
out = (~run).cumsum()
681678

682-
indices = np.concatenate(indices)
683-
values = np.concatenate(values)
684-
cumcounts[indices] = values
679+
if ascending:
680+
out -= np.repeat(out[run], rep)
681+
else:
682+
out = np.repeat(out[np.r_[run[1:], True]], rep) - out
685683

686-
return cumcounts
684+
rev = np.empty(count, dtype=np.intp)
685+
rev[sorter] = np.arange(count, dtype=np.intp)
686+
return out[rev].astype(np.int64, copy=False)
687687

688688
def _index_with_as_index(self, b):
689689
"""
@@ -1170,47 +1170,21 @@ def nth(self, n, dropna=None):
11701170
else:
11711171
raise TypeError("n needs to be an int or a list/set/tuple of ints")
11721172

1173-
m = self.grouper._max_groupsize
1174-
# filter out values that are outside [-m, m)
1175-
pos_nth_values = [i for i in nth_values if i >= 0 and i < m]
1176-
neg_nth_values = [i for i in nth_values if i < 0 and i >= -m]
1177-
1173+
nth_values = np.array(nth_values, dtype=np.intp)
11781174
self._set_selection_from_grouper()
1179-
if not dropna: # good choice
1180-
if not pos_nth_values and not neg_nth_values:
1181-
# no valid nth values
1182-
return self._selected_obj.loc[[]]
1183-
1184-
rng = np.zeros(m, dtype=bool)
1185-
for i in pos_nth_values:
1186-
rng[i] = True
1187-
is_nth = self._cumcount_array(rng)
11881175

1189-
if neg_nth_values:
1190-
rng = np.zeros(m, dtype=bool)
1191-
for i in neg_nth_values:
1192-
rng[- i - 1] = True
1193-
is_nth |= self._cumcount_array(rng, ascending=False)
1176+
if not dropna:
1177+
mask = np.in1d(self._cumcount_array(), nth_values) | \
1178+
np.in1d(self._cumcount_array(ascending=False) + 1, -nth_values)
11941179

1195-
result = self._selected_obj[is_nth]
1180+
out = self._selected_obj[mask]
1181+
if not self.as_index:
1182+
return out
11961183

1197-
# the result index
1198-
if self.as_index:
1199-
ax = self.obj._info_axis
1200-
names = self.grouper.names
1201-
if self.obj.ndim == 1:
1202-
# this is a pass-thru
1203-
pass
1204-
elif all([x in ax for x in names]):
1205-
indicies = [self.obj[name][is_nth] for name in names]
1206-
result.index = MultiIndex.from_arrays(
1207-
indicies).set_names(names)
1208-
elif self._group_selection is not None:
1209-
result.index = self.obj._get_axis(self.axis)[is_nth]
1210-
1211-
result = result.sort_index()
1184+
ids, _, _ = self.grouper.group_info
1185+
out.index = self.grouper.result_index[ids[mask]]
12121186

1213-
return result
1187+
return out.sort_index() if self.sort else out
12141188

12151189
if isinstance(self._selected_obj, DataFrame) and \
12161190
dropna not in ['any', 'all']:
@@ -1241,8 +1215,8 @@ def nth(self, n, dropna=None):
12411215
axis=self.axis, level=self.level,
12421216
sort=self.sort)
12431217

1244-
sizes = dropped.groupby(grouper).size()
1245-
result = dropped.groupby(grouper).nth(n)
1218+
grb = dropped.groupby(grouper, as_index=self.as_index, sort=self.sort)
1219+
sizes, result = grb.size(), grb.nth(n)
12461220
mask = (sizes < max_len).values
12471221

12481222
# set the results which don't meet the criteria
@@ -1380,11 +1354,8 @@ def head(self, n=5):
13801354
0 1 2
13811355
2 5 6
13821356
"""
1383-
1384-
obj = self._selected_obj
1385-
in_head = self._cumcount_array() < n
1386-
head = obj[in_head]
1387-
return head
1357+
mask = self._cumcount_array() < n
1358+
return self._selected_obj[mask]
13881359

13891360
@Substitution(name='groupby')
13901361
@Appender(_doc_template)
@@ -1409,12 +1380,8 @@ def tail(self, n=5):
14091380
0 a 1
14101381
2 b 1
14111382
"""
1412-
1413-
obj = self._selected_obj
1414-
rng = np.arange(0, -self.grouper._max_groupsize, -1, dtype='int64')
1415-
in_tail = self._cumcount_array(rng, ascending=False) > -n
1416-
tail = obj[in_tail]
1417-
return tail
1383+
mask = self._cumcount_array(ascending=False) < n
1384+
return self._selected_obj[mask]
14181385

14191386

14201387
@Appender(GroupBy.__doc__)

pandas/tests/test_groupby.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,7 @@ def test_first_last_nth(self):
167167
self.df.loc[self.df['A'] == 'foo', 'B'] = np.nan
168168
self.assertTrue(com.isnull(grouped['B'].first()['foo']))
169169
self.assertTrue(com.isnull(grouped['B'].last()['foo']))
170-
self.assertTrue(com.isnull(grouped['B'].nth(0)[0])
171-
) # not sure what this is testing
170+
self.assertTrue(com.isnull(grouped['B'].nth(0)['foo']))
172171

173172
# v0.14.0 whatsnew
174173
df = DataFrame([[1, np.nan], [1, 4], [5, 6]], columns=['A', 'B'])
@@ -221,12 +220,12 @@ def test_nth(self):
221220

222221
assert_frame_equal(g.nth(0), df.iloc[[0, 2]].set_index('A'))
223222
assert_frame_equal(g.nth(1), df.iloc[[1]].set_index('A'))
224-
assert_frame_equal(g.nth(2), df.loc[[], ['B']])
223+
assert_frame_equal(g.nth(2), df.loc[[]].set_index('A'))
225224
assert_frame_equal(g.nth(-1), df.iloc[[1, 2]].set_index('A'))
226225
assert_frame_equal(g.nth(-2), df.iloc[[0]].set_index('A'))
227-
assert_frame_equal(g.nth(-3), df.loc[[], ['B']])
228-
assert_series_equal(g.B.nth(0), df.B.iloc[[0, 2]])
229-
assert_series_equal(g.B.nth(1), df.B.iloc[[1]])
226+
assert_frame_equal(g.nth(-3), df.loc[[]].set_index('A'))
227+
assert_series_equal(g.B.nth(0), df.set_index('A').B.iloc[[0, 2]])
228+
assert_series_equal(g.B.nth(1), df.set_index('A').B.iloc[[1]])
230229
assert_frame_equal(g[['B']].nth(0),
231230
df.ix[[0, 2], ['A', 'B']].set_index('A'))
232231

@@ -262,11 +261,11 @@ def test_nth(self):
262261
4: 0.70422799999999997}}).set_index(['color',
263262
'food'])
264263

265-
result = df.groupby(level=0).nth(2)
264+
result = df.groupby(level=0, as_index=False).nth(2)
266265
expected = df.iloc[[-1]]
267266
assert_frame_equal(result, expected)
268267

269-
result = df.groupby(level=0).nth(3)
268+
result = df.groupby(level=0, as_index=False).nth(3)
270269
expected = df.loc[[]]
271270
assert_frame_equal(result, expected)
272271

@@ -290,8 +289,7 @@ def test_nth(self):
290289
# as it keeps the order in the series (and not the group order)
291290
# related GH 7287
292291
expected = s.groupby(g, sort=False).first()
293-
expected.index = pd.Index(range(1, 10), name=0)
294-
result = s.groupby(g).nth(0, dropna='all')
292+
result = s.groupby(g, sort=False).nth(0, dropna='all')
295293
assert_series_equal(result, expected)
296294

297295
# doc example
@@ -316,14 +314,14 @@ def test_nth(self):
316314
assert_frame_equal(
317315
g.nth([0, 1, -1]), df.iloc[[0, 1, 2, 3, 4]].set_index('A'))
318316
assert_frame_equal(g.nth([2]), df.iloc[[2]].set_index('A'))
319-
assert_frame_equal(g.nth([3, 4]), df.loc[[], ['B']])
317+
assert_frame_equal(g.nth([3, 4]), df.loc[[]].set_index('A'))
320318

321319
business_dates = pd.date_range(start='4/1/2014', end='6/30/2014',
322320
freq='B')
323321
df = DataFrame(1, index=business_dates, columns=['a', 'b'])
324322
# get the first, fourth and last two business days for each month
325-
result = df.groupby((df.index.year, df.index.month)).nth([0, 3, -2, -1
326-
])
323+
key = (df.index.year, df.index.month)
324+
result = df.groupby(key, as_index=False).nth([0, 3, -2, -1])
327325
expected_dates = pd.to_datetime(
328326
['2014/4/1', '2014/4/4', '2014/4/29', '2014/4/30', '2014/5/1',
329327
'2014/5/6', '2014/5/29', '2014/5/30', '2014/6/2', '2014/6/5',

0 commit comments

Comments
 (0)