diff --git a/doc/source/whatsnew/v0.25.0.rst b/doc/source/whatsnew/v0.25.0.rst
index b2a379d9fe6f5..5d611069ebd0a 100644
--- a/doc/source/whatsnew/v0.25.0.rst
+++ b/doc/source/whatsnew/v0.25.0.rst
@@ -393,6 +393,7 @@ Groupby/Resample/Rolling
 - Bug in :meth:`pandas.core.window.Rolling.count` and `pandas.core.window.Expanding.count` was previously ignoring the axis keyword (:issue:`13503`)
 - Bug in :meth:`pandas.core.groupby.GroupBy.idxmax` and :meth:`pandas.core.groupby.GroupBy.idxmin` with datetime column would return incorrect dtype (:issue:`25444`, :issue:`15306`)
 - Bug in :meth:`pandas.core.groupby.GroupBy.cumsum`, :meth:`pandas.core.groupby.GroupBy.cumprod`, :meth:`pandas.core.groupby.GroupBy.cummin` and :meth:`pandas.core.groupby.GroupBy.cummax` with categorical column having absent categories, would return incorrect result or segfault (:issue:`16771`)
+- Bug in :meth:`pandas.core.groupby.GroupBy.nth` where NA values in the grouping would return incorrect results (:issue:`26011`)
 
 
 Reshaping
diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py
index bd8a8852964e3..945885e34fa1e 100644
--- a/pandas/core/groupby/groupby.py
+++ b/pandas/core/groupby/groupby.py
@@ -12,7 +12,7 @@ class providing the base-class of operations.
 import datetime
 from functools import partial, wraps
 import types
-from typing import FrozenSet, Optional, Tuple, Type
+from typing import FrozenSet, List, Optional, Tuple, Type, Union
 import warnings
 
 import numpy as np
@@ -1546,15 +1546,16 @@ def backfill(self, limit=None):
 
     @Substitution(name='groupby')
     @Substitution(see_also=_common_see_also)
-    def nth(self, n, dropna=None):
+    def nth(self,
+            n: Union[int, List[int]],
+            dropna: Optional[str] = None) -> DataFrame:
         """
         Take the nth row from each group if n is an int, or a subset of rows
         if n is a list of ints.
 
         If dropna, will take the nth non-null row, dropna is either
-        Truthy (if a Series) or 'all', 'any' (if a DataFrame);
-        this is equivalent to calling dropna(how=dropna) before the
-        groupby.
+        'all' or 'any'; this is equivalent to calling dropna(how=dropna)
+        before the groupby.
 
         Parameters
         ----------
@@ -1617,34 +1618,43 @@ def nth(self, n, dropna=None):
         4  2  5.0
         """
 
-        if isinstance(n, int):
-            nth_values = [n]
-        elif isinstance(n, (set, list, tuple)):
-            nth_values = list(set(n))
-            if dropna is not None:
-                raise ValueError(
-                    "dropna option with a list of nth values is not supported")
-        else:
+        valid_containers = (set, list, tuple)
+        if not isinstance(n, (valid_containers, int)):
             raise TypeError("n needs to be an int or a list/set/tuple of ints")
 
-        nth_values = np.array(nth_values, dtype=np.intp)
-        self._set_group_selection()
-
         if not dropna:
-            mask_left = np.in1d(self._cumcount_array(), nth_values)
+
+            if isinstance(n, int):
+                nth_values = [n]
+            elif isinstance(n, valid_containers):
+                nth_values = list(set(n))
+
+            nth_array = np.array(nth_values, dtype=np.intp)
+            self._set_group_selection()
+
+            mask_left = np.in1d(self._cumcount_array(), nth_array)
             mask_right = np.in1d(self._cumcount_array(ascending=False) + 1,
-                                 -nth_values)
+                                 -nth_array)
             mask = mask_left | mask_right
 
+            ids, _, _ = self.grouper.group_info
+
+            # Drop NA values in grouping
+            mask = mask & (ids != -1)
+
             out = self._selected_obj[mask]
             if not self.as_index:
                 return out
 
-            ids, _, _ = self.grouper.group_info
             out.index = self.grouper.result_index[ids[mask]]
 
             return out.sort_index() if self.sort else out
 
+        # dropna is truthy
+        if isinstance(n, valid_containers):
+            raise ValueError(
+                "dropna option with a list of nth values is not supported")
+
         if dropna not in ['any', 'all']:
             if isinstance(self._selected_obj, Series) and dropna is True:
                 warnings.warn("the dropna={dropna} keyword is deprecated,"
@@ -1679,7 +1689,7 @@ def nth(self, n, dropna=None):
 
         else:
 
-            # create a grouper with the original parameters, but on the dropped
+            # create a grouper with the original parameters, but on dropped
             # object
             from pandas.core.groupby.grouper import _get_grouper
             grouper, _, _ = _get_grouper(dropped, key=self.keys,
@@ -1687,7 +1697,8 @@ def nth(self, n, dropna=None):
                                          sort=self.sort,
                                          mutated=self.mutated)
 
-        grb = dropped.groupby(grouper, as_index=self.as_index, sort=self.sort)
+        grb = dropped.groupby(
+            grouper, as_index=self.as_index, sort=self.sort)
         sizes, result = grb.size(), grb.nth(n)
         mask = (sizes < max_len).values
 
diff --git a/pandas/tests/groupby/test_nth.py b/pandas/tests/groupby/test_nth.py
index 7a3d189d3020e..6d07ab0008adb 100644
--- a/pandas/tests/groupby/test_nth.py
+++ b/pandas/tests/groupby/test_nth.py
@@ -434,3 +434,20 @@ def test_nth_column_order():
                          columns=['C', 'B'],
                          index=Index([1, 2], name='A'))
     assert_frame_equal(result, expected)
+
+
+@pytest.mark.parametrize("dropna", [None, 'any', 'all'])
+def test_nth_nan_in_grouper(dropna):
+    # GH 26011
+    df = DataFrame([
+        [np.nan, 0, 1],
+        ['abc', 2, 3],
+        [np.nan, 4, 5],
+        ['def', 6, 7],
+        [np.nan, 8, 9],
+    ], columns=list('abc'))
+    result = df.groupby('a').nth(0, dropna=dropna)
+    expected = pd.DataFrame([[2, 3], [6, 7]], columns=list('bc'),
+                            index=Index(['abc', 'def'], name='a'))
+
+    assert_frame_equal(result, expected)