diff --git a/doc/source/whatsnew/v1.0.2.rst b/doc/source/whatsnew/v1.0.2.rst index affe019d0ac86..82d603cfb8c15 100644 --- a/doc/source/whatsnew/v1.0.2.rst +++ b/doc/source/whatsnew/v1.0.2.rst @@ -76,6 +76,7 @@ Bug fixes - Fix bug in :meth:`DataFrame.convert_dtypes` for columns that were already using the ``"string"`` dtype (:issue:`31731`). - Fixed bug in setting values using a slice indexer with string dtype (:issue:`31772`) +- Fixed bug where :meth:`GroupBy.first` and :meth:`GroupBy.last` would raise a ``TypeError`` when groups contained ``pd.NA`` in a column of object dtype (:issue:`32123`) - Fix bug in :meth:`Series.convert_dtypes` for series with mix of integers and strings (:issue:`32117`) .. --------------------------------------------------------------------------- diff --git a/pandas/_libs/groupby.pyx b/pandas/_libs/groupby.pyx index edc44f1c94589..27b3095d8cb4f 100644 --- a/pandas/_libs/groupby.pyx +++ b/pandas/_libs/groupby.pyx @@ -22,6 +22,8 @@ from pandas._libs.algos cimport (swap, TiebreakEnumType, TIEBREAK_AVERAGE, from pandas._libs.algos import (take_2d_axis1_float64_float64, groupsort_indexer, tiebreakers) +from pandas._libs.missing cimport checknull + cdef int64_t NPY_NAT = get_nat() _int64_max = np.iinfo(np.int64).max @@ -887,7 +889,7 @@ def group_last(rank_t[:, :] out, for j in range(K): val = values[i, j] - if val == val: + if not checknull(val): # NB: use _treat_as_na here once # conditional-nogil is available. nobs[lab, j] += 1 @@ -976,7 +978,7 @@ def group_nth(rank_t[:, :] out, for j in range(K): val = values[i, j] - if val == val: + if not checknull(val): # NB: use _treat_as_na here once # conditional-nogil is available. nobs[lab, j] += 1 diff --git a/pandas/tests/groupby/test_nth.py b/pandas/tests/groupby/test_nth.py index 0f850f2e94581..b1476f1059d84 100644 --- a/pandas/tests/groupby/test_nth.py +++ b/pandas/tests/groupby/test_nth.py @@ -54,6 +54,46 @@ def test_first_last_nth(df): tm.assert_frame_equal(result, expected) +@pytest.mark.parametrize("method", ["first", "last"]) +def test_first_last_with_na_object(method, nulls_fixture): + # https://github.com/pandas-dev/pandas/issues/32123 + groups = pd.DataFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, nulls_fixture]}).groupby( + "a" + ) + result = getattr(groups, method)() + + if method == "first": + values = [1, 3] + else: + values = [2, 3] + + values = np.array(values, dtype=result["b"].dtype) + idx = pd.Index([1, 2], name="a") + expected = pd.DataFrame({"b": values}, index=idx) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("index", [0, -1]) +def test_nth_with_na_object(index, nulls_fixture): + # https://github.com/pandas-dev/pandas/issues/32123 + groups = pd.DataFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, nulls_fixture]}).groupby( + "a" + ) + result = groups.nth(index) + + if index == 0: + values = [1, 3] + else: + values = [2, nulls_fixture] + + values = np.array(values, dtype=result["b"].dtype) + idx = pd.Index([1, 2], name="a") + expected = pd.DataFrame({"b": values}, index=idx) + + tm.assert_frame_equal(result, expected) + + def test_first_last_nth_dtypes(df_mixed_floats): df = df_mixed_floats.copy()