Skip to content

Commit b159c75

Browse files
authored
BUG: Fix groupby nth with axis=1 (#43926)
1 parent 247703d commit b159c75

File tree

3 files changed

+57
-12
lines changed

3 files changed

+57
-12
lines changed

doc/source/whatsnew/v1.4.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -506,8 +506,10 @@ Groupby/resample/rolling
506506
- Bug in :meth:`GroupBy.apply` with time-based :class:`Grouper` objects incorrectly raising ``ValueError`` in corner cases where the grouping vector contains a ``NaT`` (:issue:`43500`, :issue:`43515`)
507507
- Bug in :meth:`GroupBy.mean` failing with ``complex`` dtype (:issue:`43701`)
508508
- Fixed bug in :meth:`Series.rolling` and :meth:`DataFrame.rolling` not calculating window bounds correctly for the first row when ``center=True`` and the index is decreasing (:issue:`43927`)
509+
- Bug in :meth:`GroupBy.nth` failing on ``axis=1`` (:issue:`43926`)
509510
- Fixed bug in :meth:`Series.rolling` and :meth:`DataFrame.rolling` not respecting right bound on centered datetime-like windows, if the index contains duplicates (:issue:`3944`)
510511

512+
511513
Reshaping
512514
^^^^^^^^^
513515
- Improved error message when creating a :class:`DataFrame` column from a multi-dimensional :class:`numpy.ndarray` (:issue:`42463`)

pandas/core/groupby/groupby.py

+32-12
Original file line numberDiff line numberDiff line change
@@ -2446,7 +2446,7 @@ def backfill(self, limit=None):
24462446
@Substitution(see_also=_common_see_also)
24472447
def nth(
24482448
self, n: int | list[int], dropna: Literal["any", "all", None] = None
2449-
) -> DataFrame:
2449+
) -> NDFrameT:
24502450
"""
24512451
Take the nth row from each group if n is an int, or a subset of rows
24522452
if n is a list of ints.
@@ -2545,18 +2545,22 @@ def nth(
25452545
# Drop NA values in grouping
25462546
mask = mask & (ids != -1)
25472547

2548-
out = self._selected_obj[mask]
2548+
out = self._mask_selected_obj(mask)
2549+
25492550
if not self.as_index:
25502551
return out
25512552

25522553
result_index = self.grouper.result_index
2553-
out.index = result_index[ids[mask]]
2554+
if self.axis == 0:
2555+
out.index = result_index[ids[mask]]
2556+
if not self.observed and isinstance(result_index, CategoricalIndex):
2557+
out = out.reindex(result_index)
25542558

2555-
if not self.observed and isinstance(result_index, CategoricalIndex):
2556-
out = out.reindex(result_index)
2559+
out = self._reindex_output(out)
2560+
else:
2561+
out.columns = result_index[ids[mask]]
25572562

2558-
out = self._reindex_output(out)
2559-
return out.sort_index() if self.sort else out
2563+
return out.sort_index(axis=self.axis) if self.sort else out
25602564

25612565
# dropna is truthy
25622566
if isinstance(n, valid_containers):
@@ -2599,7 +2603,9 @@ def nth(
25992603
mutated=self.mutated,
26002604
)
26012605

2602-
grb = dropped.groupby(grouper, as_index=self.as_index, sort=self.sort)
2606+
grb = dropped.groupby(
2607+
grouper, as_index=self.as_index, sort=self.sort, axis=self.axis
2608+
)
26032609
sizes, result = grb.size(), grb.nth(n)
26042610
mask = (sizes < max_len)._values
26052611

@@ -3317,10 +3323,7 @@ def head(self, n=5):
33173323
"""
33183324
self._reset_group_selection()
33193325
mask = self._cumcount_array() < n
3320-
if self.axis == 0:
3321-
return self._selected_obj[mask]
3322-
else:
3323-
return self._selected_obj.iloc[:, mask]
3326+
return self._mask_selected_obj(mask)
33243327

33253328
@final
33263329
@Substitution(name="groupby")
@@ -3355,6 +3358,23 @@ def tail(self, n=5):
33553358
"""
33563359
self._reset_group_selection()
33573360
mask = self._cumcount_array(ascending=False) < n
3361+
return self._mask_selected_obj(mask)
3362+
3363+
@final
3364+
def _mask_selected_obj(self, mask: np.ndarray) -> NDFrameT:
3365+
"""
3366+
Return _selected_obj with mask applied to the correct axis.
3367+
3368+
Parameters
3369+
----------
3370+
mask : np.ndarray
3371+
Boolean mask to apply.
3372+
3373+
Returns
3374+
-------
3375+
Series or DataFrame
3376+
Filtered _selected_obj.
3377+
"""
33583378
if self.axis == 0:
33593379
return self._selected_obj[mask]
33603380
else:

pandas/tests/groupby/test_nth.py

+23
Original file line numberDiff line numberDiff line change
@@ -706,3 +706,26 @@ def test_groupby_last_first_nth_with_none(method, nulls_fixture):
706706
result = getattr(data, method)()
707707

708708
tm.assert_series_equal(result, expected)
709+
710+
711+
def test_groupby_nth_with_column_axis():
712+
# GH43926
713+
df = DataFrame(
714+
[
715+
[4, 5, 6],
716+
[8, 8, 7],
717+
],
718+
index=["z", "y"],
719+
columns=["C", "B", "A"],
720+
)
721+
result = df.groupby(df.iloc[1], axis=1).nth(0)
722+
expected = DataFrame(
723+
[
724+
[6, 4],
725+
[7, 8],
726+
],
727+
index=["z", "y"],
728+
columns=[7, 8],
729+
)
730+
expected.columns.name = "y"
731+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)