From 775ddd3d020e1e2a6265e0c3bc14008d73875fcf Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 13 Jul 2023 10:48:14 -0700 Subject: [PATCH] ENH: support argmin/max, idxmin/max with object dtype --- doc/source/whatsnew/v2.1.0.rst | 1 + pandas/core/nanops.py | 2 - pandas/tests/frame/test_reductions.py | 14 ++- .../tests/groupby/aggregate/test_aggregate.py | 2 +- pandas/tests/groupby/test_function.py | 14 +-- pandas/tests/groupby/test_raises.py | 4 +- pandas/tests/reductions/test_reductions.py | 89 ++++++++++++++++--- 7 files changed, 93 insertions(+), 33 deletions(-) diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index 30b2f03dec98c..9000c83b323df 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -128,6 +128,7 @@ Other enhancements - Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`) - Performance improvement in :func:`concat` with homogeneous ``np.float64`` or ``np.float32`` dtypes (:issue:`52685`) - Performance improvement in :meth:`DataFrame.filter` when ``items`` is given (:issue:`52941`) +- Reductions :meth:`Series.argmax`, :meth:`Series.argmin`, :meth:`Series.idxmax`, :meth:`Series.idxmin`, :meth:`Index.argmax`, :meth:`Index.argmin`, :meth:`DataFrame.idxmax`, :meth:`DataFrame.idxmin` are now supported for object-dtype objects (:issue:`4279`, :issue:`18021`, :issue:`40685`, :issue:`43697`) - .. --------------------------------------------------------------------------- diff --git a/pandas/core/nanops.py b/pandas/core/nanops.py index 59520350e0dc1..467e66bcbda31 100644 --- a/pandas/core/nanops.py +++ b/pandas/core/nanops.py @@ -1094,7 +1094,6 @@ def reduction( nanmax = _nanminmax("max", fill_value_typ="-inf") -@disallow("O") def nanargmax( values: np.ndarray, *, @@ -1140,7 +1139,6 @@ def nanargmax( return result -@disallow("O") def nanargmin( values: np.ndarray, *, diff --git a/pandas/tests/frame/test_reductions.py b/pandas/tests/frame/test_reductions.py index b4a4324593d22..6636850ce20f1 100644 --- a/pandas/tests/frame/test_reductions.py +++ b/pandas/tests/frame/test_reductions.py @@ -982,13 +982,12 @@ def test_idxmin_empty(self, index, skipna, axis): @pytest.mark.parametrize("numeric_only", [True, False]) def test_idxmin_numeric_only(self, numeric_only): df = DataFrame({"a": [2, 3, 1], "b": [2, 1, 1], "c": list("xyx")}) + result = df.idxmin(numeric_only=numeric_only) if numeric_only: - result = df.idxmin(numeric_only=numeric_only) expected = Series([2, 1], index=["a", "b"]) - tm.assert_series_equal(result, expected) else: - with pytest.raises(TypeError, match="not allowed for this dtype"): - df.idxmin(numeric_only=numeric_only) + expected = Series([2, 1, 0], index=["a", "b", "c"]) + tm.assert_series_equal(result, expected) def test_idxmin_axis_2(self, float_frame): frame = float_frame @@ -1022,13 +1021,12 @@ def test_idxmax_empty(self, index, skipna, axis): @pytest.mark.parametrize("numeric_only", [True, False]) def test_idxmax_numeric_only(self, numeric_only): df = DataFrame({"a": [2, 3, 1], "b": [2, 1, 1], "c": list("xyx")}) + result = df.idxmax(numeric_only=numeric_only) if numeric_only: - result = df.idxmax(numeric_only=numeric_only) expected = Series([1, 0], index=["a", "b"]) - tm.assert_series_equal(result, expected) else: - with pytest.raises(TypeError, match="not allowed for this dtype"): - df.idxmin(numeric_only=numeric_only) + expected = Series([1, 0, 1], index=["a", "b", "c"]) + tm.assert_series_equal(result, expected) def test_idxmax_axis_2(self, float_frame): frame = float_frame diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py index 2875e1ae80501..666bd98869482 100644 --- a/pandas/tests/groupby/aggregate/test_aggregate.py +++ b/pandas/tests/groupby/aggregate/test_aggregate.py @@ -235,7 +235,7 @@ def test_agg_str_with_kwarg_axis_1_raises(df, reduction_func): warn_msg = f"DataFrameGroupBy.{reduction_func} with axis=1 is deprecated" if reduction_func in ("idxmax", "idxmin"): error = TypeError - msg = "reduction operation '.*' not allowed for this dtype" + msg = "'[<>]' not supported between instances of 'float' and 'str'" warn = FutureWarning else: error = ValueError diff --git a/pandas/tests/groupby/test_function.py b/pandas/tests/groupby/test_function.py index e3a5d308c4346..04a699cfede56 100644 --- a/pandas/tests/groupby/test_function.py +++ b/pandas/tests/groupby/test_function.py @@ -534,7 +534,7 @@ def test_idxmin_idxmax_axis1(): df["E"] = date_range("2016-01-01", periods=10) gb2 = df.groupby("A") - msg = "reduction operation 'argmax' not allowed for this dtype" + msg = "'>' not supported between instances of 'Timestamp' and 'float'" with pytest.raises(TypeError, match=msg): with tm.assert_produces_warning(FutureWarning, match=warn_msg): gb2.idxmax(axis=1) @@ -1467,7 +1467,7 @@ def test_numeric_only(kernel, has_arg, numeric_only, keys): ): result = method(*args, **kwargs) assert "b" in result.columns - elif has_arg or kernel in ("idxmax", "idxmin"): + elif has_arg: assert numeric_only is not True # kernels that are successful on any dtype were above; this will fail @@ -1486,6 +1486,10 @@ def test_numeric_only(kernel, has_arg, numeric_only, keys): re.escape(f"agg function failed [how->{kernel},dtype->object]"), ] ) + if kernel == "idxmin": + msg = "'<' not supported between instances of 'type' and 'type'" + elif kernel == "idxmax": + msg = "'>' not supported between instances of 'type' and 'type'" with pytest.raises(exception, match=msg): method(*args, **kwargs) elif not has_arg and numeric_only is not lib.no_default: @@ -1529,8 +1533,6 @@ def test_deprecate_numeric_only_series(dtype, groupby_func, request): "cummin", "cumprod", "cumsum", - "idxmax", - "idxmin", "quantile", ) # ops that give an object result on object input @@ -1556,9 +1558,7 @@ def test_deprecate_numeric_only_series(dtype, groupby_func, request): # Test default behavior; kernels that fail may be enabled in the future but kernels # that succeed should not be allowed to fail (without deprecation, at least) if groupby_func in fails_on_numeric_object and dtype is object: - if groupby_func in ("idxmax", "idxmin"): - msg = "not allowed for this dtype" - elif groupby_func == "quantile": + if groupby_func == "quantile": msg = "cannot be performed against 'object' dtypes" else: msg = "is not supported for object dtype" diff --git a/pandas/tests/groupby/test_raises.py b/pandas/tests/groupby/test_raises.py index a3fa5bf794030..f9a2b3d44b117 100644 --- a/pandas/tests/groupby/test_raises.py +++ b/pandas/tests/groupby/test_raises.py @@ -157,8 +157,8 @@ def test_groupby_raises_string( "ffill": (None, ""), "fillna": (None, ""), "first": (None, ""), - "idxmax": (TypeError, "'argmax' not allowed for this dtype"), - "idxmin": (TypeError, "'argmin' not allowed for this dtype"), + "idxmax": (None, ""), + "idxmin": (None, ""), "last": (None, ""), "max": (None, ""), "mean": ( diff --git a/pandas/tests/reductions/test_reductions.py b/pandas/tests/reductions/test_reductions.py index 83b9a83c0a6a2..11e3d905dc1d2 100644 --- a/pandas/tests/reductions/test_reductions.py +++ b/pandas/tests/reductions/test_reductions.py @@ -2,6 +2,7 @@ datetime, timedelta, ) +from decimal import Decimal import numpy as np import pytest @@ -1070,27 +1071,89 @@ def test_timedelta64_analytics(self): (Series(["foo", "foo", "bar", "bar", None, np.nan, "baz"]), TypeError), ], ) - def test_assert_idxminmax_raises(self, test_input, error_type): + def test_assert_idxminmax_empty_raises(self, test_input, error_type): """ Cases where ``Series.argmax`` and related should raise an exception """ - msg = ( - "reduction operation 'argmin' not allowed for this dtype|" - "attempt to get argmin of an empty sequence" - ) - with pytest.raises(error_type, match=msg): + test_input = Series([], dtype="float64") + msg = "attempt to get argmin of an empty sequence" + with pytest.raises(ValueError, match=msg): test_input.idxmin() - with pytest.raises(error_type, match=msg): + with pytest.raises(ValueError, match=msg): test_input.idxmin(skipna=False) - msg = ( - "reduction operation 'argmax' not allowed for this dtype|" - "attempt to get argmax of an empty sequence" - ) - with pytest.raises(error_type, match=msg): + msg = "attempt to get argmax of an empty sequence" + with pytest.raises(ValueError, match=msg): test_input.idxmax() - with pytest.raises(error_type, match=msg): + with pytest.raises(ValueError, match=msg): test_input.idxmax(skipna=False) + def test_idxminmax_object_dtype(self): + # pre-2.1 object-dtype was disallowed for argmin/max + ser = Series(["foo", "bar", "baz"]) + assert ser.idxmax() == 0 + assert ser.idxmax(skipna=False) == 0 + assert ser.idxmin() == 1 + assert ser.idxmin(skipna=False) == 1 + + ser2 = Series([(1,), (2,)]) + assert ser2.idxmax() == 1 + assert ser2.idxmax(skipna=False) == 1 + assert ser2.idxmin() == 0 + assert ser2.idxmin(skipna=False) == 0 + + # attempting to compare np.nan with string raises + ser3 = Series(["foo", "foo", "bar", "bar", None, np.nan, "baz"]) + msg = "'>' not supported between instances of 'float' and 'str'" + with pytest.raises(TypeError, match=msg): + ser3.idxmax() + with pytest.raises(TypeError, match=msg): + ser3.idxmax(skipna=False) + msg = "'<' not supported between instances of 'float' and 'str'" + with pytest.raises(TypeError, match=msg): + ser3.idxmin() + with pytest.raises(TypeError, match=msg): + ser3.idxmin(skipna=False) + + def test_idxminmax_object_frame(self): + # GH#4279 + df = DataFrame([["zimm", 2.5], ["biff", 1.0], ["bid", 12.0]]) + res = df.idxmax() + exp = Series([0, 2]) + tm.assert_series_equal(res, exp) + + def test_idxminmax_object_tuples(self): + # GH#43697 + ser = Series([(1, 3), (2, 2), (3, 1)]) + assert ser.idxmax() == 2 + assert ser.idxmin() == 0 + assert ser.idxmax(skipna=False) == 2 + assert ser.idxmin(skipna=False) == 0 + + def test_idxminmax_object_decimals(self): + # GH#40685 + df = DataFrame( + { + "idx": [0, 1], + "x": [Decimal("8.68"), Decimal("42.23")], + "y": [Decimal("7.11"), Decimal("79.61")], + } + ) + res = df.idxmax() + exp = Series({"idx": 1, "x": 1, "y": 1}) + tm.assert_series_equal(res, exp) + + res2 = df.idxmin() + exp2 = exp - 1 + tm.assert_series_equal(res2, exp2) + + def test_argminmax_object_ints(self): + # GH#18021 + ser = Series([0, 1], dtype="object") + assert ser.argmax() == 1 + assert ser.argmin() == 0 + assert ser.argmax(skipna=False) == 1 + assert ser.argmin(skipna=False) == 0 + def test_idxminmax_with_inf(self): # For numeric data with NA and Inf (GH #13595) s = Series([0, -np.inf, np.inf, np.nan])