Skip to content

Commit 7120725

Browse files
WillAydjreback
authored andcommitted
Fix Bug with NA value in Grouping for Groupby.nth (#26152)
1 parent cc3b2f0 commit 7120725

File tree

3 files changed

+50
-21
lines changed

3 files changed

+50
-21
lines changed

doc/source/whatsnew/v0.25.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ Groupby/Resample/Rolling
398398
- Bug in :meth:`pandas.core.window.Rolling.count` and `pandas.core.window.Expanding.count` was previously ignoring the axis keyword (:issue:`13503`)
399399
- Bug in :meth:`pandas.core.groupby.GroupBy.idxmax` and :meth:`pandas.core.groupby.GroupBy.idxmin` with datetime column would return incorrect dtype (:issue:`25444`, :issue:`15306`)
400400
- Bug in :meth:`pandas.core.groupby.GroupBy.cumsum`, :meth:`pandas.core.groupby.GroupBy.cumprod`, :meth:`pandas.core.groupby.GroupBy.cummin` and :meth:`pandas.core.groupby.GroupBy.cummax` with categorical column having absent categories, would return incorrect result or segfault (:issue:`16771`)
401+
- Bug in :meth:`pandas.core.groupby.GroupBy.nth` where NA values in the grouping would return incorrect results (:issue:`26011`)
401402

402403

403404
Reshaping

pandas/core/groupby/groupby.py

+32-21
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class providing the base-class of operations.
1212
import datetime
1313
from functools import partial, wraps
1414
import types
15-
from typing import FrozenSet, Optional, Tuple, Type
15+
from typing import FrozenSet, List, Optional, Tuple, Type, Union
1616
import warnings
1717

1818
import numpy as np
@@ -1543,15 +1543,16 @@ def backfill(self, limit=None):
15431543

15441544
@Substitution(name='groupby')
15451545
@Substitution(see_also=_common_see_also)
1546-
def nth(self, n, dropna=None):
1546+
def nth(self,
1547+
n: Union[int, List[int]],
1548+
dropna: Optional[str] = None) -> DataFrame:
15471549
"""
15481550
Take the nth row from each group if n is an int, or a subset of rows
15491551
if n is a list of ints.
15501552
15511553
If dropna, will take the nth non-null row, dropna is either
1552-
Truthy (if a Series) or 'all', 'any' (if a DataFrame);
1553-
this is equivalent to calling dropna(how=dropna) before the
1554-
groupby.
1554+
'all' or 'any'; this is equivalent to calling dropna(how=dropna)
1555+
before the groupby.
15551556
15561557
Parameters
15571558
----------
@@ -1614,34 +1615,43 @@ def nth(self, n, dropna=None):
16141615
4 2 5.0
16151616
"""
16161617

1617-
if isinstance(n, int):
1618-
nth_values = [n]
1619-
elif isinstance(n, (set, list, tuple)):
1620-
nth_values = list(set(n))
1621-
if dropna is not None:
1622-
raise ValueError(
1623-
"dropna option with a list of nth values is not supported")
1624-
else:
1618+
valid_containers = (set, list, tuple)
1619+
if not isinstance(n, (valid_containers, int)):
16251620
raise TypeError("n needs to be an int or a list/set/tuple of ints")
16261621

1627-
nth_values = np.array(nth_values, dtype=np.intp)
1628-
self._set_group_selection()
1629-
16301622
if not dropna:
1631-
mask_left = np.in1d(self._cumcount_array(), nth_values)
1623+
1624+
if isinstance(n, int):
1625+
nth_values = [n]
1626+
elif isinstance(n, valid_containers):
1627+
nth_values = list(set(n))
1628+
1629+
nth_array = np.array(nth_values, dtype=np.intp)
1630+
self._set_group_selection()
1631+
1632+
mask_left = np.in1d(self._cumcount_array(), nth_array)
16321633
mask_right = np.in1d(self._cumcount_array(ascending=False) + 1,
1633-
-nth_values)
1634+
-nth_array)
16341635
mask = mask_left | mask_right
16351636

1637+
ids, _, _ = self.grouper.group_info
1638+
1639+
# Drop NA values in grouping
1640+
mask = mask & (ids != -1)
1641+
16361642
out = self._selected_obj[mask]
16371643
if not self.as_index:
16381644
return out
16391645

1640-
ids, _, _ = self.grouper.group_info
16411646
out.index = self.grouper.result_index[ids[mask]]
16421647

16431648
return out.sort_index() if self.sort else out
16441649

1650+
# dropna is truthy
1651+
if isinstance(n, valid_containers):
1652+
raise ValueError(
1653+
"dropna option with a list of nth values is not supported")
1654+
16451655
if dropna not in ['any', 'all']:
16461656
if isinstance(self._selected_obj, Series) and dropna is True:
16471657
warnings.warn("the dropna={dropna} keyword is deprecated,"
@@ -1676,15 +1686,16 @@ def nth(self, n, dropna=None):
16761686

16771687
else:
16781688

1679-
# create a grouper with the original parameters, but on the dropped
1689+
# create a grouper with the original parameters, but on dropped
16801690
# object
16811691
from pandas.core.groupby.grouper import _get_grouper
16821692
grouper, _, _ = _get_grouper(dropped, key=self.keys,
16831693
axis=self.axis, level=self.level,
16841694
sort=self.sort,
16851695
mutated=self.mutated)
16861696

1687-
grb = dropped.groupby(grouper, as_index=self.as_index, sort=self.sort)
1697+
grb = dropped.groupby(
1698+
grouper, as_index=self.as_index, sort=self.sort)
16881699
sizes, result = grb.size(), grb.nth(n)
16891700
mask = (sizes < max_len).values
16901701

pandas/tests/groupby/test_nth.py

+17
Original file line numberDiff line numberDiff line change
@@ -432,3 +432,20 @@ def test_nth_column_order():
432432
columns=['C', 'B'],
433433
index=Index([1, 2], name='A'))
434434
assert_frame_equal(result, expected)
435+
436+
437+
@pytest.mark.parametrize("dropna", [None, 'any', 'all'])
438+
def test_nth_nan_in_grouper(dropna):
439+
# GH 26011
440+
df = DataFrame([
441+
[np.nan, 0, 1],
442+
['abc', 2, 3],
443+
[np.nan, 4, 5],
444+
['def', 6, 7],
445+
[np.nan, 8, 9],
446+
], columns=list('abc'))
447+
result = df.groupby('a').nth(0, dropna=dropna)
448+
expected = pd.DataFrame([[2, 3], [6, 7]], columns=list('bc'),
449+
index=Index(['abc', 'def'], name='a'))
450+
451+
assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)