Skip to content

Commit 33a53f1

Browse files
authored
ENH: A new GroupBy method to slice rows preserving index and order (#42947)
1 parent 62d7a24 commit 33a53f1

File tree

6 files changed

+709
-43
lines changed

6 files changed

+709
-43
lines changed

doc/source/whatsnew/v1.4.0.rst

+24
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,30 @@ Example:
110110
111111
s.rolling(3).rank(method="max")
112112
113+
.. _whatsnew_140.enhancements.groupby_indexing:
114+
115+
Groupby positional indexing
116+
^^^^^^^^^^^^^^^^^^^^^^^^^^^
117+
118+
It is now possible to specify positional ranges relative to the ends of each group.
119+
120+
Negative arguments for :meth:`.GroupBy.head` and :meth:`.GroupBy.tail` now work correctly and result in ranges relative to the end and start of each group, respectively.
121+
Previously, negative arguments returned empty frames.
122+
123+
.. ipython:: python
124+
125+
df = pd.DataFrame([["g", "g0"], ["g", "g1"], ["g", "g2"], ["g", "g3"],
126+
["h", "h0"], ["h", "h1"]], columns=["A", "B"])
127+
df.groupby("A").head(-1)
128+
129+
130+
:meth:`.GroupBy.nth` now accepts a slice or list of integers and slices.
131+
132+
.. ipython:: python
133+
134+
df.groupby("A").nth(slice(1, -1))
135+
df.groupby("A").nth([slice(None, 1), slice(-1, None)])
136+
113137
.. _whatsnew_140.enhancements.other:
114138

115139
Other enhancements

pandas/core/groupby/groupby.py

+49-38
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class providing the base-class of operations.
4646
ArrayLike,
4747
IndexLabel,
4848
NDFrameT,
49+
PositionalIndexer,
4950
RandomState,
5051
Scalar,
5152
T,
@@ -65,6 +66,7 @@ class providing the base-class of operations.
6566
is_bool_dtype,
6667
is_datetime64_dtype,
6768
is_float_dtype,
69+
is_integer,
6870
is_integer_dtype,
6971
is_numeric_dtype,
7072
is_object_dtype,
@@ -97,6 +99,7 @@ class providing the base-class of operations.
9799
numba_,
98100
ops,
99101
)
102+
from pandas.core.groupby.indexing import GroupByIndexingMixin
100103
from pandas.core.indexes.api import (
101104
CategoricalIndex,
102105
Index,
@@ -555,7 +558,7 @@ def f(self):
555558
]
556559

557560

558-
class BaseGroupBy(PandasObject, SelectionMixin[NDFrameT]):
561+
class BaseGroupBy(PandasObject, SelectionMixin[NDFrameT], GroupByIndexingMixin):
559562
_group_selection: IndexLabel | None = None
560563
_apply_allowlist: frozenset[str] = frozenset()
561564
_hidden_attrs = PandasObject._hidden_attrs | {
@@ -2445,23 +2448,28 @@ def backfill(self, limit=None):
24452448
@Substitution(name="groupby")
24462449
@Substitution(see_also=_common_see_also)
24472450
def nth(
2448-
self, n: int | list[int], dropna: Literal["any", "all", None] = None
2451+
self,
2452+
n: PositionalIndexer | tuple,
2453+
dropna: Literal["any", "all", None] = None,
24492454
) -> NDFrameT:
24502455
"""
2451-
Take the nth row from each group if n is an int, or a subset of rows
2452-
if n is a list of ints.
2456+
Take the nth row from each group if n is an int, otherwise a subset of rows.
24532457
24542458
If dropna, will take the nth non-null row, dropna is either
24552459
'all' or 'any'; this is equivalent to calling dropna(how=dropna)
24562460
before the groupby.
24572461
24582462
Parameters
24592463
----------
2460-
n : int or list of ints
2461-
A single nth value for the row or a list of nth values.
2464+
n : int, slice or list of ints and slices
2465+
A single nth value for the row or a list of nth values or slices.
2466+
2467+
.. versionchanged:: 1.4.0
2468+
Added slice and lists containing slices.
2469+
24622470
dropna : {'any', 'all', None}, default None
24632471
Apply the specified dropna operation before counting which row is
2464-
the nth row.
2472+
the nth row. Only supported if n is an int.
24652473
24662474
Returns
24672475
-------
@@ -2496,6 +2504,12 @@ def nth(
24962504
1 2.0
24972505
2 3.0
24982506
2 5.0
2507+
>>> g.nth(slice(None, -1))
2508+
B
2509+
A
2510+
1 NaN
2511+
1 2.0
2512+
2 3.0
24992513
25002514
Specifying `dropna` allows counting while ignoring ``NaN`` values
25012515
@@ -2520,33 +2534,16 @@ def nth(
25202534
1 1 2.0
25212535
4 2 5.0
25222536
"""
2523-
valid_containers = (set, list, tuple)
2524-
if not isinstance(n, (valid_containers, int)):
2525-
raise TypeError("n needs to be an int or a list/set/tuple of ints")
2526-
25272537
if not dropna:
2528-
2529-
if isinstance(n, int):
2530-
nth_values = [n]
2531-
elif isinstance(n, valid_containers):
2532-
nth_values = list(set(n))
2533-
2534-
nth_array = np.array(nth_values, dtype=np.intp)
25352538
with self._group_selection_context():
2536-
2537-
mask_left = np.in1d(self._cumcount_array(), nth_array)
2538-
mask_right = np.in1d(
2539-
self._cumcount_array(ascending=False) + 1, -nth_array
2540-
)
2541-
mask = mask_left | mask_right
2539+
mask = self._make_mask_from_positional_indexer(n)
25422540

25432541
ids, _, _ = self.grouper.group_info
25442542

25452543
# Drop NA values in grouping
25462544
mask = mask & (ids != -1)
25472545

25482546
out = self._mask_selected_obj(mask)
2549-
25502547
if not self.as_index:
25512548
return out
25522549

@@ -2563,19 +2560,20 @@ def nth(
25632560
return out.sort_index(axis=self.axis) if self.sort else out
25642561

25652562
# dropna is truthy
2566-
if isinstance(n, valid_containers):
2567-
raise ValueError("dropna option with a list of nth values is not supported")
2563+
if not is_integer(n):
2564+
raise ValueError("dropna option only supported for an integer argument")
25682565

25692566
if dropna not in ["any", "all"]:
25702567
# Note: when agg-ing picker doesn't raise this, just returns NaN
25712568
raise ValueError(
2572-
"For a DataFrame groupby, dropna must be "
2569+
"For a DataFrame or Series groupby.nth, dropna must be "
25732570
"either None, 'any' or 'all', "
25742571
f"(was passed {dropna})."
25752572
)
25762573

25772574
# old behaviour, but with all and any support for DataFrames.
25782575
# modified in GH 7559 to have better perf
2576+
n = cast(int, n)
25792577
max_len = n if n >= 0 else -1 - n
25802578
dropped = self.obj.dropna(how=dropna, axis=self.axis)
25812579

@@ -3301,11 +3299,16 @@ def head(self, n=5):
33013299
from the original DataFrame with original index and order preserved
33023300
(``as_index`` flag is ignored).
33033301
3304-
Does not work for negative values of `n`.
3302+
Parameters
3303+
----------
3304+
n : int
3305+
If positive: number of entries to include from start of each group.
3306+
If negative: number of entries to exclude from end of each group.
33053307
33063308
Returns
33073309
-------
33083310
Series or DataFrame
3311+
Subset of original Series or DataFrame as determined by n.
33093312
%(see_also)s
33103313
Examples
33113314
--------
@@ -3317,12 +3320,11 @@ def head(self, n=5):
33173320
0 1 2
33183321
2 5 6
33193322
>>> df.groupby('A').head(-1)
3320-
Empty DataFrame
3321-
Columns: [A, B]
3322-
Index: []
3323+
A B
3324+
0 1 2
33233325
"""
33243326
self._reset_group_selection()
3325-
mask = self._cumcount_array() < n
3327+
mask = self._make_mask_from_positional_indexer(slice(None, n))
33263328
return self._mask_selected_obj(mask)
33273329

33283330
@final
@@ -3336,11 +3338,16 @@ def tail(self, n=5):
33363338
from the original DataFrame with original index and order preserved
33373339
(``as_index`` flag is ignored).
33383340
3339-
Does not work for negative values of `n`.
3341+
Parameters
3342+
----------
3343+
n : int
3344+
If positive: number of entries to include from end of each group.
3345+
If negative: number of entries to exclude from start of each group.
33403346
33413347
Returns
33423348
-------
33433349
Series or DataFrame
3350+
Subset of original Series or DataFrame as determined by n.
33443351
%(see_also)s
33453352
Examples
33463353
--------
@@ -3352,12 +3359,16 @@ def tail(self, n=5):
33523359
1 a 2
33533360
3 b 2
33543361
>>> df.groupby('A').tail(-1)
3355-
Empty DataFrame
3356-
Columns: [A, B]
3357-
Index: []
3362+
A B
3363+
1 a 2
3364+
3 b 2
33583365
"""
33593366
self._reset_group_selection()
3360-
mask = self._cumcount_array(ascending=False) < n
3367+
if n:
3368+
mask = self._make_mask_from_positional_indexer(slice(-n, None))
3369+
else:
3370+
mask = self._make_mask_from_positional_indexer([])
3371+
33613372
return self._mask_selected_obj(mask)
33623373

33633374
@final

0 commit comments

Comments
 (0)