Skip to content

Commit 3e7c2f6

Browse files
authored
ENH: Add indexing syntax to GroupBy.nth() (#44688)
1 parent 5008ece commit 3e7c2f6

File tree

4 files changed

+102
-2
lines changed

4 files changed

+102
-2
lines changed

doc/source/whatsnew/v1.4.0.rst

+8
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,14 @@ Previously, negative arguments returned empty frames.
164164
df.groupby("A").nth(slice(1, -1))
165165
df.groupby("A").nth([slice(None, 1), slice(-1, None)])
166166
167+
:meth:`.GroupBy.nth` now accepts index notation.
168+
169+
.. ipython:: python
170+
171+
df.groupby("A").nth[1, -1]
172+
df.groupby("A").nth[1:-1]
173+
df.groupby("A").nth[:1, -1:]
174+
167175
.. _whatsnew_140.dict_tight:
168176

169177
DataFrame.from_dict and DataFrame.to_dict have new ``'tight'`` option

pandas/core/groupby/groupby.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,10 @@ class providing the base-class of operations.
100100
numba_,
101101
ops,
102102
)
103-
from pandas.core.groupby.indexing import GroupByIndexingMixin
103+
from pandas.core.groupby.indexing import (
104+
GroupByIndexingMixin,
105+
GroupByNthSelector,
106+
)
104107
from pandas.core.indexes.api import (
105108
CategoricalIndex,
106109
Index,
@@ -902,6 +905,15 @@ def __getattr__(self, attr: str):
902905
f"'{type(self).__name__}' object has no attribute '{attr}'"
903906
)
904907

908+
def __getattribute__(self, attr: str):
909+
# Intercept nth to allow both call and index
910+
if attr == "nth":
911+
return GroupByNthSelector(self)
912+
elif attr == "nth_actual":
913+
return super().__getattribute__("nth")
914+
else:
915+
return super().__getattribute__(attr)
916+
905917
@final
906918
def _make_wrapper(self, name: str) -> Callable:
907919
assert name in self._apply_allowlist
@@ -2524,6 +2536,9 @@ def nth(
25242536
"""
25252537
Take the nth row from each group if n is an int, otherwise a subset of rows.
25262538
2539+
Can be used either as a call or with index notation. ``dropna`` is not available with index notation.
2540+
Index notation accepts a comma separated list of integers and slices.
2541+
25272542
If dropna, will take the nth non-null row, dropna is either
25282543
'all' or 'any'; this is equivalent to calling dropna(how=dropna)
25292544
before the groupby.
@@ -2535,6 +2550,7 @@ def nth(
25352550
25362551
.. versionchanged:: 1.4.0
25372552
Added slice and lists containing slices.
2553+
Added index notation.
25382554
25392555
dropna : {'any', 'all', None}, default None
25402556
Apply the specified dropna operation before counting which row is
@@ -2580,6 +2596,22 @@ def nth(
25802596
1 2.0
25812597
2 3.0
25822598
2599+
Index notation may also be used
2600+
2601+
>>> g.nth[0, 1]
2602+
B
2603+
A
2604+
1 NaN
2605+
1 2.0
2606+
2 3.0
2607+
2 5.0
2608+
>>> g.nth[:-1]
2609+
B
2610+
A
2611+
1 NaN
2612+
1 2.0
2613+
2 3.0
2614+
25832615
Specifying `dropna` allows count ignoring ``NaN``
25842616
25852617
>>> g.nth(0, dropna='any')

pandas/core/groupby/indexing.py

+20
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import (
44
TYPE_CHECKING,
55
Iterable,
6+
Literal,
67
cast,
78
)
89

@@ -281,3 +282,22 @@ def __getitem__(self, arg: PositionalIndexer | tuple) -> DataFrame | Series:
281282
self.groupby_object._reset_group_selection()
282283
mask = self.groupby_object._make_mask_from_positional_indexer(arg)
283284
return self.groupby_object._mask_selected_obj(mask)
285+
286+
287+
class GroupByNthSelector:
288+
"""
289+
Dynamically substituted for GroupBy.nth to enable both call and index notation.
290+
"""
291+
292+
def __init__(self, groupby_object: groupby.GroupBy):
293+
self.groupby_object = groupby_object
294+
295+
def __call__(
296+
self,
297+
n: PositionalIndexer | tuple,
298+
dropna: Literal["any", "all", None] = None,
299+
) -> DataFrame | Series:
300+
return self.groupby_object.nth_actual(n, dropna)
301+
302+
def __getitem__(self, n: PositionalIndexer | tuple) -> DataFrame | Series:
303+
return self.groupby_object.nth_actual(n)

pandas/tests/groupby/test_nth.py

+41-1
Original file line numberDiff line numberDiff line change
@@ -720,10 +720,23 @@ def test_groupby_last_first_nth_with_none(method, nulls_fixture):
720720
def test_slice(slice_test_df, slice_test_grouped, arg, expected_rows):
721721
# Test slices GH #42947
722722

723-
result = slice_test_grouped.nth(arg)
723+
result = slice_test_grouped.nth[arg]
724+
equivalent = slice_test_grouped.nth(arg)
724725
expected = slice_test_df.iloc[expected_rows]
725726

726727
tm.assert_frame_equal(result, expected)
728+
tm.assert_frame_equal(equivalent, expected)
729+
730+
731+
def test_nth_indexed(slice_test_df, slice_test_grouped):
732+
# Test index notation GH #44688
733+
734+
result = slice_test_grouped.nth[0, 1, -2:]
735+
equivalent = slice_test_grouped.nth([0, 1, slice(-2, None)])
736+
expected = slice_test_df.iloc[[0, 1, 2, 3, 4, 6, 7]]
737+
738+
tm.assert_frame_equal(result, expected)
739+
tm.assert_frame_equal(equivalent, expected)
727740

728741

729742
def test_invalid_argument(slice_test_grouped):
@@ -769,3 +782,30 @@ def test_groupby_nth_with_column_axis():
769782
)
770783
expected.columns.name = "y"
771784
tm.assert_frame_equal(result, expected)
785+
786+
787+
@pytest.mark.parametrize(
788+
"start, stop, expected_values, expected_columns",
789+
[
790+
(None, None, [0, 1, 2, 3, 4], [5, 5, 5, 6, 6]),
791+
(None, 1, [0, 3], [5, 6]),
792+
(None, 9, [0, 1, 2, 3, 4], [5, 5, 5, 6, 6]),
793+
(None, -1, [0, 1, 3], [5, 5, 6]),
794+
(1, None, [1, 2, 4], [5, 5, 6]),
795+
(1, -1, [1], [5]),
796+
(-1, None, [2, 4], [5, 6]),
797+
(-1, 2, [4], [6]),
798+
],
799+
)
800+
@pytest.mark.parametrize("method", ["call", "index"])
801+
def test_nth_slices_with_column_axis(
802+
start, stop, expected_values, expected_columns, method
803+
):
804+
df = DataFrame([range(5)], columns=[list("ABCDE")])
805+
gb = df.groupby([5, 5, 5, 6, 6], axis=1)
806+
result = {
807+
"call": lambda start, stop: gb.nth(slice(start, stop)),
808+
"index": lambda start, stop: gb.nth[start:stop],
809+
}[method](start, stop)
810+
expected = DataFrame([expected_values], columns=expected_columns)
811+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)