diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 6a85bfd852e19..a249d0cf39a0b 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -339,6 +339,7 @@ Groupby/resample/rolling - Fixed bug in :meth:`DataFrameGroupBy.sum` and :meth:`SeriesGroupBy.sum` causing loss of precision through using Kahan summation (:issue:`38778`) - Fixed bug in :meth:`DataFrameGroupBy.cumsum`, :meth:`SeriesGroupBy.cumsum`, :meth:`DataFrameGroupBy.mean` and :meth:`SeriesGroupBy.mean` causing loss of precision through using Kahan summation (:issue:`38934`) - Bug in :meth:`.Resampler.aggregate` and :meth:`DataFrame.transform` raising ``TypeError`` instead of ``SpecificationError`` when missing keys had mixed dtypes (:issue:`39025`) +- Bug in :meth:`.DataFrameGroupBy.idxmin` and :meth:`.DataFrameGroupBy.idxmax` with ``ExtensionDtype`` columns (:issue:`38733`) Reshaping ^^^^^^^^^ diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 9a8b37e0785e0..b0979218e099c 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -30,7 +30,7 @@ from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError from pandas.util._decorators import Appender, Substitution -from pandas.util._validators import validate_fillna_kwargs +from pandas.util._validators import validate_bool_kwarg, validate_fillna_kwargs from pandas.core.dtypes.cast import maybe_cast_to_extension_array from pandas.core.dtypes.common import ( @@ -596,13 +596,17 @@ def argsort( mask=np.asarray(self.isna()), ) - def argmin(self): + def argmin(self, skipna: bool = True) -> int: """ Return the index of minimum value. In case of multiple occurrences of the minimum value, the index corresponding to the first occurrence is returned. + Parameters + ---------- + skipna : bool, default True + Returns ------- int @@ -611,15 +615,22 @@ def argmin(self): -------- ExtensionArray.argmax """ + validate_bool_kwarg(skipna, "skipna") + if not skipna and self.isna().any(): + raise NotImplementedError return nargminmax(self, "argmin") - def argmax(self): + def argmax(self, skipna: bool = True) -> int: """ Return the index of maximum value. In case of multiple occurrences of the maximum value, the index corresponding to the first occurrence is returned. + Parameters + ---------- + skipna : bool, default True + Returns ------- int @@ -628,6 +639,9 @@ def argmax(self): -------- ExtensionArray.argmin """ + validate_bool_kwarg(skipna, "skipna") + if not skipna and self.isna().any(): + raise NotImplementedError return nargminmax(self, "argmax") def fillna(self, value=None, method=None, limit=None): diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py index 7e7f1f1a6e025..3518f3b29e8c2 100644 --- a/pandas/tests/extension/base/methods.py +++ b/pandas/tests/extension/base/methods.py @@ -128,6 +128,16 @@ def test_argreduce_series( result = getattr(ser, op_name)(skipna=skipna) tm.assert_almost_equal(result, expected) + def test_argmax_argmin_no_skipna_notimplemented(self, data_missing_for_sorting): + # GH#38733 + data = data_missing_for_sorting + + with pytest.raises(NotImplementedError, match=""): + data.argmin(skipna=False) + + with pytest.raises(NotImplementedError, match=""): + data.argmax(skipna=False) + @pytest.mark.parametrize( "na_position, expected", [ diff --git a/pandas/tests/groupby/test_function.py b/pandas/tests/groupby/test_function.py index 8d7fcbfcfe694..f532e496ccca9 100644 --- a/pandas/tests/groupby/test_function.py +++ b/pandas/tests/groupby/test_function.py @@ -531,10 +531,16 @@ def test_idxmin_idxmax_returns_int_types(func, values): } ) df["c_date"] = pd.to_datetime(df["c_date"]) + df["c_date_tz"] = df["c_date"].dt.tz_localize("US/Pacific") + df["c_timedelta"] = df["c_date"] - df["c_date"].iloc[0] + df["c_period"] = df["c_date"].dt.to_period("W") result = getattr(df.groupby("name"), func)() expected = DataFrame(values, index=Index(["A", "B"], name="name")) + expected["c_date_tz"] = expected["c_date"] + expected["c_timedelta"] = expected["c_date"] + expected["c_period"] = expected["c_date"] tm.assert_frame_equal(result, expected)