diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index 4a4e7dd6d15d7..3294e701cb1ef 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -164,6 +164,14 @@ Previously, negative arguments returned empty frames. df.groupby("A").nth(slice(1, -1)) df.groupby("A").nth([slice(None, 1), slice(-1, None)]) +:meth:`.GroupBy.nth` now accepts index notation. + +.. ipython:: python + + df.groupby("A").nth[1, -1] + df.groupby("A").nth[1:-1] + df.groupby("A").nth[:1, -1:] + .. _whatsnew_140.dict_tight: DataFrame.from_dict and DataFrame.to_dict have new ``'tight'`` option diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 8cd5712597fef..2876ec1cb5a0d 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -100,7 +100,10 @@ class providing the base-class of operations. numba_, ops, ) -from pandas.core.groupby.indexing import GroupByIndexingMixin +from pandas.core.groupby.indexing import ( + GroupByIndexingMixin, + GroupByNthSelector, +) from pandas.core.indexes.api import ( CategoricalIndex, Index, @@ -902,6 +905,15 @@ def __getattr__(self, attr: str): f"'{type(self).__name__}' object has no attribute '{attr}'" ) + def __getattribute__(self, attr: str): + # Intercept nth to allow both call and index + if attr == "nth": + return GroupByNthSelector(self) + elif attr == "nth_actual": + return super().__getattribute__("nth") + else: + return super().__getattribute__(attr) + @final def _make_wrapper(self, name: str) -> Callable: assert name in self._apply_allowlist @@ -2524,6 +2536,9 @@ def nth( """ Take the nth row from each group if n is an int, otherwise a subset of rows. + Can be either a call or an index. dropna is not available with index notation. + Index notation accepts a comma separated list of integers and slices. + If dropna, will take the nth non-null row, dropna is either 'all' or 'any'; this is equivalent to calling dropna(how=dropna) before the groupby. 
@@ -2535,6 +2550,7 @@ def nth( .. versionchanged:: 1.4.0 - Added slice and lists containiing slices. + Added slice and lists containing slices. + Added index notation. dropna : {'any', 'all', None}, default None Apply the specified dropna operation before counting which row is @@ -2580,6 +2596,22 @@ def nth( 1 2.0 2 3.0 + Index notation may also be used + + >>> g.nth[0, 1] + B + A + 1 NaN + 1 2.0 + 2 3.0 + 2 5.0 + >>> g.nth[:-1] + B + A + 1 NaN + 1 2.0 + 2 3.0 + Specifying `dropna` allows count ignoring ``NaN`` >>> g.nth(0, dropna='any') diff --git a/pandas/core/groupby/indexing.py b/pandas/core/groupby/indexing.py index 4b3bb6bc0aa50..f98bdf4b8be29 100644 --- a/pandas/core/groupby/indexing.py +++ b/pandas/core/groupby/indexing.py @@ -3,6 +3,7 @@ from typing import ( TYPE_CHECKING, Iterable, + Literal, cast, ) @@ -281,3 +282,22 @@ def __getitem__(self, arg: PositionalIndexer | tuple) -> DataFrame | Series: self.groupby_object._reset_group_selection() mask = self.groupby_object._make_mask_from_positional_indexer(arg) return self.groupby_object._mask_selected_obj(mask) + + +class GroupByNthSelector: + """ + Dynamically substituted for GroupBy.nth to enable both call and index + """ + + def __init__(self, groupby_object: groupby.GroupBy): + self.groupby_object = groupby_object + + def __call__( + self, + n: PositionalIndexer | tuple, + dropna: Literal["any", "all", None] = None, + ) -> DataFrame | Series: + return self.groupby_object.nth_actual(n, dropna) + + def __getitem__(self, n: PositionalIndexer | tuple) -> DataFrame | Series: + return self.groupby_object.nth_actual(n) diff --git a/pandas/tests/groupby/test_nth.py b/pandas/tests/groupby/test_nth.py index a5cb511763eee..8a5f972c22640 100644 --- a/pandas/tests/groupby/test_nth.py +++ b/pandas/tests/groupby/test_nth.py @@ -720,10 +720,23 @@ def test_groupby_last_first_nth_with_none(method, nulls_fixture): def test_slice(slice_test_df, slice_test_grouped, arg, expected_rows): # Test slices GH #42947 - result = slice_test_grouped.nth(arg) + result = 
slice_test_grouped.nth[arg] + equivalent = slice_test_grouped.nth(arg) expected = slice_test_df.iloc[expected_rows] tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(equivalent, expected) + + +def test_nth_indexed(slice_test_df, slice_test_grouped): + # Test index notation GH #44688 + + result = slice_test_grouped.nth[0, 1, -2:] + equivalent = slice_test_grouped.nth([0, 1, slice(-2, None)]) + expected = slice_test_df.iloc[[0, 1, 2, 3, 4, 6, 7]] + + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(equivalent, expected) def test_invalid_argument(slice_test_grouped): @@ -769,3 +782,30 @@ def test_groupby_nth_with_column_axis(): ) expected.columns.name = "y" tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "start, stop, expected_values, expected_columns", + [ + (None, None, [0, 1, 2, 3, 4], [5, 5, 5, 6, 6]), + (None, 1, [0, 3], [5, 6]), + (None, 9, [0, 1, 2, 3, 4], [5, 5, 5, 6, 6]), + (None, -1, [0, 1, 3], [5, 5, 6]), + (1, None, [1, 2, 4], [5, 5, 6]), + (1, -1, [1], [5]), + (-1, None, [2, 4], [5, 6]), + (-1, 2, [4], [6]), + ], +) +@pytest.mark.parametrize("method", ["call", "index"]) +def test_nth_slices_with_column_axis( + start, stop, expected_values, expected_columns, method +): + df = DataFrame([range(5)], columns=[list("ABCDE")]) + gb = df.groupby([5, 5, 5, 6, 6], axis=1) + result = { + "call": lambda start, stop: gb.nth(slice(start, stop)), + "index": lambda start, stop: gb.nth[start:stop], + }[method](start, stop) + expected = DataFrame([expected_values], columns=expected_columns) + tm.assert_frame_equal(result, expected)