Commit c7c0b3b

rhshadrach authored and mliu08 committed
BUG: groupby.nth should be a filter (pandas-dev#49262)
1 parent 35b3d62 commit c7c0b3b

8 files changed, +179 -208 lines changed

doc/source/user_guide/groupby.rst

+24-14
Original file line number | Diff line number | Diff line change
@@ -1354,9 +1354,14 @@ This shows the first or last n rows from each group.
13541354
Taking the nth row of each group
13551355
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
13561356

1357-
To select from a DataFrame or Series the nth item, use
1358-
:meth:`~pd.core.groupby.DataFrameGroupBy.nth`. This is a reduction method, and
1359-
will return a single row (or no row) per group if you pass an int for n:
1357+
To select the nth item from each group, use :meth:`.DataFrameGroupBy.nth` or
1358+
:meth:`.SeriesGroupBy.nth`. Arguments supplied can be any integer, lists of integers,
1359+
slices, or lists of slices; see below for examples. When the nth element of a group
1360+
does not exist, an error is *not* raised; instead no corresponding rows are returned.
1361+
1362+
In general this operation acts as a filtration. In certain cases it will also return
1363+
one row per group, making it also a reduction. However, because in general it can
1364+
return zero or multiple rows per group, pandas treats it as a filtration in all cases.
13601365

13611366
.. ipython:: python
13621367
@@ -1367,6 +1372,14 @@ will return a single row (or no row) per group if you pass an int for n:
13671372
g.nth(-1)
13681373
g.nth(1)
13691374
1375+
If the nth element of a group does not exist, then no corresponding row is included
1376+
in the result. In particular, if the specified ``n`` is larger than any group, the
1377+
result will be an empty DataFrame.
1378+
1379+
.. ipython:: python
1380+
1381+
g.nth(5)
1382+
13701383
If you want to select the nth not-null item, use the ``dropna`` kwarg. For a DataFrame this should be either ``'any'`` or ``'all'`` just like you would pass to dropna:
13711384

13721385
.. ipython:: python
@@ -1376,21 +1389,11 @@ If you want to select the nth not-null item, use the ``dropna`` kwarg. For a Dat
13761389
g.first()
13771390
13781391
# nth(-1) is the same as g.last()
1379-
g.nth(-1, dropna="any") # NaNs denote group exhausted when using dropna
1392+
g.nth(-1, dropna="any")
13801393
g.last()
13811394
13821395
g.B.nth(0, dropna="all")
13831396
1384-
As with other methods, passing ``as_index=False``, will achieve a filtration, which returns the grouped row.
1385-
1386-
.. ipython:: python
1387-
1388-
df = pd.DataFrame([[1, np.nan], [1, 4], [5, 6]], columns=["A", "B"])
1389-
g = df.groupby("A", as_index=False)
1390-
1391-
g.nth(0)
1392-
g.nth(-1)
1393-
13941397
You can also select multiple rows from each group by specifying multiple nth values as a list of ints.
13951398

13961399
.. ipython:: python
@@ -1400,6 +1403,13 @@ You can also select multiple rows from each group by specifying multiple nth val
14001403
# get the first, 4th, and last date index for each month
14011404
df.groupby([df.index.year, df.index.month]).nth([0, 3, -1])
14021405
1406+
You may also use slices or lists of slices.
1407+
1408+
.. ipython:: python
1409+
1410+
df.groupby([df.index.year, df.index.month]).nth[1:]
1411+
df.groupby([df.index.year, df.index.month]).nth[1:, :-1]
1412+
14031413
Enumerate group items
14041414
~~~~~~~~~~~~~~~~~~~~~
14051415
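The user-guide changes above describe `nth` as a filter that accepts ints, lists of ints, and slices. A minimal sketch of that behavior follows, assuming pandas 2.0 or later (where this commit's behavior applies). The small frame is illustrative and mirrors the docstring data rather than the date-indexed frame used in the guide.

```python
import numpy as np
import pandas as pd

# Illustrative data (an assumption, not the guide's frame): group 1 has three
# rows (index 0, 1, 3) and group 2 has two rows (index 2, 4).
df = pd.DataFrame({"A": [1, 1, 2, 1, 2], "B": [np.nan, 2.0, 3.0, 4.0, 5.0]})
g = df.groupby("A")

# An int selects at most one row per group; the original index is kept.
second = g.nth(1)

# A list of ints or a slice may select several rows per group.
first_and_last = g.nth([0, -1])
all_but_first = g.nth[1:]

# Group 2 has no third row, so it contributes nothing for n=2.
third = g.nth(2)

for name, out in [
    ("nth(1)", second),
    ("nth([0, -1])", first_and_last),
    ("nth[1:]", all_but_first),
    ("nth(2)", third),
]:
    print(name, out.index.tolist())
```

Because the result is a subset of the input rows, the grouping column `A` stays a regular column and the index labels are those of the original frame.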

doc/source/whatsnew/v2.0.0.rst

+52-4
@@ -72,7 +72,7 @@ Notable bug fixes
7272

7373
These are bug fixes that might have notable behavior changes.
7474

75-
.. _whatsnew_200.notable_bug_fixes.notable_bug_fix1:
75+
.. _whatsnew_200.notable_bug_fixes.cumsum_cumprod_overflow:
7676

7777
:meth:`.GroupBy.cumsum` and :meth:`.GroupBy.cumprod` overflow instead of lossy casting to float
7878
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -102,10 +102,58 @@ We return incorrect results with the 6th value.
102102
103103
We overflow with the 7th value, but the 6th value is still correct.
104104

105-
.. _whatsnew_200.notable_bug_fixes.notable_bug_fix2:
105+
.. _whatsnew_200.notable_bug_fixes.groupby_nth_filter:
106106

107-
notable_bug_fix2
108-
^^^^^^^^^^^^^^^^
107+
:meth:`.DataFrameGroupBy.nth` and :meth:`.SeriesGroupBy.nth` now behave as filtrations
108+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
109+
110+
In previous versions of pandas, :meth:`.DataFrameGroupBy.nth` and
111+
:meth:`.SeriesGroupBy.nth` acted as if they were aggregations. However, for most
112+
inputs ``n``, they may return either zero or multiple rows per group. This means
113+
that they are filtrations, similar to e.g. :meth:`.DataFrameGroupBy.head`. pandas
114+
now treats them as filtrations (:issue:`13666`).
115+
116+
.. ipython:: python
117+
118+
df = pd.DataFrame({"a": [1, 1, 2, 1, 2], "b": [np.nan, 2.0, 3.0, 4.0, 5.0]})
119+
gb = df.groupby("a")
120+
121+
*Old Behavior*
122+
123+
.. code-block:: ipython
124+
125+
In [5]: gb.nth(n=1)
126+
Out[5]:
127+
A B
128+
1 1 2.0
129+
4 2 5.0
130+
131+
*New Behavior*
132+
133+
.. ipython:: python
134+
135+
gb.nth(n=1)
136+
137+
In particular, the index of the result is derived from the input by selecting
138+
the appropriate rows. Also, when ``n`` is larger than the group, no rows are returned
139+
instead of ``NaN``.
140+
141+
*Old Behavior*
142+
143+
.. code-block:: ipython
144+
145+
In [5]: gb.nth(n=3, dropna="any")
146+
Out[5]:
147+
B
148+
A
149+
1 NaN
150+
2 NaN
151+
152+
*New Behavior*
153+
154+
.. ipython:: python
155+
156+
gb.nth(n=3, dropna="any")
109157
110158
.. ---------------------------------------------------------------------------
111159
.. _whatsnew_200.api_breaking:
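The whatsnew entry above can be tried directly. The sketch below assumes pandas 2.0 or later and reuses the frame from the entry. The `set_index` step is only one possible way to get a result keyed by the group column, as older versions returned; the commit itself does not prescribe it.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"a": [1, 1, 2, 1, 2], "b": [np.nan, 2.0, 3.0, 4.0, 5.0]})
gb = df.groupby("a")

# New behavior: rows are selected from the input, so the original index survives.
new_style = gb.nth(n=1)

# Assumption: code that relied on a group-keyed index can rebuild one explicitly.
keyed_by_group = new_style.set_index("a")

# After dropping rows with NaN, no group has a fourth row, so nothing is returned
# (older versions returned one NaN row per group instead).
empty = gb.nth(n=3, dropna="any")

print(new_style)
print(keyed_by_group)
print(empty)
```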

pandas/core/groupby/base.py

+1-1
@@ -37,7 +37,6 @@ class OutputKey:
3737
"mean",
3838
"median",
3939
"min",
40-
"nth",
4140
"nunique",
4241
"prod",
4342
# as long as `quantile`'s signature accepts only
@@ -100,6 +99,7 @@ class OutputKey:
10099
"indices",
101100
"ndim",
102101
"ngroups",
102+
"nth",
103103
"ohlc",
104104
"pipe",
105105
"plot",

pandas/core/groupby/groupby.py

+43-88
@@ -2978,97 +2978,68 @@ def nth(
29782978
... 'B': [np.nan, 2, 3, 4, 5]}, columns=['A', 'B'])
29792979
>>> g = df.groupby('A')
29802980
>>> g.nth(0)
2981-
B
2982-
A
2983-
1 NaN
2984-
2 3.0
2981+
A B
2982+
0 1 NaN
2983+
2 2 3.0
29852984
>>> g.nth(1)
2986-
B
2987-
A
2988-
1 2.0
2989-
2 5.0
2985+
A B
2986+
1 1 2.0
2987+
4 2 5.0
29902988
>>> g.nth(-1)
2991-
B
2992-
A
2993-
1 4.0
2994-
2 5.0
2989+
A B
2990+
3 1 4.0
2991+
4 2 5.0
29952992
>>> g.nth([0, 1])
2996-
B
2997-
A
2998-
1 NaN
2999-
1 2.0
3000-
2 3.0
3001-
2 5.0
2993+
A B
2994+
0 1 NaN
2995+
1 1 2.0
2996+
2 2 3.0
2997+
4 2 5.0
30022998
>>> g.nth(slice(None, -1))
3003-
B
3004-
A
3005-
1 NaN
3006-
1 2.0
3007-
2 3.0
2999+
A B
3000+
0 1 NaN
3001+
1 1 2.0
3002+
2 2 3.0
30083003
30093004
Index notation may also be used
30103005
30113006
>>> g.nth[0, 1]
3012-
B
3013-
A
3014-
1 NaN
3015-
1 2.0
3016-
2 3.0
3017-
2 5.0
3007+
A B
3008+
0 1 NaN
3009+
1 1 2.0
3010+
2 2 3.0
3011+
4 2 5.0
30183012
>>> g.nth[:-1]
3019-
B
3020-
A
3021-
1 NaN
3022-
1 2.0
3023-
2 3.0
3013+
A B
3014+
0 1 NaN
3015+
1 1 2.0
3016+
2 2 3.0
30243017
3025-
Specifying `dropna` allows count ignoring ``NaN``
3018+
Specifying `dropna` allows ignoring ``NaN`` values
30263019
30273020
>>> g.nth(0, dropna='any')
3028-
B
3029-
A
3030-
1 2.0
3031-
2 3.0
3021+
A B
3022+
1 1 2.0
3023+
2 2 3.0
30323024
3033-
NaNs denote group exhausted when using dropna
3025+
When the specified ``n`` is larger than any of the groups, an
3026+
empty DataFrame is returned
30343027
30353028
>>> g.nth(3, dropna='any')
3036-
B
3037-
A
3038-
1 NaN
3039-
2 NaN
3040-
3041-
Specifying `as_index=False` in `groupby` keeps the original index.
3042-
3043-
>>> df.groupby('A', as_index=False).nth(1)
3044-
A B
3045-
1 1 2.0
3046-
4 2 5.0
3029+
Empty DataFrame
3030+
Columns: [A, B]
3031+
Index: []
30473032
"""
30483033
if not dropna:
3049-
with self._group_selection_context():
3050-
mask = self._make_mask_from_positional_indexer(n)
3034+
mask = self._make_mask_from_positional_indexer(n)
30513035

3052-
ids, _, _ = self.grouper.group_info
3036+
ids, _, _ = self.grouper.group_info
30533037

3054-
# Drop NA values in grouping
3055-
mask = mask & (ids != -1)
3038+
# Drop NA values in grouping
3039+
mask = mask & (ids != -1)
30563040

3057-
out = self._mask_selected_obj(mask)
3058-
if not self.as_index:
3059-
return out
3060-
3061-
result_index = self.grouper.result_index
3062-
if self.axis == 0:
3063-
out.index = result_index[ids[mask]]
3064-
if not self.observed and isinstance(result_index, CategoricalIndex):
3065-
out = out.reindex(result_index)
3066-
3067-
out = self._reindex_output(out)
3068-
else:
3069-
out.columns = result_index[ids[mask]]
3070-
3071-
return out.sort_index(axis=self.axis) if self.sort else out
3041+
out = self._mask_selected_obj(mask)
3042+
return out
30723043

30733044
# dropna is truthy
30743045
if not is_integer(n):
@@ -3085,7 +3056,6 @@ def nth(
30853056
# old behaviour, but with all and any support for DataFrames.
30863057
# modified in GH 7559 to have better perf
30873058
n = cast(int, n)
3088-
max_len = n if n >= 0 else -1 - n
30893059
dropped = self.obj.dropna(how=dropna, axis=self.axis)
30903060

30913061
# get a new grouper for our dropped obj
@@ -3115,22 +3085,7 @@ def nth(
31153085
grb = dropped.groupby(
31163086
grouper, as_index=self.as_index, sort=self.sort, axis=self.axis
31173087
)
3118-
sizes, result = grb.size(), grb.nth(n)
3119-
mask = (sizes < max_len)._values
3120-
3121-
# set the results which don't meet the criteria
3122-
if len(result) and mask.any():
3123-
result.loc[mask] = np.nan
3124-
3125-
# reset/reindex to the original groups
3126-
if len(self.obj) == len(dropped) or len(result) == len(
3127-
self.grouper.result_index
3128-
):
3129-
result.index = self.grouper.result_index
3130-
else:
3131-
result = result.reindex(self.grouper.result_index)
3132-
3133-
return result
3088+
return grb.nth(n)
31343089

31353090
@final
31363091
def quantile(
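After this change, the non-`dropna` path of `nth` reduces to building a boolean mask over the rows and returning the masked object. The sketch below imitates that idea using only public API. `nth_via_mask` is a hypothetical helper, and using `cumcount` is an assumption made for illustration; the commit uses the private `_make_mask_from_positional_indexer` instead.

```python
import numpy as np
import pandas as pd


def nth_via_mask(df: pd.DataFrame, key: str, n: int) -> pd.DataFrame:
    """Hypothetical helper: pick the nth row of each group with a boolean mask."""
    g = df.groupby(key)
    if n >= 0:
        # Position of each row within its group, counted from the start.
        mask = g.cumcount() == n
    else:
        # Counted from the end: 0 for the last row, 1 for the one before it.
        mask = g.cumcount(ascending=False) == -n - 1
    # Rows whose grouping key is NA belong to no group and are dropped.
    mask = mask & df[key].notna()
    return df[mask]


df = pd.DataFrame({"A": [1, 1, 2, 1, 2], "B": [np.nan, 2.0, 3.0, 4.0, 5.0]})

for n in (0, 1, -1, 5):
    sketch = nth_via_mask(df, "A", n)
    builtin = df.groupby("A").nth(n)
    print(n, sketch.index.tolist(), builtin.index.tolist())
```

With pandas 2.0 or later, the two index lists printed on each line should agree, since both approaches only select rows from the input.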

pandas/tests/groupby/test_categorical.py

+1-5
@@ -563,11 +563,7 @@ def test_observed_nth():
563563
df = DataFrame({"cat": cat, "ser": ser})
564564

565565
result = df.groupby("cat", observed=False)["ser"].nth(0)
566-
567-
index = Categorical(["a", "b", "c"], categories=["a", "b", "c"])
568-
expected = Series([1, np.nan, np.nan], index=index, name="ser")
569-
expected.index.name = "cat"
570-
566+
expected = df["ser"].iloc[[0]]
571567
tm.assert_series_equal(result, expected)
572568

573569
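The updated expectation in `test_observed_nth` shows that unobserved categories no longer produce `NaN` placeholder rows. The sketch below assumes pandas 2.0 or later. The definitions of `cat` and `ser` are assumptions, since the hunk does not show how the test builds them.

```python
import numpy as np
import pandas as pd

# Assumed inputs: only category "a" is observed, and two keys are missing.
cat = pd.Categorical(["a", np.nan, np.nan], categories=["a", "b", "c"])
ser = pd.Series([1, 2, 3])
df = pd.DataFrame({"cat": cat, "ser": ser})

# Even with observed=False, nth only filters rows that exist in the input.
result = df.groupby("cat", observed=False)["ser"].nth(0)
expected = df["ser"].iloc[[0]]

print(result)
print(result.equals(expected))
```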

pandas/tests/groupby/test_function.py

-1
@@ -405,7 +405,6 @@ def test_median_empty_bins(observed):
405405
("last", {"df": [{"a": 1, "b": 2}, {"a": 2, "b": 4}]}),
406406
("min", {"df": [{"a": 1, "b": 1}, {"a": 2, "b": 3}]}),
407407
("max", {"df": [{"a": 1, "b": 2}, {"a": 2, "b": 4}]}),
408-
("nth", {"df": [{"a": 1, "b": 2}, {"a": 2, "b": 4}], "args": [1]}),
409408
("count", {"df": [{"a": 1, "b": 2}, {"a": 2, "b": 2}], "out_type": "int64"}),
410409
],
411410
)

pandas/tests/groupby/test_grouping.py

+2
@@ -851,6 +851,8 @@ def test_groupby_with_single_column(self):
851851
exp = DataFrame(index=Index(["a", "b", "s"], name="a"))
852852
tm.assert_frame_equal(df.groupby("a").count(), exp)
853853
tm.assert_frame_equal(df.groupby("a").sum(), exp)
854+
855+
exp = df.iloc[[3, 4, 5]]
854856
tm.assert_frame_equal(df.groupby("a").nth(1), exp)
855857

856858
def test_gb_key_len_equal_axis_len(self):
