diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index cee41f248fc60..38ef1115988b5 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -280,6 +280,7 @@ Other enhancements - :meth:`Styler.highlight_null` now accepts ``subset`` argument (:issue:`31345`) - When writing directly to a sqlite connection :func:`to_sql` now supports the ``multi`` method (:issue:`29921`) - `OptionError` is now exposed in `pandas.errors` (:issue:`27553`) +- Add :meth:`ExtensionArray.argmax` and :meth:`ExtensionArray.argmin` (:issue:`24382`) - :func:`timedelta_range` will now infer a frequency when passed ``start``, ``stop``, and ``periods`` (:issue:`32377`) - Positional slicing on a :class:`IntervalIndex` now supports slices with ``step > 1`` (:issue:`31658`) - :class:`Series.str` now has a `fullmatch` method that matches a regular expression against the entire string in each row of the series, similar to `re.fullmatch` (:issue:`32806`). diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 5565b85f8d59a..32a2a30fcfd43 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -28,7 +28,7 @@ from pandas.core import ops from pandas.core.algorithms import _factorize_array, unique from pandas.core.missing import backfill_1d, pad_1d -from pandas.core.sorting import nargsort +from pandas.core.sorting import nargminmax, nargsort _extension_array_shared_docs: Dict[str, str] = dict() @@ -533,6 +533,40 @@ def argsort( result = nargsort(self, kind=kind, ascending=ascending, na_position="last") return result + def argmin(self): + """ + Return the index of minimum value. + + In case of multiple occurrences of the minimum value, the index + corresponding to the first occurrence is returned. + + Returns + ------- + int + + See Also + -------- + ExtensionArray.argmax + """ + return nargminmax(self, "argmin") + + def argmax(self): + """ + Return the index of maximum value. 
+ + In case of multiple occurrences of the maximum value, the index + corresponding to the first occurrence is returned. + + Returns + ------- + int + + See Also + -------- + ExtensionArray.argmin + """ + return nargminmax(self, "argmax") + def fillna(self, value=None, method=None, limit=None): """ Fill NA/NaN values using the specified method. diff --git a/pandas/core/sorting.py b/pandas/core/sorting.py index da9cbe1023599..ee73aa42701b0 100644 --- a/pandas/core/sorting.py +++ b/pandas/core/sorting.py @@ -319,6 +319,33 @@ def nargsort( return indexer +def nargminmax(values, method: str): + """ + Implementation of np.argmin/argmax but for ExtensionArray and which + handles missing values. + + Parameters + ---------- + values : ExtensionArray + method : {"argmax", "argmin"} + + Returns + ------- + int + """ + assert method in {"argmax", "argmin"} + func = np.argmax if method == "argmax" else np.argmin + + mask = np.asarray(isna(values)) + values = values._values_for_argsort() + + idx = np.arange(len(values)) + non_nans = values[~mask] + non_nan_idx = idx[~mask] + + return non_nan_idx[func(non_nans)] + + def ensure_key_mapped_multiindex(index, key: Callable, level=None): """ Returns a new MultiIndex in which key has been applied diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py index 874a8dfd4253f..5e1cf30efd534 100644 --- a/pandas/tests/extension/base/methods.py +++ b/pandas/tests/extension/base/methods.py @@ -75,6 +75,42 @@ def test_argsort_missing(self, data_missing_for_sorting): expected = pd.Series(np.array([1, -1, 0], dtype=np.int64)) self.assert_series_equal(result, expected) + def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting, na_value): + # GH 24382 + + # data_for_sorting -> [B, C, A] with A < B < C + assert data_for_sorting.argmax() == 1 + assert data_for_sorting.argmin() == 2 + + # with repeated values -> first occurrence + data = data_for_sorting.take([2, 0, 0, 1, 1, 2]) + assert
data.argmax() == 3 + assert data.argmin() == 0 + + # with missing values + # data_missing_for_sorting -> [B, NA, A] with A < B and NA missing. + assert data_missing_for_sorting.argmax() == 0 + assert data_missing_for_sorting.argmin() == 2 + + @pytest.mark.parametrize( + "method", ["argmax", "argmin"], + ) + def test_argmin_argmax_empty_array(self, method, data): + # GH 24382 + err_msg = "attempt to get" + with pytest.raises(ValueError, match=err_msg): + getattr(data[:0], method)() + + @pytest.mark.parametrize( + "method", ["argmax", "argmin"], + ) + def test_argmin_argmax_all_na(self, method, data, na_value): + # all missing with skipna=True is the same as empty + err_msg = "attempt to get" + data_na = type(data)._from_sequence([na_value, na_value], dtype=data.dtype) + with pytest.raises(ValueError, match=err_msg): + getattr(data_na, method)() + @pytest.mark.parametrize( "na_position, expected", [ diff --git a/pandas/tests/extension/test_boolean.py b/pandas/tests/extension/test_boolean.py index 725067951eeef..8acbeaf0b8170 100644 --- a/pandas/tests/extension/test_boolean.py +++ b/pandas/tests/extension/test_boolean.py @@ -235,6 +235,23 @@ def test_searchsorted(self, data_for_sorting, as_series): def test_value_counts(self, all_data, dropna): return super().test_value_counts(all_data, dropna) + def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting): + # override because there are only 2 unique values + + # data_for_sorting -> [B, C, A] with A < B < C -> here True, True, False + assert data_for_sorting.argmax() == 0 + assert data_for_sorting.argmin() == 2 + + # with repeated values -> first occurrence + data = data_for_sorting.take([2, 0, 0, 1, 1, 2]) + assert data.argmax() == 1 + assert data.argmin() == 0 + + # with missing values + # data_missing_for_sorting -> [B, NA, A] with A < B and NA missing.
+ assert data_missing_for_sorting.argmax() == 0 + assert data_missing_for_sorting.argmin() == 2 + class TestCasting(base.BaseCastingTests): pass diff --git a/pandas/tests/extension/test_sparse.py b/pandas/tests/extension/test_sparse.py index f318934ef5e52..68e521b005c02 100644 --- a/pandas/tests/extension/test_sparse.py +++ b/pandas/tests/extension/test_sparse.py @@ -321,6 +321,14 @@ def test_shift_0_periods(self, data): data._sparse_values[0] = data._sparse_values[1] assert result._sparse_values[0] != result._sparse_values[1] + @pytest.mark.parametrize( + "method", ["argmax", "argmin"], + ) + def test_argmin_argmax_all_na(self, method, data, na_value): + # overriding because Sparse[int64, 0] cannot handle na_value + self._check_unsupported(data) + super().test_argmin_argmax_all_na(method, data, na_value) + @pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame]) def test_equals(self, data, na_value, as_series, box): self._check_unsupported(data)