1
1
import functools
2
2
import itertools
3
3
import operator
4
- from typing import Any , Optional , Tuple , Union
4
+ from typing import Any , Optional , Tuple , Union , cast
5
5
6
6
import numpy as np
7
7
8
8
from pandas ._config import get_option
9
9
10
10
from pandas ._libs import NaT , Timedelta , Timestamp , iNaT , lib
11
- from pandas ._typing import ArrayLike , Dtype , Scalar
11
+ from pandas ._typing import ArrayLike , Dtype , F , Scalar
12
12
from pandas .compat ._optional import import_optional_dependency
13
13
14
14
from pandas .core .dtypes .cast import _int64_max , maybe_upcast_putmask
@@ -57,7 +57,7 @@ def __init__(self, *dtypes):
57
57
def check (self , obj ) -> bool :
58
58
return hasattr (obj , "dtype" ) and issubclass (obj .dtype .type , self .dtypes )
59
59
60
- def __call__ (self , f ) :
60
+ def __call__ (self , f : F ) -> F :
61
61
@functools .wraps (f )
62
62
def _f (* args , ** kwargs ):
63
63
obj_iter = itertools .chain (args , kwargs .values ())
@@ -78,7 +78,7 @@ def _f(*args, **kwargs):
78
78
raise TypeError (e ) from e
79
79
raise
80
80
81
- return _f
81
+ return cast ( F , _f )
82
82
83
83
84
84
class bottleneck_switch :
@@ -878,7 +878,7 @@ def nanargmax(
878
878
axis : Optional [int ] = None ,
879
879
skipna : bool = True ,
880
880
mask : Optional [np .ndarray ] = None ,
881
- ) -> int :
881
+ ) -> Union [ int , np . ndarray ] :
882
882
"""
883
883
Parameters
884
884
----------
@@ -890,15 +890,25 @@ def nanargmax(
890
890
891
891
Returns
892
892
-------
893
- result : int
894
- The index of max value in specified axis or -1 in the NA case
893
+ result : int or ndarray[int]
894
+ The index/indices of max value in specified axis or -1 in the NA case
895
895
896
896
Examples
897
897
--------
898
898
>>> import pandas.core.nanops as nanops
899
- >>> s = pd.Series ([1, 2, 3, np.nan, 4])
900
- >>> nanops.nanargmax(s )
899
+ >>> arr = np.array ([1, 2, 3, np.nan, 4])
900
+ >>> nanops.nanargmax(arr )
901
901
4
902
+
903
+ >>> arr = np.array(range(12), dtype=np.float64).reshape(4, 3)
904
+ >>> arr[2:, 2] = np.nan
905
+ >>> arr
906
+ array([[ 0., 1., 2.],
907
+ [ 3., 4., 5.],
908
+ [ 6., 7., nan],
909
+ [ 9., 10., nan]])
910
+ >>> nanops.nanargmax(arr, axis=1)
911
+ array([2, 2, 1, 1], dtype=int64)
902
912
"""
903
913
values , mask , dtype , _ , _ = _get_values (
904
914
values , True , fill_value_typ = "-inf" , mask = mask
@@ -914,7 +924,7 @@ def nanargmin(
914
924
axis : Optional [int ] = None ,
915
925
skipna : bool = True ,
916
926
mask : Optional [np .ndarray ] = None ,
917
- ) -> int :
927
+ ) -> Union [ int , np . ndarray ] :
918
928
"""
919
929
Parameters
920
930
----------
@@ -926,15 +936,25 @@ def nanargmin(
926
936
927
937
Returns
928
938
-------
929
- result : int
930
- The index of min value in specified axis or -1 in the NA case
939
+ result : int or ndarray[int]
940
+ The index/indices of min value in specified axis or -1 in the NA case
931
941
932
942
Examples
933
943
--------
934
944
>>> import pandas.core.nanops as nanops
935
- >>> s = pd.Series ([1, 2, 3, np.nan, 4])
936
- >>> nanops.nanargmin(s )
945
+ >>> arr = np.array ([1, 2, 3, np.nan, 4])
946
+ >>> nanops.nanargmin(arr )
937
947
0
948
+
949
+ >>> arr = np.array(range(12), dtype=np.float64).reshape(4, 3)
950
+ >>> arr[2:, 0] = np.nan
951
+ >>> arr
952
+ array([[ 0., 1., 2.],
953
+ [ 3., 4., 5.],
954
+ [nan, 7., 8.],
955
+ [nan, 10., 11.]])
956
+ >>> nanops.nanargmin(arr, axis=1)
957
+ array([0, 0, 1, 1], dtype=int64)
938
958
"""
939
959
values , mask , dtype , _ , _ = _get_values (
940
960
values , True , fill_value_typ = "+inf" , mask = mask
0 commit comments