diff --git a/pandas/core/frame.py b/pandas/core/frame.py index f8cb99e2b2e75..06b6d4242e622 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -8522,6 +8522,12 @@ def idxmin(self, axis=0, skipna=True) -> Series: """ axis = self._get_axis_number(axis) indices = nanops.nanargmin(self.values, axis=axis, skipna=skipna) + + # indices will always be np.ndarray since axis is not None and + # values is a 2d array for DataFrame + # error: Item "int" of "Union[int, Any]" has no attribute "__iter__" + assert isinstance(indices, np.ndarray) # for mypy + index = self._get_axis(axis) result = [index[i] if i >= 0 else np.nan for i in indices] return Series(result, index=self._get_agg_axis(axis)) @@ -8589,6 +8595,12 @@ def idxmax(self, axis=0, skipna=True) -> Series: """ axis = self._get_axis_number(axis) indices = nanops.nanargmax(self.values, axis=axis, skipna=skipna) + + # indices will always be np.ndarray since axis is not None and + # values is a 2d array for DataFrame + # error: Item "int" of "Union[int, Any]" has no attribute "__iter__" + assert isinstance(indices, np.ndarray) # for mypy + index = self._get_axis(axis) result = [index[i] if i >= 0 else np.nan for i in indices] return Series(result, index=self._get_agg_axis(axis)) diff --git a/pandas/core/nanops.py b/pandas/core/nanops.py index 32b05872ded3f..d74a6ba605666 100644 --- a/pandas/core/nanops.py +++ b/pandas/core/nanops.py @@ -1,14 +1,14 @@ import functools import itertools import operator -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union, cast import numpy as np from pandas._config import get_option from pandas._libs import NaT, Timedelta, Timestamp, iNaT, lib -from pandas._typing import ArrayLike, Dtype, Scalar +from pandas._typing import ArrayLike, Dtype, F, Scalar from pandas.compat._optional import import_optional_dependency from pandas.core.dtypes.cast import _int64_max, maybe_upcast_putmask @@ -57,7 +57,7 @@ def __init__(self, *dtypes): def check(self, obj) -> bool: return hasattr(obj, "dtype") and issubclass(obj.dtype.type, self.dtypes) - def __call__(self, f): + def __call__(self, f: F) -> F: @functools.wraps(f) def _f(*args, **kwargs): obj_iter = itertools.chain(args, kwargs.values()) @@ -78,7 +78,7 @@ def _f(*args, **kwargs): raise TypeError(e) from e raise - return _f + return cast(F, _f) class bottleneck_switch: @@ -878,7 +878,7 @@ def nanargmax( axis: Optional[int] = None, skipna: bool = True, mask: Optional[np.ndarray] = None, -) -> int: +) -> Union[int, np.ndarray]: """ Parameters ---------- @@ -890,15 +890,25 @@ def nanargmax( Returns ------- - result : int - The index of max value in specified axis or -1 in the NA case + result : int or ndarray[int] + The index/indices of max value in specified axis or -1 in the NA case Examples -------- >>> import pandas.core.nanops as nanops - >>> s = pd.Series([1, 2, 3, np.nan, 4]) - >>> nanops.nanargmax(s) + >>> arr = np.array([1, 2, 3, np.nan, 4]) + >>> nanops.nanargmax(arr) 4 + + >>> arr = np.array(range(12), dtype=np.float64).reshape(4, 3) + >>> arr[2:, 2] = np.nan + >>> arr + array([[ 0., 1., 2.], + [ 3., 4., 5.], + [ 6., 7., nan], + [ 9., 10., nan]]) + >>> nanops.nanargmax(arr, axis=1) + array([2, 2, 1, 1], dtype=int64) """ values, mask, dtype, _, _ = _get_values( values, True, fill_value_typ="-inf", mask=mask @@ -914,7 +924,7 @@ def nanargmin( axis: Optional[int] = None, skipna: bool = True, mask: Optional[np.ndarray] = None, -) -> int: +) -> Union[int, np.ndarray]: """ Parameters ---------- @@ -926,15 +936,25 @@ def nanargmin( Returns ------- - result : int - The index of min value in specified axis or -1 in the NA case + result : int or ndarray[int] + The index/indices of min value in specified axis or -1 in the NA case Examples -------- >>> import pandas.core.nanops as nanops - >>> s = pd.Series([1, 2, 3, np.nan, 4]) - >>> nanops.nanargmin(s) + >>> arr = np.array([1, 2, 3, np.nan, 4]) + >>> nanops.nanargmin(arr) 0 + + >>> arr = np.array(range(12), dtype=np.float64).reshape(4, 3) + >>> arr[2:, 0] = np.nan + >>> arr + array([[ 0., 1., 2.], + [ 3., 4., 5.], + [nan, 7., 8.], + [nan, 10., 11.]]) + >>> nanops.nanargmin(arr, axis=1) + array([0, 0, 1, 1], dtype=int64) """ values, mask, dtype, _, _ = _get_values( values, True, fill_value_typ="+inf", mask=mask