From 4609d11328f2fbd84b06ca2e420f89bb939044c2 Mon Sep 17 00:00:00 2001 From: John Zangwill Date: Fri, 8 Oct 2021 14:37:25 +0100 Subject: [PATCH 1/9] Fix nth with axis=1 and add minimal test --- pandas/core/groupby/groupby.py | 19 +++++++++++++------ pandas/tests/groupby/test_nth.py | 22 ++++++++++++++++++++++ 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 52bf44a0bb4ec..99d2582fc4031 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2545,18 +2545,25 @@ def nth( # Drop NA values in grouping mask = mask & (ids != -1) - out = self._selected_obj[mask] + if self.axis == 0: + out = self._selected_obj[mask] + else: + out = self._selected_obj.iloc[:, mask] + if not self.as_index: return out result_index = self.grouper.result_index - out.index = result_index[ids[mask]] + if self.axis == 0: + out.index = result_index[ids[mask]] + if not self.observed and isinstance(result_index, CategoricalIndex): + out = out.reindex(result_index) - if not self.observed and isinstance(result_index, CategoricalIndex): - out = out.reindex(result_index) + out = self._reindex_output(out) + else: + out.columns = result_index[ids[mask]] - out = self._reindex_output(out) - return out.sort_index() if self.sort else out + return out.sort_index(axis=self.axis) if self.sort else out # dropna is truthy if isinstance(n, valid_containers): diff --git a/pandas/tests/groupby/test_nth.py b/pandas/tests/groupby/test_nth.py index f0eef550b39ac..58eb1e0a4c3b5 100644 --- a/pandas/tests/groupby/test_nth.py +++ b/pandas/tests/groupby/test_nth.py @@ -706,3 +706,25 @@ def test_groupby_last_first_nth_with_none(method, nulls_fixture): result = getattr(data, method)() tm.assert_series_equal(result, expected) + + +def test_groupby_nth_with_column_axis(): + df = DataFrame( + [ + [4, 5, 6], + [8, 8, 7], + ], + index=["z", "y"], + columns=["C", "B", "A"], + ) + result = df.groupby(df.iloc[1], axis=1).nth(0) + expected = DataFrame( + [ + [6, 4], + [7, 8], + ], + index=["z", "y"], + columns=[7, 8], + ) + expected.columns.name = "y" + tm.assert_frame_equal(result, expected) From acde4ab5f3d9249b86ab1e65e8a82f62208923e7 Mon Sep 17 00:00:00 2001 From: John Zangwill Date: Fri, 8 Oct 2021 14:51:16 +0100 Subject: [PATCH 2/9] Add the PR number to the test --- pandas/tests/groupby/test_nth.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pandas/tests/groupby/test_nth.py b/pandas/tests/groupby/test_nth.py index 58eb1e0a4c3b5..8742135da59e5 100644 --- a/pandas/tests/groupby/test_nth.py +++ b/pandas/tests/groupby/test_nth.py @@ -709,6 +709,7 @@ def test_groupby_last_first_nth_with_none(method, nulls_fixture): def test_groupby_nth_with_column_axis(): + # GH43926 df = DataFrame( [ [4, 5, 6], From 6b45d2fd68ad60e0fa64f5506ebc356fee9ca45f Mon Sep 17 00:00:00 2001 From: John Zangwill Date: Fri, 8 Oct 2021 15:21:09 +0100 Subject: [PATCH 3/9] Fix axis=1 nth with dropna --- pandas/core/groupby/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 99d2582fc4031..c8bc6c5a39f0e 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2606,7 +2606,7 @@ def nth( mutated=self.mutated, ) - grb = dropped.groupby(grouper, as_index=self.as_index, sort=self.sort) + grb = dropped.groupby(grouper, as_index=self.as_index, sort=self.sort, axis=self.axis) sizes, result = grb.size(), grb.nth(n) mask = (sizes < max_len)._values From 87ec227d4b7149acbcc1edc6843deb7513e5537b Mon Sep 17 00:00:00 2001 From: John Zangwill Date: Fri, 8 Oct 2021 15:36:23 +0100 Subject: [PATCH 4/9] pep8 line length --- pandas/core/groupby/groupby.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index c8bc6c5a39f0e..3db6fc3d7e835 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2606,7 +2606,9 @@ def nth( mutated=self.mutated, ) - grb = dropped.groupby(grouper, as_index=self.as_index, sort=self.sort, axis=self.axis) + grb = dropped.groupby( + grouper, as_index=self.as_index, sort=self.sort, axis=self.axis + ) sizes, result = grb.size(), grb.nth(n) mask = (sizes < max_len)._values From 888b079e545783fc95b88d87bb6faa2e4c2cabba Mon Sep 17 00:00:00 2001 From: John Zangwill Date: Sat, 9 Oct 2021 13:49:52 +0100 Subject: [PATCH 5/9] factor out _mask_selected_obj --- pandas/core/groupby/groupby.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 3db6fc3d7e835..4b16b1e0519b5 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2446,7 +2446,7 @@ def backfill(self, limit=None): @Substitution(see_also=_common_see_also) def nth( self, n: int | list[int], dropna: Literal["any", "all", None] = None - ) -> DataFrame: + ) -> DataFrame | Series: """ Take the nth row from each group if n is an int, or a subset of rows if n is a list of ints. @@ -2545,10 +2545,7 @@ def nth( # Drop NA values in grouping mask = mask & (ids != -1) - if self.axis == 0: - out = self._selected_obj[mask] - else: - out = self._selected_obj.iloc[:, mask] + out = self._mask_selected_obj(mask) if not self.as_index: return out @@ -3326,10 +3323,7 @@ def head(self, n=5): """ self._reset_group_selection() mask = self._cumcount_array() < n - if self.axis == 0: - return self._selected_obj[mask] - else: - return self._selected_obj.iloc[:, mask] + return self._mask_selected_obj(mask) @final @Substitution(name="groupby") @@ -3364,6 +3358,10 @@ def tail(self, n=5): """ self._reset_group_selection() mask = self._cumcount_array(ascending=False) < n + return self._mask_selected_obj(mask) + + @final + def _mask_selected_obj(self, mask: np.ndarray) -> DataFrame | Series: if self.axis == 0: return self._selected_obj[mask] else: From 5d9d26a8bcee84985d2e6c4e3710fe34256846d4 Mon Sep 17 00:00:00 2001 From: John Zangwill Date: Sat, 9 Oct 2021 14:17:40 +0100 Subject: [PATCH 6/9] Resolve output types --- pandas/core/groupby/groupby.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 4b16b1e0519b5..98c77f5243aa7 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2446,7 +2446,7 @@ def backfill(self, limit=None): @Substitution(see_also=_common_see_also) def nth( self, n: int | list[int], dropna: Literal["any", "all", None] = None - ) -> DataFrame | Series: + ) -> NDFrameT: """ Take the nth row from each group if n is an int, or a subset of rows if n is a list of ints. @@ -3361,7 +3361,7 @@ def tail(self, n=5): return self._mask_selected_obj(mask) @final - def _mask_selected_obj(self, mask: np.ndarray) -> DataFrame | Series: + def _mask_selected_obj(self, mask: np.ndarray) -> NDFrameT: if self.axis == 0: return self._selected_obj[mask] else: From 41bc34d72ee92748e1da8d2c2300a40b8b98db20 Mon Sep 17 00:00:00 2001 From: John Zangwill Date: Sat, 9 Oct 2021 15:20:25 +0100 Subject: [PATCH 7/9] Add to whatsnew --- doc/source/whatsnew/v1.4.0.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index 722d0dcc10041..950e8c66d793c 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -502,6 +502,7 @@ Groupby/resample/rolling - Bug in :meth:`DataFrame.groupby.rolling` when specifying ``on`` and calling ``__getitem__`` would subsequently return incorrect results (:issue:`43355`) - Bug in :meth:`GroupBy.apply` with time-based :class:`Grouper` objects incorrectly raising ``ValueError`` in corner cases where the grouping vector contains a ``NaT`` (:issue:`43500`, :issue:`43515`) - Bug in :meth:`GroupBy.mean` failing with ``complex`` dtype (:issue:`43701`) +- Bug in :meth:`GroupBy.nth` failing on column groupings (``axis=1``) Reshaping ^^^^^^^^^ From 6e60af8d62c4a9e3c36dac55c83e703278129ed1 Mon Sep 17 00:00:00 2001 From: John Zangwill Date: Sun, 10 Oct 2021 23:19:35 +0100 Subject: [PATCH 8/9] Changed whatsnew --- doc/source/whatsnew/v1.4.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index 4a3e16f30c9b0..f9af05db4affb 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -503,7 +503,7 @@ Groupby/resample/rolling - Bug in :meth:`GroupBy.apply` with time-based :class:`Grouper` objects incorrectly raising ``ValueError`` in corner cases where the grouping vector contains a ``NaT`` (:issue:`43500`, :issue:`43515`) - Bug in :meth:`GroupBy.mean` failing with ``complex`` dtype (:issue:`43701`) - Fixed bug in :meth:`Series.rolling` and :meth:`DataFrame.rolling` not calculating window bounds correctly for the first row when ``center=True`` and index is decreasing (:issue:`43927`) -- Bug in :meth:`GroupBy.nth` failing on column groupings (``axis=1``) +- Bug in :meth:`GroupBy.nth` failing on ``axis=1`` Reshaping ^^^^^^^^^ From f984a9cabd6c8ef12ce3c8cdb1900d0a81719254 Mon Sep 17 00:00:00 2001 From: John Zangwill Date: Wed, 13 Oct 2021 08:59:37 +0100 Subject: [PATCH 9/9] Add docstring. Add GH number to whatsnew. --- doc/source/whatsnew/v1.4.0.rst | 2 +- pandas/core/groupby/groupby.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index 6e75a4f64e688..c5a22d8766a96 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -506,7 +506,7 @@ Groupby/resample/rolling - Bug in :meth:`GroupBy.apply` with time-based :class:`Grouper` objects incorrectly raising ``ValueError`` in corner cases where the grouping vector contains a ``NaT`` (:issue:`43500`, :issue:`43515`) - Bug in :meth:`GroupBy.mean` failing with ``complex`` dtype (:issue:`43701`) - Fixed bug in :meth:`Series.rolling` and :meth:`DataFrame.rolling` not calculating window bounds correctly for the first row when ``center=True`` and index is decreasing (:issue:`43927`) -- Bug in :meth:`GroupBy.nth` failing on ``axis=1`` +- Bug in :meth:`GroupBy.nth` failing on ``axis=1`` (:issue:`43926`) - Fixed bug in :meth:`Series.rolling` and :meth:`DataFrame.rolling` not respecting right bound on centered datetime-like windows, if the index contain duplicates (:issue:`#3944`) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 98c77f5243aa7..9ca05e05fc09a 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -3362,6 +3362,19 @@ def tail(self, n=5): @final def _mask_selected_obj(self, mask: np.ndarray) -> NDFrameT: + """ + Return _selected_obj with mask applied to the correct axis. + + Parameters + ---------- + mask : np.ndarray + Boolean mask to apply. + + Returns + ------- + Series or DataFrame + Filtered _selected_obj. + """ if self.axis == 0: return self._selected_obj[mask] else: