Skip to content

Commit dfa546e

Browse files
authored
BUG: GroupBy.apply() returns different results if a different GroupBy method is called first (#35314)
1 parent ce03883 commit dfa546e

File tree

5 files changed

+82
-50
lines changed

5 files changed

+82
-50
lines changed

doc/source/whatsnew/v1.2.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ Groupby/resample/rolling
156156
- Bug in :meth:`DataFrameGroupBy.apply` that would some times throw an erroneous ``ValueError`` if the grouping axis had duplicate entries (:issue:`16646`)
157157
-
158158
-
159-
159+
- Bug in :meth:`DataFrameGroupBy.apply` where a non-nuisance grouping column would be dropped from the output columns if another groupby method was called before ``.apply()`` (:issue:`34656`)
160160

161161
Reshaping
162162
^^^^^^^^^

pandas/core/groupby/groupby.py

+45-44
Original file line numberDiff line numberDiff line change
@@ -736,13 +736,12 @@ def pipe(self, func, *args, **kwargs):
736736
def _make_wrapper(self, name):
737737
assert name in self._apply_allowlist
738738

739-
self._set_group_selection()
740-
741-
# need to setup the selection
742-
# as are not passed directly but in the grouper
743-
f = getattr(self._obj_with_exclusions, name)
744-
if not isinstance(f, types.MethodType):
745-
return self.apply(lambda self: getattr(self, name))
739+
with _group_selection_context(self):
740+
# need to setup the selection
741+
# as are not passed directly but in the grouper
742+
f = getattr(self._obj_with_exclusions, name)
743+
if not isinstance(f, types.MethodType):
744+
return self.apply(lambda self: getattr(self, name))
746745

747746
f = getattr(type(self._obj_with_exclusions), name)
748747
sig = inspect.signature(f)
@@ -992,28 +991,28 @@ def _agg_general(
992991
alias: str,
993992
npfunc: Callable,
994993
):
995-
self._set_group_selection()
996-
997-
# try a cython aggregation if we can
998-
try:
999-
return self._cython_agg_general(
1000-
how=alias, alt=npfunc, numeric_only=numeric_only, min_count=min_count,
1001-
)
1002-
except DataError:
1003-
pass
1004-
except NotImplementedError as err:
1005-
if "function is not implemented for this dtype" in str(
1006-
err
1007-
) or "category dtype not supported" in str(err):
1008-
# raised in _get_cython_function, in some cases can
1009-
# be trimmed by implementing cython funcs for more dtypes
994+
with _group_selection_context(self):
995+
# try a cython aggregation if we can
996+
try:
997+
return self._cython_agg_general(
998+
how=alias,
999+
alt=npfunc,
1000+
numeric_only=numeric_only,
1001+
min_count=min_count,
1002+
)
1003+
except DataError:
10101004
pass
1011-
else:
1012-
raise
1013-
1014-
# apply a non-cython aggregation
1015-
result = self.aggregate(lambda x: npfunc(x, axis=self.axis))
1016-
return result
1005+
except NotImplementedError as err:
1006+
if "function is not implemented for this dtype" in str(
1007+
err
1008+
) or "category dtype not supported" in str(err):
1009+
# raised in _get_cython_function, in some cases can
1010+
# be trimmed by implementing cython funcs for more dtypes
1011+
pass
1012+
1013+
# apply a non-cython aggregation
1014+
result = self.aggregate(lambda x: npfunc(x, axis=self.axis))
1015+
return result
10171016

10181017
def _cython_agg_general(
10191018
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
@@ -1940,29 +1939,31 @@ def nth(self, n: Union[int, List[int]], dropna: Optional[str] = None) -> DataFra
19401939
nth_values = list(set(n))
19411940

19421941
nth_array = np.array(nth_values, dtype=np.intp)
1943-
self._set_group_selection()
1942+
with _group_selection_context(self):
19441943

1945-
mask_left = np.in1d(self._cumcount_array(), nth_array)
1946-
mask_right = np.in1d(self._cumcount_array(ascending=False) + 1, -nth_array)
1947-
mask = mask_left | mask_right
1944+
mask_left = np.in1d(self._cumcount_array(), nth_array)
1945+
mask_right = np.in1d(
1946+
self._cumcount_array(ascending=False) + 1, -nth_array
1947+
)
1948+
mask = mask_left | mask_right
19481949

1949-
ids, _, _ = self.grouper.group_info
1950+
ids, _, _ = self.grouper.group_info
19501951

1951-
# Drop NA values in grouping
1952-
mask = mask & (ids != -1)
1952+
# Drop NA values in grouping
1953+
mask = mask & (ids != -1)
19531954

1954-
out = self._selected_obj[mask]
1955-
if not self.as_index:
1956-
return out
1955+
out = self._selected_obj[mask]
1956+
if not self.as_index:
1957+
return out
19571958

1958-
result_index = self.grouper.result_index
1959-
out.index = result_index[ids[mask]]
1959+
result_index = self.grouper.result_index
1960+
out.index = result_index[ids[mask]]
19601961

1961-
if not self.observed and isinstance(result_index, CategoricalIndex):
1962-
out = out.reindex(result_index)
1962+
if not self.observed and isinstance(result_index, CategoricalIndex):
1963+
out = out.reindex(result_index)
19631964

1964-
out = self._reindex_output(out)
1965-
return out.sort_index() if self.sort else out
1965+
out = self._reindex_output(out)
1966+
return out.sort_index() if self.sort else out
19661967

19671968
# dropna is truthy
19681969
if isinstance(n, valid_containers):

pandas/tests/groupby/aggregate/test_other.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -486,13 +486,13 @@ def test_agg_timezone_round_trip():
486486
assert ts == grouped.first()["B"].iloc[0]
487487

488488
# GH#27110 applying iloc should return a DataFrame
489-
assert ts == grouped.apply(lambda x: x.iloc[0]).iloc[0, 0]
489+
assert ts == grouped.apply(lambda x: x.iloc[0]).iloc[0, 1]
490490

491491
ts = df["B"].iloc[2]
492492
assert ts == grouped.last()["B"].iloc[0]
493493

494494
# GH#27110 applying iloc should return a DataFrame
495-
assert ts == grouped.apply(lambda x: x.iloc[-1]).iloc[0, 0]
495+
assert ts == grouped.apply(lambda x: x.iloc[-1]).iloc[0, 1]
496496

497497

498498
def test_sum_uint64_overflow():

pandas/tests/groupby/test_apply.py

+29
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,35 @@ def test_apply_with_timezones_aware():
10091009
tm.assert_frame_equal(result1, result2)
10101010

10111011

1012+
def test_apply_is_unchanged_when_other_methods_are_called_first(reduction_func):
1013+
# GH #34656
1014+
# GH #34271
1015+
df = DataFrame(
1016+
{
1017+
"a": [99, 99, 99, 88, 88, 88],
1018+
"b": [1, 2, 3, 4, 5, 6],
1019+
"c": [10, 20, 30, 40, 50, 60],
1020+
}
1021+
)
1022+
1023+
expected = pd.DataFrame(
1024+
{"a": [264, 297], "b": [15, 6], "c": [150, 60]},
1025+
index=pd.Index([88, 99], name="a"),
1026+
)
1027+
1028+
# Check output when no other methods are called before .apply()
1029+
grp = df.groupby(by="a")
1030+
result = grp.apply(sum)
1031+
tm.assert_frame_equal(result, expected)
1032+
1033+
# Check output when another method is called before .apply()
1034+
grp = df.groupby(by="a")
1035+
args = {"nth": [0], "corrwith": [df]}.get(reduction_func, [])
1036+
_ = getattr(grp, reduction_func)(*args)
1037+
result = grp.apply(sum)
1038+
tm.assert_frame_equal(result, expected)
1039+
1040+
10121041
def test_apply_with_date_in_multiindex_does_not_convert_to_timestamp():
10131042
# GH 29617
10141043

pandas/tests/groupby/test_grouping.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -191,13 +191,15 @@ def test_grouper_creation_bug(self):
191191
result = g.sum()
192192
tm.assert_frame_equal(result, expected)
193193

194-
result = g.apply(lambda x: x.sum())
195-
tm.assert_frame_equal(result, expected)
196-
197194
g = df.groupby(pd.Grouper(key="A", axis=0))
198195
result = g.sum()
199196
tm.assert_frame_equal(result, expected)
200197

198+
result = g.apply(lambda x: x.sum())
199+
expected["A"] = [0, 2, 4]
200+
expected = expected.loc[:, ["A", "B"]]
201+
tm.assert_frame_equal(result, expected)
202+
201203
# GH14334
202204
# pd.Grouper(key=...) may be passed in a list
203205
df = DataFrame(

0 commit comments

Comments
 (0)