From 5cab58d8f527a2294d7e15421e58ade219d69a52 Mon Sep 17 00:00:00 2001 From: Richard Shadrach Date: Tue, 26 Mar 2024 17:39:37 -0400 Subject: [PATCH 1/5] CLN/PERF: Simplify argmin/argmax --- pandas/core/arrays/_mixins.py | 4 +-- pandas/core/arrays/base.py | 4 +-- pandas/core/arrays/sparse/array.py | 4 +-- pandas/core/base.py | 33 ++-------------------- pandas/core/indexes/base.py | 13 +++++---- pandas/core/nanops.py | 2 +- pandas/tests/extension/base/methods.py | 4 +-- pandas/tests/frame/test_reductions.py | 4 +-- pandas/tests/reductions/test_reductions.py | 29 +++++++++++-------- 9 files changed, 38 insertions(+), 59 deletions(-) diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index 7f4e6f6666382..930ee83aea00b 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -210,7 +210,7 @@ def argmin(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[overri # override base class by adding axis keyword validate_bool_kwarg(skipna, "skipna") if not skipna and self._hasna: - raise NotImplementedError + raise ValueError("Encountered an NA value with skipna=False") return nargminmax(self, "argmin", axis=axis) # Signature of "argmax" incompatible with supertype "ExtensionArray" @@ -218,7 +218,7 @@ def argmax(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[overri # override base class by adding axis keyword validate_bool_kwarg(skipna, "skipna") if not skipna and self._hasna: - raise NotImplementedError + raise ValueError("Encountered an NA value with skipna=False") return nargminmax(self, "argmax", axis=axis) def unique(self) -> Self: diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 76615704f2e33..fdc839225a557 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -885,7 +885,7 @@ def argmin(self, skipna: bool = True) -> int: # 2. argmin itself : total control over sorting. validate_bool_kwarg(skipna, "skipna") if not skipna and self._hasna: - raise NotImplementedError + raise ValueError("Encountered an NA value with skipna=False") return nargminmax(self, "argmin") def argmax(self, skipna: bool = True) -> int: @@ -919,7 +919,7 @@ def argmax(self, skipna: bool = True) -> int: # 2. argmax itself : total control over sorting. validate_bool_kwarg(skipna, "skipna") if not skipna and self._hasna: - raise NotImplementedError + raise ValueError("Encountered an NA value with skipna=False") return nargminmax(self, "argmax") def interpolate( diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py index bdcb3219a9875..2a96423017bb7 100644 --- a/pandas/core/arrays/sparse/array.py +++ b/pandas/core/arrays/sparse/array.py @@ -1623,13 +1623,13 @@ def _argmin_argmax(self, kind: Literal["argmin", "argmax"]) -> int: def argmax(self, skipna: bool = True) -> int: validate_bool_kwarg(skipna, "skipna") if not skipna and self._hasna: - raise NotImplementedError + raise ValueError("Encountered an NA value with skipna=False") return self._argmin_argmax("argmax") def argmin(self, skipna: bool = True) -> int: validate_bool_kwarg(skipna, "skipna") if not skipna and self._hasna: - raise NotImplementedError + raise ValueError("Encountered an NA value with skipna=False") return self._argmin_argmax("argmin") # ------------------------------------------------------------------------ diff --git a/pandas/core/base.py b/pandas/core/base.py index 263265701691b..b96c63300f9de 100644 --- a/pandas/core/base.py +++ b/pandas/core/base.py @@ -53,7 +53,6 @@ from pandas.core import ( algorithms, - nanops, ops, ) from pandas.core.accessor import DirNamesMixin @@ -731,43 +730,17 @@ def argmax( the minimum cereal calories is the first element, since series is zero-indexed. """ - delegate = self._values nv.validate_minmax_axis(axis) skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs) - - if skipna and len(delegate) > 0 and isna(delegate).all(): - raise ValueError("Encountered all NA values") - elif not skipna and isna(delegate).any(): - raise ValueError("Encountered an NA value with skipna=False") - - if isinstance(delegate, ExtensionArray): - return delegate.argmax() - else: - result = nanops.nanargmax(delegate, skipna=skipna) - # error: Incompatible return value type (got "Union[int, ndarray]", expected - # "int") - return result # type: ignore[return-value] + return self.array.argmax(skipna=skipna) @doc(argmax, op="min", oppose="max", value="smallest") def argmin( self, axis: AxisInt | None = None, skipna: bool = True, *args, **kwargs ) -> int: - delegate = self._values nv.validate_minmax_axis(axis) - skipna = nv.validate_argmin_with_skipna(skipna, args, kwargs) - - if skipna and len(delegate) > 0 and isna(delegate).all(): - raise ValueError("Encountered all NA values") - elif not skipna and isna(delegate).any(): - raise ValueError("Encountered an NA value with skipna=False") - - if isinstance(delegate, ExtensionArray): - return delegate.argmin() - else: - result = nanops.nanargmin(delegate, skipna=skipna) - # error: Incompatible return value type (got "Union[int, ndarray]", expected - # "int") - return result # type: ignore[return-value] + skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs) + return self.array.argmin(skipna=skipna) def tolist(self) -> list: """ diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 76dd19a9424f5..632bc93cdc479 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -6976,10 +6976,11 @@ def argmin(self, axis=None, skipna: bool = True, *args, **kwargs) -> int: if not self._is_multi and self.hasnans: # Take advantage of cache - if self._isnan.all(): - raise ValueError("Encountered all NA values") - elif not skipna: + if not skipna: raise ValueError("Encountered an NA value with skipna=False") + elif self._isnan.all(): + raise ValueError("Encountered all NA values") + return super().argmin(skipna=skipna) @Appender(IndexOpsMixin.argmax.__doc__) @@ -6989,10 +6990,10 @@ def argmax(self, axis=None, skipna: bool = True, *args, **kwargs) -> int: if not self._is_multi and self.hasnans: # Take advantage of cache - if self._isnan.all(): - raise ValueError("Encountered all NA values") - elif not skipna: + if not skipna: raise ValueError("Encountered an NA value with skipna=False") + elif self._isnan.all(): + raise ValueError("Encountered all NA values") return super().argmax(skipna=skipna) def min(self, axis=None, skipna: bool = True, *args, **kwargs): diff --git a/pandas/core/nanops.py b/pandas/core/nanops.py index b68337d9e0de9..e20c349fae2c6 100644 --- a/pandas/core/nanops.py +++ b/pandas/core/nanops.py @@ -1447,7 +1447,7 @@ def _maybe_arg_null_out( raise ValueError("Encountered an NA value with skipna=False") else: na_mask = mask.all(axis) - if na_mask.any(): + if skipna and na_mask.any(): raise ValueError("Encountered all NA values") elif not skipna: na_mask = mask.any(axis) diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py index 26638c6160b7b..225a3301b8b8c 100644 --- a/pandas/tests/extension/base/methods.py +++ b/pandas/tests/extension/base/methods.py @@ -191,10 +191,10 @@ def test_argmax_argmin_no_skipna_notimplemented(self, data_missing_for_sorting): # GH#38733 data = data_missing_for_sorting - with pytest.raises(NotImplementedError, match=""): + with pytest.raises(ValueError, match="Encountered an NA value"): data.argmin(skipna=False) - with pytest.raises(NotImplementedError, match=""): + with pytest.raises(ValueError, match="Encountered an NA value"): data.argmax(skipna=False) @pytest.mark.parametrize( diff --git a/pandas/tests/frame/test_reductions.py b/pandas/tests/frame/test_reductions.py index 408cb0ab6fc5c..c5c7ffab9b4ae 100644 --- a/pandas/tests/frame/test_reductions.py +++ b/pandas/tests/frame/test_reductions.py @@ -1066,7 +1066,7 @@ def test_idxmin(self, float_frame, int_frame, skipna, axis): frame.iloc[15:20, -2:] = np.nan for df in [frame, int_frame]: if (not skipna or axis == 1) and df is not int_frame: - if axis == 1: + if skipna: msg = "Encountered all NA values" else: msg = "Encountered an NA value" @@ -1116,7 +1116,7 @@ def test_idxmax(self, float_frame, int_frame, skipna, axis): frame.iloc[15:20, -2:] = np.nan for df in [frame, int_frame]: if (skipna is False or axis == 1) and df is frame: - if axis == 1: + if skipna: msg = "Encountered all NA values" else: msg = "Encountered an NA value" diff --git a/pandas/tests/reductions/test_reductions.py b/pandas/tests/reductions/test_reductions.py index b10319f5380e7..e570daa306215 100644 --- a/pandas/tests/reductions/test_reductions.py +++ b/pandas/tests/reductions/test_reductions.py @@ -171,9 +171,9 @@ def test_argminmax(self): obj.argmin() with pytest.raises(ValueError, match="Encountered all NA values"): obj.argmax() - with pytest.raises(ValueError, match="Encountered all NA values"): + with pytest.raises(ValueError, match="Encountered an NA value"): obj.argmin(skipna=False) - with pytest.raises(ValueError, match="Encountered all NA values"): + with pytest.raises(ValueError, match="Encountered an NA value"): obj.argmax(skipna=False) obj = Index([NaT, datetime(2011, 11, 1), datetime(2011, 11, 2), NaT]) @@ -189,9 +189,9 @@ def test_argminmax(self): obj.argmin() with pytest.raises(ValueError, match="Encountered all NA values"): obj.argmax() - with pytest.raises(ValueError, match="Encountered all NA values"): + with pytest.raises(ValueError, match="Encountered an NA value"): obj.argmin(skipna=False) - with pytest.raises(ValueError, match="Encountered all NA values"): + with pytest.raises(ValueError, match="Encountered an NA value"): obj.argmax(skipna=False) @pytest.mark.parametrize("op, expected_col", [["max", "a"], ["min", "b"]]) @@ -856,7 +856,8 @@ def test_idxmin(self): # all NaNs allna = string_series * np.nan - with pytest.raises(ValueError, match="Encountered all NA values"): + msg = "attempt to get argmin of an empty sequence" + with pytest.raises(ValueError, match=msg): allna.idxmin() # datetime64[ns] @@ -888,7 +889,8 @@ def test_idxmax(self): # all NaNs allna = string_series * np.nan - with pytest.raises(ValueError, match="Encountered all NA values"): + msg = "attempt to get argmin of an empty sequence" + with pytest.raises(ValueError, match=msg): allna.idxmax() s = Series(date_range("20130102", periods=6)) @@ -1143,14 +1145,17 @@ def test_idxminmax_object_dtype(self, using_infer_string): if not using_infer_string: # attempting to compare np.nan with string raises ser3 = Series(["foo", "foo", "bar", "bar", None, np.nan, "baz"]) - msg = "'>' not supported between instances of 'float' and 'str'" - with pytest.raises(TypeError, match=msg): - ser3.idxmax() + result = ser3.idxmax() + expected = 0 + assert result == expected + with pytest.raises(ValueError, match="Encountered an NA value"): ser3.idxmax(skipna=False) - msg = "'<' not supported between instances of 'float' and 'str'" - with pytest.raises(TypeError, match=msg): - ser3.idxmin() + + result = ser3.idxmin() + expected = 2 + assert result == expected + with pytest.raises(ValueError, match="Encountered an NA value"): ser3.idxmin(skipna=False) From 05171484d08a9ba6f1a8c6deca55b9266375a183 Mon Sep 17 00:00:00 2001 From: Richard Shadrach Date: Tue, 26 Mar 2024 17:46:40 -0400 Subject: [PATCH 2/5] More simplifications --- pandas/core/nanops.py | 19 +++++++------------ pandas/tests/reductions/test_reductions.py | 2 +- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/pandas/core/nanops.py b/pandas/core/nanops.py index e20c349fae2c6..623d61a9b2ea9 100644 --- a/pandas/core/nanops.py +++ b/pandas/core/nanops.py @@ -1439,20 +1439,15 @@ def _maybe_arg_null_out( return result if axis is None or not getattr(result, "ndim", False): - if skipna: - if mask.all(): - raise ValueError("Encountered all NA values") - else: - if mask.any(): - raise ValueError("Encountered an NA value with skipna=False") + if skipna and mask.all(): + raise ValueError("Encountered all NA values") + elif not skipna and mask.any(): + raise ValueError("Encountered an NA value with skipna=False") else: - na_mask = mask.all(axis) - if skipna and na_mask.any(): + if skipna and mask.all(axis).any(): raise ValueError("Encountered all NA values") - elif not skipna: - na_mask = mask.any(axis) - if na_mask.any(): - raise ValueError("Encountered an NA value with skipna=False") + elif not skipna and mask.any(axis).any(): + raise ValueError("Encountered an NA value with skipna=False") return result diff --git a/pandas/tests/reductions/test_reductions.py b/pandas/tests/reductions/test_reductions.py index e570daa306215..a758c79d83311 100644 --- a/pandas/tests/reductions/test_reductions.py +++ b/pandas/tests/reductions/test_reductions.py @@ -889,7 +889,7 @@ def test_idxmax(self): # all NaNs allna = string_series * np.nan - msg = "attempt to get argmin of an empty sequence" + msg = "attempt to get argmax of an empty sequence" with pytest.raises(ValueError, match=msg): allna.idxmax() From 7ad5f94871bd0b37c7382d6709499ddc9bd16d82 Mon Sep 17 00:00:00 2001 From: Richard Shadrach Date: Tue, 26 Mar 2024 18:17:54 -0400 Subject: [PATCH 3/5] Partial revert --- asv_bench/asv.conf.json | 1 + pandas/core/base.py | 23 +++++++++++++++++++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json index e02ff26ba14e9..dc58921cc4a87 100644 --- a/asv_bench/asv.conf.json +++ b/asv_bench/asv.conf.json @@ -42,6 +42,7 @@ // followed by the pip installed packages). "matrix": { "Cython": ["3.0"], + "build": [], "matplotlib": [], "sqlalchemy": [], "scipy": [], diff --git a/pandas/core/base.py b/pandas/core/base.py index b96c63300f9de..6df311f3153da 100644 --- a/pandas/core/base.py +++ b/pandas/core/base.py @@ -53,6 +53,7 @@ from pandas.core import ( algorithms, + nanops, ops, ) from pandas.core.accessor import DirNamesMixin @@ -730,17 +731,31 @@ def argmax( the minimum cereal calories is the first element, since series is zero-indexed. """ + delegate = self._values nv.validate_minmax_axis(axis) - skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs) - return self.array.argmax(skipna=skipna) + + if isinstance(delegate, ExtensionArray): + return delegate.argmax() + else: + result = nanops.nanargmax(delegate, skipna=skipna) + # error: Incompatible return value type (got "Union[int, ndarray]", expected + # "int") + return result # type: ignore[return-value] @doc(argmax, op="min", oppose="max", value="smallest") def argmin( self, axis: AxisInt | None = None, skipna: bool = True, *args, **kwargs ) -> int: + delegate = self._values nv.validate_minmax_axis(axis) - skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs) - return self.array.argmin(skipna=skipna) + + if isinstance(delegate, ExtensionArray): + return delegate.argmin() + else: + result = nanops.nanargmin(delegate, skipna=skipna) + # error: Incompatible return value type (got "Union[int, ndarray]", expected + # "int") + return result # type: ignore[return-value] def tolist(self) -> list: """ From c9bd1f855d4f7657de49766386a332ffacfabf89 Mon Sep 17 00:00:00 2001 From: Richard Shadrach Date: Tue, 26 Mar 2024 18:27:00 -0400 Subject: [PATCH 4/5] Remove comments --- asv_bench/asv.conf.json | 1 - pandas/core/indexes/base.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json index dc58921cc4a87..e02ff26ba14e9 100644 --- a/asv_bench/asv.conf.json +++ b/asv_bench/asv.conf.json @@ -42,7 +42,6 @@ // followed by the pip installed packages). "matrix": { "Cython": ["3.0"], - "build": [], "matplotlib": [], "sqlalchemy": [], "scipy": [], diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 632bc93cdc479..c57c7d1fe1232 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -6975,7 +6975,6 @@ def argmin(self, axis=None, skipna: bool = True, *args, **kwargs) -> int: nv.validate_minmax_axis(axis) if not self._is_multi and self.hasnans: - # Take advantage of cache if not skipna: raise ValueError("Encountered an NA value with skipna=False") elif self._isnan.all(): @@ -6989,7 +6988,6 @@ def argmax(self, axis=None, skipna: bool = True, *args, **kwargs) -> int: nv.validate_minmax_axis(axis) if not self._is_multi and self.hasnans: - # Take advantage of cache if not skipna: raise ValueError("Encountered an NA value with skipna=False") elif self._isnan.all(): From 46bbd5ebc7064ec2afd671de6e040719608c1a59 Mon Sep 17 00:00:00 2001 From: richard Date: Sat, 30 Mar 2024 08:46:25 -0400 Subject: [PATCH 5/5] fixups --- pandas/core/base.py | 6 ++++-- pandas/tests/reductions/test_reductions.py | 23 ++++++++++------------ 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/pandas/core/base.py b/pandas/core/base.py index 6df311f3153da..0dffc0254c550 100644 --- a/pandas/core/base.py +++ b/pandas/core/base.py @@ -733,9 +733,10 @@ def argmax( """ delegate = self._values nv.validate_minmax_axis(axis) + skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs) if isinstance(delegate, ExtensionArray): - return delegate.argmax() + return delegate.argmax(skipna=skipna) else: result = nanops.nanargmax(delegate, skipna=skipna) # error: Incompatible return value type (got "Union[int, ndarray]", expected @@ -748,9 +749,10 @@ def argmin( ) -> int: delegate = self._values nv.validate_minmax_axis(axis) + skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs) if isinstance(delegate, ExtensionArray): - return delegate.argmin() + return delegate.argmin(skipna=skipna) else: result = nanops.nanargmin(delegate, skipna=skipna) # error: Incompatible return value type (got "Union[int, ndarray]", expected diff --git a/pandas/tests/reductions/test_reductions.py b/pandas/tests/reductions/test_reductions.py index a758c79d83311..726ed4ad8a399 100644 --- a/pandas/tests/reductions/test_reductions.py +++ b/pandas/tests/reductions/test_reductions.py @@ -856,7 +856,7 @@ def test_idxmin(self): # all NaNs allna = string_series * np.nan - msg = "attempt to get argmin of an empty sequence" + msg = "Encountered all NA values" with pytest.raises(ValueError, match=msg): allna.idxmin() @@ -889,7 +889,7 @@ def test_idxmax(self): # all NaNs allna = string_series * np.nan - msg = "attempt to get argmax of an empty sequence" + msg = "Encountered all NA values" with pytest.raises(ValueError, match=msg): allna.idxmax() @@ -1145,18 +1145,15 @@ def test_idxminmax_object_dtype(self, using_infer_string): if not using_infer_string: # attempting to compare np.nan with string raises ser3 = Series(["foo", "foo", "bar", "bar", None, np.nan, "baz"]) - result = ser3.idxmax() - expected = 0 - assert result == expected - - with pytest.raises(ValueError, match="Encountered an NA value"): + msg = "'>' not supported between instances of 'float' and 'str'" + with pytest.raises(TypeError, match=msg): + ser3.idxmax() + with pytest.raises(TypeError, match=msg): ser3.idxmax(skipna=False) - - result = ser3.idxmin() - expected = 2 - assert result == expected - - with pytest.raises(ValueError, match="Encountered an NA value"): + msg = "'<' not supported between instances of 'float' and 'str'" + with pytest.raises(TypeError, match=msg): + ser3.idxmin() + with pytest.raises(TypeError, match=msg): ser3.idxmin(skipna=False) def test_idxminmax_object_frame(self):