Skip to content

Commit cc0a188

Browse files
adneujreback
authored andcommitted
BUG: Groupby.nth includes group key inconsistently #12839
closes #12839 Author: adneu <[email protected]> Closes #13316 from adneu/12839 and squashes the following commits: 16f5cd3 [adneu] Name change ac1851a [adneu] Added docstrings/comments, and new tests. 4d73cbf [adneu] Updated tests 9b75df4 [adneu] BUG: Groupby.nth includes group key inconsistently #12839
1 parent d38ee27 commit cc0a188

File tree

3 files changed

+57
-11
lines changed

3 files changed

+57
-11
lines changed

doc/source/whatsnew/v0.18.2.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ Bug Fixes
521521

522522
- Bug in ``Series`` comparison operators when dealing with zero dim NumPy arrays (:issue:`13006`)
523523
- Bug in ``groupby`` where ``apply`` returns different result depending on whether first result is ``None`` or not (:issue:`12824`)
524-
524+
- Bug in ``groupby(..).nth()`` where the group key is included inconsistently if called after ``.head()/.tail()`` (:issue:`12839`)
525525

526526
- Bug in ``pd.to_numeric`` when ``errors='coerce'`` and input contains non-hashable objects (:issue:`13324`)
527527

pandas/core/groupby.py

+26-9
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def _groupby_function(name, alias, npfunc, numeric_only=True,
9595
@Appender(_doc_template)
9696
@Appender(_local_template)
9797
def f(self):
98-
self._set_selection_from_grouper()
98+
self._set_group_selection()
9999
try:
100100
return self._cython_agg_general(alias, numeric_only=numeric_only)
101101
except AssertionError as e:
@@ -457,8 +457,21 @@ def _selected_obj(self):
457457
else:
458458
return self.obj[self._selection]
459459

460-
def _set_selection_from_grouper(self):
461-
""" we may need create a selection if we have non-level groupers """
460+
def _reset_group_selection(self):
461+
"""
462+
Clear group based selection. Used for methods needing to return info on
463+
each group regardless of whether a group selection was previously set.
464+
"""
465+
if self._group_selection is not None:
466+
self._group_selection = None
467+
# GH12839 clear cached selection too when changing group selection
468+
self._reset_cache('_selected_obj')
469+
470+
def _set_group_selection(self):
471+
"""
472+
Create group based selection. Used when selection is not passed
473+
directly but instead via a grouper.
474+
"""
462475
grp = self.grouper
463476
if self.as_index and getattr(grp, 'groupings', None) is not None and \
464477
self.obj.ndim > 1:
@@ -468,6 +481,8 @@ def _set_selection_from_grouper(self):
468481

469482
if len(groupers):
470483
self._group_selection = ax.difference(Index(groupers)).tolist()
484+
# GH12839 clear selected obj cache when group selection changes
485+
self._reset_cache('_selected_obj')
471486

472487
def _set_result_index_ordered(self, result):
473488
# set the result index on the passed values object and
@@ -511,7 +526,7 @@ def _make_wrapper(self, name):
511526

512527
# need to setup the selection
513528
# as are not passed directly but in the grouper
514-
self._set_selection_from_grouper()
529+
self._set_group_selection()
515530

516531
f = getattr(self._selected_obj, name)
517532
if not isinstance(f, types.MethodType):
@@ -979,7 +994,7 @@ def mean(self, *args, **kwargs):
979994
except GroupByError:
980995
raise
981996
except Exception: # pragma: no cover
982-
self._set_selection_from_grouper()
997+
self._set_group_selection()
983998
f = lambda x: x.mean(axis=self.axis)
984999
return self._python_agg_general(f)
9851000

@@ -997,7 +1012,7 @@ def median(self):
9971012
raise
9981013
except Exception: # pragma: no cover
9991014

1000-
self._set_selection_from_grouper()
1015+
self._set_group_selection()
10011016

10021017
def f(x):
10031018
if isinstance(x, np.ndarray):
@@ -1040,7 +1055,7 @@ def var(self, ddof=1, *args, **kwargs):
10401055
if ddof == 1:
10411056
return self._cython_agg_general('var')
10421057
else:
1043-
self._set_selection_from_grouper()
1058+
self._set_group_selection()
10441059
f = lambda x: x.var(ddof=ddof)
10451060
return self._python_agg_general(f)
10461061

@@ -1217,7 +1232,7 @@ def nth(self, n, dropna=None):
12171232
raise TypeError("n needs to be an int or a list/set/tuple of ints")
12181233

12191234
nth_values = np.array(nth_values, dtype=np.intp)
1220-
self._set_selection_from_grouper()
1235+
self._set_group_selection()
12211236

12221237
if not dropna:
12231238
mask = np.in1d(self._cumcount_array(), nth_values) | \
@@ -1325,7 +1340,7 @@ def cumcount(self, ascending=True):
13251340
dtype: int64
13261341
"""
13271342

1328-
self._set_selection_from_grouper()
1343+
self._set_group_selection()
13291344

13301345
index = self._selected_obj.index
13311346
cumcounts = self._cumcount_array(ascending=ascending)
@@ -1403,6 +1418,7 @@ def head(self, n=5):
14031418
0 1 2
14041419
2 5 6
14051420
"""
1421+
self._reset_group_selection()
14061422
mask = self._cumcount_array() < n
14071423
return self._selected_obj[mask]
14081424

@@ -1429,6 +1445,7 @@ def tail(self, n=5):
14291445
0 a 1
14301446
2 b 1
14311447
"""
1448+
self._reset_group_selection()
14321449
mask = self._cumcount_array(ascending=False) < n
14331450
return self._selected_obj[mask]
14341451

pandas/tests/test_groupby.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,35 @@ def test_nth_multi_index_as_expected(self):
354354
names=['A', 'B']))
355355
assert_frame_equal(result, expected)
356356

357+
def test_group_selection_cache(self):
358+
# GH 12839 nth, head, and tail should return same result consistently
359+
df = DataFrame([[1, 2], [1, 4], [5, 6]], columns=['A', 'B'])
360+
expected = df.iloc[[0, 2]].set_index('A')
361+
362+
g = df.groupby('A')
363+
result1 = g.head(n=2)
364+
result2 = g.nth(0)
365+
assert_frame_equal(result1, df)
366+
assert_frame_equal(result2, expected)
367+
368+
g = df.groupby('A')
369+
result1 = g.tail(n=2)
370+
result2 = g.nth(0)
371+
assert_frame_equal(result1, df)
372+
assert_frame_equal(result2, expected)
373+
374+
g = df.groupby('A')
375+
result1 = g.nth(0)
376+
result2 = g.head(n=2)
377+
assert_frame_equal(result1, expected)
378+
assert_frame_equal(result2, df)
379+
380+
g = df.groupby('A')
381+
result1 = g.nth(0)
382+
result2 = g.tail(n=2)
383+
assert_frame_equal(result1, expected)
384+
assert_frame_equal(result2, df)
385+
357386
def test_grouper_index_types(self):
358387
# related GH5375
359388
# groupby misbehaving when using a Floatlike index
@@ -6116,7 +6145,7 @@ def test_cython_transform(self):
61166145
# bit a of hack to make sure the cythonized shift
61176146
# is equivalent to pre 0.17.1 behavior
61186147
if op == 'shift':
6119-
gb._set_selection_from_grouper()
6148+
gb._set_group_selection()
61206149

61216150
for (op, args), targop in ops:
61226151
if op != 'shift' and 'int' not in gb_target:

0 commit comments

Comments
 (0)