From e89faa8888d05720312e8e6b284028997be00dc5 Mon Sep 17 00:00:00 2001 From: richard Date: Mon, 26 Dec 2022 21:31:00 -0500 Subject: [PATCH 1/2] BUG: groupby.nth with NA values in groupings after subsetting to SeriesGroupBy --- doc/source/whatsnew/v2.0.0.rst | 1 + pandas/core/groupby/groupby.py | 6 +++++- pandas/tests/groupby/test_nth.py | 22 ++++++++++++++++++++-- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index 12b0d90e68ab9..1c7e9f09fcf21 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -924,6 +924,7 @@ Groupby/resample/rolling - Bug in :meth:`.SeriesGroupBy.describe` with ``as_index=False`` would have the incorrect shape (:issue:`49256`) - Bug in :class:`.DataFrameGroupBy` and :class:`.SeriesGroupBy` with ``dropna=False`` would drop NA values when the grouper was categorical (:issue:`36327`) - Bug in :meth:`.SeriesGroupBy.nunique` would incorrectly raise when the grouper was an empty categorical and ``observed=True`` (:issue:`21334`) +- Bug in :meth:`.SeriesGroupBy.nth` would raise when grouper contained NA values after subsetting from a :class:`DataFrameGroupBy` (:issue:`26454`) Reshaping ^^^^^^^^^ diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 11e8769615470..530f620a202c9 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2933,7 +2933,11 @@ def _nth( # (e.g. we have selected out # a column that is not in the current object) axis = self.grouper.axis - grouper = axis[axis.isin(dropped.index)] + grouper = self.grouper.codes_info[axis.isin(dropped.index)] + if self.grouper.has_dropped_na: + # Null groups need to be encoded as -1 when passed to groupby + grouper = grouper.astype(object) + grouper[grouper == -1] = None else: diff --git a/pandas/tests/groupby/test_nth.py b/pandas/tests/groupby/test_nth.py index de5025b998b30..77422c28d356f 100644 --- a/pandas/tests/groupby/test_nth.py +++ b/pandas/tests/groupby/test_nth.py @@ -603,8 +603,11 @@ def test_nth_column_order(): def test_nth_nan_in_grouper(dropna): # GH 26011 df = DataFrame( - [[np.nan, 0, 1], ["abc", 2, 3], [np.nan, 4, 5], ["def", 6, 7], [np.nan, 8, 9]], - columns=list("abc"), + { + "a": [np.nan, "a", np.nan, "b", np.nan], + "b": [0, 2, 4, 6, 8], + "c": [1, 3, 5, 7, 9], + } ) result = df.groupby("a").nth(0, dropna=dropna) expected = df.iloc[[1, 3]] @@ -612,6 +615,21 @@ def test_nth_nan_in_grouper(dropna): tm.assert_frame_equal(result, expected) +@pytest.mark.parametrize("dropna", [None, "any", "all"]) +def test_nth_nan_in_grouper_series(dropna): + # GH 26454 + df = DataFrame( + { + "a": [np.nan, "a", np.nan, "b", np.nan], + "b": [0, 2, 4, 6, 8], + } + ) + result = df.groupby("a")["b"].nth(0, dropna=dropna) + expected = df["b"].iloc[[1, 3]] + + tm.assert_series_equal(result, expected) + + def test_first_categorical_and_datetime_data_nat(): # GH 20520 df = DataFrame( From 91b302b4b4c2c624cdec84b254479ecc7418c3b3 Mon Sep 17 00:00:00 2001 From: richard Date: Mon, 26 Dec 2022 21:52:04 -0500 Subject: [PATCH 2/2] Use Int64 Index instead of object --- pandas/core/groupby/groupby.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 530f620a202c9..d7b983367b596 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -44,6 +44,7 @@ class providing the base-class of operations. ) from pandas._libs.algos import rank_1d import pandas._libs.groupby as libgroupby +from pandas._libs.missing import NA from pandas._typing import ( AnyArrayLike, ArrayLike, @@ -2935,9 +2936,12 @@ def _nth( axis = self.grouper.axis grouper = self.grouper.codes_info[axis.isin(dropped.index)] if self.grouper.has_dropped_na: - # Null groups need to be encoded as -1 when passed to groupby - grouper = grouper.astype(object) - grouper[grouper == -1] = None + # Null groups need to still be encoded as -1 when passed to groupby + nulls = grouper == -1 + # error: No overload variant of "where" matches argument types + # "Any", "NAType", "Any" + values = np.where(nulls, NA, grouper) # type: ignore[call-overload] + grouper = Index(values, dtype="Int64") else: