Skip to content

Commit 2f1c1f8

Browse files
author
Lucas Kushner
committed
Adding argmax and argmin with proper behavior (pandas-dev#16830)
1 parent 2784c3f commit 2f1c1f8

File tree

3 files changed

+117
-11
lines changed

3 files changed

+117
-11
lines changed

pandas/core/frame.py

+59-1
Original file line numberDiff line numberDiff line change
@@ -5329,7 +5329,7 @@ def idxmin(self, axis=0, skipna=True):
53295329

53305330
def idxmax(self, axis=0, skipna=True):
53315331
"""
5332-
Return index of first occurrence of maximum over requested axis.
5332+
Return label of first occurrence of maximum over requested axis.
53335333
NA/null values are excluded.
53345334
53355335
Parameters
@@ -5358,6 +5358,64 @@ def idxmax(self, axis=0, skipna=True):
53585358
result = [index[i] if i >= 0 else NA for i in indices]
53595359
return Series(result, index=self._get_agg_axis(axis))
53605360

5361+
def argmin(self, axis=0, skipna=True):
5362+
"""
5363+
Return index of first occurrence of minimum over requested axis.
5364+
NA/null values are excluded.
5365+
5366+
Parameters
5367+
----------
5368+
axis : {0 or 'index', 1 or 'columns'}, default 0
5369+
0 or 'index' for row-wise, 1 or 'columns' for column-wise
5370+
skipna : boolean, default True
5371+
Exclude NA/null values. If an entire row/column is NA, the result
5372+
will be NA
5373+
5374+
Returns
5375+
-------
5376+
argmin : Series
5377+
5378+
Notes
5379+
-----
5380+
This method is the DataFrame version of ``ndarray.argmin``.
5381+
5382+
See Also
5383+
--------
5384+
Series.idxmin
5385+
"""
5386+
axis = self._get_axis_number(axis)
5387+
indices = nanops.nanargmin(self.values, axis=axis, skipna=skipna)
5388+
return Series(indices, index=self._get_agg_axis(axis))
5389+
5390+
def argmax(self, axis=0, skipna=True):
5391+
"""
5392+
Return index of first occurrence of maximum over requested axis.
5393+
NA/null values are excluded.
5394+
5395+
Parameters
5396+
----------
5397+
axis : {0 or 'index', 1 or 'columns'}, default 0
5398+
0 or 'index' for row-wise, 1 or 'columns' for column-wise
5399+
skipna : boolean, default True
5400+
Exclude NA/null values. If an entire row/column is NA, the result
5401+
will be first index.
5402+
5403+
Returns
5404+
-------
5405+
argmax : Series
5406+
5407+
Notes
5408+
-----
5409+
This method is the DataFrame version of ``ndarray.argmax``.
5410+
5411+
See Also
5412+
--------
5413+
Series.argmax
5414+
"""
5415+
axis = self._get_axis_number(axis)
5416+
indices = nanops.nanargmax(self.values, axis=axis, skipna=skipna)
5417+
return Series(indices, index=self._get_agg_axis(axis))
5418+
53615419
def _get_agg_axis(self, axis_num):
53625420
""" let's be explict about this """
53635421
if axis_num == 0:

pandas/core/series.py

+56-10
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
import pandas.core.nanops as nanops
7373
import pandas.io.formats.format as fmt
7474
from pandas.util._decorators import (
75-
Appender, deprecate, deprecate_kwarg, Substitution)
75+
Appender, deprecate_kwarg, Substitution)
7676
from pandas.util._validators import validate_bool_kwarg
7777

7878
from pandas._libs import index as libindex, tslib as libts, lib, iNaT
@@ -1239,7 +1239,7 @@ def duplicated(self, keep='first'):
12391239

12401240
def idxmin(self, axis=None, skipna=True, *args, **kwargs):
12411241
"""
1242-
Index of first occurrence of minimum of values.
1242+
Label of first occurrence of minimum of values.
12431243
12441244
Parameters
12451245
----------
@@ -1259,15 +1259,14 @@ def idxmin(self, axis=None, skipna=True, *args, **kwargs):
12591259
DataFrame.idxmin
12601260
numpy.ndarray.argmin
12611261
"""
1262-
skipna = nv.validate_argmin_with_skipna(skipna, args, kwargs)
1263-
i = nanops.nanargmin(_values_from_object(self), skipna=skipna)
1262+
i = self.argmin(axis, skipna, *args, **kwargs)
12641263
if i == -1:
12651264
return np.nan
12661265
return self.index[i]
12671266

12681267
def idxmax(self, axis=None, skipna=True, *args, **kwargs):
12691268
"""
1270-
Index of first occurrence of maximum of values.
1269+
Label of first occurrence of maximum of values.
12711270
12721271
Parameters
12731272
----------
@@ -1287,15 +1286,62 @@ def idxmax(self, axis=None, skipna=True, *args, **kwargs):
12871286
DataFrame.idxmax
12881287
numpy.ndarray.argmax
12891288
"""
1290-
skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs)
1291-
i = nanops.nanargmax(_values_from_object(self), skipna=skipna)
1289+
i = self.argmax(axis, skipna, *args, **kwargs)
12921290
if i == -1:
12931291
return np.nan
12941292
return self.index[i]
12951293

1296-
# ndarray compat
1297-
argmin = deprecate('argmin', idxmin)
1298-
argmax = deprecate('argmax', idxmax)
1294+
def argmin(self, axis=None, skipna=True, *args, **kwargs):
1295+
"""
1296+
Index of first occurrence of minimum of values.
1297+
1298+
Parameters
1299+
----------
1300+
skipna : boolean, default True
1301+
Exclude NA/null values
1302+
1303+
Returns
1304+
-------
1305+
idxmin : Index of minimum of values
1306+
1307+
Notes
1308+
-----
1309+
This method is the Series version of ``ndarray.argmin``.
1310+
1311+
See Also
1312+
--------
1313+
DataFrame.argmin
1314+
numpy.ndarray.argmin
1315+
"""
1316+
skipna = nv.validate_argmin_with_skipna(skipna, args, kwargs)
1317+
i = nanops.nanargmin(_values_from_object(self), skipna=skipna)
1318+
return i
1319+
1320+
def argmax(self, axis=None, skipna=True, *args, **kwargs):
1321+
"""
1322+
Index of first occurrence of maximum of values.
1323+
1324+
Parameters
1325+
----------
1326+
skipna : boolean, default True
1327+
Exclude NA/null values
1328+
1329+
Returns
1330+
-------
1331+
idxmax : Index of maximum of values
1332+
1333+
Notes
1334+
-----
1335+
This method is the Series version of ``ndarray.argmax``.
1336+
1337+
See Also
1338+
--------
1339+
DataFrame.argmax
1340+
numpy.ndarray.argmax
1341+
"""
1342+
skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs)
1343+
i = nanops.nanargmax(_values_from_object(self), skipna=skipna)
1344+
return i
12991345

13001346
def round(self, decimals=0, *args, **kwargs):
13011347
"""

pandas/tests/series/test_analytics.py

+2
Original file line numberDiff line numberDiff line change
@@ -1215,6 +1215,7 @@ def test_numpy_argmin(self):
12151215
data = np.random.randint(0, 11, size=10)
12161216
result = np.argmin(Series(data))
12171217
assert result == np.argmin(data)
1218+
assert result == Series(data).argmin()
12181219

12191220
if not _np_version_under1p10:
12201221
msg = "the 'out' parameter is not supported"
@@ -1271,6 +1272,7 @@ def test_numpy_argmax(self):
12711272
data = np.random.randint(0, 11, size=10)
12721273
result = np.argmax(Series(data))
12731274
assert result == np.argmax(data)
1275+
assert result == Series(data).argmax()
12741276

12751277
if not _np_version_under1p10:
12761278
msg = "the 'out' parameter is not supported"

0 commit comments

Comments
 (0)