Skip to content

Commit 4444870

Browse files
Backport PR pandas-dev#32124: BUG: Avoid ambiguous condition in GroupBy.first / last (pandas-dev#32199)
Co-authored-by: Daniel Saxton <[email protected]>
1 parent 1bc1d59 commit 4444870

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed

doc/source/whatsnew/v1.0.2.rst

+1
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,7 @@ Bug fixes
7676

7777
- Fix bug in :meth:`DataFrame.convert_dtypes` for columns that were already using the ``"string"`` dtype (:issue:`31731`).
7878
- Fixed bug in setting values using a slice indexer with string dtype (:issue:`31772`)
79+
- Fixed bug where :meth:`GroupBy.first` and :meth:`GroupBy.last` would raise a ``TypeError`` when groups contained ``pd.NA`` in a column of object dtype (:issue:`32123`)
7980
- Fix bug in :meth:`Series.convert_dtypes` for series with mix of integers and strings (:issue:`32117`)
8081

8182
.. ---------------------------------------------------------------------------

pandas/_libs/groupby.pyx

+4 -2
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,8 @@ from pandas._libs.algos cimport (swap, TiebreakEnumType, TIEBREAK_AVERAGE,
2222
from pandas._libs.algos import (take_2d_axis1_float64_float64,
2323
groupsort_indexer, tiebreakers)
2424

25+
from pandas._libs.missing cimport checknull
26+
2527
cdef int64_t NPY_NAT = get_nat()
2628
_int64_max = np.iinfo(np.int64).max
2729

@@ -888,7 +890,7 @@ def group_last(rank_t[:, :] out,
888890
for j in range(K):
889891
val = values[i, j]
890892

891-
if val == val:
893+
if not checknull(val):
892894
# NB: use _treat_as_na here once
893895
# conditional-nogil is available.
894896
nobs[lab, j] += 1
@@ -977,7 +979,7 @@ def group_nth(rank_t[:, :] out,
977979
for j in range(K):
978980
val = values[i, j]
979981

980-
if val == val:
982+
if not checknull(val):
981983
# NB: use _treat_as_na here once
982984
# conditional-nogil is available.
983985
nobs[lab, j] += 1

pandas/tests/groupby/test_nth.py

+40
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,46 @@ def test_first_last_nth(df):
5454
tm.assert_frame_equal(result, expected)
5555

5656

57+
@pytest.mark.parametrize("method", ["first", "last"])
58+
def test_first_last_with_na_object(method, nulls_fixture):
59+
# https://github.com/pandas-dev/pandas/issues/32123
60+
groups = pd.DataFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, nulls_fixture]}).groupby(
61+
"a"
62+
)
63+
result = getattr(groups, method)()
64+
65+
if method == "first":
66+
values = [1, 3]
67+
else:
68+
values = [2, 3]
69+
70+
values = np.array(values, dtype=result["b"].dtype)
71+
idx = pd.Index([1, 2], name="a")
72+
expected = pd.DataFrame({"b": values}, index=idx)
73+
74+
tm.assert_frame_equal(result, expected)
75+
76+
77+
@pytest.mark.parametrize("index", [0, -1])
78+
def test_nth_with_na_object(index, nulls_fixture):
79+
# https://github.com/pandas-dev/pandas/issues/32123
80+
groups = pd.DataFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, nulls_fixture]}).groupby(
81+
"a"
82+
)
83+
result = groups.nth(index)
84+
85+
if index == 0:
86+
values = [1, 3]
87+
else:
88+
values = [2, nulls_fixture]
89+
90+
values = np.array(values, dtype=result["b"].dtype)
91+
idx = pd.Index([1, 2], name="a")
92+
expected = pd.DataFrame({"b": values}, index=idx)
93+
94+
tm.assert_frame_equal(result, expected)
95+
96+
5797
def test_first_last_nth_dtypes(df_mixed_floats):
5898

5999
df = df_mixed_floats.copy()

0 commit comments

Comments
 (0)