Skip to content

Commit a71b040

Browse files
author
Lucas Kushner
committed
Adding argmax and argmin with proper behavior (#16830)
1 parent dfd9d06 commit a71b040

File tree

4 files changed

+149
-27
lines changed

4 files changed

+149
-27
lines changed

pandas/core/frame.py

+59-1
Original file line numberDiff line numberDiff line change
@@ -5329,7 +5329,7 @@ def idxmin(self, axis=0, skipna=True):
53295329

53305330
def idxmax(self, axis=0, skipna=True):
53315331
"""
5332-
Return index of first occurrence of maximum over requested axis.
5332+
Return label of first occurrence of maximum over requested axis.
53335333
NA/null values are excluded.
53345334
53355335
Parameters
@@ -5358,6 +5358,64 @@ def idxmax(self, axis=0, skipna=True):
53585358
result = [index[i] if i >= 0 else NA for i in indices]
53595359
return Series(result, index=self._get_agg_axis(axis))
53605360

5361+
def argmin(self, axis=0, skipna=True):
5362+
"""
5363+
Return index of first occurrence of minimum over requested axis.
5364+
NA/null values are excluded.
5365+
5366+
Parameters
5367+
----------
5368+
axis : {0 or 'index', 1 or 'columns'}, default 0
5369+
0 or 'index' for row-wise, 1 or 'columns' for column-wise
5370+
skipna : boolean, default True
5371+
Exclude NA/null values. If an entire row/column is NA, the result
5372+
will be NA
5373+
5374+
Returns
5375+
-------
5376+
argmin : Series
5377+
5378+
Notes
5379+
-----
5380+
This method is the DataFrame version of ``ndarray.argmin``.
5381+
5382+
See Also
5383+
--------
5384+
Series.idxmin
5385+
"""
5386+
axis = self._get_axis_number(axis)
5387+
indices = nanops.nanargmin(self.values, axis=axis, skipna=skipna)
5388+
return Series(indices, index=self._get_agg_axis(axis))
5389+
5390+
def argmax(self, axis=0, skipna=True):
5391+
"""
5392+
Return index of first occurrence of maximum over requested axis.
5393+
NA/null values are excluded.
5394+
5395+
Parameters
5396+
----------
5397+
axis : {0 or 'index', 1 or 'columns'}, default 0
5398+
0 or 'index' for row-wise, 1 or 'columns' for column-wise
5399+
skipna : boolean, default True
5400+
Exclude NA/null values. If an entire row/column is NA, the result
5401+
will be first index.
5402+
5403+
Returns
5404+
-------
5405+
argmax : Series
5406+
5407+
Notes
5408+
-----
5409+
This method is the DataFrame version of ``ndarray.argmax``.
5410+
5411+
See Also
5412+
--------
5413+
Series.argmax
5414+
"""
5415+
axis = self._get_axis_number(axis)
5416+
indices = nanops.nanargmax(self.values, axis=axis, skipna=skipna)
5417+
return Series(indices, index=self._get_agg_axis(axis))
5418+
53615419
def _get_agg_axis(self, axis_num):
53625420
""" let's be explict about this """
53635421
if axis_num == 0:

pandas/core/series.py

+56-10
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
import pandas.core.nanops as nanops
7373
import pandas.io.formats.format as fmt
7474
from pandas.util._decorators import (
75-
Appender, deprecate, deprecate_kwarg, Substitution)
75+
Appender, deprecate_kwarg, Substitution)
7676
from pandas.util._validators import validate_bool_kwarg
7777

7878
from pandas._libs import index as libindex, tslib as libts, lib, iNaT
@@ -1239,7 +1239,7 @@ def duplicated(self, keep='first'):
12391239

12401240
def idxmin(self, axis=None, skipna=True, *args, **kwargs):
12411241
"""
1242-
Index of first occurrence of minimum of values.
1242+
Label of first occurrence of minimum of values.
12431243
12441244
Parameters
12451245
----------
@@ -1259,15 +1259,14 @@ def idxmin(self, axis=None, skipna=True, *args, **kwargs):
12591259
DataFrame.idxmin
12601260
numpy.ndarray.argmin
12611261
"""
1262-
skipna = nv.validate_argmin_with_skipna(skipna, args, kwargs)
1263-
i = nanops.nanargmin(_values_from_object(self), skipna=skipna)
1262+
i = self.argmin(axis, skipna, *args, **kwargs)
12641263
if i == -1:
12651264
return np.nan
12661265
return self.index[i]
12671266

12681267
def idxmax(self, axis=None, skipna=True, *args, **kwargs):
12691268
"""
1270-
Index of first occurrence of maximum of values.
1269+
Label of first occurrence of maximum of values.
12711270
12721271
Parameters
12731272
----------
@@ -1287,15 +1286,62 @@ def idxmax(self, axis=None, skipna=True, *args, **kwargs):
12871286
DataFrame.idxmax
12881287
numpy.ndarray.argmax
12891288
"""
1290-
skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs)
1291-
i = nanops.nanargmax(_values_from_object(self), skipna=skipna)
1289+
i = self.argmax(axis, skipna, *args, **kwargs)
12921290
if i == -1:
12931291
return np.nan
12941292
return self.index[i]
12951293

1296-
# ndarray compat
1297-
argmin = deprecate('argmin', idxmin)
1298-
argmax = deprecate('argmax', idxmax)
1294+
def argmin(self, axis=None, skipna=True, *args, **kwargs):
1295+
"""
1296+
Index of first occurrence of minimum of values.
1297+
1298+
Parameters
1299+
----------
1300+
skipna : boolean, default True
1301+
Exclude NA/null values
1302+
1303+
Returns
1304+
-------
1305+
idxmin : Index of minimum of values
1306+
1307+
Notes
1308+
-----
1309+
This method is the Series version of ``ndarray.argmin``.
1310+
1311+
See Also
1312+
--------
1313+
DataFrame.argmin
1314+
numpy.ndarray.argmin
1315+
"""
1316+
skipna = nv.validate_argmin_with_skipna(skipna, args, kwargs)
1317+
i = nanops.nanargmin(_values_from_object(self), skipna=skipna)
1318+
return i
1319+
1320+
def argmax(self, axis=None, skipna=True, *args, **kwargs):
1321+
"""
1322+
Index of first occurrence of maximum of values.
1323+
1324+
Parameters
1325+
----------
1326+
skipna : boolean, default True
1327+
Exclude NA/null values
1328+
1329+
Returns
1330+
-------
1331+
idxmax : Index of maximum of values
1332+
1333+
Notes
1334+
-----
1335+
This method is the Series version of ``ndarray.argmax``.
1336+
1337+
See Also
1338+
--------
1339+
DataFrame.argmax
1340+
numpy.ndarray.argmax
1341+
"""
1342+
skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs)
1343+
i = nanops.nanargmax(_values_from_object(self), skipna=skipna)
1344+
return i
12991345

13001346
def round(self, decimals=0, *args, **kwargs):
13011347
"""

pandas/tests/frame/test_analytics.py

+28
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,34 @@ def test_idxmax(self):
10311031

10321032
pytest.raises(ValueError, frame.idxmax, axis=2)
10331033

1034+
def test_argmin(self):
1035+
frame = self.frame
1036+
frame.loc[5:10] = np.nan
1037+
frame.loc[15:20, -2:] = np.nan
1038+
for skipna in [True, False]:
1039+
for axis in [0, 1]:
1040+
for df in [frame, self.intframe]:
1041+
result = df.argmin(axis=axis, skipna=skipna)
1042+
expected = df.apply(Series.argmin, axis=axis,
1043+
skipna=skipna)
1044+
tm.assert_series_equal(result, expected)
1045+
1046+
pytest.raises(ValueError, frame.argmin, axis=2)
1047+
1048+
def test_argmax(self):
1049+
frame = self.frame
1050+
frame.loc[5:10] = np.nan
1051+
frame.loc[15:20, -2:] = np.nan
1052+
for skipna in [True, False]:
1053+
for axis in [0, 1]:
1054+
for df in [frame, self.intframe]:
1055+
result = df.argmax(axis=axis, skipna=skipna)
1056+
expected = df.apply(Series.argmax, axis=axis,
1057+
skipna=skipna)
1058+
tm.assert_series_equal(result, expected)
1059+
1060+
pytest.raises(ValueError, frame.argmax, axis=2)
1061+
10341062
# ----------------------------------------------------------------------
10351063
# Logical reductions
10361064

pandas/tests/series/test_analytics.py

+6-16
Original file line numberDiff line numberDiff line change
@@ -1212,14 +1212,9 @@ def test_idxmin(self):
12121212

12131213
def test_numpy_argmin(self):
12141214
data = np.random.randint(0, 11, size=10)
1215-
1216-
with pytest.warns(FutureWarning):
1217-
result = np.argmin(Series(data))
1218-
assert result == np.argmin(data)
1219-
1220-
with tm.assert_produces_warning(FutureWarning):
1221-
# argmin is aliased to idxmin
1222-
Series(data).argmin()
1215+
result = np.argmin(Series(data))
1216+
assert result == np.argmin(data)
1217+
assert result == Series(data).argmin()
12231218

12241219
if not _np_version_under1p10:
12251220
msg = "the 'out' parameter is not supported"
@@ -1272,14 +1267,9 @@ def test_idxmax(self):
12721267

12731268
def test_numpy_argmax(self):
12741269
data = np.random.randint(0, 11, size=10)
1275-
1276-
with pytest.warns(FutureWarning):
1277-
result = np.argmax(Series(data))
1278-
assert result == np.argmax(data)
1279-
1280-
with tm.assert_produces_warning(FutureWarning):
1281-
# argmax is aliased to idxmax
1282-
Series(data).argmax()
1270+
result = np.argmax(Series(data))
1271+
assert result == np.argmax(data)
1272+
assert result == Series(data).argmax()
12831273

12841274
if not _np_version_under1p10:
12851275
msg = "the 'out' parameter is not supported"

0 commit comments

Comments
 (0)