Skip to content

Commit 620784f

Browse files
authored
BUG in .groupby.apply when applying a function that has mixed data types and the user supplied function can fail on the grouping column (#20959)
1 parent e051303 commit 620784f

File tree

3 files changed

+83
-39
lines changed

3 files changed

+83
-39
lines changed

doc/source/whatsnew/v0.23.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1325,6 +1325,7 @@ Groupby/Resample/Rolling
13251325
- Bug in :func:`DataFrame.resample` that dropped timezone information (:issue:`13238`)
13261326
- Bug in :func:`DataFrame.groupby` where transformations using ``np.all`` and ``np.any`` were raising a ``ValueError`` (:issue:`20653`)
13271327
- Bug in :func:`DataFrame.resample` where ``ffill``, ``bfill``, ``pad``, ``backfill``, ``fillna``, ``interpolate``, and ``asfreq`` were ignoring ``loffset``. (:issue:`20744`)
1328+
- Bug in :func:`DataFrame.groupby` when applying a function that has mixed data types and the user supplied function can fail on the grouping column (:issue:`20949`)
13281329

13291330
Sparse
13301331
^^^^^^

pandas/core/groupby/groupby.py

+69-39
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import warnings
77
import copy
88
from textwrap import dedent
9+
from contextlib import contextmanager
910

1011
from pandas.compat import (
1112
zip, range, lzip,
@@ -549,6 +550,16 @@ def f(self):
549550
return attr
550551

551552

553+
@contextmanager
554+
def _group_selection_context(groupby):
555+
"""
556+
set / reset the _group_selection_context
557+
"""
558+
groupby._set_group_selection()
559+
yield groupby
560+
groupby._reset_group_selection()
561+
562+
552563
class _GroupBy(PandasObject, SelectionMixin):
553564
_group_selection = None
554565
_apply_whitelist = frozenset([])
@@ -696,26 +707,32 @@ def _reset_group_selection(self):
696707
each group regardless of whether a group selection was previously set.
697708
"""
698709
if self._group_selection is not None:
699-
self._group_selection = None
700710
# GH12839 clear cached selection too when changing group selection
711+
self._group_selection = None
701712
self._reset_cache('_selected_obj')
702713

703714
def _set_group_selection(self):
704715
"""
705716
Create group based selection. Used when selection is not passed
706717
directly but instead via a grouper.
718+
719+
NOTE: this should be paired with a call to _reset_group_selection
707720
"""
708721
grp = self.grouper
709-
if self.as_index and getattr(grp, 'groupings', None) is not None and \
710-
self.obj.ndim > 1:
711-
ax = self.obj._info_axis
712-
groupers = [g.name for g in grp.groupings
713-
if g.level is None and g.in_axis]
722+
if not (self.as_index and
723+
getattr(grp, 'groupings', None) is not None and
724+
self.obj.ndim > 1 and
725+
self._group_selection is None):
726+
return
727+
728+
ax = self.obj._info_axis
729+
groupers = [g.name for g in grp.groupings
730+
if g.level is None and g.in_axis]
714731

715-
if len(groupers):
716-
self._group_selection = ax.difference(Index(groupers)).tolist()
717-
# GH12839 clear selected obj cache when group selection changes
718-
self._reset_cache('_selected_obj')
732+
if len(groupers):
733+
# GH12839 clear selected obj cache when group selection changes
734+
self._group_selection = ax.difference(Index(groupers)).tolist()
735+
self._reset_cache('_selected_obj')
719736

720737
def _set_result_index_ordered(self, result):
721738
# set the result index on the passed values object and
@@ -781,10 +798,10 @@ def _make_wrapper(self, name):
781798
type(self).__name__))
782799
raise AttributeError(msg)
783800

784-
# need to setup the selection
785-
# as are not passed directly but in the grouper
786801
self._set_group_selection()
787802

803+
# need to setup the selection
804+
# as are not passed directly but in the grouper
788805
f = getattr(self._selected_obj, name)
789806
if not isinstance(f, types.MethodType):
790807
return self.apply(lambda self: getattr(self, name))
@@ -897,7 +914,22 @@ def f(g):
897914

898915
# ignore SettingWithCopy here in case the user mutates
899916
with option_context('mode.chained_assignment', None):
900-
return self._python_apply_general(f)
917+
try:
918+
result = self._python_apply_general(f)
919+
except Exception:
920+
921+
# gh-20949
922+
# try again, with .apply acting as a filtering
923+
# operation, by excluding the grouping column
924+
# This would normally not be triggered
925+
# except if the udf is trying an operation that
926+
# fails on *some* columns, e.g. a numeric operation
927+
# on a string grouper column
928+
929+
with _group_selection_context(self):
930+
return self._python_apply_general(f)
931+
932+
return result
901933

902934
def _python_apply_general(self, f):
903935
keys, values, mutated = self.grouper.apply(f, self._selected_obj,
@@ -1275,9 +1307,9 @@ def mean(self, *args, **kwargs):
12751307
except GroupByError:
12761308
raise
12771309
except Exception: # pragma: no cover
1278-
self._set_group_selection()
1279-
f = lambda x: x.mean(axis=self.axis, **kwargs)
1280-
return self._python_agg_general(f)
1310+
with _group_selection_context(self):
1311+
f = lambda x: x.mean(axis=self.axis, **kwargs)
1312+
return self._python_agg_general(f)
12811313

12821314
@Substitution(name='groupby')
12831315
@Appender(_doc_template)
@@ -1293,13 +1325,12 @@ def median(self, **kwargs):
12931325
raise
12941326
except Exception: # pragma: no cover
12951327

1296-
self._set_group_selection()
1297-
12981328
def f(x):
12991329
if isinstance(x, np.ndarray):
13001330
x = Series(x)
13011331
return x.median(axis=self.axis, **kwargs)
1302-
return self._python_agg_general(f)
1332+
with _group_selection_context(self):
1333+
return self._python_agg_general(f)
13031334

13041335
@Substitution(name='groupby')
13051336
@Appender(_doc_template)
@@ -1336,9 +1367,9 @@ def var(self, ddof=1, *args, **kwargs):
13361367
if ddof == 1:
13371368
return self._cython_agg_general('var', **kwargs)
13381369
else:
1339-
self._set_group_selection()
13401370
f = lambda x: x.var(ddof=ddof, **kwargs)
1341-
return self._python_agg_general(f)
1371+
with _group_selection_context(self):
1372+
return self._python_agg_general(f)
13421373

13431374
@Substitution(name='groupby')
13441375
@Appender(_doc_template)
@@ -1384,6 +1415,7 @@ def f(self, **kwargs):
13841415
kwargs['numeric_only'] = numeric_only
13851416
if 'min_count' not in kwargs:
13861417
kwargs['min_count'] = min_count
1418+
13871419
self._set_group_selection()
13881420
try:
13891421
return self._cython_agg_general(
@@ -1453,11 +1485,11 @@ def ohlc(self):
14531485

14541486
@Appender(DataFrame.describe.__doc__)
14551487
def describe(self, **kwargs):
1456-
self._set_group_selection()
1457-
result = self.apply(lambda x: x.describe(**kwargs))
1458-
if self.axis == 1:
1459-
return result.T
1460-
return result.unstack()
1488+
with _group_selection_context(self):
1489+
result = self.apply(lambda x: x.describe(**kwargs))
1490+
if self.axis == 1:
1491+
return result.T
1492+
return result.unstack()
14611493

14621494
@Substitution(name='groupby')
14631495
@Appender(_doc_template)
@@ -1778,13 +1810,12 @@ def ngroup(self, ascending=True):
17781810
.cumcount : Number the rows in each group.
17791811
"""
17801812

1781-
self._set_group_selection()
1782-
1783-
index = self._selected_obj.index
1784-
result = Series(self.grouper.group_info[0], index)
1785-
if not ascending:
1786-
result = self.ngroups - 1 - result
1787-
return result
1813+
with _group_selection_context(self):
1814+
index = self._selected_obj.index
1815+
result = Series(self.grouper.group_info[0], index)
1816+
if not ascending:
1817+
result = self.ngroups - 1 - result
1818+
return result
17881819

17891820
@Substitution(name='groupby')
17901821
def cumcount(self, ascending=True):
@@ -1835,11 +1866,10 @@ def cumcount(self, ascending=True):
18351866
.ngroup : Number the groups themselves.
18361867
"""
18371868

1838-
self._set_group_selection()
1839-
1840-
index = self._selected_obj.index
1841-
cumcounts = self._cumcount_array(ascending=ascending)
1842-
return Series(cumcounts, index)
1869+
with _group_selection_context(self):
1870+
index = self._selected_obj.index
1871+
cumcounts = self._cumcount_array(ascending=ascending)
1872+
return Series(cumcounts, index)
18431873

18441874
@Substitution(name='groupby')
18451875
@Appender(_doc_template)
@@ -3768,7 +3798,6 @@ def nunique(self, dropna=True):
37683798

37693799
@Appender(Series.describe.__doc__)
37703800
def describe(self, **kwargs):
3771-
self._set_group_selection()
37723801
result = self.apply(lambda x: x.describe(**kwargs))
37733802
if self.axis == 1:
37743803
return result.T
@@ -4411,6 +4440,7 @@ def transform(self, func, *args, **kwargs):
44114440
return self._transform_general(func, *args, **kwargs)
44124441

44134442
obj = self._obj_with_exclusions
4443+
44144444
# nuiscance columns
44154445
if not result.columns.equals(obj.columns):
44164446
return self._transform_general(func, *args, **kwargs)

pandas/tests/groupby/test_apply.py

+13
Original file line numberDiff line numberDiff line change
@@ -515,3 +515,16 @@ def test_func(x):
515515
index=index2)
516516
tm.assert_frame_equal(result1, expected1)
517517
tm.assert_frame_equal(result2, expected2)
518+
519+
520+
def test_apply_with_mixed_types():
521+
# gh-20949
522+
df = pd.DataFrame({'A': 'a a b'.split(), 'B': [1, 2, 3], 'C': [4, 6, 5]})
523+
g = df.groupby('A')
524+
525+
result = g.transform(lambda x: x / x.sum())
526+
expected = pd.DataFrame({'B': [1 / 3., 2 / 3., 1], 'C': [0.4, 0.6, 1.0]})
527+
tm.assert_frame_equal(result, expected)
528+
529+
result = g.apply(lambda x: x / x.sum())
530+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)