Skip to content

Commit 16bbae0

Browse files
authored
ENH: Add argmax and argmin to ExtensionArray (#27801)
1 parent 27de044 commit 16bbae0

File tree

6 files changed

+124
-1
lines changed

6 files changed

+124
-1
lines changed

doc/source/whatsnew/v1.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ Other enhancements
280280
- :meth:`Styler.highlight_null` now accepts ``subset`` argument (:issue:`31345`)
281281
- When writing directly to a sqlite connection :func:`to_sql` now supports the ``multi`` method (:issue:`29921`)
282282
- `OptionError` is now exposed in `pandas.errors` (:issue:`27553`)
283+
- Add :meth:`ExtensionArray.argmax` and :meth:`ExtensionArray.argmin` (:issue:`24382`)
283284
- :func:`timedelta_range` will now infer a frequency when passed ``start``, ``stop``, and ``periods`` (:issue:`32377`)
284285
- Positional slicing on a :class:`IntervalIndex` now supports slices with ``step > 1`` (:issue:`31658`)
285286
- :class:`Series.str` now has a `fullmatch` method that matches a regular expression against the entire string in each row of the series, similar to `re.fullmatch` (:issue:`32806`).

pandas/core/arrays/base.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from pandas.core import ops
2929
from pandas.core.algorithms import _factorize_array, unique
3030
from pandas.core.missing import backfill_1d, pad_1d
31-
from pandas.core.sorting import nargsort
31+
from pandas.core.sorting import nargminmax, nargsort
3232

3333
_extension_array_shared_docs: Dict[str, str] = dict()
3434

@@ -533,6 +533,40 @@ def argsort(
533533
result = nargsort(self, kind=kind, ascending=ascending, na_position="last")
534534
return result
535535

536+
def argmin(self):
    """
    Return the position of the smallest value in the array.

    When the minimum value occurs more than once, the position of its
    first occurrence is the one returned.

    Returns
    -------
    int

    See Also
    --------
    ExtensionArray.argmax
    """
    # All of the work is delegated to the shared sorting helper.
    position = nargminmax(self, "argmin")
    return position
552+
553+
def argmax(self):
    """
    Return the position of the largest value in the array.

    When the maximum value occurs more than once, the position of its
    first occurrence is the one returned.

    Returns
    -------
    int

    See Also
    --------
    ExtensionArray.argmin
    """
    # All of the work is delegated to the shared sorting helper.
    position = nargminmax(self, "argmax")
    return position
569+
536570
def fillna(self, value=None, method=None, limit=None):
537571
"""
538572
Fill NA/NaN values using the specified method.

pandas/core/sorting.py

+27
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,33 @@ def nargsort(
319319
return indexer
320320

321321

322+
def nargminmax(values, method: str):
    """
    Implementation of np.argmin/argmax but for ExtensionArray and which
    handles missing values.

    Missing entries are excluded before the reduction, and the returned
    position refers to the original (unfiltered) array.

    Parameters
    ----------
    values : ExtensionArray
    method : {"argmax", "argmin"}

    Returns
    -------
    int

    Raises
    ------
    ValueError
        If ``method`` is not one of "argmax" or "argmin", or (raised by
        NumPy) if no non-missing values remain to reduce over.
    """
    # Explicit check rather than ``assert``: assertions are stripped under
    # ``python -O``, which would let a bad name silently run np.argmin.
    if method not in {"argmax", "argmin"}:
        raise ValueError(
            f"method must be either 'argmax' or 'argmin', got {method!r}"
        )
    func = np.argmax if method == "argmax" else np.argmin

    # The mask has to come from the ExtensionArray itself, before it is
    # swapped for its sortable ndarray representation.
    keep = ~np.asarray(isna(values))
    values = values._values_for_argsort()

    # Original positions of the entries that survive the NA filter.
    non_nan_idx = np.flatnonzero(keep)
    non_nans = values[keep]

    # func raises ValueError on an empty selection (empty or all-NA input).
    return non_nan_idx[func(non_nans)]
347+
348+
322349
def ensure_key_mapped_multiindex(index, key: Callable, level=None):
323350
"""
324351
Returns a new MultiIndex in which key has been applied

pandas/tests/extension/base/methods.py

+36
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,42 @@ def test_argsort_missing(self, data_missing_for_sorting):
7575
expected = pd.Series(np.array([1, -1, 0], dtype=np.int64))
7676
self.assert_series_equal(result, expected)
7777

78+
def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting, na_value):
    """Check argmax/argmin on distinct, repeated and missing values."""
    # GH 24382

    # data_for_sorting -> [B, C, A] with A < B < C
    largest_at = data_for_sorting.argmax()
    assert largest_at == 1
    smallest_at = data_for_sorting.argmin()
    assert smallest_at == 2

    # with repeated values -> first occurrence
    repeated = data_for_sorting.take([2, 0, 0, 1, 1, 2])
    largest_at = repeated.argmax()
    assert largest_at == 3
    smallest_at = repeated.argmin()
    assert smallest_at == 0

    # with missing values
    # data_missing_for_sorting -> [B, NA, A] with A < B and NA missing.
    largest_at = data_missing_for_sorting.argmax()
    assert largest_at == 0
    smallest_at = data_missing_for_sorting.argmin()
    assert smallest_at == 2
94+
95+
@pytest.mark.parametrize("method", ["argmax", "argmin"])
def test_argmin_argmax_empty_array(self, method, data):
    """Check that argmax/argmin on a zero-length array raise ValueError."""
    # GH 24382
    with pytest.raises(ValueError, match="attempt to get"):
        empty = data[:0]
        reducer = getattr(empty, method)
        reducer()
103+
104+
@pytest.mark.parametrize("method", ["argmax", "argmin"])
def test_argmin_argmax_all_na(self, method, data, na_value):
    """Check that argmax/argmin on an all-missing array raise ValueError."""
    # all missing with skipna=True is the same as empty
    array_cls = type(data)
    all_missing = array_cls._from_sequence([na_value] * 2, dtype=data.dtype)
    with pytest.raises(ValueError, match="attempt to get"):
        reducer = getattr(all_missing, method)
        reducer()
113+
78114
@pytest.mark.parametrize(
79115
"na_position, expected",
80116
[

pandas/tests/extension/test_boolean.py

+17
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,23 @@ def test_searchsorted(self, data_for_sorting, as_series):
235235
def test_value_counts(self, all_data, dropna):
236236
return super().test_value_counts(all_data, dropna)
237237

238+
def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting):
    """Boolean-specific variant of the argmax/argmin checks."""
    # override because there are only 2 unique values

    # data_for_sorting -> [B, C, A] with A < B < C -> here True, True, False
    largest_at = data_for_sorting.argmax()
    assert largest_at == 0
    smallest_at = data_for_sorting.argmin()
    assert smallest_at == 2

    # with repeated values -> first occurrence
    repeated = data_for_sorting.take([2, 0, 0, 1, 1, 2])
    largest_at = repeated.argmax()
    assert largest_at == 1
    smallest_at = repeated.argmin()
    assert smallest_at == 0

    # with missing values
    # data_missing_for_sorting -> [B, NA, A] with A < B and NA missing.
    largest_at = data_missing_for_sorting.argmax()
    assert largest_at == 0
    smallest_at = data_missing_for_sorting.argmin()
    assert smallest_at == 2
254+
238255

239256
class TestCasting(base.BaseCastingTests):
240257
pass

pandas/tests/extension/test_sparse.py

+8
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,14 @@ def test_shift_0_periods(self, data):
321321
data._sparse_values[0] = data._sparse_values[1]
322322
assert result._sparse_values[0] != result._sparse_values[1]
323323

324+
@pytest.mark.parametrize("method", ["argmax", "argmin"])
def test_argmin_argmax_all_na(self, method, data, na_value):
    """Run the shared all-NA check after the sparse unsupported-data check."""
    # overriding because Sparse[int64, 0] cannot handle na_value
    self._check_unsupported(data)
    parent = super()
    parent.test_argmin_argmax_all_na(method, data, na_value)
331+
324332
@pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame])
325333
def test_equals(self, data, na_value, as_series, box):
326334
self._check_unsupported(data)

0 commit comments

Comments
 (0)