6
6
import warnings
7
7
import copy
8
8
from textwrap import dedent
9
+ from contextlib import contextmanager
9
10
10
11
from pandas .compat import (
11
12
zip , range , lzip ,
@@ -549,6 +550,16 @@ def f(self):
549
550
return attr
550
551
551
552
553
@contextmanager
def _group_selection_context(groupby):
    """
    Temporarily apply the group selection to ``groupby``.

    Calls ``groupby._set_group_selection()`` on entry and
    ``groupby._reset_group_selection()`` on exit, and yields ``groupby``
    itself as the target of the ``with`` statement.

    The reset runs in a ``finally`` clause, so the selection is cleared
    again even when the body of the ``with`` block raises; the exception
    then propagates unchanged.
    """
    groupby._set_group_selection()
    try:
        yield groupby
    finally:
        # always undo the selection, also on error paths
        groupby._reset_group_selection()
561
+
562
+
552
563
class _GroupBy (PandasObject , SelectionMixin ):
553
564
_group_selection = None
554
565
_apply_whitelist = frozenset ([])
@@ -696,26 +707,32 @@ def _reset_group_selection(self):
696
707
each group regardless of whether a group selection was previously set.
697
708
"""
698
709
if self ._group_selection is not None :
699
- self ._group_selection = None
700
710
# GH12839 clear cached selection too when changing group selection
711
+ self ._group_selection = None
701
712
self ._reset_cache ('_selected_obj' )
702
713
703
714
def _set_group_selection(self):
    """
    Create group based selection. Used when selection is not passed
    directly but instead via a grouper.

    The selection is only created when ``self.as_index`` is true, the
    grouper exposes ``groupings``, ``self.obj`` has more than one
    dimension and no group selection is currently set. It is then set to
    the labels of ``self.obj._info_axis`` minus the names of the
    groupings that have ``level is None`` and ``in_axis`` set. When no
    grouping qualifies, nothing is changed.

    NOTE: this should be paired with a call to _reset_group_selection
    """
    grp = self.grouper
    if not (self.as_index and
            getattr(grp, 'groupings', None) is not None and
            self.obj.ndim > 1 and
            self._group_selection is None):
        return

    ax = self.obj._info_axis
    groupers = [g.name for g in grp.groupings
                if g.level is None and g.in_axis]

    if groupers:
        self._group_selection = ax.difference(Index(groupers)).tolist()
        # GH12839 clear selected obj cache when group selection changes
        self._reset_cache('_selected_obj')
719
736
720
737
def _set_result_index_ordered (self , result ):
721
738
# set the result index on the passed values object and
@@ -781,10 +798,10 @@ def _make_wrapper(self, name):
781
798
type (self ).__name__ ))
782
799
raise AttributeError (msg )
783
800
784
- # need to setup the selection
785
- # as are not passed directly but in the grouper
786
801
self ._set_group_selection ()
787
802
803
+ # need to set up the selection,
804
+ # as it is not passed directly but via the grouper
788
805
f = getattr (self ._selected_obj , name )
789
806
if not isinstance (f , types .MethodType ):
790
807
return self .apply (lambda self : getattr (self , name ))
@@ -897,7 +914,22 @@ def f(g):
897
914
898
915
# ignore SettingWithCopy here in case the user mutates
899
916
with option_context ('mode.chained_assignment' , None ):
900
- return self ._python_apply_general (f )
917
+ try :
918
+ result = self ._python_apply_general (f )
919
+ except Exception :
920
+
921
+ # gh-20949
922
+ # try again, with .apply acting as a filtering
923
+ # operation, by excluding the grouping column
924
+ # This would normally not be triggered
925
+ # except if the udf is trying an operation that
926
+ # fails on *some* columns, e.g. a numeric operation
927
+ # on a string grouper column
928
+
929
+ with _group_selection_context (self ):
930
+ return self ._python_apply_general (f )
931
+
932
+ return result
901
933
902
934
def _python_apply_general (self , f ):
903
935
keys , values , mutated = self .grouper .apply (f , self ._selected_obj ,
@@ -1275,9 +1307,9 @@ def mean(self, *args, **kwargs):
1275
1307
except GroupByError :
1276
1308
raise
1277
1309
except Exception : # pragma: no cover
1278
- self . _set_group_selection ()
1279
- f = lambda x : x .mean (axis = self .axis , ** kwargs )
1280
- return self ._python_agg_general (f )
1310
+ with _group_selection_context ( self ):
1311
+ f = lambda x : x .mean (axis = self .axis , ** kwargs )
1312
+ return self ._python_agg_general (f )
1281
1313
1282
1314
@Substitution (name = 'groupby' )
1283
1315
@Appender (_doc_template )
@@ -1293,13 +1325,12 @@ def median(self, **kwargs):
1293
1325
raise
1294
1326
except Exception : # pragma: no cover
1295
1327
1296
- self ._set_group_selection ()
1297
-
1298
1328
def f (x ):
1299
1329
if isinstance (x , np .ndarray ):
1300
1330
x = Series (x )
1301
1331
return x .median (axis = self .axis , ** kwargs )
1302
- return self ._python_agg_general (f )
1332
+ with _group_selection_context (self ):
1333
+ return self ._python_agg_general (f )
1303
1334
1304
1335
@Substitution (name = 'groupby' )
1305
1336
@Appender (_doc_template )
@@ -1336,9 +1367,9 @@ def var(self, ddof=1, *args, **kwargs):
1336
1367
if ddof == 1 :
1337
1368
return self ._cython_agg_general ('var' , ** kwargs )
1338
1369
else :
1339
- self ._set_group_selection ()
1340
1370
f = lambda x : x .var (ddof = ddof , ** kwargs )
1341
- return self ._python_agg_general (f )
1371
+ with _group_selection_context (self ):
1372
+ return self ._python_agg_general (f )
1342
1373
1343
1374
@Substitution (name = 'groupby' )
1344
1375
@Appender (_doc_template )
@@ -1384,6 +1415,7 @@ def f(self, **kwargs):
1384
1415
kwargs ['numeric_only' ] = numeric_only
1385
1416
if 'min_count' not in kwargs :
1386
1417
kwargs ['min_count' ] = min_count
1418
+
1387
1419
self ._set_group_selection ()
1388
1420
try :
1389
1421
return self ._cython_agg_general (
@@ -1453,11 +1485,11 @@ def ohlc(self):
1453
1485
1454
1486
@Appender(DataFrame.describe.__doc__)
def describe(self, **kwargs):
    # The docstring is supplied by the Appender decorator, so only
    # comments are used here.
    def summarize(group):
        # forward the caller's keyword arguments to each group's describe
        return group.describe(**kwargs)

    # the group selection is set for the duration of the block and reset
    # when it is left
    with _group_selection_context(self):
        described = self.apply(summarize)
        # axis == 1: return the transposed result; otherwise unstack it
        return described.T if self.axis == 1 else described.unstack()
1461
1493
1462
1494
@Substitution (name = 'groupby' )
1463
1495
@Appender (_doc_template )
@@ -1778,13 +1810,12 @@ def ngroup(self, ascending=True):
1778
1810
.cumcount : Number the rows in each group.
1779
1811
"""
1780
1812
1781
- self ._set_group_selection ()
1782
-
1783
- index = self ._selected_obj .index
1784
- result = Series (self .grouper .group_info [0 ], index )
1785
- if not ascending :
1786
- result = self .ngroups - 1 - result
1787
- return result
1813
+ with _group_selection_context (self ):
1814
+ index = self ._selected_obj .index
1815
+ result = Series (self .grouper .group_info [0 ], index )
1816
+ if not ascending :
1817
+ result = self .ngroups - 1 - result
1818
+ return result
1788
1819
1789
1820
@Substitution (name = 'groupby' )
1790
1821
def cumcount (self , ascending = True ):
@@ -1835,11 +1866,10 @@ def cumcount(self, ascending=True):
1835
1866
.ngroup : Number the groups themselves.
1836
1867
"""
1837
1868
1838
- self ._set_group_selection ()
1839
-
1840
- index = self ._selected_obj .index
1841
- cumcounts = self ._cumcount_array (ascending = ascending )
1842
- return Series (cumcounts , index )
1869
+ with _group_selection_context (self ):
1870
+ index = self ._selected_obj .index
1871
+ cumcounts = self ._cumcount_array (ascending = ascending )
1872
+ return Series (cumcounts , index )
1843
1873
1844
1874
@Substitution (name = 'groupby' )
1845
1875
@Appender (_doc_template )
@@ -3768,7 +3798,6 @@ def nunique(self, dropna=True):
3768
3798
3769
3799
@Appender (Series .describe .__doc__ )
3770
3800
def describe (self , ** kwargs ):
3771
- self ._set_group_selection ()
3772
3801
result = self .apply (lambda x : x .describe (** kwargs ))
3773
3802
if self .axis == 1 :
3774
3803
return result .T
@@ -4411,6 +4440,7 @@ def transform(self, func, *args, **kwargs):
4411
4440
return self ._transform_general (func , * args , ** kwargs )
4412
4441
4413
4442
obj = self ._obj_with_exclusions
4443
+
4414
4444
# nuisance columns
4415
4445
if not result .columns .equals (obj .columns ):
4416
4446
return self ._transform_general (func , * args , ** kwargs )
0 commit comments