Skip to content

Commit 394af8e

Browse files
authored
ENH: support argmin/max, idxmin/max with object dtype (#54109)
1 parent 0d91d09 commit 394af8e

File tree

7 files changed

+93
-33
lines changed

7 files changed

+93
-33
lines changed

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ Other enhancements
171171
- Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`)
172172
- Performance improvement in :func:`concat` with homogeneous ``np.float64`` or ``np.float32`` dtypes (:issue:`52685`)
173173
- Performance improvement in :meth:`DataFrame.filter` when ``items`` is given (:issue:`52941`)
174+
- Reductions :meth:`Series.argmax`, :meth:`Series.argmin`, :meth:`Series.idxmax`, :meth:`Series.idxmin`, :meth:`Index.argmax`, :meth:`Index.argmin`, :meth:`DataFrame.idxmax`, :meth:`DataFrame.idxmin` now support object-dtype data (:issue:`4279`, :issue:`18021`, :issue:`40685`, :issue:`43697`)
174175
-
175176

176177
.. ---------------------------------------------------------------------------

pandas/core/nanops.py

-2
Original file line numberDiff line numberDiff line change
@@ -1100,7 +1100,6 @@ def reduction(
11001100
nanmax = _nanminmax("max", fill_value_typ="-inf")
11011101

11021102

1103-
@disallow("O")
11041103
def nanargmax(
11051104
values: np.ndarray,
11061105
*,
@@ -1146,7 +1145,6 @@ def nanargmax(
11461145
return result
11471146

11481147

1149-
@disallow("O")
11501148
def nanargmin(
11511149
values: np.ndarray,
11521150
*,

pandas/tests/frame/test_reductions.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -987,13 +987,12 @@ def test_idxmin_empty(self, index, skipna, axis):
987987
@pytest.mark.parametrize("numeric_only", [True, False])
988988
def test_idxmin_numeric_only(self, numeric_only):
989989
df = DataFrame({"a": [2, 3, 1], "b": [2, 1, 1], "c": list("xyx")})
990+
result = df.idxmin(numeric_only=numeric_only)
990991
if numeric_only:
991-
result = df.idxmin(numeric_only=numeric_only)
992992
expected = Series([2, 1], index=["a", "b"])
993-
tm.assert_series_equal(result, expected)
994993
else:
995-
with pytest.raises(TypeError, match="not allowed for this dtype"):
996-
df.idxmin(numeric_only=numeric_only)
994+
expected = Series([2, 1, 0], index=["a", "b", "c"])
995+
tm.assert_series_equal(result, expected)
997996

998997
def test_idxmin_axis_2(self, float_frame):
999998
frame = float_frame
@@ -1027,13 +1026,12 @@ def test_idxmax_empty(self, index, skipna, axis):
10271026
@pytest.mark.parametrize("numeric_only", [True, False])
10281027
def test_idxmax_numeric_only(self, numeric_only):
10291028
df = DataFrame({"a": [2, 3, 1], "b": [2, 1, 1], "c": list("xyx")})
1029+
result = df.idxmax(numeric_only=numeric_only)
10301030
if numeric_only:
1031-
result = df.idxmax(numeric_only=numeric_only)
10321031
expected = Series([1, 0], index=["a", "b"])
1033-
tm.assert_series_equal(result, expected)
10341032
else:
1035-
with pytest.raises(TypeError, match="not allowed for this dtype"):
1036-
df.idxmin(numeric_only=numeric_only)
1033+
expected = Series([1, 0, 1], index=["a", "b", "c"])
1034+
tm.assert_series_equal(result, expected)
10371035

10381036
def test_idxmax_axis_2(self, float_frame):
10391037
frame = float_frame

pandas/tests/groupby/aggregate/test_aggregate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def test_agg_str_with_kwarg_axis_1_raises(df, reduction_func):
235235
warn_msg = f"DataFrameGroupBy.{reduction_func} with axis=1 is deprecated"
236236
if reduction_func in ("idxmax", "idxmin"):
237237
error = TypeError
238-
msg = "reduction operation '.*' not allowed for this dtype"
238+
msg = "'[<>]' not supported between instances of 'float' and 'str'"
239239
warn = FutureWarning
240240
else:
241241
error = ValueError

pandas/tests/groupby/test_function.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ def test_idxmin_idxmax_axis1():
534534
df["E"] = date_range("2016-01-01", periods=10)
535535
gb2 = df.groupby("A")
536536

537-
msg = "reduction operation 'argmax' not allowed for this dtype"
537+
msg = "'>' not supported between instances of 'Timestamp' and 'float'"
538538
with pytest.raises(TypeError, match=msg):
539539
with tm.assert_produces_warning(FutureWarning, match=warn_msg):
540540
gb2.idxmax(axis=1)
@@ -1467,7 +1467,7 @@ def test_numeric_only(kernel, has_arg, numeric_only, keys):
14671467
):
14681468
result = method(*args, **kwargs)
14691469
assert "b" in result.columns
1470-
elif has_arg or kernel in ("idxmax", "idxmin"):
1470+
elif has_arg:
14711471
assert numeric_only is not True
14721472
# kernels that are successful on any dtype were above; this will fail
14731473

@@ -1486,6 +1486,10 @@ def test_numeric_only(kernel, has_arg, numeric_only, keys):
14861486
re.escape(f"agg function failed [how->{kernel},dtype->object]"),
14871487
]
14881488
)
1489+
if kernel == "idxmin":
1490+
msg = "'<' not supported between instances of 'type' and 'type'"
1491+
elif kernel == "idxmax":
1492+
msg = "'>' not supported between instances of 'type' and 'type'"
14891493
with pytest.raises(exception, match=msg):
14901494
method(*args, **kwargs)
14911495
elif not has_arg and numeric_only is not lib.no_default:
@@ -1529,8 +1533,6 @@ def test_deprecate_numeric_only_series(dtype, groupby_func, request):
15291533
"cummin",
15301534
"cumprod",
15311535
"cumsum",
1532-
"idxmax",
1533-
"idxmin",
15341536
"quantile",
15351537
)
15361538
# ops that give an object result on object input
@@ -1556,9 +1558,7 @@ def test_deprecate_numeric_only_series(dtype, groupby_func, request):
15561558
# Test default behavior; kernels that fail may be enabled in the future but kernels
15571559
# that succeed should not be allowed to fail (without deprecation, at least)
15581560
if groupby_func in fails_on_numeric_object and dtype is object:
1559-
if groupby_func in ("idxmax", "idxmin"):
1560-
msg = "not allowed for this dtype"
1561-
elif groupby_func == "quantile":
1561+
if groupby_func == "quantile":
15621562
msg = "cannot be performed against 'object' dtypes"
15631563
else:
15641564
msg = "is not supported for object dtype"

pandas/tests/groupby/test_raises.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ def test_groupby_raises_string(
157157
"ffill": (None, ""),
158158
"fillna": (None, ""),
159159
"first": (None, ""),
160-
"idxmax": (TypeError, "'argmax' not allowed for this dtype"),
161-
"idxmin": (TypeError, "'argmin' not allowed for this dtype"),
160+
"idxmax": (None, ""),
161+
"idxmin": (None, ""),
162162
"last": (None, ""),
163163
"max": (None, ""),
164164
"mean": (

pandas/tests/reductions/test_reductions.py

+76-13
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
datetime,
33
timedelta,
44
)
5+
from decimal import Decimal
56

67
import numpy as np
78
import pytest
@@ -1070,27 +1071,89 @@ def test_timedelta64_analytics(self):
10701071
(Series(["foo", "foo", "bar", "bar", None, np.nan, "baz"]), TypeError),
10711072
],
10721073
)
1073-
def test_assert_idxminmax_raises(self, test_input, error_type):
1074+
def test_assert_idxminmax_empty_raises(self, test_input, error_type):
10741075
"""
10751076
Cases where ``Series.idxmin`` and ``Series.idxmax`` should raise a ValueError on an empty Series
10761077
"""
1077-
msg = (
1078-
"reduction operation 'argmin' not allowed for this dtype|"
1079-
"attempt to get argmin of an empty sequence"
1080-
)
1081-
with pytest.raises(error_type, match=msg):
1078+
test_input = Series([], dtype="float64")
1079+
msg = "attempt to get argmin of an empty sequence"
1080+
with pytest.raises(ValueError, match=msg):
10821081
test_input.idxmin()
1083-
with pytest.raises(error_type, match=msg):
1082+
with pytest.raises(ValueError, match=msg):
10841083
test_input.idxmin(skipna=False)
1085-
msg = (
1086-
"reduction operation 'argmax' not allowed for this dtype|"
1087-
"attempt to get argmax of an empty sequence"
1088-
)
1089-
with pytest.raises(error_type, match=msg):
1084+
msg = "attempt to get argmax of an empty sequence"
1085+
with pytest.raises(ValueError, match=msg):
10901086
test_input.idxmax()
1091-
with pytest.raises(error_type, match=msg):
1087+
with pytest.raises(ValueError, match=msg):
10921088
test_input.idxmax(skipna=False)
10931089

1090+
def test_idxminmax_object_dtype(self):
1091+
# pre-2.1 object-dtype was disallowed for argmin/max
1092+
ser = Series(["foo", "bar", "baz"])
1093+
assert ser.idxmax() == 0
1094+
assert ser.idxmax(skipna=False) == 0
1095+
assert ser.idxmin() == 1
1096+
assert ser.idxmin(skipna=False) == 1
1097+
1098+
ser2 = Series([(1,), (2,)])
1099+
assert ser2.idxmax() == 1
1100+
assert ser2.idxmax(skipna=False) == 1
1101+
assert ser2.idxmin() == 0
1102+
assert ser2.idxmin(skipna=False) == 0
1103+
1104+
# attempting to compare np.nan with string raises
1105+
ser3 = Series(["foo", "foo", "bar", "bar", None, np.nan, "baz"])
1106+
msg = "'>' not supported between instances of 'float' and 'str'"
1107+
with pytest.raises(TypeError, match=msg):
1108+
ser3.idxmax()
1109+
with pytest.raises(TypeError, match=msg):
1110+
ser3.idxmax(skipna=False)
1111+
msg = "'<' not supported between instances of 'float' and 'str'"
1112+
with pytest.raises(TypeError, match=msg):
1113+
ser3.idxmin()
1114+
with pytest.raises(TypeError, match=msg):
1115+
ser3.idxmin(skipna=False)
1116+
1117+
def test_idxminmax_object_frame(self):
1118+
# GH#4279
1119+
df = DataFrame([["zimm", 2.5], ["biff", 1.0], ["bid", 12.0]])
1120+
res = df.idxmax()
1121+
exp = Series([0, 2])
1122+
tm.assert_series_equal(res, exp)
1123+
1124+
def test_idxminmax_object_tuples(self):
1125+
# GH#43697
1126+
ser = Series([(1, 3), (2, 2), (3, 1)])
1127+
assert ser.idxmax() == 2
1128+
assert ser.idxmin() == 0
1129+
assert ser.idxmax(skipna=False) == 2
1130+
assert ser.idxmin(skipna=False) == 0
1131+
1132+
def test_idxminmax_object_decimals(self):
1133+
# GH#40685
1134+
df = DataFrame(
1135+
{
1136+
"idx": [0, 1],
1137+
"x": [Decimal("8.68"), Decimal("42.23")],
1138+
"y": [Decimal("7.11"), Decimal("79.61")],
1139+
}
1140+
)
1141+
res = df.idxmax()
1142+
exp = Series({"idx": 1, "x": 1, "y": 1})
1143+
tm.assert_series_equal(res, exp)
1144+
1145+
res2 = df.idxmin()
1146+
exp2 = exp - 1
1147+
tm.assert_series_equal(res2, exp2)
1148+
1149+
def test_argminmax_object_ints(self):
1150+
# GH#18021
1151+
ser = Series([0, 1], dtype="object")
1152+
assert ser.argmax() == 1
1153+
assert ser.argmin() == 0
1154+
assert ser.argmax(skipna=False) == 1
1155+
assert ser.argmin(skipna=False) == 0
1156+
10941157
def test_idxminmax_with_inf(self):
10951158
# For numeric data with NA and Inf (GH #13595)
10961159
s = Series([0, -np.inf, np.inf, np.nan])

0 commit comments

Comments
 (0)