diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index 10dfd8406b8ce..c633583ebc2cd 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -155,7 +155,7 @@ Groupby/resample/rolling - Bug in :meth:`DataFrameGroupBy.apply` that would some times throw an erroneous ``ValueError`` if the grouping axis had duplicate entries (:issue:`16646`) - - - +- Bug in :meth:`DataFrameGroupBy.apply` where a non-nuisance grouping column would be dropped from the output columns if another groupby method was called before ``.apply()`` (:issue:`34656`) Reshaping ^^^^^^^^^ diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 6c8a780859939..a3585091feb8a 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -736,13 +736,12 @@ def pipe(self, func, *args, **kwargs): def _make_wrapper(self, name): assert name in self._apply_allowlist - self._set_group_selection() - - # need to setup the selection - # as are not passed directly but in the grouper - f = getattr(self._obj_with_exclusions, name) - if not isinstance(f, types.MethodType): - return self.apply(lambda self: getattr(self, name)) + with _group_selection_context(self): + # need to setup the selection + # as are not passed directly but in the grouper + f = getattr(self._obj_with_exclusions, name) + if not isinstance(f, types.MethodType): + return self.apply(lambda self: getattr(self, name)) f = getattr(type(self._obj_with_exclusions), name) sig = inspect.signature(f) @@ -992,28 +991,28 @@ def _agg_general( alias: str, npfunc: Callable, ): - self._set_group_selection() - - # try a cython aggregation if we can - try: - return self._cython_agg_general( - how=alias, alt=npfunc, numeric_only=numeric_only, min_count=min_count, - ) - except DataError: - pass - except NotImplementedError as err: - if "function is not implemented for this dtype" in str( - err - ) or "category dtype not supported" in str(err): - # raised in _get_cython_function, in some cases can - # be trimmed by implementing cython funcs for more dtypes + with _group_selection_context(self): + # try a cython aggregation if we can + try: + return self._cython_agg_general( + how=alias, + alt=npfunc, + numeric_only=numeric_only, + min_count=min_count, + ) + except DataError: pass - else: - raise - - # apply a non-cython aggregation - result = self.aggregate(lambda x: npfunc(x, axis=self.axis)) - return result + except NotImplementedError as err: + if "function is not implemented for this dtype" in str( + err + ) or "category dtype not supported" in str(err): + # raised in _get_cython_function, in some cases can + # be trimmed by implementing cython funcs for more dtypes + pass + + # apply a non-cython aggregation + result = self.aggregate(lambda x: npfunc(x, axis=self.axis)) + return result def _cython_agg_general( self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1 @@ -1930,29 +1929,31 @@ def nth(self, n: Union[int, List[int]], dropna: Optional[str] = None) -> DataFra nth_values = list(set(n)) nth_array = np.array(nth_values, dtype=np.intp) - self._set_group_selection() + with _group_selection_context(self): - mask_left = np.in1d(self._cumcount_array(), nth_array) - mask_right = np.in1d(self._cumcount_array(ascending=False) + 1, -nth_array) - mask = mask_left | mask_right + mask_left = np.in1d(self._cumcount_array(), nth_array) + mask_right = np.in1d( + self._cumcount_array(ascending=False) + 1, -nth_array + ) + mask = mask_left | mask_right - ids, _, _ = self.grouper.group_info + ids, _, _ = self.grouper.group_info - # Drop NA values in grouping - mask = mask & (ids != -1) + # Drop NA values in grouping + mask = mask & (ids != -1) - out = self._selected_obj[mask] - if not self.as_index: - return out + out = self._selected_obj[mask] + if not self.as_index: + return out - result_index = self.grouper.result_index - out.index = result_index[ids[mask]] + result_index = self.grouper.result_index + out.index = result_index[ids[mask]] - if not self.observed and isinstance(result_index, CategoricalIndex): - out = out.reindex(result_index) + if not self.observed and isinstance(result_index, CategoricalIndex): + out = out.reindex(result_index) - out = self._reindex_output(out) - return out.sort_index() if self.sort else out + out = self._reindex_output(out) + return out.sort_index() if self.sort else out # dropna is truthy if isinstance(n, valid_containers): diff --git a/pandas/tests/groupby/aggregate/test_other.py b/pandas/tests/groupby/aggregate/test_other.py index 264cf40dc6984..e8cd6017a117c 100644 --- a/pandas/tests/groupby/aggregate/test_other.py +++ b/pandas/tests/groupby/aggregate/test_other.py @@ -486,13 +486,13 @@ def test_agg_timezone_round_trip(): assert ts == grouped.first()["B"].iloc[0] # GH#27110 applying iloc should return a DataFrame - assert ts == grouped.apply(lambda x: x.iloc[0]).iloc[0, 0] + assert ts == grouped.apply(lambda x: x.iloc[0]).iloc[0, 1] ts = df["B"].iloc[2] assert ts == grouped.last()["B"].iloc[0] # GH#27110 applying iloc should return a DataFrame - assert ts == grouped.apply(lambda x: x.iloc[-1]).iloc[0, 0] + assert ts == grouped.apply(lambda x: x.iloc[-1]).iloc[0, 1] def test_sum_uint64_overflow(): diff --git a/pandas/tests/groupby/test_apply.py b/pandas/tests/groupby/test_apply.py index 665cd12225ad7..ee38722ffb8ce 100644 --- a/pandas/tests/groupby/test_apply.py +++ b/pandas/tests/groupby/test_apply.py @@ -1009,6 +1009,35 @@ def test_apply_with_timezones_aware(): tm.assert_frame_equal(result1, result2) +def test_apply_is_unchanged_when_other_methods_are_called_first(reduction_func): + # GH #34656 + # GH #34271 + df = DataFrame( + { + "a": [99, 99, 99, 88, 88, 88], + "b": [1, 2, 3, 4, 5, 6], + "c": [10, 20, 30, 40, 50, 60], + } + ) + + expected = pd.DataFrame( + {"a": [264, 297], "b": [15, 6], "c": [150, 60]}, + index=pd.Index([88, 99], name="a"), + ) + + # Check output when no other methods are called before .apply() + grp = df.groupby(by="a") + result = grp.apply(sum) + tm.assert_frame_equal(result, expected) + + # Check output when another method is called before .apply() + grp = df.groupby(by="a") + args = {"nth": [0], "corrwith": [df]}.get(reduction_func, []) + _ = getattr(grp, reduction_func)(*args) + result = grp.apply(sum) + tm.assert_frame_equal(result, expected) + + def test_apply_with_date_in_multiindex_does_not_convert_to_timestamp(): # GH 29617 diff --git a/pandas/tests/groupby/test_grouping.py b/pandas/tests/groupby/test_grouping.py index efcd22f9c0c82..40b4ce46e550b 100644 --- a/pandas/tests/groupby/test_grouping.py +++ b/pandas/tests/groupby/test_grouping.py @@ -191,13 +191,15 @@ def test_grouper_creation_bug(self): result = g.sum() tm.assert_frame_equal(result, expected) - result = g.apply(lambda x: x.sum()) - tm.assert_frame_equal(result, expected) - g = df.groupby(pd.Grouper(key="A", axis=0)) result = g.sum() tm.assert_frame_equal(result, expected) + result = g.apply(lambda x: x.sum()) + expected["A"] = [0, 2, 4] + expected = expected.loc[:, ["A", "B"]] + tm.assert_frame_equal(result, expected) + # GH14334 # pd.Grouper(key=...) may be passed in a list df = DataFrame(