diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index e9101c9ff1f12..c5a22d8766a96 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -506,8 +506,10 @@ Groupby/resample/rolling - Bug in :meth:`GroupBy.apply` with time-based :class:`Grouper` objects incorrectly raising ``ValueError`` in corner cases where the grouping vector contains a ``NaT`` (:issue:`43500`, :issue:`43515`) - Bug in :meth:`GroupBy.mean` failing with ``complex`` dtype (:issue:`43701`) - Fixed bug in :meth:`Series.rolling` and :meth:`DataFrame.rolling` not calculating window bounds correctly for the first row when ``center=True`` and index is decreasing (:issue:`43927`) +- Bug in :meth:`GroupBy.nth` failing on ``axis=1`` (:issue:`43926`) - Fixed bug in :meth:`Series.rolling` and :meth:`DataFrame.rolling` not respecting right bound on centered datetime-like windows, if the index contains duplicates (:issue:`3944`) + Reshaping ^^^^^^^^^ - Improved error message when creating a :class:`DataFrame` column from a multi-dimensional :class:`numpy.ndarray` (:issue:`42463`) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 52bf44a0bb4ec..9ca05e05fc09a 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2446,7 +2446,7 @@ def backfill(self, limit=None): @Substitution(see_also=_common_see_also) def nth( self, n: int | list[int], dropna: Literal["any", "all", None] = None - ) -> DataFrame: + ) -> NDFrameT: """ Take the nth row from each group if n is an int, or a subset of rows if n is a list of ints. 
@@ -2545,18 +2545,22 @@ def nth( # Drop NA values in grouping mask = mask & (ids != -1) - out = self._selected_obj[mask] + out = self._mask_selected_obj(mask) + if not self.as_index: return out result_index = self.grouper.result_index - out.index = result_index[ids[mask]] + if self.axis == 0: + out.index = result_index[ids[mask]] + if not self.observed and isinstance(result_index, CategoricalIndex): + out = out.reindex(result_index) - if not self.observed and isinstance(result_index, CategoricalIndex): - out = out.reindex(result_index) + out = self._reindex_output(out) + else: + out.columns = result_index[ids[mask]] - out = self._reindex_output(out) - return out.sort_index() if self.sort else out + return out.sort_index(axis=self.axis) if self.sort else out # dropna is truthy if isinstance(n, valid_containers): @@ -2599,7 +2603,9 @@ def nth( mutated=self.mutated, ) - grb = dropped.groupby(grouper, as_index=self.as_index, sort=self.sort) + grb = dropped.groupby( + grouper, as_index=self.as_index, sort=self.sort, axis=self.axis + ) sizes, result = grb.size(), grb.nth(n) mask = (sizes < max_len)._values @@ -3317,10 +3323,7 @@ def head(self, n=5): """ self._reset_group_selection() mask = self._cumcount_array() < n - if self.axis == 0: - return self._selected_obj[mask] - else: - return self._selected_obj.iloc[:, mask] + return self._mask_selected_obj(mask) @final @Substitution(name="groupby") @@ -3355,6 +3358,23 @@ def tail(self, n=5): """ self._reset_group_selection() mask = self._cumcount_array(ascending=False) < n + return self._mask_selected_obj(mask) + + @final + def _mask_selected_obj(self, mask: np.ndarray) -> NDFrameT: + """ + Return _selected_obj with mask applied to the correct axis. + + Parameters + ---------- + mask : np.ndarray + Boolean mask to apply. + + Returns + ------- + Series or DataFrame + Filtered _selected_obj. 
+ """ if self.axis == 0: return self._selected_obj[mask] else: diff --git a/pandas/tests/groupby/test_nth.py b/pandas/tests/groupby/test_nth.py index f0eef550b39ac..8742135da59e5 100644 --- a/pandas/tests/groupby/test_nth.py +++ b/pandas/tests/groupby/test_nth.py @@ -706,3 +706,26 @@ def test_groupby_last_first_nth_with_none(method, nulls_fixture): result = getattr(data, method)() tm.assert_series_equal(result, expected) + + +def test_groupby_nth_with_column_axis(): + # GH43926 + df = DataFrame( + [ + [4, 5, 6], + [8, 8, 7], + ], + index=["z", "y"], + columns=["C", "B", "A"], + ) + result = df.groupby(df.iloc[1], axis=1).nth(0) + expected = DataFrame( + [ + [6, 4], + [7, 8], + ], + index=["z", "y"], + columns=[7, 8], + ) + expected.columns.name = "y" + tm.assert_frame_equal(result, expected)