Skip to content

BUG: groupby.nth should be a filter #49262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions doc/source/user_guide/groupby.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1354,9 +1354,14 @@ This shows the first or last n rows from each group.
Taking the nth row of each group
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To select from a DataFrame or Series the nth item, use
:meth:`~pd.core.groupby.DataFrameGroupBy.nth`. This is a reduction method, and
will return a single row (or no row) per group if you pass an int for n:
To select the nth item from each group, use :meth:`.DataFrameGroupBy.nth` or
:meth:`.SeriesGroupBy.nth`. Arguments supplied can be any integer, lists of integers,
slices, or lists of slices; see below for examples. When the nth element of a group
does not exist an error is *not* raised; instead no corresponding rows are returned.

In general this operation acts as a filtration. In certain cases it will also return
one row per group, making it also a reduction. However because in general it can
return zero or multiple rows per group, pandas treats it as a filtration in all cases.

.. ipython:: python

Expand All @@ -1367,6 +1372,14 @@ will return a single row (or no row) per group if you pass an int for n:
g.nth(-1)
g.nth(1)

If the nth element of a group does not exist, then no corresponding row is included
in the result. In particular, if the specified ``n`` is larger than any group, the
result will be an empty DataFrame.

.. ipython:: python

g.nth(5)

If you want to select the nth not-null item, use the ``dropna`` kwarg. For a DataFrame this should be either ``'any'`` or ``'all'`` just like you would pass to dropna:

.. ipython:: python
Expand All @@ -1376,21 +1389,11 @@ If you want to select the nth not-null item, use the ``dropna`` kwarg. For a Dat
g.first()

# nth(-1) is the same as g.last()
g.nth(-1, dropna="any") # NaNs denote group exhausted when using dropna
g.nth(-1, dropna="any")
g.last()

g.B.nth(0, dropna="all")

As with other methods, passing ``as_index=False``, will achieve a filtration, which returns the grouped row.

.. ipython:: python

df = pd.DataFrame([[1, np.nan], [1, 4], [5, 6]], columns=["A", "B"])
g = df.groupby("A", as_index=False)

g.nth(0)
g.nth(-1)

You can also select multiple rows from each group by specifying multiple nth values as a list of ints.

.. ipython:: python
Expand All @@ -1400,6 +1403,13 @@ You can also select multiple rows from each group by specifying multiple nth val
# get the first, 4th, and last date index for each month
df.groupby([df.index.year, df.index.month]).nth([0, 3, -1])

You may also use a slices or lists of slices.

.. ipython:: python

df.groupby([df.index.year, df.index.month]).nth[1:]
df.groupby([df.index.year, df.index.month]).nth[1:, :-1]

Enumerate group items
~~~~~~~~~~~~~~~~~~~~~

Expand Down
56 changes: 52 additions & 4 deletions doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ Notable bug fixes

These are bug fixes that might have notable behavior changes.

.. _whatsnew_200.notable_bug_fixes.notable_bug_fix1:
.. _whatsnew_200.notable_bug_fixes.cumsum_cumprod_overflow:

:meth:`.GroupBy.cumsum` and :meth:`.GroupBy.cumprod` overflow instead of lossy casting to float
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -102,10 +102,58 @@ We return incorrect results with the 6th value.

We overflow with the 7th value, but the 6th value is still correct.

.. _whatsnew_200.notable_bug_fixes.notable_bug_fix2:
.. _whatsnew_200.notable_bug_fixes.groupby_nth_filter:

notable_bug_fix2
^^^^^^^^^^^^^^^^
:meth:`.DataFrameGroupBy.nth` and :meth:`.SeriesGroupBy.nth` now behave as filtrations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In previous versions of pandas, :meth:`.DataFrameGroupBy.nth` and
:meth:`.SeriesGroupBy.nth` acted as if they were aggregations. However, for most
inputs ``n``, they may return either zero or multiple rows per group. This means
that they are filtrations, similar to e.g. :meth:`.DataFrameGroupBy.head`. pandas
now treats them as filtrations (:issue:`13666`).

.. ipython:: python

df = pd.DataFrame({"a": [1, 1, 2, 1, 2], "b": [np.nan, 2.0, 3.0, 4.0, 5.0]})
gb = df.groupby("a")

*Old Behavior*

.. code-block:: ipython

In [5]: gb.nth(n=1)
Out[5]:
A B
1 1 2.0
4 2 5.0

*New Behavior*

.. ipython:: python

gb.nth(n=1)

In particular, the index of the result is derived from the input by selecting
the appropriate rows. Also, when ``n`` is larger than the group, no rows instead of
``NaN`` is returned.

*Old Behavior*

.. code-block:: ipython

In [5]: gb.nth(n=3, dropna="any")
Out[5]:
B
A
1 NaN
2 NaN

*New Behavior*

.. ipython:: python

gb.nth(n=3, dropna="any")

.. ---------------------------------------------------------------------------
.. _whatsnew_200.api_breaking:
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/groupby/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ class OutputKey:
"mean",
"median",
"min",
"nth",
"nunique",
"prod",
# as long as `quantile`'s signature accepts only
Expand Down Expand Up @@ -100,6 +99,7 @@ class OutputKey:
"indices",
"ndim",
"ngroups",
"nth",
"ohlc",
"pipe",
"plot",
Expand Down
131 changes: 43 additions & 88 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2978,97 +2978,68 @@ def nth(
... 'B': [np.nan, 2, 3, 4, 5]}, columns=['A', 'B'])
>>> g = df.groupby('A')
>>> g.nth(0)
B
A
1 NaN
2 3.0
A B
0 1 NaN
2 2 3.0
>>> g.nth(1)
B
A
1 2.0
2 5.0
A B
1 1 2.0
4 2 5.0
>>> g.nth(-1)
B
A
1 4.0
2 5.0
A B
3 1 4.0
4 2 5.0
>>> g.nth([0, 1])
B
A
1 NaN
1 2.0
2 3.0
2 5.0
A B
0 1 NaN
1 1 2.0
2 2 3.0
4 2 5.0
>>> g.nth(slice(None, -1))
B
A
1 NaN
1 2.0
2 3.0
A B
0 1 NaN
1 1 2.0
2 2 3.0

Index notation may also be used

>>> g.nth[0, 1]
B
A
1 NaN
1 2.0
2 3.0
2 5.0
A B
0 1 NaN
1 1 2.0
2 2 3.0
4 2 5.0
>>> g.nth[:-1]
B
A
1 NaN
1 2.0
2 3.0
A B
0 1 NaN
1 1 2.0
2 2 3.0

Specifying `dropna` allows count ignoring ``NaN``
Specifying `dropna` allows ignoring ``NaN`` values

>>> g.nth(0, dropna='any')
B
A
1 2.0
2 3.0
A B
1 1 2.0
2 2 3.0

NaNs denote group exhausted when using dropna
When the specified ``n`` is larger than any of the groups, an
empty DataFrame is returned

>>> g.nth(3, dropna='any')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think this example should be shown in the whatsnew? On the surface, this example appears to be quite different from before

B
A
1 NaN
2 NaN

Specifying `as_index=False` in `groupby` keeps the original index.

>>> df.groupby('A', as_index=False).nth(1)
A B
1 1 2.0
4 2 5.0
Empty DataFrame
Columns: [A, B]
Index: []
"""
if not dropna:
with self._group_selection_context():
mask = self._make_mask_from_positional_indexer(n)
mask = self._make_mask_from_positional_indexer(n)

ids, _, _ = self.grouper.group_info
ids, _, _ = self.grouper.group_info

# Drop NA values in grouping
mask = mask & (ids != -1)
# Drop NA values in grouping
mask = mask & (ids != -1)

out = self._mask_selected_obj(mask)
if not self.as_index:
return out

result_index = self.grouper.result_index
if self.axis == 0:
out.index = result_index[ids[mask]]
if not self.observed and isinstance(result_index, CategoricalIndex):
out = out.reindex(result_index)

out = self._reindex_output(out)
else:
out.columns = result_index[ids[mask]]

return out.sort_index(axis=self.axis) if self.sort else out
out = self._mask_selected_obj(mask)
return out

# dropna is truthy
if not is_integer(n):
Expand All @@ -3085,7 +3056,6 @@ def nth(
# old behaviour, but with all and any support for DataFrames.
# modified in GH 7559 to have better perf
n = cast(int, n)
max_len = n if n >= 0 else -1 - n
dropped = self.obj.dropna(how=dropna, axis=self.axis)

# get a new grouper for our dropped obj
Expand Down Expand Up @@ -3115,22 +3085,7 @@ def nth(
grb = dropped.groupby(
grouper, as_index=self.as_index, sort=self.sort, axis=self.axis
)
sizes, result = grb.size(), grb.nth(n)
mask = (sizes < max_len)._values

# set the results which don't meet the criteria
if len(result) and mask.any():
result.loc[mask] = np.nan

# reset/reindex to the original groups
if len(self.obj) == len(dropped) or len(result) == len(
self.grouper.result_index
):
result.index = self.grouper.result_index
else:
result = result.reindex(self.grouper.result_index)

return result
return grb.nth(n)

@final
def quantile(
Expand Down
6 changes: 1 addition & 5 deletions pandas/tests/groupby/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,11 +563,7 @@ def test_observed_nth():
df = DataFrame({"cat": cat, "ser": ser})

result = df.groupby("cat", observed=False)["ser"].nth(0)

index = Categorical(["a", "b", "c"], categories=["a", "b", "c"])
expected = Series([1, np.nan, np.nan], index=index, name="ser")
expected.index.name = "cat"

expected = df["ser"].iloc[[0]]
tm.assert_series_equal(result, expected)


Expand Down
1 change: 0 additions & 1 deletion pandas/tests/groupby/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,6 @@ def test_median_empty_bins(observed):
("last", {"df": [{"a": 1, "b": 2}, {"a": 2, "b": 4}]}),
("min", {"df": [{"a": 1, "b": 1}, {"a": 2, "b": 3}]}),
("max", {"df": [{"a": 1, "b": 2}, {"a": 2, "b": 4}]}),
("nth", {"df": [{"a": 1, "b": 2}, {"a": 2, "b": 4}], "args": [1]}),
("count", {"df": [{"a": 1, "b": 2}, {"a": 2, "b": 2}], "out_type": "int64"}),
],
)
Expand Down
2 changes: 2 additions & 0 deletions pandas/tests/groupby/test_grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,8 @@ def test_groupby_with_single_column(self):
exp = DataFrame(index=Index(["a", "b", "s"], name="a"))
tm.assert_frame_equal(df.groupby("a").count(), exp)
tm.assert_frame_equal(df.groupby("a").sum(), exp)

exp = df.iloc[[3, 4, 5]]
tm.assert_frame_equal(df.groupby("a").nth(1), exp)

def test_gb_key_len_equal_axis_len(self):
Expand Down
Loading