Skip to content

Commit 881c7e1

Browse files
committed
use a context manager
1 parent 41d930b commit 881c7e1

File tree

2 files changed

+37
-27
lines changed

2 files changed

+37
-27
lines changed

pandas/core/groupby/groupby.py

+35-25
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import warnings
77
import copy
88
from textwrap import dedent
9+
from contextlib import contextmanager
910

1011
from pandas.compat import (
1112
zip, range, lzip,
@@ -549,6 +550,16 @@ def f(self):
549550
return attr
550551

551552

553+
@contextmanager
554+
def _group_selection_context(groupby):
555+
"""
556+
set / reset the _group_selection_context
557+
"""
558+
groupby._set_group_selection()
559+
yield groupby
560+
groupby._reset_group_selection()
561+
562+
552563
class _GroupBy(PandasObject, SelectionMixin):
553564
_group_selection = None
554565
_apply_whitelist = frozenset([])
@@ -704,6 +715,8 @@ def _set_group_selection(self):
704715
"""
705716
Create group based selection. Used when selection is not passed
706717
directly but instead via a grouper.
718+
719+
NOTE: this should be paired with a call to _reset_group_selection
707720
"""
708721
grp = self.grouper
709722
if not (self.as_index and
@@ -785,10 +798,10 @@ def _make_wrapper(self, name):
785798
type(self).__name__))
786799
raise AttributeError(msg)
787800

788-
# need to setup the selection
789-
# as are not passed directly but in the grouper
790801
self._set_group_selection()
791802

803+
# need to setup the selection
804+
# as are not passed directly but in the grouper
792805
f = getattr(self._selected_obj, name)
793806
if not isinstance(f, types.MethodType):
794807
return self.apply(lambda self: getattr(self, name))
@@ -913,9 +926,8 @@ def f(g):
913926
# fails on *some* columns, e.g. a numeric operation
914927
# on a string grouper column
915928

916-
self._set_group_selection()
917-
result = self._python_apply_general(f)
918-
self._reset_group_selection()
929+
with _group_selection_context(self):
930+
return self._python_apply_general(f)
919931

920932
return result
921933

@@ -1295,9 +1307,9 @@ def mean(self, *args, **kwargs):
12951307
except GroupByError:
12961308
raise
12971309
except Exception: # pragma: no cover
1298-
self._set_group_selection()
1299-
f = lambda x: x.mean(axis=self.axis, **kwargs)
1300-
return self._python_agg_general(f)
1310+
with _group_selection_context(self):
1311+
f = lambda x: x.mean(axis=self.axis, **kwargs)
1312+
return self._python_agg_general(f)
13011313

13021314
@Substitution(name='groupby')
13031315
@Appender(_doc_template)
@@ -1313,13 +1325,12 @@ def median(self, **kwargs):
13131325
raise
13141326
except Exception: # pragma: no cover
13151327

1316-
self._set_group_selection()
1317-
13181328
def f(x):
13191329
if isinstance(x, np.ndarray):
13201330
x = Series(x)
13211331
return x.median(axis=self.axis, **kwargs)
1322-
return self._python_agg_general(f)
1332+
with _group_selection_context(self):
1333+
return self._python_agg_general(f)
13231334

13241335
@Substitution(name='groupby')
13251336
@Appender(_doc_template)
@@ -1356,9 +1367,9 @@ def var(self, ddof=1, *args, **kwargs):
13561367
if ddof == 1:
13571368
return self._cython_agg_general('var', **kwargs)
13581369
else:
1359-
self._set_group_selection()
13601370
f = lambda x: x.var(ddof=ddof, **kwargs)
1361-
return self._python_agg_general(f)
1371+
with _group_selection_context(self):
1372+
return self._python_agg_general(f)
13621373

13631374
@Substitution(name='groupby')
13641375
@Appender(_doc_template)
@@ -1404,6 +1415,7 @@ def f(self, **kwargs):
14041415
kwargs['numeric_only'] = numeric_only
14051416
if 'min_count' not in kwargs:
14061417
kwargs['min_count'] = min_count
1418+
14071419
self._set_group_selection()
14081420
try:
14091421
return self._cython_agg_general(
@@ -1797,13 +1809,12 @@ def ngroup(self, ascending=True):
17971809
.cumcount : Number the rows in each group.
17981810
"""
17991811

1800-
self._set_group_selection()
1801-
1802-
index = self._selected_obj.index
1803-
result = Series(self.grouper.group_info[0], index)
1804-
if not ascending:
1805-
result = self.ngroups - 1 - result
1806-
return result
1812+
with _group_selection_context(self):
1813+
index = self._selected_obj.index
1814+
result = Series(self.grouper.group_info[0], index)
1815+
if not ascending:
1816+
result = self.ngroups - 1 - result
1817+
return result
18071818

18081819
@Substitution(name='groupby')
18091820
def cumcount(self, ascending=True):
@@ -1854,11 +1865,10 @@ def cumcount(self, ascending=True):
18541865
.ngroup : Number the groups themselves.
18551866
"""
18561867

1857-
self._set_group_selection()
1858-
1859-
index = self._selected_obj.index
1860-
cumcounts = self._cumcount_array(ascending=ascending)
1861-
return Series(cumcounts, index)
1868+
with _group_selection_context(self):
1869+
index = self._selected_obj.index
1870+
cumcounts = self._cumcount_array(ascending=ascending)
1871+
return Series(cumcounts, index)
18621872

18631873
@Substitution(name='groupby')
18641874
@Appender(_doc_template)

pandas/tests/groupby/test_apply.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -519,11 +519,11 @@ def test_func(x):
519519

520520
def test_apply_with_mixed_types():
521521
# gh-20949
522-
df = pd.DataFrame({'A': 'a a b'.split(), 'B': [1,2,3], 'C': [4, 6, 5]})
522+
df = pd.DataFrame({'A': 'a a b'.split(), 'B': [1, 2, 3], 'C': [4, 6, 5]})
523523
g = df.groupby('A')
524524

525525
result = g.transform(lambda x: x / x.sum())
526-
expected = pd.DataFrame({'B': [1/3., 2/3., 1], 'C': [0.4, 0.6, 1.0]})
526+
expected = pd.DataFrame({'B': [1 / 3., 2 / 3., 1], 'C': [0.4, 0.6, 1.0]})
527527
tm.assert_frame_equal(result, expected)
528528

529529
result = g.apply(lambda x: x / x.sum())

0 commit comments

Comments
 (0)