Skip to content

ENH: support argmin/max, idxmin/max with object dtype #54109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ Other enhancements
- Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`)
- Performance improvement in :func:`concat` with homogeneous ``np.float64`` or ``np.float32`` dtypes (:issue:`52685`)
- Performance improvement in :meth:`DataFrame.filter` when ``items`` is given (:issue:`52941`)
- Reductions :meth:`Series.argmax`, :meth:`Series.argmin`, :meth:`Series.idxmax`, :meth:`Series.idxmin`, :meth:`Index.argmax`, :meth:`Index.argmin`, :meth:`DataFrame.idxmax`, :meth:`DataFrame.idxmin` are now supported for object-dtype objects (:issue:`4279`, :issue:`18021`, :issue:`40685`, :issue:`43697`)
-

.. ---------------------------------------------------------------------------
Expand Down
2 changes: 0 additions & 2 deletions pandas/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,6 @@ def reduction(
nanmax = _nanminmax("max", fill_value_typ="-inf")


@disallow("O")
def nanargmax(
values: np.ndarray,
*,
Expand Down Expand Up @@ -1140,7 +1139,6 @@ def nanargmax(
return result


@disallow("O")
def nanargmin(
values: np.ndarray,
*,
Expand Down
14 changes: 6 additions & 8 deletions pandas/tests/frame/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,13 +982,12 @@ def test_idxmin_empty(self, index, skipna, axis):
@pytest.mark.parametrize("numeric_only", [True, False])
def test_idxmin_numeric_only(self, numeric_only):
df = DataFrame({"a": [2, 3, 1], "b": [2, 1, 1], "c": list("xyx")})
result = df.idxmin(numeric_only=numeric_only)
if numeric_only:
result = df.idxmin(numeric_only=numeric_only)
expected = Series([2, 1], index=["a", "b"])
tm.assert_series_equal(result, expected)
else:
with pytest.raises(TypeError, match="not allowed for this dtype"):
df.idxmin(numeric_only=numeric_only)
expected = Series([2, 1, 0], index=["a", "b", "c"])
tm.assert_series_equal(result, expected)

def test_idxmin_axis_2(self, float_frame):
frame = float_frame
Expand Down Expand Up @@ -1022,13 +1021,12 @@ def test_idxmax_empty(self, index, skipna, axis):
@pytest.mark.parametrize("numeric_only", [True, False])
def test_idxmax_numeric_only(self, numeric_only):
df = DataFrame({"a": [2, 3, 1], "b": [2, 1, 1], "c": list("xyx")})
result = df.idxmax(numeric_only=numeric_only)
if numeric_only:
result = df.idxmax(numeric_only=numeric_only)
expected = Series([1, 0], index=["a", "b"])
tm.assert_series_equal(result, expected)
else:
with pytest.raises(TypeError, match="not allowed for this dtype"):
df.idxmin(numeric_only=numeric_only)
expected = Series([1, 0, 1], index=["a", "b", "c"])
tm.assert_series_equal(result, expected)

def test_idxmax_axis_2(self, float_frame):
frame = float_frame
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/groupby/aggregate/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def test_agg_str_with_kwarg_axis_1_raises(df, reduction_func):
warn_msg = f"DataFrameGroupBy.{reduction_func} with axis=1 is deprecated"
if reduction_func in ("idxmax", "idxmin"):
error = TypeError
msg = "reduction operation '.*' not allowed for this dtype"
msg = "'[<>]' not supported between instances of 'float' and 'str'"
warn = FutureWarning
else:
error = ValueError
Expand Down
14 changes: 7 additions & 7 deletions pandas/tests/groupby/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def test_idxmin_idxmax_axis1():
df["E"] = date_range("2016-01-01", periods=10)
gb2 = df.groupby("A")

msg = "reduction operation 'argmax' not allowed for this dtype"
msg = "'>' not supported between instances of 'Timestamp' and 'float'"
with pytest.raises(TypeError, match=msg):
with tm.assert_produces_warning(FutureWarning, match=warn_msg):
gb2.idxmax(axis=1)
Expand Down Expand Up @@ -1467,7 +1467,7 @@ def test_numeric_only(kernel, has_arg, numeric_only, keys):
):
result = method(*args, **kwargs)
assert "b" in result.columns
elif has_arg or kernel in ("idxmax", "idxmin"):
elif has_arg:
assert numeric_only is not True
# kernels that are successful on any dtype were above; this will fail

Expand All @@ -1486,6 +1486,10 @@ def test_numeric_only(kernel, has_arg, numeric_only, keys):
re.escape(f"agg function failed [how->{kernel},dtype->object]"),
]
)
if kernel == "idxmin":
msg = "'<' not supported between instances of 'type' and 'type'"
elif kernel == "idxmax":
msg = "'>' not supported between instances of 'type' and 'type'"
with pytest.raises(exception, match=msg):
method(*args, **kwargs)
elif not has_arg and numeric_only is not lib.no_default:
Expand Down Expand Up @@ -1529,8 +1533,6 @@ def test_deprecate_numeric_only_series(dtype, groupby_func, request):
"cummin",
"cumprod",
"cumsum",
"idxmax",
"idxmin",
"quantile",
)
# ops that give an object result on object input
Expand All @@ -1556,9 +1558,7 @@ def test_deprecate_numeric_only_series(dtype, groupby_func, request):
# Test default behavior; kernels that fail may be enabled in the future but kernels
# that succeed should not be allowed to fail (without deprecation, at least)
if groupby_func in fails_on_numeric_object and dtype is object:
if groupby_func in ("idxmax", "idxmin"):
msg = "not allowed for this dtype"
elif groupby_func == "quantile":
if groupby_func == "quantile":
msg = "cannot be performed against 'object' dtypes"
else:
msg = "is not supported for object dtype"
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/groupby/test_raises.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ def test_groupby_raises_string(
"ffill": (None, ""),
"fillna": (None, ""),
"first": (None, ""),
"idxmax": (TypeError, "'argmax' not allowed for this dtype"),
"idxmin": (TypeError, "'argmin' not allowed for this dtype"),
"idxmax": (None, ""),
"idxmin": (None, ""),
"last": (None, ""),
"max": (None, ""),
"mean": (
Expand Down
89 changes: 76 additions & 13 deletions pandas/tests/reductions/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
datetime,
timedelta,
)
from decimal import Decimal

import numpy as np
import pytest
Expand Down Expand Up @@ -1070,27 +1071,89 @@ def test_timedelta64_analytics(self):
(Series(["foo", "foo", "bar", "bar", None, np.nan, "baz"]), TypeError),
],
)
def test_assert_idxminmax_raises(self, test_input, error_type):
def test_assert_idxminmax_empty_raises(self, test_input, error_type):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like error_type can be removed from the parametrization now?

"""
Cases where ``Series.argmax`` and related should raise an exception
"""
msg = (
"reduction operation 'argmin' not allowed for this dtype|"
"attempt to get argmin of an empty sequence"
)
with pytest.raises(error_type, match=msg):
test_input = Series([], dtype="float64")
msg = "attempt to get argmin of an empty sequence"
with pytest.raises(ValueError, match=msg):
test_input.idxmin()
with pytest.raises(error_type, match=msg):
with pytest.raises(ValueError, match=msg):
test_input.idxmin(skipna=False)
msg = (
"reduction operation 'argmax' not allowed for this dtype|"
"attempt to get argmax of an empty sequence"
)
with pytest.raises(error_type, match=msg):
msg = "attempt to get argmax of an empty sequence"
with pytest.raises(ValueError, match=msg):
test_input.idxmax()
with pytest.raises(error_type, match=msg):
with pytest.raises(ValueError, match=msg):
test_input.idxmax(skipna=False)

def test_idxminmax_object_dtype(self):
# pre-2.1 object-dtype was disallowed for argmin/max
ser = Series(["foo", "bar", "baz"])
assert ser.idxmax() == 0
assert ser.idxmax(skipna=False) == 0
assert ser.idxmin() == 1
assert ser.idxmin(skipna=False) == 1

ser2 = Series([(1,), (2,)])
assert ser2.idxmax() == 1
assert ser2.idxmax(skipna=False) == 1
assert ser2.idxmin() == 0
assert ser2.idxmin(skipna=False) == 0

# attempting to compare np.nan with string raises
ser3 = Series(["foo", "foo", "bar", "bar", None, np.nan, "baz"])
msg = "'>' not supported between instances of 'float' and 'str'"
with pytest.raises(TypeError, match=msg):
ser3.idxmax()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So with object dtype, NA-like values are not considered "missing" but are instead treated as ordinary object scalars, so skipna=True has no effect here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The skipping is applied only after np.argmin/np.argmax has already run, so the comparison inside np.argmin/np.argmax can still raise.

with pytest.raises(TypeError, match=msg):
ser3.idxmax(skipna=False)
msg = "'<' not supported between instances of 'float' and 'str'"
with pytest.raises(TypeError, match=msg):
ser3.idxmin()
with pytest.raises(TypeError, match=msg):
ser3.idxmin(skipna=False)

def test_idxminmax_object_frame(self):
# GH#4279
df = DataFrame([["zimm", 2.5], ["biff", 1.0], ["bid", 12.0]])
res = df.idxmax()
exp = Series([0, 2])
tm.assert_series_equal(res, exp)

def test_idxminmax_object_tuples(self):
# GH#43697
ser = Series([(1, 3), (2, 2), (3, 1)])
assert ser.idxmax() == 2
assert ser.idxmin() == 0
assert ser.idxmax(skipna=False) == 2
assert ser.idxmin(skipna=False) == 0

def test_idxminmax_object_decimals(self):
# GH#40685
df = DataFrame(
{
"idx": [0, 1],
"x": [Decimal("8.68"), Decimal("42.23")],
"y": [Decimal("7.11"), Decimal("79.61")],
}
)
res = df.idxmax()
exp = Series({"idx": 1, "x": 1, "y": 1})
tm.assert_series_equal(res, exp)

res2 = df.idxmin()
exp2 = exp - 1
tm.assert_series_equal(res2, exp2)

def test_argminmax_object_ints(self):
# GH#18021
ser = Series([0, 1], dtype="object")
assert ser.argmax() == 1
assert ser.argmin() == 0
assert ser.argmax(skipna=False) == 1
assert ser.argmin(skipna=False) == 0

def test_idxminmax_with_inf(self):
# For numeric data with NA and Inf (GH #13595)
s = Series([0, -np.inf, np.inf, np.nan])
Expand Down