6
6
import warnings
7
7
import copy
8
8
from textwrap import dedent
9
+ from contextlib import contextmanager
9
10
10
11
from pandas .compat import (
11
12
zip , range , lzip ,
@@ -549,6 +550,16 @@ def f(self):
549
550
return attr
550
551
551
552
553
+ @contextmanager
554
+ def _group_selection_context (groupby ):
555
+ """
556
+ set / reset the _group_selection_context
557
+ """
558
+ groupby ._set_group_selection ()
559
+ yield groupby
560
+ groupby ._reset_group_selection ()
561
+
562
+
552
563
class _GroupBy (PandasObject , SelectionMixin ):
553
564
_group_selection = None
554
565
_apply_whitelist = frozenset ([])
@@ -704,6 +715,8 @@ def _set_group_selection(self):
704
715
"""
705
716
Create group based selection. Used when selection is not passed
706
717
directly but instead via a grouper.
718
+
719
+ NOTE: this should be paired with a call to _reset_group_selection
707
720
"""
708
721
grp = self .grouper
709
722
if not (self .as_index and
@@ -785,10 +798,10 @@ def _make_wrapper(self, name):
785
798
type (self ).__name__ ))
786
799
raise AttributeError (msg )
787
800
788
- # need to setup the selection
789
- # as are not passed directly but in the grouper
790
801
self ._set_group_selection ()
791
802
803
+ # need to setup the selection
804
+ # as are not passed directly but in the grouper
792
805
f = getattr (self ._selected_obj , name )
793
806
if not isinstance (f , types .MethodType ):
794
807
return self .apply (lambda self : getattr (self , name ))
@@ -913,9 +926,8 @@ def f(g):
913
926
# fails on *some* columns, e.g. a numeric operation
914
927
# on a string grouper column
915
928
916
- self ._set_group_selection ()
917
- result = self ._python_apply_general (f )
918
- self ._reset_group_selection ()
929
+ with _group_selection_context (self ):
930
+ return self ._python_apply_general (f )
919
931
920
932
return result
921
933
@@ -1295,9 +1307,9 @@ def mean(self, *args, **kwargs):
1295
1307
except GroupByError :
1296
1308
raise
1297
1309
except Exception : # pragma: no cover
1298
- self . _set_group_selection ()
1299
- f = lambda x : x .mean (axis = self .axis , ** kwargs )
1300
- return self ._python_agg_general (f )
1310
+ with _group_selection_context ( self ):
1311
+ f = lambda x : x .mean (axis = self .axis , ** kwargs )
1312
+ return self ._python_agg_general (f )
1301
1313
1302
1314
@Substitution (name = 'groupby' )
1303
1315
@Appender (_doc_template )
@@ -1313,13 +1325,12 @@ def median(self, **kwargs):
1313
1325
raise
1314
1326
except Exception : # pragma: no cover
1315
1327
1316
- self ._set_group_selection ()
1317
-
1318
1328
def f (x ):
1319
1329
if isinstance (x , np .ndarray ):
1320
1330
x = Series (x )
1321
1331
return x .median (axis = self .axis , ** kwargs )
1322
- return self ._python_agg_general (f )
1332
+ with _group_selection_context (self ):
1333
+ return self ._python_agg_general (f )
1323
1334
1324
1335
@Substitution (name = 'groupby' )
1325
1336
@Appender (_doc_template )
@@ -1356,9 +1367,9 @@ def var(self, ddof=1, *args, **kwargs):
1356
1367
if ddof == 1 :
1357
1368
return self ._cython_agg_general ('var' , ** kwargs )
1358
1369
else :
1359
- self ._set_group_selection ()
1360
1370
f = lambda x : x .var (ddof = ddof , ** kwargs )
1361
- return self ._python_agg_general (f )
1371
+ with _group_selection_context (self ):
1372
+ return self ._python_agg_general (f )
1362
1373
1363
1374
@Substitution (name = 'groupby' )
1364
1375
@Appender (_doc_template )
@@ -1404,6 +1415,7 @@ def f(self, **kwargs):
1404
1415
kwargs ['numeric_only' ] = numeric_only
1405
1416
if 'min_count' not in kwargs :
1406
1417
kwargs ['min_count' ] = min_count
1418
+
1407
1419
self ._set_group_selection ()
1408
1420
try :
1409
1421
return self ._cython_agg_general (
@@ -1797,13 +1809,12 @@ def ngroup(self, ascending=True):
1797
1809
.cumcount : Number the rows in each group.
1798
1810
"""
1799
1811
1800
- self ._set_group_selection ()
1801
-
1802
- index = self ._selected_obj .index
1803
- result = Series (self .grouper .group_info [0 ], index )
1804
- if not ascending :
1805
- result = self .ngroups - 1 - result
1806
- return result
1812
+ with _group_selection_context (self ):
1813
+ index = self ._selected_obj .index
1814
+ result = Series (self .grouper .group_info [0 ], index )
1815
+ if not ascending :
1816
+ result = self .ngroups - 1 - result
1817
+ return result
1807
1818
1808
1819
@Substitution (name = 'groupby' )
1809
1820
def cumcount (self , ascending = True ):
@@ -1854,11 +1865,10 @@ def cumcount(self, ascending=True):
1854
1865
.ngroup : Number the groups themselves.
1855
1866
"""
1856
1867
1857
- self ._set_group_selection ()
1858
-
1859
- index = self ._selected_obj .index
1860
- cumcounts = self ._cumcount_array (ascending = ascending )
1861
- return Series (cumcounts , index )
1868
+ with _group_selection_context (self ):
1869
+ index = self ._selected_obj .index
1870
+ cumcounts = self ._cumcount_array (ascending = ascending )
1871
+ return Series (cumcounts , index )
1862
1872
1863
1873
@Substitution (name = 'groupby' )
1864
1874
@Appender (_doc_template )
0 commit comments