Skip to content

Commit a2db6fd

Browse files
authored
BUG: GroupBy.idxmax/idxmin with EA dtypes (#38733)
1 parent bf3375c commit a2db6fd

File tree

4 files changed, with 34 additions and 3 deletions.

doc/source/whatsnew/v1.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,7 @@ Groupby/resample/rolling
340340
- Fixed bug in :meth:`DataFrameGroupBy.sum` and :meth:`SeriesGroupBy.sum` causing loss of precision through using Kahan summation (:issue:`38778`)
341341
- Fixed bug in :meth:`DataFrameGroupBy.cumsum`, :meth:`SeriesGroupBy.cumsum`, :meth:`DataFrameGroupBy.mean` and :meth:`SeriesGroupBy.mean` causing loss of precision through using Kahan summation (:issue:`38934`)
342342
- Bug in :meth:`.Resampler.aggregate` and :meth:`DataFrame.transform` raising ``TypeError`` instead of ``SpecificationError`` when missing keys had mixed dtypes (:issue:`39025`)
343+
- Bug in :meth:`.DataFrameGroupBy.idxmin` and :meth:`.DataFrameGroupBy.idxmax` with ``ExtensionDtype`` columns (:issue:`38733`)
343344

344345
Reshaping
345346
^^^^^^^^^

pandas/core/arrays/base.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from pandas.compat.numpy import function as nv
3131
from pandas.errors import AbstractMethodError
3232
from pandas.util._decorators import Appender, Substitution
33-
from pandas.util._validators import validate_fillna_kwargs
33+
from pandas.util._validators import validate_bool_kwarg, validate_fillna_kwargs
3434

3535
from pandas.core.dtypes.cast import maybe_cast_to_extension_array
3636
from pandas.core.dtypes.common import (
@@ -596,13 +596,17 @@ def argsort(
596596
mask=np.asarray(self.isna()),
597597
)
598598

599-
def argmin(self):
599+
def argmin(self, skipna: bool = True) -> int:
600600
"""
601601
Return the index of minimum value.
602602
603603
In case of multiple occurrences of the minimum value, the index
604604
corresponding to the first occurrence is returned.
605605
606+
Parameters
607+
----------
608+
skipna : bool, default True
609+
606610
Returns
607611
-------
608612
int
@@ -611,15 +615,22 @@ def argmin(self):
611615
--------
612616
ExtensionArray.argmax
613617
"""
618+
validate_bool_kwarg(skipna, "skipna")
619+
if not skipna and self.isna().any():
620+
raise NotImplementedError
614621
return nargminmax(self, "argmin")
615622

616-
def argmax(self):
623+
def argmax(self, skipna: bool = True) -> int:
617624
"""
618625
Return the index of maximum value.
619626
620627
In case of multiple occurrences of the maximum value, the index
621628
corresponding to the first occurrence is returned.
622629
630+
Parameters
631+
----------
632+
skipna : bool, default True
633+
623634
Returns
624635
-------
625636
int
@@ -628,6 +639,9 @@ def argmax(self):
628639
--------
629640
ExtensionArray.argmin
630641
"""
642+
validate_bool_kwarg(skipna, "skipna")
643+
if not skipna and self.isna().any():
644+
raise NotImplementedError
631645
return nargminmax(self, "argmax")
632646

633647
def fillna(self, value=None, method=None, limit=None):

pandas/tests/extension/base/methods.py

+10
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,16 @@ def test_argreduce_series(
128128
result = getattr(ser, op_name)(skipna=skipna)
129129
tm.assert_almost_equal(result, expected)
130130

131+
def test_argmax_argmin_no_skipna_notimplemented(self, data_missing_for_sorting):
132+
# GH#38733
133+
data = data_missing_for_sorting
134+
135+
with pytest.raises(NotImplementedError, match=""):
136+
data.argmin(skipna=False)
137+
138+
with pytest.raises(NotImplementedError, match=""):
139+
data.argmax(skipna=False)
140+
131141
@pytest.mark.parametrize(
132142
"na_position, expected",
133143
[

pandas/tests/groupby/test_function.py

+6
Original file line numberDiff line numberDiff line change
@@ -531,10 +531,16 @@ def test_idxmin_idxmax_returns_int_types(func, values):
531531
}
532532
)
533533
df["c_date"] = pd.to_datetime(df["c_date"])
534+
df["c_date_tz"] = df["c_date"].dt.tz_localize("US/Pacific")
535+
df["c_timedelta"] = df["c_date"] - df["c_date"].iloc[0]
536+
df["c_period"] = df["c_date"].dt.to_period("W")
534537

535538
result = getattr(df.groupby("name"), func)()
536539

537540
expected = DataFrame(values, index=Index(["A", "B"], name="name"))
541+
expected["c_date_tz"] = expected["c_date"]
542+
expected["c_timedelta"] = expected["c_date"]
543+
expected["c_period"] = expected["c_date"]
538544

539545
tm.assert_frame_equal(result, expected)
540546

0 commit comments

Comments
 (0)