Skip to content

Commit ac1851a

Browse files
committed
Added docstrings/comments, and new tests.
Small changes Added tests
1 parent 4d73cbf commit ac1851a

File tree

2 files changed

+38
-25
lines changed

2 files changed

+38
-25
lines changed

pandas/core/groupby.py

+21-12
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def _groupby_function(name, alias, npfunc, numeric_only=True,
9595
@Appender(_doc_template)
9696
@Appender(_local_template)
9797
def f(self):
98-
self._set_selection_from_grouper()
98+
self._set_group_selection()
9999
try:
100100
return self._cython_agg_general(alias, numeric_only=numeric_only)
101101
except AssertionError as e:
@@ -457,13 +457,21 @@ def _selected_obj(self):
457457
else:
458458
return self.obj[self._selection]
459459

460-
def _reset_group_selection(self):
460+
def _clear_group_selection(self):
461+
"""
462+
Clear group based selection. Used for methods needing to return info on
463+
each group regardless of whether a group selection was previously set.
464+
"""
461465
if self._group_selection is not None:
462466
self._group_selection = None
467+
# GH12839 clear cached selection too when changing group selection
463468
self._reset_cache('_selected_obj')
464469

465-
def _set_selection_from_grouper(self):
466-
""" we may need create a selection if we have non-level groupers """
470+
def _set_group_selection(self):
471+
"""
472+
Create group based selection. Used when selection is not passed
473+
directly but instead via a grouper.
474+
"""
467475
grp = self.grouper
468476
if self.as_index and getattr(grp, 'groupings', None) is not None and \
469477
self.obj.ndim > 1:
@@ -473,6 +481,7 @@ def _set_selection_from_grouper(self):
473481

474482
if len(groupers):
475483
self._group_selection = ax.difference(Index(groupers)).tolist()
484+
# GH12839 clear selected obj cache when group selection changes
476485
self._reset_cache('_selected_obj')
477486

478487
def _set_result_index_ordered(self, result):
@@ -517,7 +526,7 @@ def _make_wrapper(self, name):
517526

518527
# need to setup the selection
519528
# as are not passed directly but in the grouper
520-
self._set_selection_from_grouper()
529+
self._set_group_selection()
521530

522531
f = getattr(self._selected_obj, name)
523532
if not isinstance(f, types.MethodType):
@@ -985,7 +994,7 @@ def mean(self, *args, **kwargs):
985994
except GroupByError:
986995
raise
987996
except Exception: # pragma: no cover
988-
self._set_selection_from_grouper()
997+
self._set_group_selection()
989998
f = lambda x: x.mean(axis=self.axis)
990999
return self._python_agg_general(f)
9911000

@@ -1003,7 +1012,7 @@ def median(self):
10031012
raise
10041013
except Exception: # pragma: no cover
10051014

1006-
self._set_selection_from_grouper()
1015+
self._set_group_selection()
10071016

10081017
def f(x):
10091018
if isinstance(x, np.ndarray):
@@ -1046,7 +1055,7 @@ def var(self, ddof=1, *args, **kwargs):
10461055
if ddof == 1:
10471056
return self._cython_agg_general('var')
10481057
else:
1049-
self._set_selection_from_grouper()
1058+
self._set_group_selection()
10501059
f = lambda x: x.var(ddof=ddof)
10511060
return self._python_agg_general(f)
10521061

@@ -1222,7 +1231,7 @@ def nth(self, n, dropna=None):
12221231
raise TypeError("n needs to be an int or a list/set/tuple of ints")
12231232

12241233
nth_values = np.array(nth_values, dtype=np.intp)
1225-
self._set_selection_from_grouper()
1234+
self._set_group_selection()
12261235

12271236
if not dropna:
12281237
mask = np.in1d(self._cumcount_array(), nth_values) | \
@@ -1330,7 +1339,7 @@ def cumcount(self, ascending=True):
13301339
dtype: int64
13311340
"""
13321341

1333-
self._set_selection_from_grouper()
1342+
self._set_group_selection()
13341343

13351344
index = self._selected_obj.index
13361345
cumcounts = self._cumcount_array(ascending=ascending)
@@ -1408,7 +1417,7 @@ def head(self, n=5):
14081417
0 1 2
14091418
2 5 6
14101419
"""
1411-
self._reset_group_selection()
1420+
self._clear_group_selection()
14121421
mask = self._cumcount_array() < n
14131422
return self._selected_obj[mask]
14141423

@@ -1435,7 +1444,7 @@ def tail(self, n=5):
14351444
0 a 1
14361445
2 b 1
14371446
"""
1438-
self._reset_group_selection()
1447+
self._clear_group_selection()
14391448
mask = self._cumcount_array(ascending=False) < n
14401449
return self._selected_obj[mask]
14411450

pandas/tests/test_groupby.py

+17-13
Original file line numberDiff line numberDiff line change
@@ -360,24 +360,28 @@ def test_group_selection_cache(self):
360360
expected = df.iloc[[0, 2]].set_index('A')
361361

362362
g = df.groupby('A')
363-
g.head()
364-
result = g.nth(0)
365-
assert_frame_equal(result, expected)
363+
result1 = g.head(n=2)
364+
result2 = g.nth(0)
365+
assert_frame_equal(result1, df)
366+
assert_frame_equal(result2, expected)
366367

367368
g = df.groupby('A')
368-
g.tail()
369-
result = g.nth(0)
370-
assert_frame_equal(result, expected)
369+
result1 = g.tail(n=2)
370+
result2 = g.nth(0)
371+
assert_frame_equal(result1, df)
372+
assert_frame_equal(result2, expected)
371373

372374
g = df.groupby('A')
373-
g.nth(0)
374-
result = g.head(n=2)
375-
assert_frame_equal(result, df)
375+
result1 = g.nth(0)
376+
result2 = g.head(n=2)
377+
assert_frame_equal(result1, expected)
378+
assert_frame_equal(result2, df)
376379

377380
g = df.groupby('A')
378-
g.nth(0)
379-
result = g.tail(n=2)
380-
assert_frame_equal(result, df)
381+
result1 = g.nth(0)
382+
result2 = g.tail(n=2)
383+
assert_frame_equal(result1, expected)
384+
assert_frame_equal(result2, df)
381385

382386
def test_grouper_index_types(self):
383387
# related GH5375
@@ -6132,7 +6136,7 @@ def test_cython_transform(self):
61326136
# bit a of hack to make sure the cythonized shift
61336137
# is equivalent to pre 0.17.1 behavior
61346138
if op == 'shift':
6135-
gb._set_selection_from_grouper()
6139+
gb._set_group_selection()
61366140

61376141
for (op, args), targop in ops:
61386142
if op != 'shift' and 'int' not in gb_target:

0 commit comments

Comments
 (0)