Skip to content

Commit c4db4a9

Browse files
dsaxtonroberthdevries
authored andcommitted
BUG: Avoid ambiguous condition in GroupBy.first / last (pandas-dev#32124)
1 parent e740263 commit c4db4a9

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed

doc/source/whatsnew/v1.0.2.rst

+1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ Bug fixes
7777

7878
- Fix bug in :meth:`DataFrame.convert_dtypes` for columns that were already using the ``"string"`` dtype (:issue:`31731`).
7979
- Fixed bug in setting values using a slice indexer with string dtype (:issue:`31772`)
80+
- Fixed bug where :meth:`GroupBy.first` and :meth:`GroupBy.last` would raise a ``TypeError`` when groups contained ``pd.NA`` in a column of object dtype (:issue:`32123`)
8081
- Fix bug in :meth:`Series.convert_dtypes` for series with mix of integers and strings (:issue:`32117`)
8182

8283
.. ---------------------------------------------------------------------------

pandas/_libs/groupby.pyx

+4-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ from pandas._libs.algos cimport (swap, TiebreakEnumType, TIEBREAK_AVERAGE,
2222
from pandas._libs.algos import (take_2d_axis1_float64_float64,
2323
groupsort_indexer, tiebreakers)
2424

25+
from pandas._libs.missing cimport checknull
26+
2527
cdef int64_t NPY_NAT = get_nat()
2628
_int64_max = np.iinfo(np.int64).max
2729

@@ -887,7 +889,7 @@ def group_last(rank_t[:, :] out,
887889
for j in range(K):
888890
val = values[i, j]
889891

890-
if val == val:
892+
if not checknull(val):
891893
# NB: use _treat_as_na here once
892894
# conditional-nogil is available.
893895
nobs[lab, j] += 1
@@ -976,7 +978,7 @@ def group_nth(rank_t[:, :] out,
976978
for j in range(K):
977979
val = values[i, j]
978980

979-
if val == val:
981+
if not checknull(val):
980982
# NB: use _treat_as_na here once
981983
# conditional-nogil is available.
982984
nobs[lab, j] += 1

pandas/tests/groupby/test_nth.py

+40
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,46 @@ def test_first_last_nth(df):
5454
tm.assert_frame_equal(result, expected)
5555

5656

57+
@pytest.mark.parametrize("method", ["first", "last"])
58+
def test_first_last_with_na_object(method, nulls_fixture):
59+
# https://github.com/pandas-dev/pandas/issues/32123
60+
groups = pd.DataFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, nulls_fixture]}).groupby(
61+
"a"
62+
)
63+
result = getattr(groups, method)()
64+
65+
if method == "first":
66+
values = [1, 3]
67+
else:
68+
values = [2, 3]
69+
70+
values = np.array(values, dtype=result["b"].dtype)
71+
idx = pd.Index([1, 2], name="a")
72+
expected = pd.DataFrame({"b": values}, index=idx)
73+
74+
tm.assert_frame_equal(result, expected)
75+
76+
77+
@pytest.mark.parametrize("index", [0, -1])
78+
def test_nth_with_na_object(index, nulls_fixture):
79+
# https://github.com/pandas-dev/pandas/issues/32123
80+
groups = pd.DataFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, nulls_fixture]}).groupby(
81+
"a"
82+
)
83+
result = groups.nth(index)
84+
85+
if index == 0:
86+
values = [1, 3]
87+
else:
88+
values = [2, nulls_fixture]
89+
90+
values = np.array(values, dtype=result["b"].dtype)
91+
idx = pd.Index([1, 2], name="a")
92+
expected = pd.DataFrame({"b": values}, index=idx)
93+
94+
tm.assert_frame_equal(result, expected)
95+
96+
5797
def test_first_last_nth_dtypes(df_mixed_floats):
5898

5999
df = df_mixed_floats.copy()

0 commit comments

Comments
 (0)