Skip to content

Commit 2c30be4

Browse files
rhshadrachpmhatre1
authored andcommitted
CLN/PERF: Simplify argmin/argmax (pandas-dev#58019)
* CLN/PERF: Simplify argmin/argmax * More simplifications * Partial revert * Remove comments * fixups
1 parent 258ec1b commit 2c30be4

File tree

9 files changed

+37
-51
lines changed

9 files changed

+37
-51
lines changed

pandas/core/arrays/_mixins.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,15 @@ def argmin(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[overri
210210
# override base class by adding axis keyword
211211
validate_bool_kwarg(skipna, "skipna")
212212
if not skipna and self._hasna:
213-
raise NotImplementedError
213+
raise ValueError("Encountered an NA value with skipna=False")
214214
return nargminmax(self, "argmin", axis=axis)
215215

216216
# Signature of "argmax" incompatible with supertype "ExtensionArray"
217217
def argmax(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[override]
218218
# override base class by adding axis keyword
219219
validate_bool_kwarg(skipna, "skipna")
220220
if not skipna and self._hasna:
221-
raise NotImplementedError
221+
raise ValueError("Encountered an NA value with skipna=False")
222222
return nargminmax(self, "argmax", axis=axis)
223223

224224
def unique(self) -> Self:

pandas/core/arrays/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -885,7 +885,7 @@ def argmin(self, skipna: bool = True) -> int:
885885
# 2. argmin itself : total control over sorting.
886886
validate_bool_kwarg(skipna, "skipna")
887887
if not skipna and self._hasna:
888-
raise NotImplementedError
888+
raise ValueError("Encountered an NA value with skipna=False")
889889
return nargminmax(self, "argmin")
890890

891891
def argmax(self, skipna: bool = True) -> int:
@@ -919,7 +919,7 @@ def argmax(self, skipna: bool = True) -> int:
919919
# 2. argmax itself : total control over sorting.
920920
validate_bool_kwarg(skipna, "skipna")
921921
if not skipna and self._hasna:
922-
raise NotImplementedError
922+
raise ValueError("Encountered an NA value with skipna=False")
923923
return nargminmax(self, "argmax")
924924

925925
def interpolate(

pandas/core/arrays/sparse/array.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1623,13 +1623,13 @@ def _argmin_argmax(self, kind: Literal["argmin", "argmax"]) -> int:
16231623
def argmax(self, skipna: bool = True) -> int:
16241624
validate_bool_kwarg(skipna, "skipna")
16251625
if not skipna and self._hasna:
1626-
raise NotImplementedError
1626+
raise ValueError("Encountered an NA value with skipna=False")
16271627
return self._argmin_argmax("argmax")
16281628

16291629
def argmin(self, skipna: bool = True) -> int:
16301630
validate_bool_kwarg(skipna, "skipna")
16311631
if not skipna and self._hasna:
1632-
raise NotImplementedError
1632+
raise ValueError("Encountered an NA value with skipna=False")
16331633
return self._argmin_argmax("argmin")
16341634

16351635
# ------------------------------------------------------------------------

pandas/core/base.py

+3-13
Original file line numberDiff line numberDiff line change
@@ -735,13 +735,8 @@ def argmax(
735735
nv.validate_minmax_axis(axis)
736736
skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs)
737737

738-
if skipna and len(delegate) > 0 and isna(delegate).all():
739-
raise ValueError("Encountered all NA values")
740-
elif not skipna and isna(delegate).any():
741-
raise ValueError("Encountered an NA value with skipna=False")
742-
743738
if isinstance(delegate, ExtensionArray):
744-
return delegate.argmax()
739+
return delegate.argmax(skipna=skipna)
745740
else:
746741
result = nanops.nanargmax(delegate, skipna=skipna)
747742
# error: Incompatible return value type (got "Union[int, ndarray]", expected
@@ -754,15 +749,10 @@ def argmin(
754749
) -> int:
755750
delegate = self._values
756751
nv.validate_minmax_axis(axis)
757-
skipna = nv.validate_argmin_with_skipna(skipna, args, kwargs)
758-
759-
if skipna and len(delegate) > 0 and isna(delegate).all():
760-
raise ValueError("Encountered all NA values")
761-
elif not skipna and isna(delegate).any():
762-
raise ValueError("Encountered an NA value with skipna=False")
752+
skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs)
763753

764754
if isinstance(delegate, ExtensionArray):
765-
return delegate.argmin()
755+
return delegate.argmin(skipna=skipna)
766756
else:
767757
result = nanops.nanargmin(delegate, skipna=skipna)
768758
# error: Incompatible return value type (got "Union[int, ndarray]", expected

pandas/core/indexes/base.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -6934,11 +6934,11 @@ def argmin(self, axis=None, skipna: bool = True, *args, **kwargs) -> int:
69346934
nv.validate_minmax_axis(axis)
69356935

69366936
if not self._is_multi and self.hasnans:
6937-
# Take advantage of cache
6938-
if self._isnan.all():
6939-
raise ValueError("Encountered all NA values")
6940-
elif not skipna:
6937+
if not skipna:
69416938
raise ValueError("Encountered an NA value with skipna=False")
6939+
elif self._isnan.all():
6940+
raise ValueError("Encountered all NA values")
6941+
69426942
return super().argmin(skipna=skipna)
69436943

69446944
@Appender(IndexOpsMixin.argmax.__doc__)
@@ -6947,11 +6947,10 @@ def argmax(self, axis=None, skipna: bool = True, *args, **kwargs) -> int:
69476947
nv.validate_minmax_axis(axis)
69486948

69496949
if not self._is_multi and self.hasnans:
6950-
# Take advantage of cache
6951-
if self._isnan.all():
6952-
raise ValueError("Encountered all NA values")
6953-
elif not skipna:
6950+
if not skipna:
69546951
raise ValueError("Encountered an NA value with skipna=False")
6952+
elif self._isnan.all():
6953+
raise ValueError("Encountered all NA values")
69556954
return super().argmax(skipna=skipna)
69566955

69576956
def min(self, axis=None, skipna: bool = True, *args, **kwargs):

pandas/core/nanops.py

+7-12
Original file line numberDiff line numberDiff line change
@@ -1428,20 +1428,15 @@ def _maybe_arg_null_out(
14281428
return result
14291429

14301430
if axis is None or not getattr(result, "ndim", False):
1431-
if skipna:
1432-
if mask.all():
1433-
raise ValueError("Encountered all NA values")
1434-
else:
1435-
if mask.any():
1436-
raise ValueError("Encountered an NA value with skipna=False")
1431+
if skipna and mask.all():
1432+
raise ValueError("Encountered all NA values")
1433+
elif not skipna and mask.any():
1434+
raise ValueError("Encountered an NA value with skipna=False")
14371435
else:
1438-
na_mask = mask.all(axis)
1439-
if na_mask.any():
1436+
if skipna and mask.all(axis).any():
14401437
raise ValueError("Encountered all NA values")
1441-
elif not skipna:
1442-
na_mask = mask.any(axis)
1443-
if na_mask.any():
1444-
raise ValueError("Encountered an NA value with skipna=False")
1438+
elif not skipna and mask.any(axis).any():
1439+
raise ValueError("Encountered an NA value with skipna=False")
14451440
return result
14461441

14471442

pandas/tests/extension/base/methods.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,10 @@ def test_argmax_argmin_no_skipna_notimplemented(self, data_missing_for_sorting):
191191
# GH#38733
192192
data = data_missing_for_sorting
193193

194-
with pytest.raises(NotImplementedError, match=""):
194+
with pytest.raises(ValueError, match="Encountered an NA value"):
195195
data.argmin(skipna=False)
196196

197-
with pytest.raises(NotImplementedError, match=""):
197+
with pytest.raises(ValueError, match="Encountered an NA value"):
198198
data.argmax(skipna=False)
199199

200200
@pytest.mark.parametrize(

pandas/tests/frame/test_reductions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1066,7 +1066,7 @@ def test_idxmin(self, float_frame, int_frame, skipna, axis):
10661066
frame.iloc[15:20, -2:] = np.nan
10671067
for df in [frame, int_frame]:
10681068
if (not skipna or axis == 1) and df is not int_frame:
1069-
if axis == 1:
1069+
if skipna:
10701070
msg = "Encountered all NA values"
10711071
else:
10721072
msg = "Encountered an NA value"
@@ -1116,7 +1116,7 @@ def test_idxmax(self, float_frame, int_frame, skipna, axis):
11161116
frame.iloc[15:20, -2:] = np.nan
11171117
for df in [frame, int_frame]:
11181118
if (skipna is False or axis == 1) and df is frame:
1119-
if axis == 1:
1119+
if skipna:
11201120
msg = "Encountered all NA values"
11211121
else:
11221122
msg = "Encountered an NA value"

pandas/tests/reductions/test_reductions.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,9 @@ def test_argminmax(self):
171171
obj.argmin()
172172
with pytest.raises(ValueError, match="Encountered all NA values"):
173173
obj.argmax()
174-
with pytest.raises(ValueError, match="Encountered all NA values"):
174+
with pytest.raises(ValueError, match="Encountered an NA value"):
175175
obj.argmin(skipna=False)
176-
with pytest.raises(ValueError, match="Encountered all NA values"):
176+
with pytest.raises(ValueError, match="Encountered an NA value"):
177177
obj.argmax(skipna=False)
178178

179179
obj = Index([NaT, datetime(2011, 11, 1), datetime(2011, 11, 2), NaT])
@@ -189,9 +189,9 @@ def test_argminmax(self):
189189
obj.argmin()
190190
with pytest.raises(ValueError, match="Encountered all NA values"):
191191
obj.argmax()
192-
with pytest.raises(ValueError, match="Encountered all NA values"):
192+
with pytest.raises(ValueError, match="Encountered an NA value"):
193193
obj.argmin(skipna=False)
194-
with pytest.raises(ValueError, match="Encountered all NA values"):
194+
with pytest.raises(ValueError, match="Encountered an NA value"):
195195
obj.argmax(skipna=False)
196196

197197
@pytest.mark.parametrize("op, expected_col", [["max", "a"], ["min", "b"]])
@@ -856,7 +856,8 @@ def test_idxmin(self):
856856

857857
# all NaNs
858858
allna = string_series * np.nan
859-
with pytest.raises(ValueError, match="Encountered all NA values"):
859+
msg = "Encountered all NA values"
860+
with pytest.raises(ValueError, match=msg):
860861
allna.idxmin()
861862

862863
# datetime64[ns]
@@ -888,7 +889,8 @@ def test_idxmax(self):
888889

889890
# all NaNs
890891
allna = string_series * np.nan
891-
with pytest.raises(ValueError, match="Encountered all NA values"):
892+
msg = "Encountered all NA values"
893+
with pytest.raises(ValueError, match=msg):
892894
allna.idxmax()
893895

894896
s = Series(date_range("20130102", periods=6))
@@ -1155,12 +1157,12 @@ def test_idxminmax_object_dtype(self, using_infer_string):
11551157
msg = "'>' not supported between instances of 'float' and 'str'"
11561158
with pytest.raises(TypeError, match=msg):
11571159
ser3.idxmax()
1158-
with pytest.raises(ValueError, match="Encountered an NA value"):
1160+
with pytest.raises(TypeError, match=msg):
11591161
ser3.idxmax(skipna=False)
11601162
msg = "'<' not supported between instances of 'float' and 'str'"
11611163
with pytest.raises(TypeError, match=msg):
11621164
ser3.idxmin()
1163-
with pytest.raises(ValueError, match="Encountered an NA value"):
1165+
with pytest.raises(TypeError, match=msg):
11641166
ser3.idxmin(skipna=False)
11651167

11661168
def test_idxminmax_object_frame(self):

0 commit comments

Comments
 (0)