Skip to content

Commit 7aa710a

Browse files
TYP: disallow decorator preserves function signature (#33521)
1 parent fd3e5a1 commit 7aa710a

File tree

2 files changed

+46
-14
lines changed

2 files changed

+46
-14
lines changed

pandas/core/frame.py

+12
Original file line numberDiff line numberDiff line change
@@ -8522,6 +8522,12 @@ def idxmin(self, axis=0, skipna=True) -> Series:
85228522
"""
85238523
axis = self._get_axis_number(axis)
85248524
indices = nanops.nanargmin(self.values, axis=axis, skipna=skipna)
8525+
8526+
# indices will always be np.ndarray since axis is not None and
8527+
# values is a 2d array for DataFrame
8528+
# error: Item "int" of "Union[int, Any]" has no attribute "__iter__"
8529+
assert isinstance(indices, np.ndarray) # for mypy
8530+
85258531
index = self._get_axis(axis)
85268532
result = [index[i] if i >= 0 else np.nan for i in indices]
85278533
return Series(result, index=self._get_agg_axis(axis))
@@ -8589,6 +8595,12 @@ def idxmax(self, axis=0, skipna=True) -> Series:
85898595
"""
85908596
axis = self._get_axis_number(axis)
85918597
indices = nanops.nanargmax(self.values, axis=axis, skipna=skipna)
8598+
8599+
# indices will always be np.ndarray since axis is not None and
8600+
# values is a 2d array for DataFrame
8601+
# error: Item "int" of "Union[int, Any]" has no attribute "__iter__"
8602+
assert isinstance(indices, np.ndarray) # for mypy
8603+
85928604
index = self._get_axis(axis)
85938605
result = [index[i] if i >= 0 else np.nan for i in indices]
85948606
return Series(result, index=self._get_agg_axis(axis))

pandas/core/nanops.py

+34-14
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import functools
22
import itertools
33
import operator
4-
from typing import Any, Optional, Tuple, Union
4+
from typing import Any, Optional, Tuple, Union, cast
55

66
import numpy as np
77

88
from pandas._config import get_option
99

1010
from pandas._libs import NaT, Timedelta, Timestamp, iNaT, lib
11-
from pandas._typing import ArrayLike, Dtype, Scalar
11+
from pandas._typing import ArrayLike, Dtype, F, Scalar
1212
from pandas.compat._optional import import_optional_dependency
1313

1414
from pandas.core.dtypes.cast import _int64_max, maybe_upcast_putmask
@@ -57,7 +57,7 @@ def __init__(self, *dtypes):
5757
def check(self, obj) -> bool:
5858
return hasattr(obj, "dtype") and issubclass(obj.dtype.type, self.dtypes)
5959

60-
def __call__(self, f):
60+
def __call__(self, f: F) -> F:
6161
@functools.wraps(f)
6262
def _f(*args, **kwargs):
6363
obj_iter = itertools.chain(args, kwargs.values())
@@ -78,7 +78,7 @@ def _f(*args, **kwargs):
7878
raise TypeError(e) from e
7979
raise
8080

81-
return _f
81+
return cast(F, _f)
8282

8383

8484
class bottleneck_switch:
@@ -878,7 +878,7 @@ def nanargmax(
878878
axis: Optional[int] = None,
879879
skipna: bool = True,
880880
mask: Optional[np.ndarray] = None,
881-
) -> int:
881+
) -> Union[int, np.ndarray]:
882882
"""
883883
Parameters
884884
----------
@@ -890,15 +890,25 @@ def nanargmax(
890890
891891
Returns
892892
-------
893-
result : int
894-
The index of max value in specified axis or -1 in the NA case
893+
result : int or ndarray[int]
894+
The index/indices of max value in specified axis or -1 in the NA case
895895
896896
Examples
897897
--------
898898
>>> import pandas.core.nanops as nanops
899-
>>> s = pd.Series([1, 2, 3, np.nan, 4])
900-
>>> nanops.nanargmax(s)
899+
>>> arr = np.array([1, 2, 3, np.nan, 4])
900+
>>> nanops.nanargmax(arr)
901901
4
902+
903+
>>> arr = np.array(range(12), dtype=np.float64).reshape(4, 3)
904+
>>> arr[2:, 2] = np.nan
905+
>>> arr
906+
array([[ 0., 1., 2.],
907+
[ 3., 4., 5.],
908+
[ 6., 7., nan],
909+
[ 9., 10., nan]])
910+
>>> nanops.nanargmax(arr, axis=1)
911+
array([2, 2, 1, 1], dtype=int64)
902912
"""
903913
values, mask, dtype, _, _ = _get_values(
904914
values, True, fill_value_typ="-inf", mask=mask
@@ -914,7 +924,7 @@ def nanargmin(
914924
axis: Optional[int] = None,
915925
skipna: bool = True,
916926
mask: Optional[np.ndarray] = None,
917-
) -> int:
927+
) -> Union[int, np.ndarray]:
918928
"""
919929
Parameters
920930
----------
@@ -926,15 +936,25 @@ def nanargmin(
926936
927937
Returns
928938
-------
929-
result : int
930-
The index of min value in specified axis or -1 in the NA case
939+
result : int or ndarray[int]
940+
The index/indices of min value in specified axis or -1 in the NA case
931941
932942
Examples
933943
--------
934944
>>> import pandas.core.nanops as nanops
935-
>>> s = pd.Series([1, 2, 3, np.nan, 4])
936-
>>> nanops.nanargmin(s)
945+
>>> arr = np.array([1, 2, 3, np.nan, 4])
946+
>>> nanops.nanargmin(arr)
937947
0
948+
949+
>>> arr = np.array(range(12), dtype=np.float64).reshape(4, 3)
950+
>>> arr[2:, 0] = np.nan
951+
>>> arr
952+
array([[ 0., 1., 2.],
953+
[ 3., 4., 5.],
954+
[nan, 7., 8.],
955+
[nan, 10., 11.]])
956+
>>> nanops.nanargmin(arr, axis=1)
957+
array([0, 0, 1, 1], dtype=int64)
938958
"""
939959
values, mask, dtype, _, _ = _get_values(
940960
values, True, fill_value_typ="+inf", mask=mask

0 commit comments

Comments
 (0)